-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathget_anuvaad_preds.py
36 lines (28 loc) · 1.06 KB
/
get_anuvaad_preds.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
from anuvaad import Anuvaad
# Translation models keyed by the three-letter language code that appears in
# column 3 of the input TSV. Every supported English->X direction is listed
# here; only the codes named in _ENABLED_LANGS are actually instantiated.
_ANUVAAD_DIRECTIONS = {
    'tam': 'english-tamil',
    'tel': 'english-telugu',
    'kan': 'english-kannada',
    'mal': 'english-malayalam',
    'mar': 'english-marathi',
    'hin': 'english-hindi',
}
# Add codes here to evaluate more language pairs in one run.
_ENABLED_LANGS = ('hin',)
models = {code: Anuvaad(_ANUVAAD_DIRECTIONS[code]) for code in _ENABLED_LANGS}
# Translate a fixed-seed sample of Tatoeba sentence pairs with every enabled
# model and write each input row plus the model prediction to
# "<first model code>-results.tsv".
import random
import sys

from tqdm import tqdm

INPUT_PATH = 'english-telugu_tamil_hindi_kannada_malayalam_marathi.tatoeba-sentpairs.tsv'
SAMPLE_SIZE = 1000  # maximum number of pairs translated per run
SAMPLE_SEED = 42    # fixed seed so every run draws the same subset
LANG_COL = 3        # TGT_LANG column of the input TSV (see output header)
SRC_SENT_COL = 4    # SRC_SENT column of the input TSV (see output header)
# Tabs/line breaks inside a prediction would split it across TSV cells or rows.
_TSV_UNSAFE = str.maketrans({'\t': ' ', '\r': ' ', '\n': ' '})


def _columns(line):
    """Split one TSV line into fields, removing only the line terminator.

    Unlike ``str.strip()``, this keeps a trailing empty field, so the number
    of columns written back out stays stable.
    """
    return line.rstrip('\r\n').split('\t')


def _is_candidate(line):
    """Return True if *line* has a language column with a loaded model."""
    cols = _columns(line)
    # Short lines (e.g. a trailing blank line) are skipped instead of raising.
    return len(cols) > LANG_COL and cols[LANG_COL] in models


if not models:
    raise SystemExit('No translation models enabled; nothing to evaluate.')

with open(INPUT_PATH, encoding='utf-8') as inf:
    lines = [l for l in inf if _is_candidate(l)]

# Same shuffle-then-truncate scheme as earlier runs, so the sample is stable.
if len(lines) >= SAMPLE_SIZE:
    random.seed(SAMPLE_SEED)
    random.shuffle(lines)
    lines = lines[:SAMPLE_SIZE]

with open(next(iter(models)) + '-results.tsv', 'w', encoding='utf-8') as of:
    of.write('SRC_ID\tTGT_ID\tSRC_LANG\tTGT_LANG\tSRC_SENT\tTGT_SENT\tPRED\n')
    for line in tqdm(lines):
        cols = _columns(line)
        if len(cols) <= SRC_SENT_COL:
            tqdm.write('Skipping row without a source sentence: %r' % (line,),
                       file=sys.stderr)
            continue
        try:
            pred = models[cols[LANG_COL]].anuvaad(cols[SRC_SENT_COL].strip())
            pred = pred.translate(_TSV_UNSAFE)
        except Exception as ex:
            # Best effort: one failing sentence must not abort the whole batch.
            tqdm.write('Translation failed for %r: %r' % (cols[0], ex),
                       file=sys.stderr)
            continue
        of.write('\t'.join(cols + [pred]) + '\n')
        # Flush per row so results produced so far survive an interrupted run.
        of.flush()