-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathbuild_token_frequencies.py
44 lines (35 loc) · 1.14 KB
/
build_token_frequencies.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
import json
from itertools import islice

from tqdm import tqdm
from transformers import AutoTokenizer
# Number of sentences tokenized per batch_encode_plus call.
BATCH_SIZE=128
# NOTE(review): defined but never used in this script — presumably intended
# as a tokenizer truncation length; confirm before removing.
max_length=32
# Number of lines to read from the wiki sentence corpus.
N=100000
# HuggingFace model whose tokenizer (and vocabulary) defines the token set.
model_name="bert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(model_name)
# Global token -> occurrence-count accumulator, filled by update_freqs().
token_frequencies={}
def read_wikisent(dataset, n_lines=10):
    """Read up to *n_lines* lines from ``../<dataset>``.

    Args:
        dataset: File name, resolved relative to the parent directory
            (the ``../`` prefix matches the repo layout this script runs in).
        n_lines: Maximum number of lines to return.

    Returns:
        A list of at most ``n_lines`` raw lines (trailing newline kept).

    Bug fixed: the previous fixed-count ``readline()`` loop kept appending
    empty strings ``''`` after EOF whenever ``n_lines`` exceeded the file
    length, and those empty "sentences" were then tokenized downstream.
    ``islice`` over the file iterator stops cleanly at EOF.
    """
    with open(f'../{dataset}', 'r', encoding='utf-8') as f:
        return list(islice(f, n_lines))
def update_freqs(batch):
    """Tokenize a batch of sentences and accumulate token counts.

    Args:
        batch: List of raw sentence strings.

    Side effects:
        Increments the module-level ``token_frequencies`` dict in place.

    Bug fixed: the previous version used ``padding=True`` and then counted
    every position of the padded tensor, so ``'[PAD]'`` was counted once per
    padded slot and dominated the frequency table. Tokenizing without
    padding (no tensors needed — plain Python lists suffice for counting)
    counts only real tokens; special tokens like ``[CLS]``/``[SEP]`` are
    still included, matching the previous behavior for non-pad tokens.
    """
    encoded = tokenizer.batch_encode_plus(batch)
    for ids in encoded.input_ids:
        for token in tokenizer.convert_ids_to_tokens(ids):
            token_frequencies[token] = token_frequencies.get(token, 0) + 1
# --- Build and persist the token frequency table ---------------------------
sentences = read_wikisent('wiki1m_for_simcse_shuf.txt', N)
batch = []
for sentence in tqdm(sentences):
    # Normalize whitespace: drop the newline and collapse runs of spaces.
    batch.append(' '.join(sentence.replace('\n', '').split()))
    if len(batch) >= BATCH_SIZE:
        update_freqs(batch)
        batch = []
# Flush the leftover partial batch. Bug fixed: this was
# `if len(batch) >= 0:` (always true), which called update_freqs([]) on an
# empty list whenever the corpus length was a multiple of BATCH_SIZE.
if batch:
    update_freqs(batch)
# Sort descending by count so the most frequent tokens come first in the JSON.
token_frequencies = dict(
    sorted(token_frequencies.items(), key=lambda x: x[1], reverse=True))
with open('token_freqs.json', 'w') as f:
    json.dump(token_frequencies, f)