scheduler.py
import warnings

from torch.optim.lr_scheduler import _LRScheduler


class CustomScheduler(_LRScheduler):
    """Learning-rate schedule that ramps up linearly, optionally holds, then decays."""

    def __init__(self, optimizer, lr_start=5e-6, lr_max=1e-5,
                 lr_min=1e-6, lr_ramp_ep=5, lr_sus_ep=0, lr_decay=0.8,
                 last_epoch=-1):
        self.lr_start = lr_start      # LR used at epoch 0
        self.lr_max = lr_max          # peak LR reached after the ramp
        self.lr_min = lr_min          # floor that the decay approaches
        self.lr_ramp_ep = lr_ramp_ep  # epochs spent ramping up linearly
        self.lr_sus_ep = lr_sus_ep    # epochs held at lr_max before decay
        self.lr_decay = lr_decay      # multiplicative decay factor per epoch
        super().__init__(optimizer, last_epoch)

    def get_lr(self):
        if not self._get_lr_called_within_step:
            warnings.warn("To get the last learning rate computed by the scheduler, "
                          "please use `get_last_lr()`.", UserWarning)

        if self.last_epoch == 0:
            self.last_epoch += 1
            return [self.lr_start for _ in self.optimizer.param_groups]

        lr = self._compute_lr_from_epoch()
        self.last_epoch += 1
        return [lr for _ in self.optimizer.param_groups]

    def _get_closed_form_lr(self):
        return self.base_lrs

    def _compute_lr_from_epoch(self):
        if self.last_epoch < self.lr_ramp_ep:
            # Linear warm-up from lr_start to lr_max over lr_ramp_ep epochs.
            lr = ((self.lr_max - self.lr_start) /
                  self.lr_ramp_ep * self.last_epoch +
                  self.lr_start)
        elif self.last_epoch < self.lr_ramp_ep + self.lr_sus_ep:
            # Hold at lr_max during the sustain phase.
            lr = self.lr_max
        else:
            # Exponential decay from lr_max towards lr_min after ramp + sustain.
            lr = ((self.lr_max - self.lr_min) * self.lr_decay **
                  (self.last_epoch - self.lr_ramp_ep - self.lr_sus_ep) +
                  self.lr_min)
        return lr
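

# --- Usage sketch (illustrative, not part of the original file) ---
# A minimal example of how this scheduler might be driven; the model,
# optimizer choice, and epoch count below are assumptions made only to
# show the ramp / sustain / decay phases of the schedule.
if __name__ == "__main__":
    import torch

    model = torch.nn.Linear(10, 2)  # hypothetical model for demonstration
    optimizer = torch.optim.Adam(model.parameters(), lr=5e-6)
    scheduler = CustomScheduler(optimizer, lr_start=5e-6, lr_max=1e-5,
                                lr_min=1e-6, lr_ramp_ep=5, lr_sus_ep=0,
                                lr_decay=0.8)

    for epoch in range(15):  # assumed number of epochs
        # ... one epoch of training would run here ...
        optimizer.step()     # step the optimizer before the scheduler
        scheduler.step()
        print(epoch, scheduler.get_last_lr())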