earlystop.py

import os

import torch
import torch.distributed as dist


class EarlyStopping:
    """Early stopping with learning-rate decay and rolling checkpoints.

    When validation accuracy stops improving for `patience` epochs, the
    learning rate is reduced by 10x; once it cannot be reduced below
    `min_lr`, training is flagged to stop.
    """

    def __init__(
        self,
        path,
        patience=7,
        verbose=False,
        min_lr=1e-6,
        early_stopping_enabled=True,
        best_score=None,
        counter=0,
        args=None,
    ):
        self.patience = patience
        self.verbose = verbose
        self.best_score = best_score
        self.counter = counter
        self.early_stop = False
        self.path = path
        self.early_stopping_enabled = early_stopping_enabled
        self.best_epochs = []
        self.last_epochs = []
        self.min_lr = min_lr
        # `args` is expected to provide `model_name`; guard the None default.
        self.model_name = args.model_name if args is not None else None

    def __call__(self, val_accuracy, model, optimizer, epoch):
        score = val_accuracy
        if self.early_stopping_enabled:
            if self.best_score is None:
                # First validation result: take it as the initial best.
                self.best_score = score
                self.save_best_model(model, optimizer, epoch)
            elif score < self.best_score + 0.001:
                # Gains smaller than 0.001 count as no improvement.
                self.counter += 1
                if dist.get_rank() == 0:
                    print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
                if self.counter >= self.patience:
                    for param_group in optimizer.param_groups:
                        if param_group['lr'] > self.min_lr:
                            if dist.get_rank() == 0:
                                print(f'Reducing learning rate from {param_group["lr"]} to {param_group["lr"] * 0.1}')
                            param_group['lr'] *= 0.1
                            self.counter = 0  # reset the counter
                        else:
                            # Learning rate already at its floor: stop training.
                            self.early_stop = True
            else:
                if self.verbose and dist.get_rank() == 0:
                    print(f'Validation accuracy increased ({self.best_score:.4f} --> {score:.4f}). Saving model ...')
                self.best_score = score
                self.save_best_model(model, optimizer, epoch)
                self.counter = 0
        # Always keep a rolling "last" checkpoint, whether or not early
        # stopping is enabled.
        self.save_last_epochs(model, optimizer, epoch, index=1, laststop=True)

    def save_best_model(self, model, optimizer, epoch):
        # If the model is CLIP, save only the fc layer's state_dict.
        if self.model_name == 'clip':
            model_state_dict = model.module.fc.state_dict()
        else:
            model_state_dict = model.state_dict()
        if dist.get_rank() == 0:
            state = {
                'epoch': epoch,
                'counter': self.counter,
                'best_score': self.best_score,
                'model_state_dict': model_state_dict,
                'optimizer_state_dict': optimizer.state_dict(),
            }
            torch.save(state, self.path + '.pth')
        self.save_best_epochs(model, optimizer, epoch, index=1, earlystop=True)

    def save_best_epochs(self, model, optimizer, epoch, index=3, earlystop=False):
        # If the model is CLIP, save only the fc layer's state_dict.
        if self.model_name == 'clip':
            model_state_dict = model.module.fc.state_dict()
        else:
            model_state_dict = model.state_dict()
        self.best_epochs.append((epoch, model_state_dict, optimizer.state_dict()))
        suffix = '_best' if earlystop else ''
        # Keep only the latest `index` checkpoints, deleting the oldest files.
        while len(self.best_epochs) > index:
            oldest_epoch, _, _ = self.best_epochs.pop(0)
            if dist.get_rank() == 0:
                os.remove(f"{self.path}{suffix}_ep{oldest_epoch}.pth")
        # Save the latest `index` checkpoints.
        for saved_epoch, model_state_dict, optimizer_state_dict in self.best_epochs[-index:]:
            if dist.get_rank() == 0:
                state = {
                    'epoch': saved_epoch,
                    'counter': self.counter,
                    'best_score': self.best_score,
                    'model_state_dict': model_state_dict,
                    'optimizer_state_dict': optimizer_state_dict,
                }
                torch.save(state, f"{self.path}{suffix}_ep{saved_epoch}.pth")

    def save_last_epochs(self, model, optimizer, epoch, index=3, laststop=False):
        # If the model is CLIP, save only the fc layer's state_dict.
        if self.model_name == 'clip':
            model_state_dict = model.module.fc.state_dict()
        else:
            model_state_dict = model.state_dict()
        self.last_epochs.append((epoch, model_state_dict, optimizer.state_dict()))
        suffix = '_last' if laststop else ''
        # Keep only the latest `index` checkpoints, deleting the oldest files.
        while len(self.last_epochs) > index:
            oldest_epoch, _, _ = self.last_epochs.pop(0)
            if dist.get_rank() == 0:
                os.remove(f"{self.path}{suffix}_ep{oldest_epoch}.pth")
        # Save the latest `index` checkpoints.
        for saved_epoch, model_state_dict, optimizer_state_dict in self.last_epochs[-index:]:
            if dist.get_rank() == 0:
                state = {
                    'epoch': saved_epoch,
                    'counter': self.counter,
                    'best_score': self.best_score,
                    'model_state_dict': model_state_dict,
                    'optimizer_state_dict': optimizer_state_dict,
                }
                torch.save(state, f"{self.path}{suffix}_ep{saved_epoch}.pth")