This repository has been archived by the owner on Jan 23, 2024. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path: evaluator.py
132 lines (99 loc) · 4.48 KB
/
evaluator.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
#!/usr/bin/env python3
import csv
from typing import *
import logging
import sys
import json
# Process exit codes passed to sys.exit() when evaluation cannot proceed.
# Answers JSONL could not be parsed, repeated an id, or contained no answers.
EXIT_STATUS_ANSWERS_MALFORMED = 1
# Predictions CSV had a short row, an empty/repeated id, an empty label,
# or a CSV syntax error.
EXIT_STATUS_PREDICTIONS_MALFORMED = 2
# Predictions contained question ids that have no entry in the answer key.
EXIT_STATUS_PREDICTIONS_EXTRA = 3
# A question in the answer key had no corresponding prediction.
EXIT_STATUS_PREDICTION_MISSING = 4
def calculate_accuracy(question_answers: Dict[str, str], predictions: Dict[str, List[str]]) -> float:
    """Score predictions against the answer key.

    A question earns ``1 / k`` when its correct answer appears among the
    ``k`` labels predicted for it, and 0 otherwise; the result is the mean
    over all questions in ``question_answers``.

    :param question_answers: mapping of question id -> correct answer label.
    :param predictions: mapping of question id -> list of predicted labels.
        This mapping is only read, never modified.
    :return: the accuracy; 0.0 when ``question_answers`` is empty.

    Exits the process with EXIT_STATUS_PREDICTION_MISSING if some question
    has no prediction, or with EXIT_STATUS_PREDICTIONS_EXTRA if
    ``predictions`` contains ids that are not in ``question_answers``.
    """
    score = 0.0
    for question_id, answer in question_answers.items():
        try:
            predictions_for_q = predictions[question_id]
        except KeyError:
            logging.error("Missing prediction for question '%s'.", question_id)
            sys.exit(EXIT_STATUS_PREDICTION_MISSING)
        # Partial credit: spreading a guess over several labels splits the point.
        if answer in predictions_for_q:
            score += 1.0 / len(predictions_for_q)

    # Find ids with no matching answer without deleting from the caller's
    # dict; insertion order is preserved so the logged examples are stable.
    extra_ids = [qid for qid in predictions if qid not in question_answers]
    if extra_ids:
        logging.error("Found %d extra predictions, for example: %s", len(extra_ids),
                      ", ".join(extra_ids[:3]))
        sys.exit(EXIT_STATUS_PREDICTIONS_EXTRA)

    # Avoid ZeroDivisionError on an empty answer key.
    if not question_answers:
        return 0.0
    return score / len(question_answers)
def read_answers(filename: str) -> Dict[str, str]:
    """Load the answer key from a JSONL file.

    Every non-blank line must be a JSON object with an ``"id"`` field and an
    ``"answerKey"`` field. Blank lines are skipped.

    :param filename: path to the JSONL answers file.
    :return: mapping of question id -> answer label, in file order.

    Exits the process with EXIT_STATUS_ANSWERS_MALFORMED if a line is not
    valid JSON, is not an object with both required fields, repeats an id,
    or if the file contains no answers at all.
    """
    answers = {}
    with open(filename, "rt", encoding="UTF-8", errors="replace") as f:
        for line in f:
            line = line.strip()
            # A blank line holds no record; parsing it would fail json.loads.
            if not line:
                continue
            try:
                record = json.loads(line)
            except ValueError as e:
                logging.error("Error while reading file %s: %s", filename, e)
                sys.exit(EXIT_STATUS_ANSWERS_MALFORMED)
            try:
                question_id = record["id"]
                answer = record["answerKey"]
            except (KeyError, TypeError) as e:
                # KeyError: a required field is absent; TypeError: the line
                # was valid JSON but not an object (e.g. a list or a number).
                logging.error('Malformed record in file %s (need "id" and "answerKey"): %r', filename, e)
                sys.exit(EXIT_STATUS_ANSWERS_MALFORMED)
            if question_id in answers:
                logging.error("Key %s repeated in %s", question_id, filename)
                sys.exit(EXIT_STATUS_ANSWERS_MALFORMED)
            answers[question_id] = answer
    if len(answers) == 0:
        logging.error("No answers found in file %s", filename)
        sys.exit(EXIT_STATUS_ANSWERS_MALFORMED)
    return answers
