# model.py
import torch
import torch.nn as nn
import torch.nn.functional as F

from utils import * # expected to provide CUDA, maskset, the Tensor/LongTensor/zeros helpers and the hyperparameters used below
from embedding import embed
class rnn_encoder_decoder(nn.Module):
def __init__(self, x_cti_size, x_wti_size, y_wti_size):
super().__init__()
# architecture
self.enc = encoder(x_cti_size, x_wti_size)
self.dec = decoder(y_wti_size)
self = self.cuda() if CUDA else self
def forward(self, xc, xw, y0): # for training
b = y0.size(0) # batch size
loss = 0
self.zero_grad()
mask, lens = maskset(xw)
        self.dec.hs = self.enc(b, xc, xw, lens) # encoder outputs, attended over by the decoder
        self.dec.hidden = self.enc.hidden # seed the decoder with the final encoder state
        yi = LongTensor([SOS_IDX] * b) # start every sequence with SOS
        if self.dec.feed_input:
            self.dec.attn.v = zeros(b, 1, HIDDEN_SIZE) # zero attention vector for the first input-feeding step
for t in range(y0.size(1)):
yo = self.dec(yi.unsqueeze(1), mask, t)
yi = y0[:, t] # teacher forcing
loss += F.nll_loss(yo, yi, ignore_index = PAD_IDX)
        loss /= y0.size(1) # divide by sequence length
# loss /= y0.gt(0).sum().float() # divide by the number of unpadded tokens
return loss
def decode(self, x): # for inference
pass
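        # ------------------------------------------------------------------
        # NOTE: the original file leaves decode() as a bare stub (the pass
        # above). The lines below are a minimal greedy-decoding sketch added
        # for illustration, not the author's implementation. They assume x is
        # the pair (xc, xw) consumed by forward(), reuse the same encoder and
        # decoder wiring, and feed back the argmax token at each step for a
        # fixed number of steps (no EOS handling, no beam search).
        # ------------------------------------------------------------------
        xc, xw = x # assumed input format: character and word index tensors
        b = xw.size(0) # batch size
        mask, lens = maskset(xw)
        self.dec.hs = self.enc(b, xc, xw, lens)
        self.dec.hidden = self.enc.hidden
        if self.dec.feed_input:
            self.dec.attn.v = zeros(b, 1, HIDDEN_SIZE)
        yi = LongTensor([SOS_IDX] * b)
        hyps = []
        max_len = 100 # arbitrary cap for this sketch
        for t in range(max_len):
            yo = self.dec(yi.unsqueeze(1), mask, t) # log-probabilities [B, V]
            yi = yo.argmax(1) # greedy choice, fed back at the next step
            hyps.append(yi)
        return torch.stack(hyps, 1) # predicted token indices [B, max_len]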
class encoder(nn.Module):
def __init__(self, cti_size, wti_size):
super().__init__()
self.hidden = None # encoder hidden state
# architecture
self.embed = embed(ENC_EMBED, cti_size, wti_size)
self.rnn = getattr(nn, RNN_TYPE)(
input_size = self.embed.dim,
hidden_size = HIDDEN_SIZE // NUM_DIRS,
num_layers = NUM_LAYERS,
bias = True,
batch_first = True,
dropout = DROPOUT,
bidirectional = (NUM_DIRS == 2)
)
def init_state(self, b): # initialize RNN states
n = NUM_LAYERS * NUM_DIRS
h = HIDDEN_SIZE // NUM_DIRS
hs = zeros(n, b, h) # hidden state
if RNN_TYPE == "LSTM":
cs = zeros(n, b, h) # LSTM cell state
return (hs, cs)
return hs
def forward(self, b, xc, xw, lens):
self.hidden = self.init_state(b)
x = self.embed(xc, xw)
x = nn.utils.rnn.pack_padded_sequence(x, lens, batch_first = True)
        h, self.hidden = self.rnn(x, self.hidden) # keep the final state so the decoder can be seeded with it
h, _ = nn.utils.rnn.pad_packed_sequence(h, batch_first = True)
return h
class decoder(nn.Module):
def __init__(self, wti_size):
super().__init__()
self.hs = None # source hidden state
self.hidden = None # decoder hidden state
self.feed_input = True # input feeding
# architecture
self.embed = embed(DEC_EMBED, 0, wti_size)
self.rnn = getattr(nn, RNN_TYPE)(
input_size = self.embed.dim + (HIDDEN_SIZE if self.feed_input else 0),
hidden_size = HIDDEN_SIZE // NUM_DIRS,
num_layers = NUM_LAYERS,
bias = True,
batch_first = True,
dropout = DROPOUT,
bidirectional = (NUM_DIRS == 2)
)
self.attn = attn()
self.out = nn.Linear(HIDDEN_SIZE, wti_size)
self.softmax = nn.LogSoftmax(1)
def forward(self, y1, mask, t):
        x = self.embed(None, y1) # embed the previous target token (no character-level input on the target side)
        if self.feed_input:
            x = torch.cat((x, self.attn.v), 2) # input feeding: append the previous attention vector
        h, self.hidden = self.rnn(x, self.hidden) # carry the decoder state across time steps
if self.attn:
h = self.attn(h, self.hs, mask, t)
h = self.out(h).squeeze(1)
y = self.softmax(h)
return y
class attn(nn.Module): # attention layer (Luong et al 2015)
def __init__(self):
super().__init__()
self.type = "global" # global, local-m, local-p
self.method = "dot" # dot, general, concat
self.v = None # attention vector (for input feeding)
self.w = None # attention weights (for visualization)
# architecture
if self.type in ("local-m", "local-p"):
self.window_size = 5
if self.type[-1] == "p":
self.Wp = nn.Linear(HIDDEN_SIZE, HIDDEN_SIZE)
self.Vp = nn.Linear(HIDDEN_SIZE, 1)
if self.method == "general":
self.Wa = nn.Linear(HIDDEN_SIZE, HIDDEN_SIZE)
elif self.method == "concat":
pass # TODO
self.softmax = nn.Softmax(2)
self.Wc = nn.Linear(HIDDEN_SIZE * 2, HIDDEN_SIZE)
def window(self, ht, hs, mask, t): # for local attention
if self.type == "global":
return hs, mask, None
if self.type[-1] == "m": # local-m (monotonic)
p0 = max(0, min(t - self.window_size, hs.size(1) - self.window_size))
p1 = min(hs.size(1), t + 1 + self.window_size)
return hs[:, p0:p1], mask[0][:, p0:p1], None
        if self.type[-1] == "p": # local-p (predictive)
s = Tensor(mask[1]) # source sequence length
pt = s * torch.sigmoid(self.Vp(torch.tanh(self.Wp(ht)))).view(-1) # aligned position
hs_w, mask_w = [], []
k = [] # truncated Gaussian distribution as kernel function
for i in range(BATCH_SIZE):
p, seq_len, min_len = int(s[i].item()), mask[1][i], mask[1][-1]
p0 = max(0, min(p - self.window_size, seq_len - self.window_size))
p1 = min(seq_len, p + 1 + self.window_size)
if min_len < p1 - p0:
p0, p1 = 0, min_len
hs_w.append(hs[i, p0:p1])
mask_w.append(mask[0][i, p0:p1])
sd = (p1 - p0) / 2 # standard deviation
v = [torch.exp(-(j - pt[i]).pow(2) / (2 * sd ** 2)) for j in range(p0, p1)]
                k.append(torch.stack(v)) # stack the 0-dim weights into a 1-D kernel (torch.cat cannot concatenate 0-dim tensors)
hs_w = torch.cat(hs_w).view(BATCH_SIZE, -1, HIDDEN_SIZE)
mask_w = torch.cat(mask_w).view(BATCH_SIZE, -1)
k = torch.cat(k).view(BATCH_SIZE, 1, -1)
return hs_w, mask_w, k
def align(self, ht, hs, mask, k):
# content-based function [B, 1, H] @ [B, H, L] = [B, 1, L]
if self.method == "dot":
a = ht.bmm(hs.transpose(1, 2))
elif self.method == "general":
a = ht.bmm(self.Wa(hs).transpose(1, 2))
elif self.method == "concat":
pass # TODO
a = self.softmax(a.masked_fill(mask.unsqueeze(1), -10000))
if self.type == "local-p":
a = a * k
return a # alignment vector as attention weights
def forward(self, ht, hs, mask, t):
hs, mask, k = self.window(ht, hs, mask, t)
self.w = self.align(ht, hs, mask, k)
c = self.w.bmm(hs) # context vector [B, 1, L] @ [B, L, H] = [B, 1, H]
self.v = torch.tanh(self.Wc(torch.cat((c, ht), 2)))
return self.v # attention vector as attentional hidden state
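
# ---------------------------------------------------------------------------
# Illustration (added, not part of the original file): with the default
# settings above (type = "global", method = "dot"), the score/weight/context
# computation in align() and forward() boils down to the tensor algebra below;
# the final Wc projection and tanh are omitted. The shapes (B = 2, L = 5,
# H = 4) and the padding mask are made up for this standalone smoke test.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import torch
    B, L, H = 2, 5, 4
    ht = torch.randn(B, 1, H) # current decoder hidden state
    hs = torch.randn(B, L, H) # encoder hidden states
    pad = torch.tensor([[False] * 5, [False, False, False, True, True]]) # True = padding
    a = ht.bmm(hs.transpose(1, 2)) # alignment scores [B, 1, L]
    a = a.masked_fill(pad.unsqueeze(1), -10000).softmax(2) # attention weights
    c = a.bmm(hs) # context vector [B, 1, H]
    print(a.size(), c.size()) # torch.Size([2, 1, 5]) torch.Size([2, 1, 4])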