saver.py
import os

import torch as th
import torch.nn as nn
import torch.optim as optim
from time import localtime, strftime


class Saver:
    """Utility for saving and loading PyTorch model/optimizer checkpoints."""

    def __init__(self, directory: str = 'pytorch_model') -> None:
        # Directory where checkpoint files are written and read.
        self.directory = directory
    def save_checkpoint(self,
                        state,
                        file_name: str = 'pytorch_model.pt',
                        append_time: bool = True):
        """Save a model or a checkpoint dict to disk, optionally time-stamped."""
        os.makedirs(self.directory, exist_ok=True)
        timestamp = strftime("%Y_%m_%d__%H_%M_%S", localtime())
        # os.path.splitext is robust to file names containing extra dots.
        filebasename, fileext = os.path.splitext(file_name)
        if append_time:
            filepath = os.path.join(
                self.directory, f'{filebasename}_{timestamp}{fileext}')
        else:
            filepath = os.path.join(self.directory, file_name)
        if isinstance(state, nn.Module):
            # Wrap a bare module so load_checkpoint finds the 'model_dict' key.
            checkpoint = {'model_dict': state.state_dict()}
            th.save(checkpoint, filepath)
        elif isinstance(state, dict):
            th.save(state, filepath)
        else:
            raise TypeError('state must be an nn.Module or a dict')
    def load_checkpoint(self,
                        model: nn.Module,
                        optimizer: optim.Optimizer = None,
                        file_name: str = 'pytorch_model.pt'):
        """Restore model (and optionally optimizer) state from a checkpoint file."""
        filepath = os.path.join(self.directory, file_name)
        # map_location keeps tensors on CPU so GPU checkpoints load on any machine.
        checkpoint: dict = th.load(
            filepath, map_location=lambda storage, loc: storage)
        model.load_state_dict(checkpoint['model_dict'])
        if optimizer is not None:
            optimizer.load_state_dict(checkpoint['optimizer_dict'])
        # Everything that is not a state dict is returned as a hyperparameter.
        hyperparam_dict = {
            k: v
            for k, v in checkpoint.items()
            if k not in ('model_dict', 'optimizer_dict')
        }
        return model, optimizer, hyperparam_dict
    def create_checkpoint(self, model: nn.Module, optimizer: optim.Optimizer,
                          hyperparam_dict: dict):
        """Bundle model/optimizer state dicts and hyperparameters into one dict."""
        model_dict = model.state_dict()
        optimizer_dict = optimizer.state_dict()
        state_dict = {
            'model_dict': model_dict,
            'optimizer_dict': optimizer_dict,
            'timestamp': strftime('%l:%M%p GMT%z on %b %d, %Y', localtime())
        }
        # Hyperparameters are stored alongside the state dicts in the same file.
        checkpoint = {**state_dict, **hyperparam_dict}
        return checkpoint
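

# Minimal usage sketch (illustration only, not part of the original module):
# the linear model, SGD optimizer, file name, and hyperparameters below are
# hypothetical placeholders chosen just to exercise the Saver API.
if __name__ == '__main__':
    model = nn.Linear(10, 2)
    optimizer = optim.SGD(model.parameters(), lr=0.01)

    saver = Saver(directory='pytorch_model')
    # Bundle states and hyperparameters, then write the checkpoint to disk.
    checkpoint = saver.create_checkpoint(model, optimizer, {'lr': 0.01})
    saver.save_checkpoint(checkpoint, file_name='example.pt', append_time=False)

    # Restore into fresh instances; remaining keys come back as hyperparameters.
    new_model = nn.Linear(10, 2)
    new_optimizer = optim.SGD(new_model.parameters(), lr=0.01)
    new_model, new_optimizer, hparams = saver.load_checkpoint(
        new_model, new_optimizer, file_name='example.pt')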