-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathrun_verifier.py
70 lines (60 loc) · 2.36 KB
/
run_verifier.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
# -*- coding: utf-8 -*-
# Copyright (c) 2021 by Phuc Phan
import json
import argparse
import collections
def _split_paths(spec):
    """Split a comma-separated path list, trimming padding and dropping blanks."""
    return [piece.strip() for piece in spec.split(",") if piece.strip()]


def _weight_at(weights, position):
    """Return the weight configured for *position*, or 1 when none is listed."""
    return weights[position] if position < len(weights) else 1


def _read_json(path):
    """Parse and return the JSON document stored at *path*."""
    with open(path, "r", encoding="utf-8") as reader:
        # strict=False lets the decoder accept raw control characters
        # inside JSON strings instead of rejecting the whole file.
        return json.load(reader, strict=False)


def get_score1(args):
    """Combine score files and n-best files, then write ``results.json``.

    For every key the scores found in the ``input_null_files`` are averaged,
    and the probabilities of identical candidate texts in the
    ``input_nbest_files`` are summed.  The most probable text becomes the
    prediction for that key, unless the averaged score is strictly greater
    than ``args.thresh``, in which case the prediction is the empty string
    (presumably meaning "no answer" -- confirm against the producer of the
    score files).

    Args:
        args: Object exposing
            ``input_null_files`` -- comma-separated paths of JSON objects
            mapping a key to a numeric score;
            ``input_nbest_files`` -- comma-separated paths of JSON objects
            mapping a key to a list of entries with ``"text"`` and
            ``"probability"`` fields;
            ``thresh`` -- numeric threshold compared against the mean score.
            Whitespace around each path is ignored.

    Returns:
        None.  The predictions are written as indented JSON to
        ``results.json`` in the current working directory.

    Raises:
        KeyError: If a key of an n-best file has no score in any score file.
        OSError: If an input file cannot be opened or the output cannot be
            written.
    """
    # Per-file weights, by position in the respective list.  Files beyond
    # the configured positions get weight 1 rather than raising IndexError.
    null_weights = [1, 1]
    nbest_weights = [1]

    # key -> weighted scores collected from every score file.
    all_scores = collections.OrderedDict()
    for idx, path in enumerate(_split_paths(args.input_null_files)):
        weight = _weight_at(null_weights, idx)
        for key, score in _read_json(path).items():
            all_scores.setdefault(key, []).append(weight * score)

    output_scores = {
        key: sum(scores, 0.0) / float(len(scores))
        for key, scores in all_scores.items()
    }

    # key -> {candidate text -> summed weighted probability}.
    all_nbest = collections.OrderedDict()
    for idx, path in enumerate(_split_paths(args.input_nbest_files)):
        weight = _weight_at(nbest_weights, idx)
        for key, entries in _read_json(path).items():
            if key not in all_nbest:
                all_nbest[key] = collections.defaultdict(float)
            for entry in entries:
                all_nbest[key][entry["text"]] += weight * entry["probability"]

    best_th = args.thresh
    output_predictions = {}
    for key, entry_map in all_nbest.items():
        # max() keeps the first candidate among ties, like a stable
        # descending sort; a key without candidates gets the empty answer.
        best_text = max(entry_map, key=entry_map.get, default="")
        if output_scores[key] > best_th:
            best_text = ""
        output_predictions[key] = best_text

    output_prediction_file = "results.json"
    with open(output_prediction_file, "w", encoding="utf-8") as writer:
        writer.write(json.dumps(output_predictions, indent=4) + "\n")
def main():
    """Parse the command-line options and run the verifier on them."""
    parser = argparse.ArgumentParser()
    # Every option is optional and falls back to the default given here.
    parser.add_argument(
        '--input_null_files', type=str,
        # No space after the comma, so each item is an exact file name
        # once the value is split on ','.
        default="cls_score.json,null_odds.json")
    parser.add_argument(
        '--input_nbest_files', type=str, default="nbest_predictions.json")
    parser.add_argument('--thresh', default=0, type=float)
    # NOTE(review): --predict_file is accepted but not read by any code
    # visible in this file; kept so existing command lines keep working.
    parser.add_argument("--predict_file", default="data/dev-v2.0.json")
    args = parser.parse_args()
    get_score1(args)


if __name__ == "__main__":
    main()