This repository has been archived by the owner on Jul 26, 2022. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmain.py
127 lines (105 loc) · 4.27 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
from collections import namedtuple
from pprint import pprint
from dotenv import load_dotenv
# needs to happen as very first thing, otherwise HF ignores env vars
load_dotenv()
import os
import pandas as pd
from dataclasses import dataclass, field
from typing import Dict, cast, List
from datasets import DatasetDict, load_dataset
from src.readers.base_reader import Reader
from src.evaluation import evaluate
from src.readers.dpr_reader import DprReader
from src.readers.longformer_reader import LongformerReader
from src.retrievers.base_retriever import Retriever
from src.retrievers.es_retriever import ESRetriever
from src.retrievers.faiss_retriever import (
FaissRetriever,
FaissRetrieverOptions
)
from src.utils.log import logger
from src.utils.preprocessing import context_to_reader_input
from src.utils.timing import get_times, timeit
# One evaluated question: `correct` is the gold answer from the dataset,
# `given` is the answer text the reader produced for it.
ExperimentResult = namedtuple('ExperimentResult', ['correct', 'given'])


@dataclass
class Experiment:
    """A retriever/reader pairing to benchmark, plus its collected results."""
    retriever: Retriever
    reader: Reader
    # Reader family ("dpr" or "longformer"); the main loop uses this to
    # decide how to unpack the reader's answer object.
    lm: str
    # One ExperimentResult per question, appended while the experiment runs.
    results: List[ExperimentResult] = field(default_factory=list)
if __name__ == '__main__':
    dataset_name = "GroNLP/ik-nlp-22_slp"
    # The "paragraphs" config is the retrieval corpus; "questions" holds the
    # QA pairs used for evaluation.
    paragraphs = cast(DatasetDict, load_dataset(dataset_name, "paragraphs"))
    questions = cast(DatasetDict, load_dataset(dataset_name, "questions"))

    # Number of questions to evaluate. Currently the whole test split;
    # lower this to run a quick subset for speed.
    subset_idx = len(questions["test"])
    questions_test = questions["test"][:subset_idx]

    # Every retriever/reader combination we want to compare.
    experiments: Dict[str, Experiment] = {
        "faiss_dpr": Experiment(
            retriever=FaissRetriever(
                paragraphs,
                FaissRetrieverOptions.dpr("./src/models/dpr.faiss")),
            reader=DprReader(),
            lm="dpr"
        ),
        "faiss_longformer": Experiment(
            retriever=FaissRetriever(
                paragraphs,
                FaissRetrieverOptions.longformer(
                    "./src/models/longformer.faiss")),
            reader=LongformerReader(),
            lm="longformer"
        ),
        "es_dpr": Experiment(
            retriever=ESRetriever(paragraphs),
            reader=DprReader(),
            lm="dpr"
        ),
        "es_longformer": Experiment(
            retriever=ESRetriever(paragraphs),
            reader=LongformerReader(),
            lm="longformer"
        ),
    }

    for experiment_name, experiment in experiments.items():
        logger.info(f"Running experiment {experiment_name}...")

        # Workaround so we can use the timing decorator with a dynamic name.
        # The wrappers depend only on the experiment, so build them once here
        # instead of re-wrapping on every question.
        t_retrieve = timeit(f"{experiment_name}.retrieve")(
            experiment.retriever.retrieve)
        t_read = timeit(f"{experiment_name}.read")(
            experiment.reader.read)

        question_answer_pairs = zip(
            questions_test["question"], questions_test["answer"])
        for idx, (question, answer) in enumerate(question_answer_pairs):
            print(f"\x1b[1K\r[{idx+1:03}] - \"{question}\"", end='')

            scores, context = t_retrieve(question, 5)
            reader_input = context_to_reader_input(context)

            # Requesting a single answer results in us getting the best one.
            given_answer = t_read(question, reader_input, 1)[0]

            # Save the result so we can evaluate later. The two readers
            # return differently shaped answers, hence the branch on `lm`.
            if experiment.lm == "longformer":
                experiment.results.append(
                    ExperimentResult(answer, given_answer[0]))
            else:
                experiment.results.append(
                    ExperimentResult(answer, given_answer.text))
        print()

    if os.getenv("ENABLE_TIMING", "false").lower() == "true":
        # Persist the retrieve/read timings recorded by the timeit wrappers.
        times = get_times()
        df = pd.DataFrame(times)
        os.makedirs("./results/", exist_ok=True)
        df.to_csv("./results/timings.csv")

    # Compute exact-match and F1 per experiment, one column per experiment.
    f1_results = pd.DataFrame(columns=experiments.keys())
    em_results = pd.DataFrame(columns=experiments.keys())
    for experiment_name, experiment in experiments.items():
        em, f1 = zip(*(
            evaluate(result.correct, result.given)
            for result in experiment.results
        ))
        em_results[experiment_name] = em
        f1_results[experiment_name] = f1

    os.makedirs("./results/", exist_ok=True)
    f1_results.to_csv("./results/f1_scores.csv")
    em_results.to_csv("./results/em_scores.csv")