-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathself_att_bilstm.py
199 lines (147 loc) · 7.7 KB
/
self_att_bilstm.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
import tensorflow as tf
from arenets.attention import common
from arenets.context.configurations.self_att_bilstm import SelfAttentionBiLSTMConfig
from arenets.sample import InputSample
from arenets.arekit.common.data_type import DataType
from arenets.tf_helpers import sequence
from arenets.context.architectures.base.base import SingleInstanceNeuralNetwork
class SelfAttentionBiLSTM(SingleInstanceNeuralNetwork):
    """
    A Structured Self-attentive Sentence Embedding (ICLR 2017)
    Paper: https://arxiv.org/pdf/1703.03130.pdf
    Code Author: roomylee, https://github.com/roomylee (C)
    Code: https://github.com/roomylee/self-attentive-emb-tf

    The context is encoded by a bidirectional RNN; `RSize` attention
    distributions over the context terms are computed from the RNN outputs
    and used to build the (flattened) context embedding. The cost is
    extended with a penalization term computed from the attention matrix.
    """

    def __init__(self):
        super(SelfAttentionBiLSTM, self).__init__()

        # hidden
        self.__A = None
        self.__avg_by_r_A = None
        self.__W_s1 = None
        self.__W_s2 = None
        # Declared here (and assigned in `init_logits_hidden_states`) so that
        # every hidden variable of the model exists from construction time.
        self.__W_fc = None
        self.__b_fc = None
        self.__W_output = None
        self.__b_output = None

        self.__dropout_rnn_keep_prob = None

    # region properties

    @property
    def ContextEmbeddingSize(self):
        """
        Size of the flattened sentence embedding M.

        Returns: r * 2u, where r is `Config.RSize` and u (`Config.HiddenSize`)
            is the output size of a single direction of the bilstm.
        """
        return self.Config.RSize * 2 * self.Config.HiddenSize

    # endregion

    # region public 'set' methods

    def set_input_rnn_keep_prob(self, value):
        """
        Replaces the RNN dropout keep-probability input with `value`.
        """
        self.__dropout_rnn_keep_prob = value

    # endregion

    # region public 'init' methods

    def init_input(self):
        """
        Initializes the inputs of the base network and, additionally, the
        placeholder for the RNN dropout keep probability.
        """
        super(SelfAttentionBiLSTM, self).init_input()
        self.__dropout_rnn_keep_prob = tf.compat.v1.placeholder(dtype=tf.float32,
                                                                name="ctx_dropout_rnn_keep_prob")

    def init_context_embedding(self, embedded_terms):
        """
        Builds the self-attentive context embedding.

        embedded_terms: embedded context terms, passed as `inputs` to the
            bidirectional rnn.

        Returns: M_flat, reshaped to (-1, ContextEmbeddingSize)
        """
        assert(isinstance(self.Config, SelfAttentionBiLSTMConfig))

        # Bidirectional(Left&Right) Recurrent Structure
        with tf.name_scope("bi-lstm"):
            x_length = sequence.calculate_sequence_length(self.get_input_parameter(InputSample.I_X_INDS))
            # Lengths are clamped to be at least 1 before being passed to the rnn.
            s_length = tf.cast(x=tf.maximum(x_length, 1), dtype=tf.int32)

            fw_cell = sequence.get_cell(hidden_size=self.Config.HiddenSize,
                                        cell_type=self.Config.CellType,
                                        dropout_rnn_keep_prob=self.__dropout_rnn_keep_prob)

            bw_cell = sequence.get_cell(hidden_size=self.Config.HiddenSize,
                                        cell_type=self.Config.CellType,
                                        dropout_rnn_keep_prob=self.__dropout_rnn_keep_prob)

            # Final states of the rnn are not used by this architecture.
            (self.output_fw, self.output_bw), _ = sequence.bidirectional_rnn(cell_fw=fw_cell,
                                                                             cell_bw=bw_cell,
                                                                             inputs=embedded_terms,
                                                                             sequence_length=s_length,
                                                                             dtype=tf.float32)

            # Forward and backward outputs are joined along the last axis (2u),
            # then flattened to 2D so that a single matmul may be applied.
            H = tf.concat([self.output_fw, self.output_bw], axis=2)
            H_reshape = tf.reshape(H, [-1, 2 * self.Config.HiddenSize])

        with tf.name_scope("self-attention"):
            _H_s1 = tf.nn.tanh(tf.matmul(H_reshape, self.__W_s1))
            _H_s2 = tf.matmul(_H_s1, self.__W_s2)
            # (-1, terms_per_context, r) -> (-1, r, terms_per_context)
            _H_s2_reshape = tf.transpose(tf.reshape(_H_s2, [-1, self.Config.TermsPerContext, self.Config.RSize]),
                                         perm=[0, 2, 1])
            # Softmax is applied over the last axis, i.e. over context terms.
            self.__A = tf.nn.softmax(_H_s2_reshape, name="attention")
            # Attention weights averaged over the r attention rows.
            self.__avg_by_r_A = tf.reduce_mean(self.__A, axis=-2)

        with tf.name_scope("sentence-embedding"):
            # M shape (r, 2u)
            M = tf.matmul(self.__A, H)

        # M_flat (batch_size, r * 2u)
        return tf.reshape(M, shape=[-1, self.ContextEmbeddingSize])

    def init_body_dependent_hidden_states(self):
        """
        Creates the attention weights:
            W_s1 of shape (2u, DASize) and W_s2 of shape (DASize, RSize).
        """
        assert(isinstance(self.Config, SelfAttentionBiLSTMConfig))

        self.__W_s1 = tf.compat.v1.get_variable(
            name="W_s1",
            shape=[2 * self.Config.HiddenSize, self.Config.DASize],
            regularizer=self.Config.LayerRegularizer,
            initializer=self.Config.WeightInitializer)

        self.__W_s2 = tf.compat.v1.get_variable(
            name="W_s2",
            shape=[self.Config.DASize, self.Config.RSize],
            regularizer=self.Config.LayerRegularizer,
            initializer=self.Config.WeightInitializer)

    def init_logits_hidden_states(self):
        """
        Creates the variables of the fully-connected layer (W_fc, b_fc)
        and of the output layer (W_output, b_output).
        """
        assert(isinstance(self.Config, SelfAttentionBiLSTMConfig))

        self.__W_output = tf.compat.v1.get_variable(
            name="W_output",
            shape=[self.Config.FullyConnectionSize, self.Config.ClassesCount],
            regularizer=self.Config.LayerRegularizer,
            initializer=self.Config.WeightInitializer)

        self.__b_output = tf.compat.v1.get_variable(
            name="b_output",
            shape=[self.Config.ClassesCount],
            regularizer=self.Config.LayerRegularizer,
            initializer=self.Config.BiasInitializer)

        self.__W_fc = tf.compat.v1.get_variable(
            name="W_fc",
            shape=[2 * self.Config.HiddenSize * self.Config.RSize, self.Config.FullyConnectionSize],
            regularizer=self.Config.LayerRegularizer,
            initializer=self.Config.WeightInitializer)

        self.__b_fc = tf.compat.v1.get_variable(
            name="b_fc",
            shape=[self.Config.FullyConnectionSize],
            regularizer=self.Config.LayerRegularizer,
            initializer=self.Config.BiasInitializer)

    def init_logits_unscaled(self, context_embedding):
        """
        context_embedding: M_flat parameter of init_context_embedding
            M_flat shape (r * 2u)

        Returns: pair (logits, logits with dropout applied)
        """
        with tf.name_scope("fully-connected"):
            fc = tf.nn.relu(tf.compat.v1.nn.xw_plus_b(context_embedding, self.__W_fc, self.__b_fc), name="fc")

        with tf.name_scope("output"):
            logits = tf.compat.v1.nn.xw_plus_b(x=fc, weights=self.__W_output, biases=self.__b_output, name="logits")

        # `DropoutKeepProb` is a *keep* probability. It is passed by keyword
        # through the v1 API because the second positional parameter of
        # `tf.nn.dropout` is `rate` (drop probability) in TensorFlow 2.x.
        logits_dropped = tf.compat.v1.nn.dropout(logits, keep_prob=self.DropoutKeepProb)

        return logits, logits_dropped

    def init_cost(self, logits_unscaled_dropped):
        """
        Extends the cost of the base network with the penalization term
        ||A * A^T - I||_F^2, averaged over the batch and scaled by
        `Config.PenaltizationTermCoef`.
        """
        loss = super(SelfAttentionBiLSTM, self).init_cost(logits_unscaled_dropped)

        with tf.name_scope("penalization"):
            AA_T = tf.matmul(self.__A, self.__A, transpose_b=True)
            # The (r, r) identity is broadcast over the batch axis.
            I = tf.eye(self.Config.RSize)
            # Squared Frobenius norm written as a plain sum of squares:
            # same value, without taking a square root only to square it back.
            P = tf.reduce_sum(tf.square(AA_T - I), axis=[-2, -1])

        return loss + tf.reduce_mean(P * self.Config.PenaltizationTermCoef)

    # endregion

    # region public 'create' methods

    def create_feed_dict(self, input, data_type):
        """
        Extends the feed dictionary of the base network with the RNN dropout
        keep probability: `Config.DropoutRNNKeepProb` for `DataType.Train`,
        and 1.0 for any other data type.
        """
        feed_dict = super(SelfAttentionBiLSTM, self).create_feed_dict(input=input, data_type=data_type)

        if data_type == DataType.Train:
            rnn_keep_prob = self.Config.DropoutRNNKeepProb
        else:
            rnn_keep_prob = 1.0

        feed_dict[self.__dropout_rnn_keep_prob] = rnn_keep_prob
        return feed_dict

    # endregion

    # region public 'iter' methods

    def iter_input_dependent_hidden_parameters(self):
        """
        Yields (name, value) pairs of the base network, followed by the
        attention weights averaged over the r attention rows.
        """
        for name, value in super(SelfAttentionBiLSTM, self).iter_input_dependent_hidden_parameters():
            yield name, value

        yield common.ATTENTION_WEIGHTS_LOG_PARAMETER, self.__avg_by_r_A

    def iter_hidden_parameters(self):
        """
        Yields (name, variable) pairs for those hidden variables
        which have been already initialized.
        """
        # NOTE(review): W_fc and b_fc are not reported here -- confirm
        # whether this omission is intentional.
        hidden = (("W_s1", self.__W_s1),
                  ("W_s2", self.__W_s2),
                  ("W_output", self.__W_output),
                  ("b_output", self.__b_output))

        for name, variable in hidden:
            if variable is not None:
                yield (name, variable)

    # endregion