utils.py
import json

import numpy as np
import torch

from dataset import WordleDataset
def get_default_features() -> torch.Tensor:
"""
Returns the default features.
The features are:
[1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] x 26 rows
    They correspond to the following for each letter of the alphabet:
1) Not guessed
2) Absent in word
3) Correct Position - 1
4) Correct Position - 2
5) Correct Position - 3
6) Correct Position - 4
7) Correct Position - 5
8) Incorrect Position - 1
9) Incorrect Position - 2
10) Incorrect Position - 3
11) Incorrect Position - 4
12) Incorrect Position - 5
These are all binary features.
"""
    # Column 0 ("not guessed") starts at 1 for every letter; columns 1-11 start at 0.
    not_guessed = torch.ones((26, 1)).float()
    other_features = torch.zeros((26, 11)).float()
    return torch.hstack((not_guessed, other_features))
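# Illustrative usage (a sketch; the values follow from the definition above):
#   features = get_default_features()
#   features.shape         # torch.Size([26, 12])
#   features[:, 0].sum()   # tensor(26.) -- every letter starts as "not guessed"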
def get_label_tensor(word : str) -> torch.Tensor:
"""
Given a word, we need to create labels for that word.
Each label is the offset from 'a'.
    E.g., 'd' has label 3, 'a' has label 0, and 'z' has label 25.
Since each word has 5 characters, the resulting tensor has size 5
"""
output = torch.empty(5, dtype=torch.long)
for i, k in enumerate(word):
output[i] = ord(k) - ord('a')
return output
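# Illustrative usage (a sketch based on the offset rule above):
#   get_label_tensor("adieu")   # tensor([0, 3, 8, 4, 20])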
def get_feedback(guessed_word : str, correct_word : str) -> list:
"""
    This function goes through the guessed word and the correct word to produce
    Wordle-style feedback.
    Edge cases to watch out for have been outlined in tests/test_feedback.py
    The implementation is similar, with the difference that rather than returning
    a list of strings from {'red', 'yellow', 'green'} we return {-1, 0, 1}.
"""
correct_word_counter = { c : 0 for c in correct_word }
for k in correct_word:
correct_word_counter[k] += 1
guessed_word_counter = { c : 0 for c in guessed_word }
for k in guessed_word:
guessed_word_counter[k] += 1
feedback_counter = {}
for k in correct_word:
if k in guessed_word:
feedback_counter[k] = min(guessed_word_counter[k], correct_word_counter[k])
else:
feedback_counter[k] = correct_word_counter[k]
    feedback = [-1 for _ in guessed_word]
    # First pass: mark letters in the correct position (1, i.e. green).
    for i, k in enumerate(guessed_word):
if correct_word[i] == k:
if feedback_counter[k]:
feedback[i] = 1
feedback_counter[k] -= 1
    # Second pass: mark remaining letters that appear elsewhere in the word (0, i.e. yellow).
    for i, k in enumerate(guessed_word):
if k in correct_word and feedback[i] == -1:
if feedback_counter[k]:
feedback[i] = 0
feedback_counter[k] -= 1
return feedback
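# Illustrative usage (a sketch; note how duplicate letters are handled):
#   get_feedback("speed", "abide")   # [-1, -1, 0, -1, 0]
# The first 'e' in "speed" gets a 0 (yellow), but the second stays -1
# because "abide" contains only one 'e'.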
def get_updated_features(features : torch.Tensor, feedback : list, guessed_word : str) -> torch.Tensor:
"""
This function updates the features based on the feedback and the guessed_word.
Arguments:
    `features`: Either the default feature set created by get_default_features() or
    an updated feature set previously returned by this function.
    `feedback`: A list of integers from {-1, 0, 1}. This comes from the output of get_feedback().
    `guessed_word`: The word that was guessed by the model.
"""
for i, k in enumerate(guessed_word):
row_idx = ord(k) - ord('a')
        if feedback[i] == 0:
            col_idx = 7 + i   # incorrect position i + 1
        elif feedback[i] == 1:
            col_idx = 2 + i   # correct position i + 1
        elif feedback[i] == -1:
            col_idx = 1       # absent in word
        else:
            raise ValueError(f"Feedback values must be -1, 0, or 1, got {feedback[i]}")
features[row_idx][col_idx] = 1
features[row_idx][0] = 0
return features
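# Illustrative usage (a sketch chaining the helpers defined above):
#   features = get_default_features()
#   feedback = get_feedback("crane", "crate")            # [1, 1, 1, -1, 1]
#   features = get_updated_features(features, feedback, "crane")
#   features[ord('n') - ord('a')][1]                     # tensor(1.) -- 'n' marked absent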
def get_word(outputs : torch.Tensor) -> str:
"""
    Converts the output of our model to a readable word.
    Take the argmax of each (1, 26) row of the output and add it to ord('a') to get each character.
Arguments:
`outputs`: The output from the model. Should be of the shape [5, 26].
"""
word = ""
for o in outputs:
        word += chr(torch.argmax(o).item() + ord('a'))
return word
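# Illustrative usage (a sketch with a hand-built tensor):
#   logits = torch.zeros(5, 26)
#   for pos, ch in enumerate("crane"):
#       logits[pos][ord(ch) - ord('a')] = 1.0
#   get_word(logits)   # 'crane'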
def get_wordlist(wordlist_path : str) -> list:
"""
    Reads the words from the file at the path, strips trailing newline characters, and lowercases all the words.
Arguments:
`wordlist_path`: Needs to be the full path to the wordlist to use. Usually word lists are under data/ subdirectory.
"""
    with open(wordlist_path, 'r') as f:
        words = f.readlines()
    words = [word.strip().lower() for word in words]
return words
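# Illustrative usage (a sketch; the path and word values are hypothetical):
#   words = get_wordlist("data/words.txt")
#   words[:2]   # e.g. ['aback', 'abase'] -- stripped and lowercased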
def get_wordset(wordlist_path : str) -> set:
"""
    Reads the words from the file at the path, strips trailing newline characters, and lowercases all the words.
    In addition, it converts the list to a set, useful for constant-time lookups to check whether a word is in the word list.
Arguments:
`wordlist_path`: Needs to be the full path to the wordlist to use. Usually word lists are under data/ subdirectory.
"""
words = get_wordlist(wordlist_path)
return set(words)
def get_dataset(root_dir : str) -> WordleDataset:
"""
    Creates and returns a WordleDataset that reads its words from the file at root_dir.
Arguments:
`root_dir`: The full path to the word list to be used. Usually found under data/ subdirectory.
"""
dataset = WordleDataset(root_dir)
return dataset
def get_split_dataset(dataset : WordleDataset, splits : list) -> dict:
"""
Given a dataset, creates 3 splits with the ratios specified in splits.
Arguments:
`dataset`: Should be an instance of the WordleDataset or any other Dataset.
    `splits`: Should be a list of 3 floats giving the train/val/test ratios, with the last one being 0.
    The test count is computed by subtracting the train and val counts from the total number of
    data points, which guarantees the three splits cover the whole dataset; it is just easier to
    pass the last value as 0.
Return:
`datasets`: A dict with 'train', 'val', 'test' as keys.
{
'train': Training_dataset,
    'val': Validation_dataset,
'test': Testing_dataset,
}
"""
total_count = len(dataset)
splits = [int(total_count * ratio) for ratio in splits]
splits[-1] = total_count - sum(splits)
train_set, val_set, test_set = torch.utils.data.random_split(dataset, splits, generator=torch.Generator().manual_seed(42))
datasets = {
'train': train_set,
'test': test_set,
'val': val_set,
}
return datasets
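# Illustrative usage (a sketch; the word-list path is hypothetical):
#   dataset = get_dataset("data/words.txt")
#   parts = get_split_dataset(dataset, [0.8, 0.1, 0.0])
#   assert len(parts['train']) + len(parts['val']) + len(parts['test']) == len(dataset)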
def save_model(model : torch.nn.Module, model_name : str) -> None:
"""
Save the model with the given name under the models/ subdirectory
"""
torch.save(model, f"models/{model_name}")
def save_history(history : dict, file_name : str) -> None:
"""
    Save the interaction history as a JSON file with the chosen file_name under the interactions/ subdirectory
"""
json_file = json.dumps(history)
with open(f"interactions/{file_name}", "w") as f:
f.write(json_file)
def save_loss(loss : np.ndarray, file_name : str) -> None:
    """
    Save the loss array (an np.ndarray) as a .npy file under the plots/ subdirectory.
    """
    np.save(f"plots/{file_name}", loss)
def get_mask_tree(wordlist_path : str) -> dict:
"""
Reads the words from the file at the path, and creates a word mask tree.
A mask tree specifies the masks for each possible part of the word.
This way we can find what are the possible next characters for any part of a word.
    ex:
    mask_tree[0] is the mask for the first character, over 'a' - 'z'.
    mask_tree[1]['a'] has the mask for 'a****', specifying which characters can appear
    at the second place.
    mask_tree[3]['art'] has the mask for 'art**', specifying which characters can appear
    at the fourth place.
Arguments:
`wordlist_path`: Needs to be the full path to the wordlist to use. Usually word lists are under data/ subdirectory.
"""
words = get_wordlist(wordlist_path)
mask_tree = {
0: [0 for _ in range(26)],
}
for word in words:
idx = ord(word[0]) - ord('a')
mask_tree[0][idx] = 1
for word in words:
for pass_len in range(1, 5):
key = word[:pass_len]
idx = ord(word[pass_len]) - ord('a')
if pass_len not in mask_tree:
mask_tree[pass_len] = {}
if key not in mask_tree[pass_len]:
mask_tree[pass_len][key] = [0 for _ in range(26)]
mask_tree[pass_len][key][idx] = 1
return mask_tree
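# Illustrative usage (a sketch; the path is hypothetical):
#   tree = get_mask_tree("data/words.txt")
#   tree[0][ord('c') - ord('a')]        # 1 if any word starts with 'c'
#   tree[2]['cr'][ord('a') - ord('a')]  # 1 if some word starts with 'cra'
#                                       # (assumes some word starts with 'cr')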
def get_word_beam_search(outputs : torch.Tensor, mask_tree : dict, k : int = 3) -> str:
"""
    Converts the output of our model to a readable word.
    Rather than taking the argmax independently of the underlying word distribution, we carry out
    a beam search to find the best-scoring word in our dictionary.
Arguments:
`outputs`: The output from the model. Should be of the shape [5, 26].
    `mask_tree`: The mask tree to be used. This is created using the get_mask_tree() function.
    `k`: The beam width, i.e. the number of candidate prefixes kept at each step.
"""
# initialize
soft_outputs = torch.nn.functional.softmax(outputs, dim=1)
mask = mask_tree[0]
mask = torch.tensor(mask)
mask = mask * soft_outputs[0]
values, indices = torch.topk(mask, k=k)
characters = [ chr(i + ord('a')) for i in indices ]
for i, output in enumerate(soft_outputs[1:]):
new_output = torch.tensor([])
for j, c in enumerate(characters):
mask = mask_tree[i + 1][c]
mask = torch.tensor(mask)
mask = values[j] * mask
mask = mask * output
new_output = torch.hstack((new_output, mask))
        # update for the next iteration
values, indices = torch.topk(new_output, k=k)
        temp_char = list(characters)
        for n, idx in enumerate(indices):
            old_ch = characters[torch.div(idx, 26, rounding_mode='trunc')]
            new_ch = old_ch + chr(idx % 26 + ord('a'))
            temp_char[n] = new_ch
characters = temp_char
return characters[0]
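# Illustrative usage (a sketch; `model`, `features`, and the path are hypothetical):
#   tree = get_mask_tree("data/words.txt")
#   outputs = model(features)                  # any [5, 26] tensor of logits works
#   get_word_beam_search(outputs, tree, k=3)   # best-scoring word found in the dictionary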