# OneCycle.py
class OneCycle(object):
"""
In paper (https://arxiv.org/pdf/1803.09820.pdf), author suggests to do one cycle during
whole run with 2 steps of equal length. During first step, increase the learning rate
from lower learning rate to higher learning rate. And in second step, decrease it from
higher to lower learning rate. This is Cyclic learning rate policy. Author suggests one
addition to this. - During last few hundred/thousand iterations of cycle reduce the
learning rate to 1/100th or 1/1000th of the lower learning rate.
Also, Author suggests that reducing momentum when learning rate is increasing. So, we make
one cycle of momentum also with learning rate - Decrease momentum when learning rate is
increasing and increase momentum when learning rate is decreasing.
Args:
nb Total number of iterations including all epochs
max_lr The optimum learning rate. This learning rate will be used as highest
learning rate. The learning rate will fluctuate between max_lr to
max_lr/div and then (max_lr/div)/div.
momentum_vals The maximum and minimum momentum values between which momentum will
fluctuate during cycle.
Default values are (0.95, 0.85)
prcnt The percentage of cycle length for which we annihilate learning rate
way below the lower learnig rate.
The default value is 10
div The division factor used to get lower boundary of learning rate. This
will be used with max_lr value to decide lower learning rate boundary.
This value is also used to decide how much we annihilate the learning
rate below lower learning rate.
The default value is 10.
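
    Example (illustrative only; wiring the returned values into an optimizer
    is left to the caller):
        >>> onecycle = OneCycle(nb=1000, max_lr=0.1)
        >>> lr, mom = onecycle.calc()   # call once per training iteration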
"""
    def __init__(self, nb, max_lr, momentum_vals=(0.95, 0.85), prcnt=10, div=10):
        self.nb = nb
        self.div = div
        # Each of the two equal steps covers half of the run left over after
        # reserving the final prcnt% for annihilation.
        self.step_len = int(self.nb * (1 - prcnt / 100) / 2)
        self.high_lr = max_lr
        self.low_mom = momentum_vals[1]
        self.high_mom = momentum_vals[0]
        self.prcnt = prcnt
        self.iteration = 0
        self.lrs = []    # history of learning rates, one entry per calc() call
        self.moms = []   # history of momentum values, one entry per calc() call

    def calc(self):
        """Advance one iteration and return the (lr, momentum) pair for it."""
        self.iteration += 1
        lr = self.calc_lr()
        mom = self.calc_mom()
        return (lr, mom)

    def calc_lr(self):
        if self.iteration == self.nb:
            # End of the run: reset the counter and return the lower boundary.
            self.iteration = 0
            self.lrs.append(self.high_lr / self.div)
            return self.high_lr / self.div
        if self.iteration > 2 * self.step_len:
            # Annihilation phase: decay linearly from high_lr/div towards
            # 1/100th of it at the last iteration.
            ratio = (self.iteration - 2 * self.step_len) / (self.nb - 2 * self.step_len)
            lr = self.high_lr * (1 - 0.99 * ratio) / self.div
        elif self.iteration > self.step_len:
            # Second step: decrease linearly from high_lr back to high_lr/div.
            ratio = 1 - (self.iteration - self.step_len) / self.step_len
            lr = self.high_lr * (1 + ratio * (self.div - 1)) / self.div
        else:
            # First step: increase linearly from high_lr/div up to high_lr.
            ratio = self.iteration / self.step_len
            lr = self.high_lr * (1 + ratio * (self.div - 1)) / self.div
        self.lrs.append(lr)
        return lr

    def calc_mom(self):
        if self.iteration == self.nb:
            # End of the run: reset the counter and return the higher momentum.
            self.iteration = 0
            self.moms.append(self.high_mom)
            return self.high_mom
        if self.iteration > 2 * self.step_len:
            # Annihilation phase: momentum is held at its maximum.
            mom = self.high_mom
        elif self.iteration > self.step_len:
            # Second step: increase linearly from low_mom back to high_mom.
            ratio = (self.iteration - self.step_len) / self.step_len
            mom = self.low_mom + ratio * (self.high_mom - self.low_mom)
        else:
            # First step: decrease linearly from high_mom down to low_mom.
            ratio = self.iteration / self.step_len
            mom = self.high_mom - ratio * (self.high_mom - self.low_mom)
        self.moms.append(mom)
        return mom
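

# A minimal usage sketch (not part of the original file): OneCycle is driven
# manually, one calc() per training iteration. The optimizer-update lines are
# commented out because the optimizer is a hypothetical stand-in; with e.g.
# torch.optim.SGD you would write the values into optimizer.param_groups.
if __name__ == "__main__":
    nb_iterations = 1000
    onecycle = OneCycle(nb_iterations, max_lr=0.1)
    for _ in range(nb_iterations):
        lr, mom = onecycle.calc()
        # for group in optimizer.param_groups:
        #     group['lr'] = lr
        #     group['momentum'] = mom
    # The full schedules are recorded for inspection or plotting.
    print(f"peak lr: {max(onecycle.lrs):.4f}, min lr: {min(onecycle.lrs):.6f}")
    print(f"momentum range: {min(onecycle.moms):.3f} .. {max(onecycle.moms):.3f}")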