util.py
import os
import errno

import torch

import modules


def model_parameters(model):
    """Return the total number of parameters in a model."""
    num_params = 0
    for param in model.parameters():
        num_params += param.numel()
    return num_params


def symlink_force(target, link_name):
    """Create a symlink, replacing an existing link with the same name."""
    try:
        os.symlink(target, link_name)
    except OSError as e:
        if e.errno == errno.EEXIST:
            os.remove(link_name)
            os.symlink(target, link_name)
        else:
            raise


def save(model, optimizer, scheduler, args, epoch, module_type):
    """Save a training checkpoint and point a 'last.pth' symlink at it."""
    checkpoint = {
        'epoch': epoch,
        'model': model.state_dict(),
        'optimizer': optimizer.state_dict(),
        # Guard against a missing scheduler instead of crashing on None
        'scheduler': scheduler.state_dict() if scheduler is not None else None,
        'args': args,
    }
    if not os.path.exists(os.path.join(args.output_dir, module_type)):
        os.makedirs(os.path.join(args.output_dir, module_type))
    cp = os.path.join(args.output_dir, module_type, 'last.pth')
    fn = os.path.join(args.output_dir, module_type, 'epoch_' + str(epoch) + '.pth')
    torch.save(checkpoint, fn)
    symlink_force(fn, cp)


def load(fn, module_type, device=torch.device('cuda')):
    """Load a checkpoint and rebuild the model, optimizer and scheduler."""
    checkpoint = torch.load(fn, map_location=device)
    args = checkpoint['args']
    if device == torch.device('cpu'):
        args.use_cuda = False
    if module_type == 'fr':
        model = modules.set_fr_module(args)
    elif module_type == 'fc':
        model = modules.set_fc_module(args)
    else:
        raise NotImplementedError('Module type not recognized')
    model.load_state_dict(checkpoint['model'])
    optimizer, scheduler = set_optim(args, model, module_type)
    if checkpoint['scheduler'] is not None:
        scheduler.load_state_dict(checkpoint['scheduler'])
    optimizer.load_state_dict(checkpoint['optimizer'])
    return model, optimizer, scheduler, args, checkpoint['epoch']


def set_optim(args, module, module_type):
    """Create an Adam optimizer and a ReduceLROnPlateau scheduler for the module."""
    if module_type == 'fr':
        optimizer = torch.optim.Adam(module.parameters(), lr=args.lr_fr)
    elif module_type == 'fc':
        optimizer = torch.optim.Adam(module.parameters(), lr=args.lr_fc)
    else:
        raise ValueError('Expected module_type to be fr or fc but got {}'.format(module_type))
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=7, factor=0.5, verbose=True)
    return optimizer, scheduler


def print_args(logger, args):
    """Log all run arguments and write them to run.args in the output directory."""
    message = ''
    for k, v in sorted(vars(args).items()):
        message += '\n{:>30}: {:<30}'.format(str(k), str(v))
    logger.info(message)
    args_path = os.path.join(args.output_dir, 'run.args')
    with open(args_path, 'wt') as args_file:
        args_file.write(message)
        args_file.write('\n')
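

# --- Hedged usage sketch (not part of the original module) -------------------
# Shows how these helpers might be wired together for the 'fr' module type.
# The toy nn.Linear model and the argparse.Namespace fields (lr_fr, output_dir)
# are assumptions for illustration only; the real training script builds its
# model via modules.set_fr_module and a full argument parser.
if __name__ == '__main__':
    import argparse
    import torch.nn as nn

    args = argparse.Namespace(lr_fr=1e-3, output_dir='/tmp/checkpoints')
    model = nn.Linear(10, 10)  # hypothetical stand-in for the real fr module

    # Count parameters and build the optimizer/scheduler pair.
    print('parameters:', model_parameters(model))
    optimizer, scheduler = set_optim(args, model, 'fr')

    # Writes epoch_0.pth under <output_dir>/fr and points last.pth at it.
    save(model, optimizer, scheduler, args, epoch=0, module_type='fr')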