-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathLM_API.py
41 lines (35 loc) · 1.06 KB
/
LM_API.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
import pickle
# Paths to the pickled n-gram tables. They are relative paths, so open()
# resolves them against the current working directory, not this file.
# NOTE(review): consider anchoring them to os.path.dirname(__file__) so the
# module can be imported from any working directory.
S_gram_file = "./LM_results/S_gram.pkl"
Bi_gram_file = "./LM_results/Bi_gram.pkl"
Tri_gram_file = "./LM_results/Tri_gram.pkl"
STOP_TAG = '</s>'
# Order of the n-gram model used by LM_score; only 2 and 3 are supported.
ngram = 3


def _load_pickle(path):
    """Open *path* in binary mode and return the unpickled object.

    pickle can execute arbitrary code while loading, so this must only be
    pointed at trusted files produced by this project.
    """
    with open(path, 'rb') as f:
        return pickle.load(f)


# LM_score indexes up_model with ngram-length strings (numerator) and
# down_model with (ngram-1)-length strings plus len(down_model) (denominator);
# presumably both are count tables — confirm against the training script.
if ngram == 3:
    down_model = _load_pickle(Bi_gram_file)
    up_model = _load_pickle(Tri_gram_file)
elif ngram == 2:
    up_model = _load_pickle(Bi_gram_file)
    down_model = _load_pickle(S_gram_file)
else:
    # Fail loudly at import time instead of with a NameError on the first
    # LM_score call, which is what happened when neither branch ran.
    raise ValueError(
        "unsupported n-gram order {!r}; expected 2 or 3".format(ngram))
def LM_score(x, *, up_counts=None, down_counts=None, order=None,
             stop_tag=None):
    """Score one string, or a list of strings, with the n-gram count tables.

    Each string is split into characters. Every length-``order`` window then
    contributes an add-one-smoothed factor::

        (count(window) + 1) / (len(down_counts) + count(window[1:]))

    The string's score is the product of those factors. A string shorter than
    ``order`` yields no windows and scores 1.

    If the string contains ``stop_tag`` anywhere, every occurrence is removed.
    A single stop tag is then appended as the final token.

    Args:
        x: A string or a list of strings. Anything that is not a ``list`` is
            wrapped and scored as a single item.
        up_counts: Mapping from ``order``-grams to counts. Defaults to the
            module-level ``up_model``.
        down_counts: Mapping from ``(order - 1)``-grams to counts. Defaults
            to the module-level ``down_model``.
        order: n-gram order. Defaults to the module-level ``ngram``.
        stop_tag: End-of-sentence token. Defaults to ``STOP_TAG``.

    Returns:
        A list with one score per input string, in input order.

    NOTE(review): the context is the window's *last* n-1 characters, so each
    factor estimates P(w1 | w2..wn), i.e. a right-to-left model, even though
    the stop tag is appended on the right. Confirm this matches how the
    count tables were built.
    NOTE(review): scores are raw probability products and can underflow to
    0.0 for long strings; a sum of log-probabilities would avoid that.
    """
    # Only fall back to the module-level tables when no override is given.
    up_counts = up_model if up_counts is None else up_counts
    down_counts = down_model if down_counts is None else down_counts
    order = ngram if order is None else order
    stop_tag = STOP_TAG if stop_tag is None else stop_tag

    if not isinstance(x, list):
        x = [x]
    # Constant term of the smoothing denominator. It is safe to hoist because
    # the lookups below never mutate down_counts.
    vocab = len(down_counts)
    scores = []
    for s in x:
        if stop_tag in s:
            # Drop every stop tag and re-append exactly one as the last token.
            tokens = list(s.replace(stop_tag, '')) + [stop_tag]
        else:
            tokens = list(s)
        score = 1
        for i in range(len(tokens) - order + 1):
            up_word = ''.join(tokens[i:i + order])
            down_word = up_word[1:]
            # .get(..., 0) instead of [] so unseen n-grams neither raise
            # KeyError on a plain dict nor get inserted into a defaultdict
            # (which would grow len(down_counts) and change later scores).
            p = ((up_counts.get(up_word, 0) + 1)
                 / (vocab + down_counts.get(down_word, 0)))
            score *= p
        scores.append(score)
    return scores