import unittest
import random
from typing import List, Dict, Any

import numpy as np
import torch
from torch import nn

from cl_data.src.constants import TaskTypes
from cl_pretrainer.batch_builder import BatchBuilder
from cl_pretrainer.checkpoint_manager import CheckPointManager
from lr_scheduler import NoamOpt
from response_parser.simple_response_parser import SimpleResponseParser
from transformer import Transformer
from vocabulary_builder.simple_vocabulary_builder import SimpleVocabBuilder


def train(
    vocabulary: SimpleVocabBuilder,
    transformer: nn.Module,
    scheduler: Any,
    criterion: Any,
    batches: Dict[str, List[List[List[dict]]]],
    masks: Dict[str, List[torch.Tensor]],
    n_epochs: int,
    task_type: str,
    start_epoch=0,
    is_training=True,
    verbose_log=False,
):
"""
Main training loop
:param vocabulary: Vocabulary class instance.
:param task_type: The type of task.
:param start_epoch: From which epoch training should resume.
:param verbose_log: Log in detailed level with tgt_output and decoder_output
:param transformer: the transformer model
:param scheduler: the learning rate scheduler
:param criterion: the optimization criterion (loss function)
:param batches: aligned src and tgt batches that contain tokens ids
:param masks: source key padding mask and target future mask for each batch
:param n_epochs: the number of epochs to train the model for
:param is_training: is the model used for training or inference
:return: the accuracy and loss on the latest batch
"""
    transformer.train(is_training)
    if not is_training:
        n_epochs = 1

    num_iters = 0
    for e in range(start_epoch, start_epoch + n_epochs):
        for i, (src_batch, src_padding_mask, tgt_batch, tgt_future_mask) in enumerate(
            zip(batches[BatchBuilder.ENCODER_IO_PARSER_OUTPUT_KEY],
                masks[BatchBuilder.PADDING_MASK_KEY],
                batches[BatchBuilder.DECODER_IO_PARSER_OUTPUT_KEY],
                masks[BatchBuilder.FUTURE_MASK_KEY])
        ):
            encoder_output = transformer.encoder(src_batch, task_type, src_padding_mask=src_padding_mask)
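            # encoder_output is the encoder memory for the source batch; its shape is assumed to be
            # [batch_size, src_seq_len, hidden_dim].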
            # Perform one decoder forward pass to obtain *all* next-token predictions for every index i given
            # its previous *gold standard* tokens [1, ..., i] (i.e. teacher forcing) in parallel/at once.
            decoder_output = transformer.decoder(
                tgt_batch,
                task_type,
                encoder_output,
                src_padding_mask=src_padding_mask,
                future_mask=tgt_future_mask,
            )  # type: ignore
            # Align labels with predictions: the last decoder prediction is meaningless because there is no
            # target token for it. With teacher forcing, every token the decoder should learn must be provided
            # in the decoder input; the prediction made at that final input position has no corresponding
            # target, so it is dropped. The decoder output also has no BOS, since BOS only appears in the
            # decoder input as the first token.
            # [batch_size, sequence_length, logits]
            decoder_output = decoder_output[:, :-1, :]
            # print(f"after decoder_output {decoder_output.shape}")

            # Convert tgt_batch into integer token ids
            tgt_batch = torch.tensor(vocabulary.batch_encoder(tgt_batch))
            # The BOS token in the target is also not something we want to compute a loss for,
            # as it is not present in the decoder output. Padding and EOS are fine, because the decoder
            # output is computed up to max_length, which includes the EOS and padding mask tokens.
            # [batch_size, sequence_length]
            tgt_batch = tgt_batch[:, 1:]
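            # Illustrative alignment (hypothetical tokens):
            #   decoder input : [BOS, A,  B,  C ]
            #   predictions   : [A',  B', C', ? ]  -> last position dropped above
            #   target labels : [A,   B,  C ]      -> BOS dropped above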
            # Optionally, pad tokens in the target could be set to -100 so they don't incur a loss (disabled here):
            # tgt_batch[tgt_batch == transformer.padding_idx] = -100
            # Compute the average cross-entropy loss over all next-token predictions at each index i given
            # [1, ..., i] for the entire batch. Note that the original paper uses label smoothing (I was too lazy).
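            # nn.CrossEntropyLoss expects logits shaped [batch_size, num_classes, seq_len] and integer targets
            # shaped [batch_size, seq_len], hence the permute of the vocab dimension into position 1 below.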
            batch_loss = criterion(
                decoder_output.contiguous().permute(0, 2, 1),
                tgt_batch.contiguous().long(),
            )
            # Rough estimate of per-token accuracy in the current training batch
            batch_accuracy = (
                torch.sum(decoder_output.argmax(dim=-1) == tgt_batch)
            ) / torch.numel(tgt_batch)
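            # Note: padding positions still count towards this accuracy (and the loss above), since pad targets
            # are not masked out with ignore_index / -100 here.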
            if num_iters % len(batches[BatchBuilder.ENCODER_IO_PARSER_OUTPUT_KEY]) == 0 or not is_training:
                print(
                    f"epoch: {e}, num_iters: {num_iters}, batch_loss: {batch_loss}, batch_accuracy: {batch_accuracy}"
                )
                if verbose_log:
                    # Printing target and predicted tokens
                    print("~~~Printing target batch~~~\n")
                    SimpleResponseParser.print_response_to_console(vocabulary.batch_decode(tgt_batch.tolist()))
                    print("~~~Printing decoder output batch~~~\n")
                    SimpleResponseParser.print_response_to_console(
                        vocabulary.batch_decode(decoder_output.argmax(dim=-1).tolist())
                    )
            # Update parameters
            if is_training:
                batch_loss.backward()
                scheduler.step()
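                # NoamOpt is assumed to wrap the underlying optimizer: step() updates the learning rate
                # according to the Noam warmup schedule and then calls optimizer.step().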
                scheduler.optimizer.zero_grad()
            num_iters += 1
    return batch_loss, batch_accuracy


class TestTransformerTraining(unittest.TestCase):
    seed = 0
    torch.manual_seed(seed)
    random.seed(seed)
    np.random.seed(seed)
    PATH = "./saved_models/model.pth"

    def test_train_and_save(self):
        """
        Test training by trying to (over)fit a small synthetic translation dataset - bringing the loss to ~zero.
        (GPU recommended; falls back to CPU.)
        """
        device = (
            torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
        )
        batch_size = 3
        n_epochs = 35
        # Minimum encoding length is 16
        max_encoding_length = 20
        max_decoding_length = 12
        task_type = TaskTypes.NL_TO_NL_TRANSLATION.value
        # Construct the vocabulary from a small synthetic corpus of source and target sentences
        corpus_source = [
            "The sun is shining brightly in the clear blue sky",
            "She studied hard for her exams and earned top grades",
            "The cat chased the mouse around the house",
            "He loves to play the guitar and sing songs",
            "They enjoyed a delicious meal at their favorite restaurant",
            "The book was so captivating that she couldn't put it down",
        ]
        corpus_target = [
            "I like chocolate ice cream with milk",
            "Dogs bark loudly at snowy night",
            "Raindrops fall gently from clouds",
            "Each children will receive ##division(9,3) candies",
            "Adding 3 plus 2 equals ##addition(3,2)",
            "The result of subtracting 1 from 5 is ##subtraction(5,1)",
        ]
        combined_list = corpus_source + corpus_target

        # Creating the vocabulary
        vocab = SimpleVocabBuilder(BatchBuilder.get_batch_io_parser_output(combined_list, True, 20))
        vocab_size = len(list(vocab.vocab_item_to_index.keys()))
        valid_tokens = list(vocab.vocab_item_to_index.keys())[3:]
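        # The [3:] slice assumes the first three vocabulary entries are reserved special tokens
        # (e.g. padding / BOS / EOS).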
print(f"Vocabulary size: {vocab_size}")
        # Construct src-tgt aligned input batches (note: the original paper uses dynamic batching based on tokens)
        # Creating the batches
        corpus = [
            {BatchBuilder.SOURCE_LANGUAGE_KEY: src, BatchBuilder.TARGET_LANGUAGE_KEY: tgt}
            for src, tgt in zip(corpus_source, corpus_target)
        ]
        batches, masks = BatchBuilder.construct_batches_for_transformer(
            corpus,
            batch_size=batch_size,
            max_encoder_sequence_length=max_encoding_length,
            max_decoder_sequence_length=max_decoding_length,
        )
        print(
            f"valid tokens: {len(valid_tokens)}\n"
            f"corpus size: {len(corpus)}\n"
            f"batch size: {batch_size}, number of batches: {len(batches[BatchBuilder.ENCODER_IO_PARSER_OUTPUT_KEY])},"
            f" expected: {len(corpus) / batch_size}"
        )
        # Initialize transformer
        transformer = Transformer(
            hidden_dim=768,
            batch_size=batch_size,
            ff_dim=2048,
            num_heads=8,
            num_layers=2,
            max_decoding_length=max_decoding_length,
            vocab_size=vocab_size,
            dropout_p=0.1,
        ).to(device)
        # Initialize learning rate scheduler, optimizer and loss (note: the original paper uses label smoothing)
        optimizer = torch.optim.Adam(
            transformer.parameters(), lr=0, betas=(0.9, 0.98), eps=1e-9
        )
        scheduler = NoamOpt(
            transformer.hidden_dim,
            factor=1,
            warmup=400,
            optimizer=optimizer,
        )
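        # NoamOpt is assumed to implement the schedule from "Attention Is All You Need":
        #   lr = factor * hidden_dim ** -0.5 * min(step ** -0.5, step * warmup ** -1.5)
        # i.e. linear warmup for `warmup` steps followed by inverse-square-root decay.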
        criterion = nn.CrossEntropyLoss()
        # Start training and verify ~zero loss and >99% accuracy on the last batch
        latest_batch_loss, latest_batch_accuracy = train(
            vocabulary=vocab,
            transformer=transformer,
            scheduler=scheduler,
            criterion=criterion,
            batches=batches,
            masks=masks,
            n_epochs=n_epochs,
            task_type=task_type,
            is_training=True,
            verbose_log=False,
        )
        print(f"batch loss {latest_batch_loss.item()}")
        print(f"batch accuracy {latest_batch_accuracy}")
        self.assertLessEqual(latest_batch_loss.item(), 0.01)
        self.assertGreaterEqual(latest_batch_accuracy, 0.99)
        CheckPointManager.save_checkpoint_map(
            TestTransformerTraining.PATH,
            n_epochs,
            transformer,
            optimizer,
        )
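        # test_model_load below reads the checkpoint written here, so it depends on this test running first.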

    def test_model_load(self):
        device = (
            torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
        )
        batch_size = 3
        n_epochs = 1
        # Minimum encoding length is 16
        max_encoding_length = 20
        max_decoding_length = 12
        task_type = TaskTypes.NL_TO_NL_TRANSLATION.value
        # Construct the vocabulary from the same small synthetic corpus used in test_train_and_save
        corpus_source = [
            "The sun is shining brightly in the clear blue sky",
            "She studied hard for her exams and earned top grades",
            "The cat chased the mouse around the house",
            "He loves to play the guitar and sing songs",
            "They enjoyed a delicious meal at their favorite restaurant",
            "The book was so captivating that she couldn't put it down",
        ]
        corpus_target = [
            "I like chocolate ice cream with milk",
            "Dogs bark loudly at snowy night",
            "Raindrops fall gently from clouds",
            "Each children will receive ##division(9,3) candies",
            "Adding 3 plus 2 equals ##addition(3,2)",
            "The result of subtracting 1 from 5 is ##subtraction(5,1)",
        ]
        combined_list = corpus_source + corpus_target

        # Creating the vocabulary
        vocab = SimpleVocabBuilder(BatchBuilder.get_batch_io_parser_output(combined_list, True, None))
        vocab_size = len(list(vocab.vocab_item_to_index.keys()))
        valid_tokens = list(vocab.vocab_item_to_index.keys())[3:]
print(f"Vocabulary size: {vocab_size}")
        # Creating the batches
        corpus = [
            {BatchBuilder.SOURCE_LANGUAGE_KEY: src, BatchBuilder.TARGET_LANGUAGE_KEY: tgt}
            for src, tgt in zip(corpus_source, corpus_target)
        ]
        batches, masks = BatchBuilder.construct_batches_for_transformer(
            corpus,
            batch_size=batch_size,
            max_encoder_sequence_length=max_encoding_length,
            max_decoder_sequence_length=max_decoding_length,
        )
        print(
            f"valid tokens: {len(valid_tokens)}\n"
            f"corpus size: {len(corpus)}\n"
            f"batch size: {batch_size}, number of batches: {len(batches[BatchBuilder.ENCODER_IO_PARSER_OUTPUT_KEY])},"
            f" expected: {len(corpus) / batch_size}"
        )
        # Initialize transformer
        transformer = Transformer(
            hidden_dim=768,
            batch_size=batch_size,
            ff_dim=2048,
            num_heads=8,
            num_layers=2,
            max_decoding_length=max_decoding_length,
            vocab_size=vocab_size,
            dropout_p=0.1,
        ).to(device)
        # Initialize learning rate scheduler, optimizer and loss (note: the original paper uses label smoothing)
        optimizer = torch.optim.Adam(
            transformer.parameters(), lr=0, betas=(0.9, 0.98), eps=1e-9
        )
        # Load the model...
        checkpoint_map = CheckPointManager.load_checkpoint_map(
            TestTransformerTraining.PATH
        )
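        # checkpoint_map holds the saved transformer state dict, optimizer state dict and last epoch,
        # which are unpacked below via CheckPointManager.get_checkpoint_item.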
        transformer.load_saved_model_from_state_dict(
            CheckPointManager.get_checkpoint_item(
                checkpoint_map,
                CheckPointManager.TRANSFORMER_STATE,
            ),
        )
        optimizer.load_state_dict(
            CheckPointManager.get_checkpoint_item(
                checkpoint_map,
                CheckPointManager.OPTIM_STATE,
            ),
        )
        start_epoch = CheckPointManager.get_checkpoint_item(
            checkpoint_map,
            CheckPointManager.EPOCH,
        )
        print("Model loaded correctly...")
        scheduler = NoamOpt(
            transformer.hidden_dim,
            factor=1,
            warmup=400,
            optimizer=optimizer,
        )
        criterion = nn.CrossEntropyLoss()
        # Run a single evaluation pass (is_training=False) and verify ~zero loss and >99% accuracy on the last batch
        latest_batch_loss, latest_batch_accuracy = train(
            vocabulary=vocab,
            transformer=transformer,
            scheduler=scheduler,
            criterion=criterion,
            batches=batches,
            masks=masks,
            n_epochs=n_epochs,
            task_type=task_type,
            start_epoch=start_epoch,
            is_training=False,
            verbose_log=True,
        )
        print(f"batch loss {latest_batch_loss.item()}")
        print(f"batch accuracy {latest_batch_accuracy}")
        self.assertLessEqual(latest_batch_loss.item(), 0.01)
        self.assertGreaterEqual(latest_batch_accuracy, 0.99)


if __name__ == "__main__":
    unittest.main()