-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtrainer.py
153 lines (121 loc) · 5.19 KB
/
trainer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
import torch
from torch import nn
from torch.utils.data import Dataset
import os
class Trainer():
    """Train and validate a model with best-checkpoint saving and early stopping.

    Every batch yielded by the datasets must be a mapping with the keys
    ``'inputs'``, ``'outputs'`` and ``'char'`` (tensors); the model is called
    as ``model(inputs, chars)`` and the loss as
    ``loss_function(predictions_2d, labels_1d)``.
    """

    def __init__(
            self,
            model: nn.Module,
            loss_function,
            optimizer,
            label_vocab: dict,
            device,
            log_steps: int = 500,
            log_level: int = 2,
            checkpoint_path: str = 'models/best_model.pt'):
        """
        Args:
            - model: the model we want to train.
            - loss_function: the loss function to minimize.
            - optimizer: the optimizer used to minimize the loss function.
            - label_vocab (dict): vocabulary for the labels (stored, not used here).
            - device: device every batch (and ``predict`` input) is moved to.
            - log_steps (int): when ``log_level > 1``, print the running average
              training loss every ``log_steps`` batches; 0 disables it.
            - log_level (int): verbosity. 0 prints nothing, 1 prints per-epoch
              summaries, 2 also prints the per-step running loss.
            - checkpoint_path (str): file where ``evaluate`` saves the model's
              ``state_dict`` whenever the validation loss improves; missing
              parent directories are created.
        """
        self.model = model
        self.loss_function = loss_function
        self.optimizer = optimizer
        self.device = device
        self.label_vocab = label_vocab
        self.log_steps = log_steps
        self.log_level = log_level
        self.checkpoint_path = checkpoint_path

    def train(self, train_dataset: Dataset,
              valid_dataset: Dataset,
              epochs: int = 1,
              max_patience: int = 5) -> float:
        """
        Args:
            - train_dataset: a Dataset or DataLoader instance containing
              the training batches.
            - valid_dataset: a Dataset or DataLoader instance used to evaluate
              learning progress after every epoch.
            - epochs (int): the maximum number of passes over train_dataset.
            - max_patience (int): stop early once the validation loss has failed
              to improve for more than ``max_patience`` consecutive epochs.
        Returns:
            - the mean of the per-epoch average training losses over the
              epochs that were actually run (early stopping included).
        Raises:
            - ValueError: if ``epochs`` is not a positive int, or a dataset
              yields no batches.
        """
        if not isinstance(epochs, int) or epochs < 1:
            raise ValueError(f"epochs must be a positive int, got {epochs!r}")
        train_loss = 0.0
        patience = 0
        best_loss = float('inf')  # any finite first validation loss is an improvement
        epochs_run = 0
        for epoch in range(epochs):
            if self.log_level > 0:
                print(' Epoch {:03d}'.format(epoch + 1))
            avg_epoch_loss = self._train_epoch(train_dataset, epoch)
            train_loss += avg_epoch_loss
            epochs_run += 1
            if self.log_level > 0:
                print('\t[E: {:2d}] train loss = {:0.4f}'.format(epoch + 1, avg_epoch_loss))
            valid_loss, patience, best_loss = self.evaluate(valid_dataset, patience, best_loss)
            if self.log_level > 0:
                print('  [E: {:2d}] valid loss = {:0.4f}'.format(epoch + 1, valid_loss))
            if patience > max_patience:
                if self.log_level > 0:
                    print(f"\033[93mNo improvement for {patience} epochs, stopping early\033[0m")
                break
        if self.log_level > 0:
            print('... Done!')
        # Average over the epochs really executed, not the requested count.
        return train_loss / epochs_run

    def _train_epoch(self, train_dataset, epoch: int) -> float:
        """Run one optimisation pass over ``train_dataset``; return the mean batch loss."""
        self.model.train()
        epoch_loss = 0.0
        n_steps = 0
        for step, sample in enumerate(train_dataset):
            self.optimizer.zero_grad()
            sample_loss = self._compute_loss(sample)
            sample_loss.backward()
            self.optimizer.step()
            epoch_loss += sample_loss.item()
            n_steps = step + 1
            # log_steps <= 0 disables per-step logging (and avoids modulo by zero).
            if (self.log_level > 1 and self.log_steps > 0
                    and step % self.log_steps == self.log_steps - 1):
                print('\t[E: {:2d} @ step {}] current avg loss = {:0.4f}'.format(
                    epoch + 1, step, epoch_loss / n_steps))
        if n_steps == 0:
            raise ValueError("train_dataset yielded no batches")
        return epoch_loss / n_steps

    def _compute_loss(self, sample):
        """Move one batch to ``self.device``, run the model and return the loss tensor."""
        inputs = sample['inputs'].to(self.device)
        labels = sample['outputs'].to(self.device)
        chars = sample['char'].to(self.device)
        predictions = self.model(inputs, chars)
        # Flatten predictions to 2-D (rows x last dim) and labels to 1-D before the loss.
        predictions = predictions.view(-1, predictions.shape[-1])
        labels = labels.view(-1)
        return self.loss_function(predictions, labels)

    def _save_checkpoint(self):
        """Save the model's state_dict to ``self.checkpoint_path``, creating its directory."""
        directory = os.path.dirname(self.checkpoint_path)
        if directory:
            os.makedirs(directory, exist_ok=True)
        torch.save(self.model.state_dict(), self.checkpoint_path)

    def evaluate(self, valid_dataset, patience: int = 0, best_loss: float = float('inf')):
        """
        Compute the average validation loss and update the early-stopping state.

        Args:
            - valid_dataset: the dataset to use to evaluate the model.
            - patience (int): consecutive non-improving evaluations so far.
            - best_loss (float): best average validation loss seen so far.
        Returns:
            - (avg_valid_loss, patience, best_loss): the average validation loss
              over the batches of valid_dataset, and the updated patience
              counter (reset to 0 on improvement, otherwise incremented) and
              best loss. On improvement the model is also checkpointed to
              ``self.checkpoint_path``.
        Raises:
            - ValueError: if valid_dataset yields no batches.
        """
        self.model.eval()
        valid_loss = 0.0
        n_batches = 0
        with torch.no_grad():
            for sample in valid_dataset:
                valid_loss += self._compute_loss(sample).item()
                n_batches += 1
        if n_batches == 0:
            raise ValueError("valid_dataset yielded no batches")
        avg_valid_loss = valid_loss / n_batches
        if avg_valid_loss < best_loss:
            best_loss = avg_valid_loss
            patience = 0
            self._save_checkpoint()
            if self.log_level > 0:
                print("\033[92mImproved performance, model saved\033[0m")
        else:
            patience += 1
        return avg_valid_loss, patience, best_loss

    def predict(self, x, chars=None):
        """
        Args:
            - x: a tensor of token indices.
            - chars: optional character-level tensor. When given, the model is
              called as ``model(x, chars)`` (the same call ``train``/``evaluate``
              use); otherwise as ``model(x)``.
        Returns:
            - (logits, predictions): the raw model output and its argmax over
              the last dimension, i.e. the predicted label index per position.
        """
        self.model.eval()
        with torch.no_grad():
            x = x.to(self.device)
            if chars is None:
                logits = self.model(x)
            else:
                logits = self.model(x, chars.to(self.device))
            predictions = torch.argmax(logits, -1)
        return logits, predictions