-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathprocess_embedding.py
60 lines (48 loc) · 1.82 KB
/
process_embedding.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
import os
import tqdm
import torch
import numpy as np
import argparse
def load_vocab_dict_from_file(dict_file):
    """Read one word per line from *dict_file* and return {word: line_index}."""
    with open(dict_file) as handle:
        tokens = [line.strip() for line in handle]
    lookup = {token: index for index, token in enumerate(tokens)}
    print("words:", len(tokens))
    return lookup
def generate_emb(dataset):
    """Build a GloVe embedding matrix for *dataset*'s vocabulary.

    Looks up every vocabulary word in the pre-trained 300-d GloVe file
    (data/glove.840B.300d.txt) and saves a (vocab_size, 300) float32
    array to data/<dataset>/glove_emb.npy. Words without a GloVe vector
    are initialised from N(0, 1) and logged to data/<dataset>/not_exist.txt.
    """
    glove_file = 'data/glove.840B.300d.txt'
    npy_file = 'data/{}/glove_emb.npy'.format(dataset)
    # corpus.pth is assumed to hold an object exposing .dictionary.word2idx
    # (word -> integer index) — TODO confirm against the corpus builder.
    vocab_dict = torch.load('data/{}/corpus.pth'.format(dataset)).dictionary.word2idx
    vocab_size = len(vocab_dict)
    emb_size = 300

    # Collect GloVe vectors only for words present in the vocabulary.
    # Stream line by line: readlines() would hold the whole multi-GB file
    # in memory at once for no benefit.
    emb_dict = {}
    with open(glove_file, 'r') as glo:
        print("Loading Glove File...")
        for line in tqdm.tqdm(glo):
            items = line.strip().split(' ')
            word, emb = items[0], items[1:]
            if word in vocab_dict:
                emb_dict[word] = np.array(emb, np.float32)
        print("Loaded Gloves.")

    # Build the dense vocab-indexed array; unknown words get a random
    # normal init and are recorded for later inspection. `with` ensures
    # not_exist.txt is closed even if an exception interrupts the loop.
    vocab_emb_array = np.zeros((vocab_size, emb_size), dtype=np.float32)
    print("shape:", vocab_emb_array.shape)
    with open('data/{}/not_exist.txt'.format(dataset), 'w') as missing:
        for w, idx in vocab_dict.items():
            if w in emb_dict:
                vocab_emb_array[idx, :] = emb_dict[w]
            else:
                missing.write(w + '\n')
                vocab_emb_array[idx, :] = np.random.normal(loc=0.0, scale=1.0, size=(emb_size,))
    np.save(npy_file, vocab_emb_array)
    print("Saving to npy file.")
if __name__ == '__main__':
    # CLI entry point: -d selects the dataset, e.g. 'referit' or 'Gref'.
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument('-d', type=str, default='referit')
    generate_emb(arg_parser.parse_args().d)