-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmodel.py
42 lines (30 loc) · 1.29 KB
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
import tensorflow as tf
from tensorflow.keras.layers import LSTM, Dense
class LSTMOpt(tf.keras.models.Model):
    """Learned optimizer (two stacked LSTMs + a dense head).

    Based on the optimizer from "Learning to Learn by Gradient Descent by
    Gradient Descent" (Andrychowicz et al., 2016), with a slightly
    different architecture.

    The previously hard-coded sizes (batch 128, hidden 20, param dim 10)
    are now constructor arguments with the same defaults, so existing
    callers are unaffected.
    """

    def __init__(self, batch_size=128, hidden_units=20, param_dim=10, **kwargs):
        super(LSTMOpt, self).__init__(**kwargs)
        self.batch_size = batch_size
        self.hidden_units = hidden_units
        self.param_dim = param_dim
        # First LSTM returns the full sequence (fed to the second LSTM)
        # as well as its final hidden/cell state.
        self.lstm1 = LSTM(hidden_units,
                          return_state=True,
                          return_sequences=True)
        # Second LSTM: only the final hidden state is consumed downstream.
        self.lstm2 = LSTM(hidden_units,
                          return_state=True)
        # Maps the second LSTM's hidden state to a param_dim-sized update.
        # Fixed: the original name '10 dim update' contains spaces and
        # starts with a digit, which several TF/Keras versions reject as
        # an invalid layer/variable-scope name.
        self.dense1 = Dense(param_dim,
                            name='param_update')

    def call(self, gradients, states):
        """Compute a parameter update from the incoming gradients.

        Args:
            gradients: tensor reshaped to [batch_size, 1, param_dim],
                or None on the first step (treated as zeros).
            states: list [h1, c1, h2, c2] of LSTM states, or None on the
                first step (initialized to zeros).

        Returns:
            (update, states): update of shape [batch_size, param_dim, 1]
            and the refreshed state list [h1, c1, h2, c2].
        """
        # Fixed: use `is None`, not `== None` — tensors overload `==`
        # elementwise, so the identity check is the only safe test here.
        if states is None:
            h1, c1, h2, c2 = (
                tf.zeros([self.batch_size, self.hidden_units]) for _ in range(4))
        else:
            h1, c1, h2, c2 = states
        if gradients is None:
            gradients = tf.zeros([self.batch_size, 1, self.param_dim])
        else:
            gradients = tf.reshape(gradients, [self.batch_size, 1, self.param_dim])
        seq, h1, c1 = self.lstm1(gradients, initial_state=[h1, c1])
        _, h2, c2 = self.lstm2(seq, initial_state=[h2, c2])
        update = self.dense1(h2)
        states = [h1, c1, h2, c2]
        # Trailing axis added so the update broadcasts as a column vector.
        return update[..., None], states