forked from yikangshen/PRPN
-
Notifications
You must be signed in to change notification settings - Fork 0
/
model_PRPN.py
112 lines (86 loc) · 4.11 KB
/
model_PRPN.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
import torch
import torch.nn as nn
from ParsingNetwork import ParsingNetwork
from PredictNetwork import PredictNetwork
from ReadingNetwork import ReadingNetwork
class PRPN(nn.Module):
    """Parsing-Reading-Predict Network language model.

    ``forward`` works in four stages:

    1. Embed the input tokens.
    2. Run ``self.parser`` once over the whole embedded sequence to obtain
       per-timestep gates.
    3. Step through time. Each step feeds the embedding through the stack of
       ``ReadingNetwork`` layers in ``self.reader`` and then
       ``self.predictor.attention``.
    4. Combine the collected per-step outputs with ``self.predictor``, apply
       dropout, and decode them with ``self.decoder``.

    After each ``forward`` call:

    * ``self.attentions`` holds the per-timestep gate/attention tensors. They
      are stacked along dim 1 within a step and along dim 0 over steps.
    * ``self.gates`` holds the ``gate`` output of the parser.
    """

    def __init__(self, ntoken, ninp, nhid, nlayers,
                 nslots=5, nlookback=1, resolution=0.1,
                 dropout=0.4, idropout=0.4, rdropout=0.1,
                 tie_weights=False, hard=False, res=1):
        """Build the sub-networks.

        Args:
            ntoken: vocabulary size (embedding rows / decoder outputs).
            ninp: embedding size; also the decoder's input size.
            nhid: hidden size of the reading layers.
            nlayers: number of ``ReadingNetwork`` layers used in ``forward``.
            nslots: passed to the parser, the readers and the predictor.
            nlookback, resolution, hard: passed through to ``ParsingNetwork``.
            dropout: rate of ``self.drop``, which is applied to the predictor
                output before decoding. Also the ``dropout`` of the first
                reading layer.
            idropout: passed as ``idropout`` to every sub-network. Also the
                ``dropout`` of the reading layers after the first.
            rdropout: rate of the per-layer recurrent mask built in
                ``forward``.
            tie_weights: if truthy, the decoder shares the embedding's weight
                tensor.
            res: passed through to ``PredictNetwork``.
        """
        super(PRPN, self).__init__()
        self.nhid = nhid
        self.ninp = ninp
        self.nlayers = nlayers
        self.nslots = nslots
        self.nlookback = nlookback

        self.drop = nn.Dropout(dropout)
        # NOTE(review): ``idrop`` is not used anywhere in this class; it is kept
        # so external code referencing ``model.idrop`` keeps working.
        self.idrop = nn.Dropout(idropout)
        self.rdrop = nn.Dropout(rdropout)

        # Feedforward layers
        self.encoder = nn.Embedding(ntoken, ninp)
        self.parser = ParsingNetwork(ninp, nhid, nslots, nlookback, resolution, idropout, hard)
        # The first reading layer consumes embeddings (size ninp); later layers
        # consume the previous layer's output (size nhid).
        self.reader = nn.ModuleList(
            [ReadingNetwork(ninp, nhid, nslots, dropout=dropout, idropout=idropout)]
            + [ReadingNetwork(nhid, nhid, nslots, dropout=idropout, idropout=idropout)
               for _ in range(nlayers - 1)])
        self.predictor = PredictNetwork(nhid, ninp, nslots, idropout, res)
        self.decoder = nn.Linear(ninp, ntoken)

        if tie_weights:
            # Both weights are (ntoken, ninp), so one tensor can back both.
            self.decoder.weight = self.encoder.weight

        self.attentions = None
        self.gates = None

        self.init_weights()

    def init_weights(self):
        """Initialise the embedding and decoder weights and the decoder bias.

        The embedding and decoder weights are drawn uniformly from
        [-0.01, 0.01]. The decoder bias is set to zero.
        """
        initrange = 0.01
        self.encoder.weight.data.uniform_(-initrange, initrange)
        self.decoder.bias.data.fill_(0)
        # When weights are tied this re-initialises the shared tensor.
        self.decoder.weight.data.uniform_(-initrange, initrange)

    def clip_grad_norm(self, clip):
        """Clip gradients of each reading layer's ``memory_rnn`` parameters.

        Every layer is clipped independently to a total L2 norm of at most
        ``clip``. Other parameters are left untouched.
        """
        for reader in self.reader:
            # ``clip_grad_norm`` (no underscore) is deprecated in PyTorch.
            torch.nn.utils.clip_grad_norm_(reader.memory_rnn.parameters(), clip)

    def forward(self, input, hidden_states):
        """Run the model over a batch of token sequences.

        Args:
            input: token-index tensor. ``input.size(0)`` is used as the number
                of timesteps and ``input.size(1)`` as the batch size.
            hidden_states: ``(reader_state, parser_state, predictor_state)``,
                as returned by ``init_hidden`` or a previous ``forward`` call.
                ``reader_state`` holds one entry per reading layer and may be
                a list or a tuple. It is copied, not mutated.

        Returns:
            A pair ``(decoded, (reader_state, parser_state, predictor_state))``:

            * ``decoded`` is the decoder output viewed as
              ``(ntimestep, bsz, -1)``.
            * ``reader_state`` is a new list of the updated per-layer states.
        """
        ntimestep = input.size(0)
        bsz = input.size(1)
        emb = self.encoder(input)  # (ntimestep, bsz, ninp)

        reader_state, parser_state, predictor_state = hidden_states
        # Work on a private copy: the caller's container is not modified, and
        # tuple-valued states (e.g. after detaching) are accepted as well.
        reader_state = list(reader_state)

        (memory_gate, memory_gate_next), gate, parser_state = self.parser(emb, parser_state)

        # One recurrent dropout mask per reading layer, shared by all timesteps.
        # ``new_ones`` puts it on emb's device/dtype instead of relying on the
        # current CUDA device.
        rmask = self.rdrop(emb.new_ones(self.nlayers, self.nhid))

        output_h = []
        output_memory = []
        attentions = []
        for i in range(ntimestep):
            h_i = emb[i]
            attention = [memory_gate[i]]

            # Reading layers: each consumes the previous layer's output.
            for j in range(self.nlayers):
                h_i, new_memory, attention0 = self.reader[j](
                    h_i, reader_state[j], memory_gate[i], rmask[j])
                attention.append(attention0)
                reader_state[j] = new_memory

            # Predict layer
            selected_memory_h, predictor_state, attention1 = self.predictor.attention(
                h_i, predictor_state, gate_time=memory_gate_next[i])
            output_h.append(h_i)
            output_memory.append(selected_memory_h)

            attention.append(memory_gate_next[i])
            attention.append(attention1)
            attentions.append(torch.stack(attention, dim=1))

        self.attentions = torch.stack(attentions, dim=0)
        self.gates = gate

        output_h = torch.stack(output_h, dim=0)
        output_memory = torch.stack(output_memory, dim=0)
        output = self.predictor(output_h.view(-1, self.nhid), output_memory.view(-1, self.nhid))
        output = self.drop(output)
        decoded = self.decoder(output)
        return decoded.view(ntimestep, bsz, -1), (reader_state, parser_state, predictor_state)

    def init_hidden(self, bsz):
        """Return fresh states for a batch of size ``bsz``.

        The result is ``(reader_states, parser_state, predictor_state)``.
        ``reader_states`` is a list with one entry per reading layer.
        """
        return ([self.reader[i].init_hidden(bsz) for i in range(self.nlayers)],
                self.parser.init_hidden(bsz),
                self.predictor.init_hidden(bsz))