-
Notifications
You must be signed in to change notification settings - Fork 12
/
Copy pathtransformer.py
324 lines (268 loc) · 10.7 KB
/
transformer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from torch.autograd import Variable
import math
import copy
from load_data import EXTRA_CHARS
class OneHotEmbedding(nn.Module):
    """Map integer indices to fixed one-hot vectors of length ``alphabet_size``.

    The lookup table is the identity matrix loaded through
    ``nn.Embedding.from_pretrained``, which freezes the weights by default,
    so this layer has no trainable parameters.

    Args:
        alphabet_size: number of distinct indices (and one-hot width).
    """

    def __init__(self, alphabet_size):
        super().__init__()
        self.alphabet_size = alphabet_size
        self.embedding = nn.Embedding.from_pretrained(torch.eye(alphabet_size))

    def forward(self, x):
        # Bug fix: this previously called ``self.embed``, an attribute that is
        # never defined on this class, so every forward pass raised
        # AttributeError. The table lives in ``self.embedding``.
        return self.embedding(x)
class Embedding(nn.Module):
    """Learned lookup table turning token ids into ``d_model``-wide vectors.

    Args:
        alphabet_size: number of distinct token ids.
        d_model: width of every embedding vector.
    """

    def __init__(self, alphabet_size, d_model):
        super().__init__()
        self.alphabet_size, self.d_model = alphabet_size, d_model
        self.embed = nn.Embedding(num_embeddings=alphabet_size,
                                  embedding_dim=d_model)

    def forward(self, x):
        embedded = self.embed(x)
        return embedded
class PositionalEncoder(nn.Module):
    """Scale embeddings by sqrt(d_model), add fixed sinusoidal encodings, apply dropout.

    Args:
        d_model: width of the inputs passed to ``forward`` (last dimension).
        max_seq_len: number of positions precomputed; ``forward`` accepts
            inputs whose dimension 1 is at most this long.
        dropout: dropout probability applied after the encodings are added.
    """

    def __init__(self, d_model, max_seq_len = 6000, dropout = 0.1):
        super().__init__()
        self.d_model = d_model
        self.dropout = nn.Dropout(p=dropout)
        # Build the (max_seq_len, d_model) table in one vectorized pass rather
        # than a per-element Python loop. Column j uses the angle
        # pos / 10000 ** (2 * j / d_model): sine for even j, cosine for odd j.
        # NOTE: this exponent (2*j/d_model) reproduces the values this module
        # has always produced; it differs from Vaswani et al., who use
        # 2*floor(j/2)/d_model. It is kept so existing models see identical
        # encodings. Computed in float64 and then cast, matching the previous
        # Python-float computation stored into a default-dtype tensor.
        position = torch.arange(max_seq_len, dtype=torch.float64).unsqueeze(1)
        column = torch.arange(d_model, dtype=torch.float64)
        angle = position / torch.pow(10000.0, 2 * column / d_model)
        pe = torch.empty(max_seq_len, d_model, dtype=torch.float64)
        pe[:, 0::2] = torch.sin(angle[:, 0::2])
        # Slicing also handles an odd d_model (the old loop raised IndexError).
        pe[:, 1::2] = torch.cos(angle[:, 1::2])
        pe = pe.to(torch.get_default_dtype()).unsqueeze(0)
        self.register_buffer('pe', pe)

    def forward(self, x):
        # Make the embeddings relatively larger so the additive encodings do
        # not dominate them.
        x = x * math.sqrt(self.d_model)
        seq_len = x.size(1)
        # Buffers never require grad, so no Variable wrapper is needed. The
        # buffer follows the module on .to()/.cuda(); .to(x.device) also
        # covers a module left on another device (a bare `pe.cuda()` would
        # discard its result and do nothing).
        pe = self.pe[:, :seq_len].to(x.device)
        return self.dropout(x + pe)
class Norm(nn.Module):
    """Normalise over the last dimension with a learnable gain and shift.

    Divides by torch's default (unbiased) standard deviation plus ``eps``,
    rather than by sqrt(variance + eps), so results are not numerically
    identical to ``nn.LayerNorm``.

    Args:
        d_model: size of the last dimension being normalised.
        eps: constant added to the standard deviation for stability.
    """

    def __init__(self, d_model, eps = 1e-6):
        super().__init__()
        self.size = d_model
        # Learnable per-feature gain (starts at 1) and shift (starts at 0).
        self.alpha = nn.Parameter(torch.ones(d_model))
        self.bias = nn.Parameter(torch.zeros(d_model))
        self.eps = eps

    def forward(self, x):
        centered = x - x.mean(dim=-1, keepdim=True)
        scale = x.std(dim=-1, keepdim=True) + self.eps
        return self.alpha * centered / scale + self.bias
