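"""main_single.py

Launcher for training or testing the single-model variant.

launch_training() either starts a fresh run (creating timestamped
log_skgan_<time> and ckpt_skgan_<time> directories, snapshotting the
sources in ./src_single, and saving all hyperparameters to param_0.json)
or resumes a past run from its latest checkpoint; launch_test() restores
a past run's parameters and calls the entry-point module's test function.
"""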
import argparse
import importlib
import os
import sys
import shutil
import json
import tensorflow as tf
from time import gmtime, strftime
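
# Directory holding the training/model sources; its *.py files are copied
# into each run's log directory for record keeping.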
src_dir = './src_single'


def launch_training(**kwargs):
    # Deal with files and paths
    appendix = kwargs["resume_from"]
    if appendix is None or appendix == '':
        # Fresh run: create timestamped log and checkpoint directories
        cur_time = strftime("%Y-%m-%d-%H-%M-%S", gmtime())
        log_dir = './log_skgan_' + cur_time
        ckpt_dir = './ckpt_skgan_' + cur_time
        if not os.path.exists(log_dir):
            os.makedirs(log_dir)
        if not os.path.exists(ckpt_dir):
            os.makedirs(ckpt_dir)
        # Copy the current scripts in the src folder to the log dir for record
        if not os.path.isdir(src_dir):
            print("src folder does not exist.")
            return None, appendix
        for file in os.listdir(src_dir):
            if file.endswith(".py"):
                shutil.copy(os.path.join(src_dir, file), log_dir)
        kwargs['log_dir'] = log_dir
        kwargs['ckpt_dir'] = ckpt_dir
        appendix = cur_time
        kwargs["resume_from"] = appendix
        kwargs["iter_from"] = 0
        # Save parameters
        with open(os.path.join(log_dir, 'param_%d.json' % 0), 'w') as fp:
            json.dump(kwargs, fp, indent=4)
        sys.path.append(src_dir)
        entry_point_module = kwargs['entry_point']
        from config import Config
        Config.set_from_dict(kwargs)
        print("Launching new train: %s" % cur_time)
    else:
        # Resume: the folder appendix must be a yyyy-mm-dd-hr-min-sec timestamp
        if len(appendix.split('-')) != 6:
            print("Invalid resume folder")
            return None, appendix
        log_dir = './log_skgan_' + appendix
        ckpt_dir = './ckpt_skgan_' + appendix
        # Get last parameters (recover entry point module name) from the
        # param_<iter>.json file with the largest iteration number
        json_files = [f for f in os.listdir(log_dir) if
                      os.path.isfile(os.path.join(log_dir, f)) and os.path.splitext(f)[1] == '.json']
        iter_starts = max(int(os.path.splitext(filename)[0].split('_')[1]) for filename in json_files)
        with open(os.path.join(log_dir, 'param_%d.json' % iter_starts), 'r') as fp:
            params = json.load(fp)
        entry_point_module = params['entry_point']
        # Recover saved parameters, except those allowed to change between runs
        _ignored = ['num_gpu', 'iter_from']
        for k, v in params.items():
            if k not in _ignored:
                kwargs[k] = v
        sys.path.append(log_dir)
        # Get the latest checkpoint filename
        # if stage == 1:
        #     ckpt_file = tf.train.latest_checkpoint(stage_1_ckpt_dir)
        # elif stage == 2:
        ckpt_file = tf.train.latest_checkpoint(ckpt_dir)
        if ckpt_file is None:
            raise RuntimeError("No checkpoint found in %s" % ckpt_dir)
        # Checkpoint files are named <prefix>-<iteration>; resume one step later
        iter_from = int(os.path.split(ckpt_file)[1].split('-')[1]) + 1
        kwargs['log_dir'] = log_dir
        kwargs['ckpt_dir'] = ckpt_dir
        kwargs['iter_from'] = iter_from
        # Save the new set of parameters
        with open(os.path.join(log_dir, 'param_%d.json' % iter_from), 'w') as fp:
            kwargs['entry_point'] = entry_point_module
            json.dump(kwargs, fp, indent=4)
        from config import Config
        Config.set_from_dict(kwargs)
        print("Launching train from checkpoint: %s" % appendix)

    # Launch train
    train_module = importlib.import_module(entry_point_module)
    # from train_paired_aug_multi_gpu import train
    status = train_module.train(**kwargs)
    return status, appendix
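
# A run stamped e.g. 2018-01-31-12-00-00 (hypothetical timestamp) leaves behind:
#   ./log_skgan_2018-01-31-12-00-00/   copied *.py sources plus param_<iter>.json files
#   ./ckpt_skgan_2018-01-31-12-00-00/  TensorFlow checkpoints named <prefix>-<iter>
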
def launch_test(**kwargs):
    # Deal with files and paths
    appendix = kwargs["resume_from"]
    if appendix is None or appendix == '' or len(appendix.split('-')) != 6:
        print("Invalid resume folder")
        return
    log_dir = './log_skgan_' + appendix
    ckpt_dir = './ckpt_skgan_' + appendix
    sys.path.append(log_dir)
    kwargs['log_dir'] = log_dir
    kwargs['ckpt_dir'] = ckpt_dir
    # Get last parameters (recover entry point module name) from the
    # param_<iter>.json file with the largest iteration number
    json_files = [f for f in os.listdir(log_dir) if
                  os.path.isfile(os.path.join(log_dir, f)) and os.path.splitext(f)[1] == '.json']
    iter_starts = max(int(os.path.splitext(filename)[0].split('_')[1]) for filename in json_files)
    with open(os.path.join(log_dir, 'param_%d.json' % iter_starts), 'r') as fp:
        params = json.load(fp)
    entry_point_module = params['entry_point']
    # Recover parameters
    _ignored = ["num_gpu", 'iter_from']
    for k, v in params.items():
        if k not in _ignored:
            kwargs[k] = v
    from config import Config
    Config.set_from_dict(kwargs)
    print("Launching test from checkpoint: %s" % appendix)
    # Launch test
    train_module = importlib.import_module(entry_point_module)
    train_module.test(**kwargs)
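
# The entry-point module (default train_single, imported from ./src_single for a
# fresh run or from the run's log dir on resume) is assumed to expose train(**kwargs)
# and test(**kwargs). A minimal sketch of that assumed interface, not the actual
# implementation:
#
#   def train(**kwargs):
#       # build the graph; restore from kwargs['ckpt_dir'] when kwargs['iter_from'] > 0
#       return 0  # or -1 if training diverged (NaN), which triggers a restart below
#
#   def test(**kwargs):
#       # restore the latest checkpoint from kwargs['ckpt_dir'] and run inference
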
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Train or test the model')
    parser.add_argument('--mode', type=str, default="train", help="train or test")
    parser.add_argument('--resume_from', type=str, default='',
                        help="Resume from the last checkpoint of a past run. Note: only fill in "
                             "the part after skgan_, i.e. the yyyy-mm-dd-hr-min-sec string")
    parser.add_argument('--entry_point', type=str, default='train_single',
                        help="Name of the training .py file")
    parser.add_argument('--batch_size', default=16, type=int, help="Batch size per GPU")
    parser.add_argument('--max_iter_step', default=300000, type=int, help="Max number of iterations")
    parser.add_argument('--disc_iterations', default=1, type=int, help="Number of discriminator iterations")
    parser.add_argument('--ld', default=10, type=float, help="Gradient penalty lambda hyperparameter")
    parser.add_argument('--optimizer', type=str, default="Adam", help="Optimizer for the graph")
    parser.add_argument('--lr_G', type=float, default=2e-4, help="Learning rate for the generator")
    parser.add_argument('--lr_D', type=float, default=4e-4, help="Learning rate for the discriminator")
    parser.add_argument('--num_gpu', default=2, type=int, help="Number of GPUs to use")
    parser.add_argument('--distance_map', default=1, type=int, help="Whether to use distance maps for sketches")
    parser.add_argument('--small_img', default=1, type=int, help="Whether to use 64x64 instead of 256x256 images")
    parser.add_argument('--extra_info', default="", type=str, help="Extra information saved for record")
    args = parser.parse_args()
    assert args.optimizer in ["RMSprop", "Adam", "AdaDelta", "AdaGrad"], "Unsupported optimizer"
    # Set default params
    d_params = {"resume_from": args.resume_from,
                "entry_point": args.entry_point,
                "batch_size": args.batch_size,
                "max_iter_step": args.max_iter_step,
                "disc_iterations": args.disc_iterations,
                "ld": args.ld,
                "optimizer": args.optimizer,
                "lr_G": args.lr_G,
                "lr_D": args.lr_D,
                "num_gpu": args.num_gpu,
                "distance_map": args.distance_map,
                "small_img": args.small_img,
                "extra_info": args.extra_info,
                }
    if args.mode == 'train':
        # Launch training; restart from the last checkpoint whenever the run
        # ends with status -1 (NaN encountered during training)
        status, appendix = launch_training(**d_params)
        while status == -1:
            print("Training ended with status -1. Restarting..")
            d_params["resume_from"] = appendix
            status, appendix = launch_training(**d_params)
    elif args.mode == 'test':
        launch_test(**d_params)
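
# Example invocations (the timestamp below is hypothetical):
#   python main_single.py --mode train --batch_size 16 --max_iter_step 300000
#   python main_single.py --mode train --resume_from 2018-01-31-12-00-00
#   python main_single.py --mode test --resume_from 2018-01-31-12-00-00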