adabound.py
"""AdaBound for Tensorflow."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
import re
class AdaBoundOptimizer(tf.train.Optimizer):
"""Optimizer that implements the AdaBound algorithm.
See [Luo et al., 2019](https://openreview.net/forum?id=Bkg3g2R9FX)
([pdf](https://openreview.net/pdf?id=Bkg3g2R9FX)).
"""
    def __init__(self,
                 learning_rate=0.001,
                 final_lr=0.1,
                 beta1=0.9,
                 beta2=0.999,
                 gamma=1e-3,
                 epsilon=1e-8,
                 amsbound=False,
                 decay=0.,
                 weight_decay=0.,
                 exclude_from_weight_decay=None,
                 use_locking=False,
                 name="AdaBound"):
        super(AdaBoundOptimizer, self).__init__(use_locking, name)
        if final_lr <= 0.:
            raise ValueError("Invalid final learning rate: {}".format(final_lr))
        if not 0. <= beta1 < 1.:
            raise ValueError("Invalid beta1 value: {}".format(beta1))
        if not 0. <= beta2 < 1.:
            raise ValueError("Invalid beta2 value: {}".format(beta2))
        if not 0. <= gamma < 1.:
            raise ValueError("Invalid gamma value: {}".format(gamma))
        if epsilon <= 0.:
            raise ValueError("Invalid epsilon value: {}".format(epsilon))

        self._lr = learning_rate
        self._beta1 = beta1
        self._beta2 = beta2
        self._final_lr = final_lr
        self._gamma = gamma
        self._epsilon = epsilon
        self._amsbound = amsbound
        self._decay = decay
        self._weight_decay = weight_decay
        self._exclude_from_weight_decay = exclude_from_weight_decay
        self._base_lr = learning_rate
    def apply_gradients(self, grads_and_vars, global_step=None, name=None):
        lr = self._lr
        t = tf.cast(global_step, dtype=tf.float32)

        if self._decay > 0.:
            lr *= (1. / (1. + self._decay * t))

        t += 1

        bias_correction1 = 1. - (self._beta1 ** t)
        bias_correction2 = 1. - (self._beta2 ** t)
        step_size = (lr * tf.sqrt(bias_correction2) / bias_correction1)

        # Applies bounds on the actual learning rate.
        # An lr scheduler cannot affect final_lr; this is a workaround to apply lr decay.
        final_lr = self._final_lr * lr / self._base_lr
        lower_bound = final_lr * (1. - 1. / (self._gamma * t + 1.))
        upper_bound = final_lr * (1. + 1. / (self._gamma * t))
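        # Note (added comment): both bounds converge to final_lr as t grows
        # (lower_bound starts below it, upper_bound above it), so the clipped
        # per-parameter step size is gradually squeezed from an Adam-like step
        # toward an SGD-like step of size final_lr.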
        assignments = []
        for grad, param in grads_and_vars:
            if grad is None or param is None:
                continue

            param_name = self._get_variable_name(param.name)

            m = tf.get_variable(
                name=param_name + "/adabound_m",
                shape=param.shape.as_list(),
                dtype=tf.float32,
                trainable=False,
                initializer=tf.zeros_initializer())
            v = tf.get_variable(
                name=param_name + "/adabound_v",
                shape=param.shape.as_list(),
                dtype=tf.float32,
                trainable=False,
                initializer=tf.zeros_initializer())
            if self._amsbound:
                v_hat = tf.get_variable(
                    name=param_name + "/adabound_v_hat",
                    shape=param.shape.as_list(),
                    dtype=tf.float32,
                    trainable=False,
                    initializer=tf.zeros_initializer())

            m_t = (
                tf.multiply(self._beta1, m) + tf.multiply(1. - self._beta1, grad))
            v_t = (
                tf.multiply(self._beta2, v) + tf.multiply(1. - self._beta2, tf.square(grad)))

            if self._amsbound:
                # Maintains the maximum of all 2nd moment running averages so far.
                v_hat_t = tf.maximum(v_hat, v_t)
                # Use the maximum for normalizing the running average of the gradient.
                denom = (tf.sqrt(v_hat_t) + self._epsilon)
            else:
                denom = (tf.sqrt(v_t) + self._epsilon)

            step_size_p = step_size * tf.ones_like(denom)
            step_size_p_bound = step_size_p / denom
            lr_t = m_t * tf.clip_by_value(t=step_size_p_bound,
                                          clip_value_min=lower_bound,
                                          clip_value_max=upper_bound)
            p_t = param - lr_t

            if self._do_use_weight_decay(param_name):
                # Decoupled weight decay: shrink the parameter toward zero
                # (the original line added the term, which would grow the weights).
                p_t -= self._weight_decay * param

            update_list = [param.assign(p_t), m.assign(m_t), v.assign(v_t)]
            if self._amsbound:
                update_list.append(v_hat.assign(v_hat_t))
            assignments.extend(update_list)

        # Update the global step.
        assignments.append(global_step.assign_add(1))
        return tf.group(*assignments, name=name)
    def _do_use_weight_decay(self, param_name):
        """Whether to use L2 weight decay for `param_name`."""
        if not self._weight_decay:
            return False
        if self._exclude_from_weight_decay:
            for r in self._exclude_from_weight_decay:
                if re.search(r, param_name) is not None:
                    return False
        return True
    @staticmethod
    def _get_variable_name(param_name):
        """Get the variable name from the tensor name."""
        m = re.match("^(.*):\\d+$", param_name)
        if m is not None:
            param_name = m.group(1)
        return param_name
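

# Minimal usage sketch (not part of the original file): wires AdaBoundOptimizer
# into a toy least-squares graph. It assumes TensorFlow 1.x graph-mode APIs
# (tf.placeholder, tf.train.get_or_create_global_step, tf.Session); the variable
# names, shapes, and hyperparameters below are illustrative only.
if __name__ == "__main__":
    x = tf.placeholder(tf.float32, shape=[None, 10])
    y = tf.placeholder(tf.float32, shape=[None, 1])
    w = tf.get_variable("w", shape=[10, 1], initializer=tf.zeros_initializer())
    loss = tf.reduce_mean(tf.square(tf.matmul(x, w) - y))

    # apply_gradients casts and increments global_step, so one must be supplied.
    global_step = tf.train.get_or_create_global_step()
    optimizer = AdaBoundOptimizer(learning_rate=1e-3, final_lr=0.1, amsbound=True)
    grads_and_vars = optimizer.compute_gradients(loss)  # inherited from tf.train.Optimizer
    train_op = optimizer.apply_gradients(grads_and_vars, global_step=global_step)

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        # sess.run(train_op, feed_dict={x: ..., y: ...})  # supply real batches here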