-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathself_p_zhou.py
52 lines (40 loc) · 2.02 KB
/
self_p_zhou.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
import tensorflow as tf
def self_attention_by_peng_zhou(inputs):
    """Compute attention weights over the time axis (Zhou et al., 2016).

    Attention method proposed by:
    Title: Attention-Based Bidirectional Long Short-Term Memory Networks for Relation Classification
    Authors: Peng Zhou, Wei Shi, Jun Tian, Zhenyu Qi, Bingchen Li, Hongwei Hao, Bo Xu
    Paper: https://www.aclweb.org/anthology/P16-2034
    Code author: SeoSangwoo (c), https://github.com/SeoSangwoo
    Code: https://github.com/SeoSangwoo/Attention-Based-BiLSTM-relation-extraction

    Computes alpha = softmax(u^T tanh(H)), where H are the inputs and u is a
    trainable vector of size D obtained via ``tf.compat.v1.get_variable`` under
    the name "u_omega".

    inputs: Tensor
        tensor of shape (B, T, D), where
        B -- batch, T -- time-series (context terms), D -- data or hidden embedding.
        The last dimension D must be statically known.

    Returns: Tensor
        attention weights of shape (B, T); the softmax is taken over the last
        axis (T).

    Raises:
        AssertionError: if `inputs` is not a `tf.Tensor` (only when assertions
            are enabled).
        ValueError: if the last dimension of `inputs` is not statically known.
    """
    assert(isinstance(inputs, tf.Tensor))

    # Trainable parameters.
    # `tf.compat.dimension_value` accepts both TF1 `Dimension` objects and the
    # plain ints returned by TF2 shapes; `inputs.shape[2].value` raises
    # AttributeError under TF2 shape behaviour. The size must match the axis
    # that `tf.tensordot(..., axes=1)` contracts below, i.e. the last one.
    hidden_size = tf.compat.dimension_value(inputs.shape[-1])
    if hidden_size is None:
        raise ValueError(
            "self_attention_by_peng_zhou requires a statically known last "
            "dimension, got inputs of shape {}".format(inputs.shape))

    u_omega = tf.compat.v1.get_variable(name="u_omega",
                                        shape=[hidden_size],
                                        initializer=tf.keras.initializers.glorot_normal())

    with tf.name_scope('v'):
        v = tf.tanh(inputs)

    # For each timestamp, its D-sized vector from `v` is reduced (dot product)
    # with the `u_omega` vector.
    vu = tf.tensordot(v, u_omega, axes=1, name='vu')  # (B,T) shape
    alphas = tf.nn.softmax(vu, name='alphas')  # (B,T) shape
    return alphas
def calculate_sequential_attentive_weights_by_peng_zhou(inputs, alphas):
    """Pool a sequence into a single vector using attention weights (Zhou et al., 2016).

    Authors: Peng Zhou, Wei Shi, Jun Tian, Zhenyu Qi, Bingchen Li, Hongwei Hao, Bo Xu
    Paper: https://www.aclweb.org/anthology/P16-2034

    Reduces the time-series axis by computing tanh(sum_t alpha_t * H_t).

    inputs: Tensor
        tensor of shape (B, T, D), where
        B -- batch, T -- time-series (context terms), D -- data or hidden embedding.
    alphas: Tensor
        tensor of shape (B, T) holding the weight of each time-series item.

    Returns: Tensor
        tensor of shape (B, D): tanh of the alpha-weighted sum over T.
    """
    # (B, T) -> (B, T, 1) so every weight broadcasts across the D features.
    weights = tf.expand_dims(alphas, axis=-1)
    # Weighted sum over the time axis (axis 1) collapses (B, T, D) to (B, D).
    pooled = tf.reduce_sum(inputs * weights, axis=1)
    # Final non-linearity.
    return tf.tanh(pooled)