-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathexp_mgmt.py
125 lines (112 loc) · 4.58 KB
/
exp_mgmt.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
import os
import util
import datetime
import shutil
import mlconfig
import torch
import json
import misc
from collections import OrderedDict
# Module-level device selection, shared by load_state via the `device` global.
if torch.cuda.is_available():
    # benchmark=True lets cuDNN autotune conv algorithms for fixed input
    # shapes (faster, but algorithm choice is non-deterministic).
    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.benchmark = True
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
class ExperimentManager():
    """Owns an experiment's on-disk layout: config copy, log file,
    per-epoch stats, evaluation stats, and model checkpoints.

    Layout (under ``<exp_path>/<exp_name>``):
        checkpoints/     model state dicts saved as ``<name>.pt``
        stats/           per-epoch training stats (JSON)
        stats_eval/      named evaluation stats (JSON)
        <exp_name>.log   log file
        <exp_name>.yaml  copy of the config file

    Filesystem writes (directory creation, config copy) and logger setup
    are restricted to rank 0, so construction is safe on every DDP worker.
    """

    def __init__(self, exp_name, exp_path, config_file_path, eval_mode=False):
        """Build all paths, create directories / copy config on rank 0,
        and set up the rank-0 logger.

        Args:
            exp_name: experiment name; '' or None falls back to a
                timestamp-based name.
            exp_path: parent directory that will contain the experiment folder.
            config_file_path: YAML config to load (and copy next to the log);
                None means no config is loaded (self.config stays None).
            eval_mode: when True, skip directory creation and the config copy.
        """
        if exp_name == '' or exp_name is None:
            # BUGFIX: the original concatenated a datetime object directly to
            # a str, which raises TypeError; convert to str first.
            exp_name = 'exp_at' + str(datetime.datetime.now())
        exp_path = os.path.join(exp_path, exp_name)
        checkpoint_path = os.path.join(exp_path, 'checkpoints')
        log_filepath = os.path.join(exp_path, exp_name) + ".log"
        stas_hist_path = os.path.join(exp_path, 'stats')
        stas_eval_path = os.path.join(exp_path, 'stats_eval')
        # Only rank 0 touches the filesystem, and never in eval mode.
        if misc.get_rank() == 0 and not eval_mode:
            util.build_dirs(exp_path)
            util.build_dirs(checkpoint_path)
            util.build_dirs(stas_hist_path)
            util.build_dirs(stas_eval_path)
        if config_file_path is not None:
            dst = os.path.join(exp_path, exp_name + '.yaml')
            # Guard against copying the config file onto itself when
            # re-running with the config already inside the experiment dir.
            if dst != config_file_path and misc.get_rank() == 0 and not eval_mode:
                shutil.copyfile(config_file_path, dst)
            config = mlconfig.load(config_file_path)
            config.set_immutable()
        else:
            config = None
        self.exp_name = exp_name
        self.exp_path = exp_path
        self.checkpoint_path = checkpoint_path
        self.log_filepath = log_filepath
        self.stas_hist_path = stas_hist_path
        self.stas_eval_path = stas_eval_path
        self.config = config
        self.logger = None  # stays None on non-zero ranks
        self.eval_mode = eval_mode
        if misc.get_rank() == 0:
            self.logger = util.setup_logger(name=self.exp_path, log_file=self.log_filepath,
                                            ddp=misc.get_world_size() > 1)

    def save_eval_stats(self, exp_stats, name):
        """Write evaluation stats to stats_eval/<name>_exp_stats_eval.json."""
        filename = '%s_exp_stats_eval.json' % name
        filename = os.path.join(self.stas_eval_path, filename)
        with open(filename, 'w') as outfile:
            json.dump(exp_stats, outfile)
        return

    def load_eval_stats(self, name):
        """Read stats_eval/<name>_exp_stats_eval.json; None when missing."""
        filename = '%s_exp_stats_eval.json' % name
        filename = os.path.join(self.stas_eval_path, filename)
        if os.path.exists(filename):
            with open(filename, 'r') as json_file:
                return json.load(json_file)
        return None

    def save_epoch_stats(self, epoch, exp_stats):
        """Write epoch stats to stats/exp_stats_epoch_<epoch>.json."""
        filename = 'exp_stats_epoch_%d.json' % epoch
        filename = os.path.join(self.stas_hist_path, filename)
        with open(filename, 'w') as outfile:
            json.dump(exp_stats, outfile)
        return

    def _epoch_stats_path(self, epoch):
        """Full path of the stats JSON file for a given epoch."""
        return os.path.join(self.stas_hist_path, 'exp_stats_epoch_%d.json' % epoch)

    def load_epoch_stats(self, epoch=None):
        """Load stats for `epoch`, or — when epoch is None — for the latest
        epoch whose stats file exists, scanning down from self.config.epochs.
        Returns the parsed dict, or None when no file is found (in the
        epoch=None case only; an explicit epoch raises on a missing file,
        matching the original behavior).
        """
        if epoch is not None:
            with open(self._epoch_stats_path(epoch), 'r') as json_file:
                return json.load(json_file)
        # No epoch given: find the newest stats file actually on disk.
        epoch = self.config.epochs
        while epoch >= 0 and not os.path.exists(self._epoch_stats_path(epoch)):
            epoch -= 1
        filename = self._epoch_stats_path(epoch)
        if not os.path.exists(filename):
            return None
        # BUGFIX (consistency): the original opened this one file in 'rb';
        # open in text mode like every other JSON reader here. The dead
        # trailing `return None` after this return was removed.
        with open(filename, 'r') as json_file:
            return json.load(json_file)

    def save_state(self, target, name):
        """Save target's state_dict to checkpoints/<name>.pt.

        (Distributed)DataParallel wrappers are unwrapped first so saved keys
        carry no 'module.' prefix — generalized to DDP since this codebase
        runs distributed (see misc.get_world_size usage in __init__).
        """
        if isinstance(target, (torch.nn.DataParallel,
                               torch.nn.parallel.DistributedDataParallel)):
            target = target.module
        filename = os.path.join(self.checkpoint_path, name) + '.pt'
        torch.save(target.state_dict(), filename)
        if misc.get_rank() == 0:
            self.logger.info('%s saved at %s' % (name, filename))
        return

    def load_state(self, target, name, strict=True):
        """Load checkpoints/<name>.pt into target and return target.

        Keys containing 'total_ops'/'total_params' (counters injected by
        thop-style FLOP profiling hooks) are stripped before loading.

        Args:
            strict: forwarded to load_state_dict. BUGFIX: the original
                accepted this parameter but never used it.
        """
        filename = os.path.join(self.checkpoint_path, name) + '.pt'
        d = torch.load(filename, map_location=device)
        d = {k: v for k, v in d.items()
             if 'total_ops' not in k and 'total_params' not in k}
        target.load_state_dict(d, strict=strict)
        if misc.get_rank() == 0 and not self.eval_mode:
            self.logger.info('%s loaded from %s' % (name, filename))
        return target