custom_trainer.py
import logging

import torch
import transformers
from transformers import Seq2SeqTrainer

logger = logging.getLogger('custom_trainer')
logger.setLevel(logging.INFO)


class Seq2SeqTrainerCustomLinearScheduler(Seq2SeqTrainer):
    """
    Custom trainer that initializes the linear learning-rate scheduler so that
    the learning rate reaches a given target value at the end of training
    instead of decaying all the way to zero.
    """

    @staticmethod
    def scheduler_n_steps_for_fixed_lr_in_end(lr_max, lr_end, num_train_steps, num_warmup_steps) -> int:
        """
        Number of total steps to pass to the linear scheduler so that the learning
        rate equals `lr_end` (rather than 0) at step `num_train_steps`. The linear
        scheduler decays from `lr_max` to 0 over (total_steps - num_warmup_steps)
        steps, so stretching the total to
        num_warmup_steps + (num_train_steps - num_warmup_steps) * lr_max / (lr_max - lr_end)
        makes the schedule pass through `lr_end` exactly at the last real training step.
        """
        assert lr_end < lr_max
        return int(num_warmup_steps + (num_train_steps - num_warmup_steps) * lr_max / (lr_max - lr_end))

    def create_scheduler(self, num_training_steps: int, optimizer: torch.optim.Optimizer = None):
        use_custom_scheduler = False
        try:
            # use the custom scheduler only if learning_rate_end was passed in the training arguments
            learning_rate_end = self.args.learning_rate_end
            use_custom_scheduler = True
            logger.info('Seq2SeqTrainerCustomLinearScheduler.create_scheduler(): '
                        f'initializing custom linear scheduler with learning_rate_end={learning_rate_end}')
        except AttributeError:
            logger.info('Seq2SeqTrainerCustomLinearScheduler.create_scheduler(): '
                        'learning_rate_end was not set, falling back to the default behavior')

        if use_custom_scheduler:
            scheduler_num_steps = self.scheduler_n_steps_for_fixed_lr_in_end(
                lr_max=self.args.learning_rate,
                lr_end=learning_rate_end,
                num_train_steps=num_training_steps,
                num_warmup_steps=self.args.warmup_steps,
            )
            self.lr_scheduler = transformers.get_scheduler(
                'linear', optimizer=optimizer,
                num_warmup_steps=self.args.warmup_steps,
                num_training_steps=scheduler_num_steps,
            )
            return self.lr_scheduler
        else:
            return super().create_scheduler(num_training_steps, optimizer)
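
A minimal usage sketch, not part of the original file: the stock Seq2SeqTrainingArguments class does not define a learning_rate_end field, so the snippet below assumes a small dataclass extension that adds it (the subclass name and the example hyperparameter values are illustrative). The sanity check at the end runs standalone; the commented-out trainer construction uses placeholder model and train_ds objects supplied by whatever training script imports this file.

# usage_sketch.py -- illustrative only, assumes custom_trainer.py is importable
from dataclasses import dataclass, field
from typing import Optional

from transformers import Seq2SeqTrainingArguments

from custom_trainer import Seq2SeqTrainerCustomLinearScheduler


@dataclass
class Seq2SeqTrainingArgumentsWithLREnd(Seq2SeqTrainingArguments):
    # extra field read by Seq2SeqTrainerCustomLinearScheduler.create_scheduler()
    learning_rate_end: Optional[float] = field(
        default=None,
        metadata={'help': 'learning rate to reach at the last training step'})


args = Seq2SeqTrainingArgumentsWithLREnd(
    output_dir='out',
    learning_rate=5e-5,        # lr_max
    learning_rate_end=1e-5,    # lr at the final step instead of 0
    warmup_steps=500,
    max_steps=10_000,
)

# Sanity check of the step-count formula: with these values the scheduler is
# told to decay over 12_375 steps, so that the linear schedule reaches 1e-5
# (not 0) at training step 10_000.
n = Seq2SeqTrainerCustomLinearScheduler.scheduler_n_steps_for_fixed_lr_in_end(
    lr_max=args.learning_rate, lr_end=args.learning_rate_end,
    num_train_steps=args.max_steps, num_warmup_steps=args.warmup_steps)
print(n)  # 12375

# The trainer itself is constructed as usual; model and train_ds are
# placeholders for the model and dataset provided by the training script:
# trainer = Seq2SeqTrainerCustomLinearScheduler(model=model, args=args, train_dataset=train_ds)
# trainer.train()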