def attention(q, k, v, d_k, mask=None, dropout=None):
    """Scaled dot-product attention.

    Args:
        q, k, v: query/key/value tensors; the last two dims are
            (sequence, feature) and ``k`` is transposed over them.
        d_k: feature width used for the 1/sqrt(d_k) scaling.
        mask: optional tensor; positions where it equals 0 are blocked.
            A dimension is inserted at index 1 before broadcasting
            (presumably the heads axis — confirm against callers).
        dropout: optional callable applied to the attention weights.

    Returns:
        The attention-weighted sum of ``v``.
    """
    logits = q @ k.transpose(-2, -1) / math.sqrt(d_k)
    if mask is not None:
        # Large negative fill drives blocked positions to ~0 after softmax.
        logits = logits.masked_fill(mask.unsqueeze(1) == 0, -1e9)
    weights = F.softmax(logits, dim=-1)
    if dropout is not None:
        weights = dropout(weights)
    return weights @ v
class MultiHeadAttention(nn.Module):
    """Multi-head attention: project, split into heads, attend, merge, project.

    Args:
        heads: number of attention heads.
        d_model: model width; each head works on ``d_model // heads`` features.
        dropout: dropout probability applied to the attention weights.
    """

    def __init__(self, heads, d_model, dropout = 0.1):
        super().__init__()
        self.d_model = d_model
        self.d_k = d_model // heads
        self.h = heads
        # Registration order is kept stable so parameter order (and thus
        # state_dict / optimizer-state layout) does not change.
        self.q_linear = nn.Linear(d_model, d_model)
        self.v_linear = nn.Linear(d_model, d_model)
        self.k_linear = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(dropout)
        self.out = nn.Linear(d_model, d_model)

    def forward(self, q, k, v, mask=None):
        batch = q.size(0)

        def split(projection, tensor):
            # (bs, seq, d_model) -> (bs, heads, seq, d_k)
            return projection(tensor).view(batch, -1, self.h, self.d_k).transpose(1, 2)

        k = split(self.k_linear, k)
        q = split(self.q_linear, q)
        v = split(self.v_linear, v)
        per_head = attention(q, k, v, self.d_k, mask, self.dropout)
        # (bs, heads, seq, d_k) -> (bs, seq, d_model)
        merged = per_head.transpose(1, 2).contiguous().view(batch, -1, self.d_model)
        return self.out(merged)
class FeedForward(nn.Module):
    """Position-wise feed-forward network: Linear -> ReLU -> dropout -> Linear.

    Args:
        d_model: input and output width.
        d_ff: hidden width (2048 unless overridden).
        dropout: dropout probability on the hidden activations.
    """

    def __init__(self, d_model, d_ff=2048, dropout = 0.1):
        super().__init__()
        self.linear_1 = nn.Linear(d_model, d_ff)
        self.dropout = nn.Dropout(dropout)
        self.linear_2 = nn.Linear(d_ff, d_model)

    def forward(self, x):
        hidden = F.relu(self.linear_1(x))
        hidden = self.dropout(hidden)
        return self.linear_2(hidden)
class EncoderLayer(nn.Module):
    """One pre-norm encoder block: self-attention, then feed-forward.

    Each sub-layer receives a normalised copy of its input, and its output is
    added back to the un-normalised input through dropout (residual branch).
    """

    def __init__(self, d_model, heads, dropout=0.1):
        super().__init__()
        self.norm_1 = Norm(d_model)
        self.norm_2 = Norm(d_model)
        self.attn = MultiHeadAttention(heads, d_model, dropout=dropout)
        self.ff = FeedForward(d_model, dropout=dropout)
        self.dropout_1 = nn.Dropout(dropout)
        self.dropout_2 = nn.Dropout(dropout)

    def forward(self, x, mask):
        normed = self.norm_1(x)
        attended = self.attn(normed, normed, normed, mask)
        x = x + self.dropout_1(attended)
        fed = self.ff(self.norm_2(x))
        return x + self.dropout_2(fed)
# Decoder layer: masked self-attention, encoder-decoder attention, then a
# feed-forward network, each wrapped in a pre-norm residual branch.
class DecoderLayer(nn.Module):
    """One pre-norm decoder block.

    ``forward(x, e_outputs, src_mask, trg_mask)``: ``x`` attends to itself
    under ``trg_mask``, then to ``e_outputs`` (keys and values) under
    ``src_mask``, then goes through the feed-forward network.
    """

    def __init__(self, d_model, heads, dropout=0.1):
        super().__init__()
        self.norm_1 = Norm(d_model)
        self.norm_2 = Norm(d_model)
        self.norm_3 = Norm(d_model)
        self.dropout_1 = nn.Dropout(dropout)
        self.dropout_2 = nn.Dropout(dropout)
        self.dropout_3 = nn.Dropout(dropout)
        self.attn_1 = MultiHeadAttention(heads, d_model, dropout=dropout)
        self.attn_2 = MultiHeadAttention(heads, d_model, dropout=dropout)
        self.ff = FeedForward(d_model, dropout=dropout)

    def forward(self, x, e_outputs, src_mask, trg_mask):
        query = self.norm_1(x)
        x = x + self.dropout_1(self.attn_1(query, query, query, trg_mask))
        query = self.norm_2(x)
        cross = self.attn_2(query, e_outputs, e_outputs, src_mask)
        x = x + self.dropout_2(cross)
        return x + self.dropout_3(self.ff(self.norm_3(x)))
