unify the api used in search and perplexity
liyucheng09 committed Nov 6, 2023
1 parent 8bdb734 commit 73cfb41
Showing 2 changed files with 74 additions and 76 deletions.
98 changes: 27 additions & 71 deletions perplexity.py
@@ -9,44 +9,15 @@
import datasets
import numpy as np
import time
from utils import (
Column_to_check,
Hf_Name_and_Split,
prepare_query,
prepare_dataset,
)

WIKI_API_ENDPOINT = "https://en.wikipedia.org/w/api.php"
np.random.seed(42)

MEMORISED = {
'wiki': 'RealTimeData/wikitext_alltime',
'bbc': 'RealTimeData/bbc_alltime'
}

CLEAN = {
'wiki': 'RealTimeData/wikitext_latest',
'bbc': 'RealTimeData/bbc_latest'
}

# the column name of the main text in the dataset, usually passage or context
COLUMNS = {
'RealTimeData/wikitext_alltime': 'text',
'RealTimeData/wikitext_latest': 'text',
'RealTimeData/bbc_latest': 'content',
'RealTimeData/bbc_alltime': 'content',
'iohadrubin/mini_xsum': 'document',
'quac': 'context',
'boolq': 'passage',
'squad_v2': 'context',
}

# which split of each dataset you want to analyze
SPLITS = {
'RealTimeData/wikitext_alltime': 'train',
'RealTimeData/wikitext_latest': 'train',
'RealTimeData/bbc_latest': 'train',
'RealTimeData/bbc_alltime': 'train',
'iohadrubin/mini_xsum': 'validation',
'quac': 'validation',
'boolq': 'validation',
'squad_v2': 'validation',
}

def self_info(text, model, tokenizer, merge = False):
def merge_sub_tokens(log_probs, word_ids):
# merge log probs of sub_tokens
@@ -103,23 +74,6 @@ def select_token_window(text, token_count=400):
tokens = tokens[ramdom_start:ramdom_start + token_count]
return ' '.join(tokens)
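For orientation, a minimal usage sketch of select_token_window (the sample text is invented); it whitespace-tokenizes the input and returns a randomly placed window of roughly token_count tokens:

sample_text = 'lorem ipsum dolor sit amet ' * 200   # toy text, ~1000 whitespace tokens
window = select_token_window(sample_text, token_count=300)
print(len(window.split()))                          # 300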

def prepare_data(dataset, column, split, config = None, num_samples=200, token_count=300):
# This function is used to prepare the data to analyze.
# It takes a dataset name as input and returns a list of strings.
# Currently it mainly supports downloading datasets from the Hugging Face Hub;
# you could easily extend it to support other datasets.

if config is None:
ds = datasets.load_dataset(dataset, split=split)
else:
ds = datasets.load_dataset(dataset, config, split=split)

ds = ds.select(np.random.choice(len(ds), num_samples))
ds = ds[column]

texts = [select_token_window(text, token_count=token_count) for text in ds]
return texts

def load_model_and_tokenizer(model_name):

if 'GPTQ' in model_name:
@@ -137,6 +91,9 @@ def load_model_and_tokenizer(model_name):
elif 'gpt2' == model_name.lower():
model = AutoModelForCausalLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)
else:
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto", trust_remote_code=True)

return model, tokenizer
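A brief usage sketch, assuming 'gpt2' is reachable on the Hub; the helper returns a (model, tokenizer) pair ready for scoring:

model, tokenizer = load_model_and_tokenizer('gpt2')
enc = tokenizer('The quick brown fox jumps over the lazy dog', return_tensors='pt')
print(enc['input_ids'].shape)   # torch.Size([1, <num_tokens>])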

@@ -147,33 +104,32 @@ def load_model_and_tokenizer(model_name):
# default is to use the val or test split; set this to True if you want to use the train split
use_train_split = False

num_token = 200
num_token = 300
num_samples = 300
model_names = ['gpt2', 'TheBloke/Llama-2-13B-GPTQ', 'facebook/opt-6.7b']

evaluation_datasets = ['quac', 'boolq', 'squad_v2']
output_file = f'reports/perplexity_report.json'
model_names = ['Qwen/Qwen-7B', 'baichuan-inc/Baichuan2-7B-Base']

datasets_and_texts = {}
# Prepare evaluation datasets
for ds in evaluation_datasets:
# default is to use the val or test split
datasets_and_texts[ds] = prepare_data(ds, COLUMNS[ds], SPLITS[ds], num_samples = num_samples, token_count = num_token)
evaluation_datasets = ['mmlu', 'mmlu_train', 'ceval', 'ceval_train']
datasets_to_check = prepare_dataset(evaluation_datasets, n = num_samples)

if use_train_split:
datasets_and_texts[f'{ds}_train'] = prepare_data(ds, COLUMNS[ds], 'train', num_samples = num_samples, token_count = num_token)

if doing_contamination_test:
# What is the source of the evaluation?
# We support Wikipedia and BBC in the current version.
evaluation_base = 'wiki'
memorised_time = '2022-8'

# Prepare two baselines
datasets_and_texts['memorised'] = prepare_data(MEMORISED[evaluation_base], COLUMNS[MEMORISED[evaluation_base]], SPLITS[MEMORISED[evaluation_base]], \
config = memorised_time, num_samples = num_samples, token_count = num_token)
datasets_and_texts['clean'] = prepare_data(CLEAN[evaluation_base], COLUMNS[CLEAN[evaluation_base]], SPLITS[CLEAN[evaluation_base]], \
num_samples = num_samples, token_count = num_token)
baseline_datasets = [f'{evaluation_base}_clean', f'{evaluation_base}_all']
datasets_to_check.update(prepare_dataset(baseline_datasets, n = 500, config = memorised_time))

datasets_and_texts = {}
for dataset_name, ds in datasets_to_check.items():

all_texts = []
for i, row in tqdm(enumerate(ds), desc=f'Processing {dataset_name}'):
# query is the verbalized test sample
query = prepare_query(dataset_name, row)
if query['query'] is None: continue
query_chunked = select_token_window(query['query'], token_count=num_token)

datasets_and_texts.setdefault(dataset_name, []).append(query_chunked)

results = {}
for model_name in model_names:
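The body of this scoring loop is not shown in the diff above. A hypothetical sketch (not the repository's actual code) of how each model could be scored on datasets_and_texts with the Transformers API; mean_nll is an invented helper:

import torch

def mean_nll(text, model, tokenizer):
    # average negative log-likelihood per token, i.e. log(perplexity)
    enc = tokenizer(text, return_tensors='pt').to(model.device)
    with torch.no_grad():
        out = model(**enc, labels=enc['input_ids'])
    return out.loss.item()

# for model_name in model_names:
#     model, tokenizer = load_model_and_tokenizer(model_name)
#     for dataset_name, texts in datasets_and_texts.items():
#         scores = [mean_nll(t, model, tokenizer) for t in texts]
#         results.setdefault(model_name, {})[dataset_name] = float(np.mean(scores))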
52 changes: 47 additions & 5 deletions utils.py
@@ -3,20 +3,40 @@
Column_to_check = {
'winogrande': {'input': 'sentence', 'label': lambda x: x[f'option{x["answer"]}'], 'id': 'id'},
'ceval': {'input': 'question', 'label': lambda x: x[x['answer']], 'id': 'id'},
'ceval_train': {'input': 'question', 'label': lambda x: x[x['answer']], 'id': 'id'},
'mmlu': {'input': 'question', 'label': lambda x: x[x['answer']], 'id': 'id'},
'mmlu_train': {'input': 'question', 'label': lambda x: x['choices'][x['answer']], 'id': 'id'},
'hellaswag': {'input': 'ctx', 'label': lambda x: x['endings'][int(x['label'])], 'id': 'ind'},
'ARC': {'input': 'question', 'label': lambda x: x['choices']['text'][x['choices']['label'].index(x['answerKey'])], 'id': 'id'},
'commonsense_qa': {'input': 'question', 'label': lambda x: x['choices']['text'][x['choices']['label'].index(x['answerKey'])], 'id': 'id'}
'commonsense_qa': {'input': 'question', 'label': lambda x: x['choices']['text'][x['choices']['label'].index(x['answerKey'])], 'id': 'id'},
'squad_v2': {'passage': 'context', 'question': 'question', 'label': lambda x:x['answers']['text'], 'id': 'id'},
'squad_v2_train': {'passage': 'context', 'question': 'question', 'label': lambda x:x['answers']['text'], 'id': 'id'},
'quac': {'passage': 'background', 'id': 'dialogue_id'},
'boolq': {'passage': 'passage', 'id': 'question'},
'wiki_clean': {'passage': 'text', 'id': 'title'},
'bbc_clean': {'passage': 'content', 'id': 'link'},
'wiki_all': {'passage': 'text', 'id': 'title'},
'bbc_all': {'passage': 'content', 'id': 'link'},
}
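To illustrate how the label lambdas above resolve, a hypothetical mmlu_train row (field values invented for illustration):

row = {'id': 0, 'question': 'What is 2 + 2?', 'choices': ['3', '4', '5', '6'], 'answer': 1}
print(row[Column_to_check['mmlu_train']['input']])    # 'What is 2 + 2?'
print(Column_to_check['mmlu_train']['label'](row))    # '4'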

# The names of the benchmarks on the Hugging Face Hub, and the splits to be used
Hf_Name_and_Split = {
'winogrande': {'hf_name': 'liyucheng/winogrande_val', 'split': 'validation'},
'ceval': {'hf_name': 'liyucheng/ceval_all', 'split': 'val'},
'ceval_train': {'hf_name': 'liyucheng/ceval_all_dev', 'split': 'dev'},
'mmlu': {'hf_name': 'liyucheng/mmlu_test', 'split': 'train'},
'mmlu_train': {'hf_name': 'liyucheng/mmlu_train', 'split': 'train'},
'hellaswag': {'hf_name': 'Rowan/hellaswag', 'split': 'validation'},
'ARC': {'hf_name': 'liyucheng/arc_test', 'split': 'test'},
'commonsense_qa': {'hf_name': 'commonsense_qa', 'split': 'validation'},
'quac': {'hf_name': 'quac', 'split': 'validation'},
'boolq': {'hf_name': 'boolq', 'split': 'validation'},
'squad_v2': {'hf_name': 'squad_v2', 'split': 'validation'},
'squad_v2_train': {'hf_name': 'squad_v2', 'split': 'train'},
'wiki_clean': {'hf_name': 'RealTimeData/wikitext_latest', 'split': 'train'},
'bbc_clean': {'hf_name': 'RealTimeData/bbc_latest', 'split': 'train'},
'wiki_all': {'hf_name': 'RealTimeData/wikitext_alltime', 'split': 'train'},
'bbc_all': {'hf_name': 'RealTimeData/bbc_alltime', 'split': 'train'},
}

# This is used to choose the right Bing market, based on the language of the dataset
@@ -58,13 +78,29 @@ def random_sample_ds(ds, n = 100):
return ds.select(np.random.choice(len(ds), n))

def prepare_query(dataset_name, row):
"""Here we verbalize the input and label to form a textual query so that we can send them to Bing Search.
For some benchmarks which have blanks in the input, we replace the blank with the label
"""
Here we verbalize the test instances to form a textual query, so that we can send them to Bing Search or compute perplexity on them.
For multi-choice benchmarks, we use 'input' and 'label' to form the query.
For reading comprehension benchmarks, we use 'passage', 'question', and 'answer' to form the query.
For test samples which have blanks in the input, we replace the blank with the label.
Otherwise, we directly append the answer to the input.
"""
assert dataset_name in Column_to_check.keys(), \
f'Column_to_check for {dataset_name} is not configured in utils.py'

id_ = row[Column_to_check[dataset_name]['id']]
if Column_to_check[dataset_name].get('passage', None) is not None:
# rendering reading comprehension benchmarks
passage = row[Column_to_check[dataset_name]['passage']]
query = f'{passage}'
return {
'id': id_,
'query': query,
}

# multi-choice benchmarks are more complicated; sometimes we need to fill in blanks
assert Column_to_check[dataset_name].get('input', None) is not None
input_ = row[Column_to_check[dataset_name]['input']]
label = Column_to_check[dataset_name]['label'](row)

@@ -100,7 +136,9 @@ def fill_blanks(question, answers, placeholder = '____'):
verbalize = {
'winogrande': lambda input_, label: fill_blanks(input_, label, '_'),
'ceval': lambda input_, label: fill_blanks(input_, label, '____'),
'ceval_train': lambda input_, label: fill_blanks(input_, label, '____'),
'mmlu': lambda input_, label: fill_blanks(input_, label),
'mmlu_train': lambda input_, label: fill_blanks(input_, label),

'hellaswag': lambda input_, label: f'{input_} {label}',
'ARC': lambda input_, label: f'{input_} {label}',
Expand All @@ -113,15 +151,19 @@ def fill_blanks(question, answers, placeholder = '____'):
'query': verbalize[dataset_name](input_, label),
}

def prepare_dataset(dataset_names, n = 500):
def prepare_dataset(dataset_names, n = 500, config = None):
"""Load the datasets from Huggingface Hub, and randomly sample n samples from each dataset.
if n == 'all', then load all samples.
"""
dses = {}
for dataset_name in dataset_names:
assert dataset_name in Hf_Name_and_Split.keys(), \
f'Hf_Name_and_Split for {dataset_name} is not defined in utils.py'
if Hf_Name_and_Split[dataset_name].get('config', None) is not None:

if dataset_name in ['wiki_clean', 'bbc_clean', 'wiki_all', 'bbc_all']:
assert config is not None, 'The config is used to set a time range; if you use wiki_clean, bbc_clean, wiki_all, or bbc_all, you need to set a time with config = time'
ds = load_dataset(Hf_Name_and_Split[dataset_name]['hf_name'], config, split = Hf_Name_and_Split[dataset_name]['split'])
elif Hf_Name_and_Split[dataset_name].get('config', None) is not None:
ds = load_dataset(Hf_Name_and_Split[dataset_name]['hf_name'], Hf_Name_and_Split[dataset_name]['config'], split = Hf_Name_and_Split[dataset_name]['split'])
else:
ds = load_dataset(Hf_Name_and_Split[dataset_name]['hf_name'], split = Hf_Name_and_Split[dataset_name]['split'])
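Putting the unified API together, a rough sketch of the calling pattern perplexity.py now follows (dataset names taken from this commit; the search pipeline consumes the same query dicts):

datasets_to_check = prepare_dataset(['mmlu', 'boolq'], n=100)
for dataset_name, ds in datasets_to_check.items():
    for row in ds:
        query = prepare_query(dataset_name, row)
        if query['query'] is None:
            continue
        # query['query'] is the verbalized text shared by the Bing-search
        # and the perplexity pipelines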