def read_predictions(filename: str) -> Dict[str, List[str]]:
    """Load leaderboard predictions from a CSV file.

    Each row is ``question_id,labels`` where ``labels`` is a ``;``-separated
    list of predicted labels. Columns after the second are ignored.

    :param filename: path to the predictions CSV file.
    :return: mapping of question id -> list of predicted labels, in file order.

    Exits the process with EXIT_STATUS_PREDICTIONS_MALFORMED on a row with
    fewer than two columns (including a blank row), an empty or repeated id,
    an empty label, or a CSV syntax error.
    """
    predictions = {}
    # newline="" is required by the csv module so that line breaks inside
    # quoted fields are preserved rather than translated by the text layer.
    with open(filename, "rt", encoding="UTF-8", errors="replace", newline="") as f:
        reader = csv.reader(f)
        try:
            for row in reader:
                try:
                    question_id = row[0]
                    prediction_raw = row[1]
                except IndexError as e:
                    logging.error("Error reading value from CSV file %s on line %d: %s", filename, reader.line_num, e)
                    sys.exit(EXIT_STATUS_PREDICTIONS_MALFORMED)
                # Check emptiness before duplication so a repeated empty id is
                # reported as empty, the more specific problem.
                if question_id == "":
                    logging.error("Key is empty in file %s on line %d", filename, reader.line_num)
                    sys.exit(EXIT_STATUS_PREDICTIONS_MALFORMED)
                if question_id in predictions:
                    logging.error("Key %s repeated in file %s on line %d", question_id, filename, reader.line_num)
                    sys.exit(EXIT_STATUS_PREDICTIONS_MALFORMED)
                prediction = prediction_raw.split(";")
                # Prediction labels cannot be empty strings (e.g. "", "A;;B", "A;").
                if "" in prediction:
                    logging.error("Key %s has empty labels for prediction in file %s on line %d",
                                  question_id, filename, reader.line_num)
                    sys.exit(EXIT_STATUS_PREDICTIONS_MALFORMED)
                predictions[question_id] = prediction
        except csv.Error as e:
            logging.error('file %s, line %d: %s', filename, reader.line_num, e)
            sys.exit(EXIT_STATUS_PREDICTIONS_MALFORMED)
    return predictions
def main():
    """Command-line entry point for the evaluator.

    Parses --question-answers, --predictions and --output (all required),
    scores the predictions against the answers with calculate_accuracy(),
    and writes a JSON object ``{"accuracy": <float>}`` to the output file.
    """
    import argparse

    arg_parser = argparse.ArgumentParser(description='Evaluate leaderboard predictions for questions.')
    # (flags, help text) for each required option, registered in order.
    option_specs = [
        (('--question-answers', '-qa'),
         'Filename of the question answers to read. Expects a JSONL file with documents that have field "id" and "answerKey".'),
        (('--predictions', '-p'),
         "Filename of the leaderboard predictions, in CSV format."),
        (('--output', '-o'),
         'Output results to this file.'),
    ]
    for flags, help_text in option_specs:
        arg_parser.add_argument(*flags, help=help_text, required=True)
    cli = arg_parser.parse_args()

    # Arguments are evaluated left to right: answers are read before predictions.
    accuracy = calculate_accuracy(
        read_answers(cli.question_answers),
        read_predictions(cli.predictions),
    )

    payload = json.dumps({"accuracy": accuracy})
    with open(cli.output, "wt", encoding="UTF-8") as out_file:
        out_file.write(payload)
# Only evaluate when run as a script; importing this module has no side effects.
if __name__ == "__main__":
    main()