def get_clones(module, N):
    """Return an ``nn.ModuleList`` holding ``N`` independent deep copies of ``module``.

    Copies share initial weights with ``module`` but not storage, so they
    train independently.
    """
    clones = nn.ModuleList()
    for _ in range(N):
        clones.append(copy.deepcopy(module))
    return clones
class Encoder(nn.Module):
    """Embed tokens, add positional encodings, run N encoder layers, normalise.

    Args:
        alphabet_size: vocabulary size for the token embedding.
        d_model: model width.
        N: number of stacked ``EncoderLayer`` copies.
        heads: attention heads per layer.
        dropout: dropout probability used throughout.
    """

    def __init__(self, alphabet_size, d_model, N, heads, dropout):
        super().__init__()
        self.N = N
        self.embed = Embedding(alphabet_size, d_model)
        self.pe = PositionalEncoder(d_model, dropout=dropout)
        template = EncoderLayer(d_model, heads, dropout)
        self.layers = get_clones(template, N)
        self.norm = Norm(d_model)

    def forward(self, src, mask):
        hidden = self.pe(self.embed(src))
        for idx in range(self.N):
            layer = self.layers[idx]
            hidden = layer(hidden, mask)
        return self.norm(hidden)
class Decoder(nn.Module):
    """Embed target tokens, add positional encodings, run N decoder layers, normalise.

    Args:
        alphabet_size: vocabulary size for the token embedding.
        d_model: model width.
        N: number of stacked ``DecoderLayer`` copies.
        heads: attention heads per layer.
        dropout: dropout probability used throughout.
    """

    def __init__(self, alphabet_size, d_model, N, heads, dropout):
        super().__init__()
        self.N = N
        self.embed = Embedding(alphabet_size, d_model)
        self.pe = PositionalEncoder(d_model, dropout=dropout)
        template = DecoderLayer(d_model, heads, dropout)
        self.layers = get_clones(template, N)
        self.norm = Norm(d_model)

    def forward(self, trg, e_outputs, src_mask, trg_mask):
        hidden = self.pe(self.embed(trg))
        for idx in range(self.N):
            layer = self.layers[idx]
            hidden = layer(hidden, e_outputs, src_mask, trg_mask)
        return self.norm(hidden)
class Transformer(nn.Module):
    """Encoder-decoder Transformer over a shared alphabet.

    Encoder and decoder both embed ids from ``alphabet_size`` tokens, and the
    final linear layer maps decoder states to ``alphabet_size`` scores
    (no softmax is applied here).

    Args:
        alphabet_size: vocabulary size shared by source and target.
        d_model: model width.
        N: number of layers in each of the encoder and decoder.
        heads: attention heads per layer.
        dropout: dropout probability used throughout.
    """

    def __init__(self, alphabet_size, d_model, N, heads=8, dropout=0.1):
        super().__init__()
        self.encoder = Encoder(alphabet_size, d_model, N, heads, dropout)
        self.decoder = Decoder(alphabet_size, d_model, N, heads, dropout)
        self.out = nn.Linear(d_model, alphabet_size)

    def forward(self, src, trg, src_mask, trg_mask):
        memory = self.encoder(src, src_mask)
        decoded = self.decoder(trg, memory, src_mask, trg_mask)
        return self.out(decoded)
def nopeak_mask(size, device):
    """Return a (1, size, size) boolean look-ahead mask on ``device``.

    Entry [0, row, col] is True when ``col <= row``, i.e. position ``row``
    may attend to itself and earlier positions only.
    """
    rows = torch.arange(size).unsqueeze(1)
    cols = torch.arange(size).unsqueeze(0)
    allowed = cols <= rows
    return allowed.unsqueeze(0).to(device)
