get_data.py

import json
import pickle
import random
from collections import defaultdict

import dspy
import pandas as pd
import requests
# opentom_long.json contains the "extra long" narratives produced by sampling
# 100 existing OpenToM plots; opentom.json is the standard release.
# URL = "https://raw.githubusercontent.com/SeacowX/OpenToM/main/data/opentom_long.json"
URL = "https://raw.githubusercontent.com/SeacowX/OpenToM/main/data/opentom.json"


def default_factory():
    # Module-level function (rather than a lambda) so the defaultdict built
    # below stays picklable: pickle serializes functions by reference and
    # cannot handle lambdas.
    return []


def load_dataset():
    response = requests.get(URL, timeout=30).json()
    df = pd.DataFrame(response)

    # Extract 'type' and 'answer' from the nested question dict into columns
    df["type"] = df["question"].apply(lambda x: x["type"])
    df["answer"] = df["question"].apply(lambda x: x["answer"])
    # Series mapping each question type to the array of answers seen for it
    unique_answers_by_type = df.groupby("type")["answer"].unique()

    # Convert the dataset to what DSPy expects (a list of Example objects)
    dataset = []
    for _, row in df.iterrows():
        context = row["narrative"]
        question = row["question"]["question"]
        answer = row["question"]["answer"]
        q_type = row["question"]["type"]  # renamed from `type` to avoid shadowing the builtin
        plot_info = json.dumps(row["plot_info"])  # keep every Example field a string

        # Tag location questions as coarse (yes/no answer) or fine-grained
        # (the answer names a specific location)
        is_location = "location" in q_type
        if is_location:
            granularity = "coarse" if answer.lower().strip() in ("yes", "no") else "fine"
            q_type = f"{q_type}-{granularity}"

        # Answer choices
        if is_location and granularity == "fine":
            # don't provide answer choices for fine-grained location questions
            answer_choices = "n/a, list a specific location"
        elif is_location:
            answer_choices = "No, Yes"
        else:
            answer_choices = ", ".join(unique_answers_by_type[q_type])
        dataset.append(
            dspy.Example(
                context=context,
                question=question,
                answer=answer,
                type=q_type,
                plot_info=plot_info,
                answer_choices=answer_choices,
            ).with_inputs("context", "question", "answer_choices")
        )
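    # with_inputs() marks context/question/answer_choices as the fields DSPy
    # passes to the program; the remaining fields (answer, type, plot_info)
    # stay on the Example as labels/metadata available during evaluation.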

    # Split the examples by question type
    datasets = defaultdict(default_factory)
    for example in dataset:
        datasets[example.type].append(example)

    # Create the train/test split per question type. Note the deliberately
    # small train side (20% train, 80% test), and that no random seed is set,
    # so the splits differ between runs.
    for question_type, dataset in datasets.items():
        random.shuffle(dataset)
        datasets[question_type] = {
            "train": dataset[int(len(dataset) * 0.8):],  # last 20%
            "test": dataset[: int(len(dataset) * 0.8)],  # first 80%
        }
        print(f"Train {question_type}: {len(datasets[question_type]['train'])}")
        print(f"Test {question_type}: {len(datasets[question_type]['test'])}")

    # Serialize and save the datasets object to a file
    with open("datasets.pkl", "wb") as file:
        pickle.dump(datasets, file)
    print("🫡 Datasets object has been saved to 'datasets.pkl' 🫡")


if __name__ == "__main__":
    load_dataset()
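
# A minimal loading sketch (an assumption about downstream use, not part of
# this script). One pickle caveat: the saved object is a defaultdict whose
# factory was recorded under the module this script ran as (typically
# "__main__" when executed directly), so loading from a different entry point
# may raise AttributeError unless default_factory is resolvable there, e.g.:
#
#   import pickle
#   import __main__
#   from get_data import default_factory  # this file, assuming it's importable
#   __main__.default_factory = default_factory  # satisfy pickle's lookup
#
#   with open("datasets.pkl", "rb") as f:
#       datasets = pickle.load(f)
#   print({k: len(v["train"]) for k, v in datasets.items()})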