-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathevaluate.py
111 lines (97 loc) · 3.77 KB
/
evaluate.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
import os
import sys
import time
from collections import OrderedDict
from time import strftime, gmtime
from dataset import AsrDataset, DataLoader, AsrCollator
from megengine.data import SequentialSampler, RandomSampler, DataLoader
from models.transformer import Model
import hparams as hp
import argparse
from tqdm import tqdm
import difflib
import megengine as mge
import megengine.functional as F
class Session:
    """Evaluation session for the ASR transformer model.

    Loads the vocabulary and a checkpointed model, then accumulates
    token-level statistics over batches passed to :meth:`evaluate`:

    * ``numerator``    -- number of non-pad positions predicted correctly
    * ``denominator``  -- number of non-pad target positions
    * ``error_number`` -- accumulated (clamped) edit-distance errors
    """

    def __init__(self, args):
        """Load ``vocab.txt`` and the checkpoint given by ``args.model_path``.

        Args:
            args: parsed command-line namespace; only ``model_path`` is read.
        """
        # NOTE(review): the file is opened with the platform default
        # encoding; confirm whether vocab.txt should be read as UTF-8.
        with open(os.path.join(hp.dataset_root, "vocab.txt")) as f:
            tokens = [line.strip() for line in f]
        # Index 0 is reserved for padding, so "<pad>" is prepended.
        self.vocab = ["<pad>"] + tokens
        print(f"Vocab Size: {len(self.vocab)}")
        self.pad_id = 0

        self.model = Model(hp.num_mels, len(self.vocab))
        ckpt = mge.load(args.model_path)
        self.model.load_state_dict(ckpt["model"], strict=False)
        self.model.eval()

        self.reset_evaluate()

    def reset_evaluate(self):
        """Zero all accumulated evaluation counters."""
        self.numerator = 0
        self.denominator = 0
        self.error_number = 0

    def get_evaluate_acc(self):
        """Return the accumulated evaluation statistics.

        Returns:
            A 5-tuple ``(accuracy, error_rate, numerator, error_number,
            denominator)``.  When nothing has been accumulated yet
            (``denominator == 0``) both rates are ``nan`` instead of raising
            ``ZeroDivisionError``.
        """
        if not self.denominator:
            undefined = float("nan")
            return (
                undefined,
                undefined,
                self.numerator,
                self.error_number,
                self.denominator,
            )
        return (
            self.numerator / self.denominator,
            self.error_number / self.denominator,
            self.numerator,
            self.error_number,
            self.denominator,
        )

    def GetEditDistance(self, str1, str2):
        """Return an edit cost between two sequences based on difflib opcodes.

        The cost sums the sizes of the ``replace`` / ``insert`` / ``delete``
        opcodes that turn ``str1`` into ``str2``.  difflib does not promise a
        minimal edit script, so this is an upper bound on the Levenshtein
        distance rather than the exact value.

        Args:
            str1: source sequence (hypothesis).
            str2: target sequence (reference).

        Returns:
            The integer edit cost.
        """
        # autojunk must be disabled: with the default heuristic, once the
        # second sequence has 200+ items every frequently occurring item is
        # ignored during matching, which can make near-identical long
        # transcripts look almost completely different.
        matcher = difflib.SequenceMatcher(None, str1, str2, autojunk=False)
        leven_cost = 0
        for tag, i1, i2, j1, j2 in matcher.get_opcodes():
            if tag == "replace":
                leven_cost += max(i2 - i1, j2 - j1)
            elif tag == "insert":
                leven_cost += j2 - j1
            elif tag == "delete":
                leven_cost += i2 - i1
        return leven_cost

    def _decode(self, token_ids):
        """Concatenate the vocab entries for ``token_ids``; "<eos>" becomes "."."""
        return "".join(
            "." if self.vocab[idx] == "<eos>" else self.vocab[idx]
            for idx in token_ids
        )

    def evaluate(self, data):
        """Run the model on one batch and update the accumulated counters.

        Args:
            data: 7-tuple ``(text_input, text_output, mel, pos_text, pos_mel,
                text_length, mel_length)``; the two ``pos_*`` entries are not
                used here.
        """
        text_input, text_output, mel, _pos_text, _pos_mel, text_length, mel_length = data
        ys = self.model.forward(mel, mel_length, text_input, text_length, evaluate=True)

        # Only positions whose target is not padding take part in scoring.
        mask = text_output != self.pad_id
        predicted = ys[mask]
        reference = text_output[mask]

        ys_str = self._decode(predicted)
        text_output_str = self._decode(reference)

        correct = F.sum(predicted == reference).item()
        total = F.sum(mask).item()

        # NOTE(review): the edit cost is measured in characters of the decoded
        # strings while `total` counts tokens; these agree only if every
        # vocab entry other than "<eos>" is a single character -- confirm.
        edit_distance = self.GetEditDistance(ys_str, text_output_str)
        # Never count more errors than there are mismatched positions.
        self.error_number += min(edit_distance, total - correct)
        self.numerator += correct
        self.denominator += total
def main():
    """Evaluate a checkpoint on one dataset split and print accuracy and CER.

    Command-line options:
        --model-path  path of the checkpoint to load (required)
        --dataset     split to evaluate: "dev" (default), "test" or "train"
    """
    os.makedirs(hp.checkpoint_path, exist_ok=True)

    cli = argparse.ArgumentParser()
    cli.add_argument("--model-path", required=True)
    cli.add_argument("--dataset", default="dev", choices=["dev", "test", "train"])
    args = cli.parse_args()

    dataset = AsrDataset(args.dataset)
    sess = Session(args)

    # Sequential batches of 32 samples, collated by AsrCollator.
    sampler = SequentialSampler(dataset=dataset, batch_size=32)
    loader = DataLoader(dataset=dataset, sampler=sampler, collator=AsrCollator())

    sess.reset_evaluate()
    for batch in tqdm(loader):
        sess.evaluate(batch)

    # Only the two rates are reported; the raw counts are discarded.
    acc, wer, _numerator, _error_number, _denominator = sess.get_evaluate_acc()
    print("ACC: ", acc, "CER", wer)


if __name__ == "__main__":
    main()