def create_masks(src, trg=None, pad_idx=ord(EXTRA_CHARS['pad']), device=None):
    """Build attention masks from padded id tensors.

    Args:
        src: source ids; positions equal to ``pad_idx`` are masked out.
        trg: optional target ids; its dimension 1 is the sequence length.
        pad_idx: id treated as padding.
        device: device for the look-ahead mask. Defaults to ``trg.device`` so
            the mask always lives alongside the target it is combined with.

    Returns:
        ``src_mask`` alone when ``trg`` is None, otherwise
        ``(src_mask, trg_mask)``. ``src_mask`` is ``src != pad_idx`` with a
        dimension inserted before the last (for ``src`` of shape
        (batch, seq) that is (batch, 1, seq)). ``trg_mask`` combines the
        target padding mask with the look-ahead mask from ``nopeak_mask``.
    """
    src_mask = (src != pad_idx).unsqueeze(-2)
    if trg is None:
        return src_mask
    if device is None:
        # Previously the look-ahead mask was not placed on trg's device when
        # no device was given, which breaks the `&` below for GPU tensors.
        device = trg.device
    trg_pad_mask = (trg != pad_idx).unsqueeze(-2)
    # nopeak_mask already moves its result to `device`; the old extra
    # `np_mask.to(device)` discarded its return value and did nothing.
    np_mask = nopeak_mask(trg.size(1), device)
    trg_mask = trg_pad_mask & np_mask
    return src_mask, trg_mask
class CosineWithRestarts(torch.optim.lr_scheduler._LRScheduler):
    """
    Cosine annealing with restarts.

    Within a cycle, each parameter group's learning rate follows half a
    cosine wave from its base lr down toward ``eta_min``. When a cycle ends,
    the schedule restarts at the base lr and the next cycle's length is
    rescaled by ``factor``.

    Parameters
    ----------
    optimizer : torch.optim.Optimizer
        The optimizer whose learning rates are scheduled.
    T_max : int
        The maximum number of iterations within the first cycle.
    eta_min : float, optional (default: 0)
        The minimum learning rate.
    last_epoch : int, optional (default: -1)
        The index of the last epoch.
    factor : float, optional (default: 1.)
        Multiplier applied to the cycle length at every restart: after n
        restarts the cycle length is ``int(factor ** n * T_max)``. With the
        default of 1, every cycle has length ``T_max``.
    """

    def __init__(self,
                 optimizer,
                 T_max,
                 eta_min = 0.,
                 last_epoch = -1,
                 factor = 1.):
        # pylint: disable=invalid-name
        self.T_max = T_max
        self.eta_min = eta_min
        self.factor = factor
        # Step index at which the current cycle started.
        self._last_restart = 0
        # Steps elapsed since the last restart (recomputed in get_lr).
        self._cycle_counter = 0
        # Product of `factor` over all restarts so far.
        self._cycle_factor = 1.
        # Length of the current cycle.
        self._updated_cycle_len = T_max
        # False until get_lr() has run once (see the HACK note in get_lr).
        self._initialized = False
        # All attributes above must exist before the base constructor runs,
        # because (per the HACK note in get_lr) it calls get_lr().
        super(CosineWithRestarts, self).__init__(optimizer, last_epoch)

    def get_lr(self):
        """Get updated learning rate.

        Returns one learning rate per entry of ``self.base_lrs``. Also updates
        the restart bookkeeping: when the current step lands on a cycle
        boundary, the cycle length is rescaled by ``factor`` and the restart
        step is recorded.
        """
        # HACK: We need to check if this is the first time get_lr() was called, since
        # we want to start with step = 0, but _LRScheduler calls get_lr with
        # last_epoch + 1 when initialized.
        if not self._initialized:
            self._initialized = True
            return self.base_lrs
        # NOTE(review): last_epoch has presumably already been advanced by the
        # base-class step() before get_lr runs, so `step` may be one ahead of
        # the number of completed steps — confirm this offset is intended.
        step = self.last_epoch + 1
        self._cycle_counter = step - self._last_restart
        # eta_min + (lr - eta_min)/2 * (cos(pi * t / len) + 1): equals the base
        # lr when t is a multiple of the cycle length and approaches eta_min as
        # t nears the end of the cycle.
        lrs = [
            (
                self.eta_min + ((lr - self.eta_min) / 2) *
                (
                    np.cos(
                        np.pi *
                        ((self._cycle_counter) % self._updated_cycle_len) /
                        self._updated_cycle_len
                    ) + 1
                )
            ) for lr in self.base_lrs
        ]
        if self._cycle_counter % self._updated_cycle_len == 0:
            # Adjust the cycle length.
            # The returned `lrs` were computed with the old length, before
            # this update.
            self._cycle_factor *= self.factor
            self._cycle_counter = 0
            # NOTE(review): int() truncates; if factor < 1 this can reach 0,
            # and the next `% self._updated_cycle_len` would raise
            # ZeroDivisionError — confirm factor >= 1 or add a guard.
            self._updated_cycle_len = int(self._cycle_factor * self.T_max)
            self._last_restart = step
        return lrs