-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathtrain.py
113 lines (94 loc) · 3.47 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
import argparse
import os

from siren.data_loader import get_data_loader_cls_by_type
from siren.model import get_model_cls_by_type
from siren.optimizer import JaxOptimizer
from util.log import Logger
from util.timer import Timer
def parse_args(argv=None):
    """Parse command-line options for SirenHighres training.

    Args:
        argv: Optional list of argument strings. Defaults to ``None``,
            in which case ``argparse`` reads ``sys.argv`` as before
            (backward compatible).

    Returns:
        argparse.Namespace with the parsed training options.
    """
    parser = argparse.ArgumentParser(description="Train SirenHighres")
    parser.add_argument("--file", type=str, help="location of the file", required=True)
    parser.add_argument(
        "--nc",
        type=int,
        default=3,
        help="number of channels of input image. if the source is color (3) and --nc is 1, then the source is converted to gray scale",
    )
    parser.add_argument(
        "--type",
        type=str,
        default="normal",
        choices=["normal", "gradient", "laplacian", "combined"],
        help="training image type",
    )
    parser.add_argument(
        "--size",
        type=int,
        default=256,
        # Typo fixes in the help text ("squre"/"goint go" in the original).
        help="resize the image to this (square) shape. 0 if not going to resize",
    )
    parser.add_argument(
        "--batch_size",
        type=int,
        default=0,
        help="the size of batches. 0 for single batch",
    )
    parser.add_argument("--epoch", type=int, default=10000, help="number of epochs")
    parser.add_argument("--lr", type=float, default=0.0001, help="learning rate")
    parser.add_argument(
        "--print_iter", type=int, default=200, help="when to print intermediate info"
    )
    parser.add_argument(
        "--layers",
        type=str,
        default="256,256,256",
        help="layers of multi layer perceptron",
    )
    parser.add_argument("--omega", type=float, default=30, help="omega value of Siren")
    args = parser.parse_args(argv)
    return args
def main(args):
    """Train a Siren model on one image and log progress/artifacts.

    Args:
        args: argparse.Namespace produced by ``parse_args`` — expects
            file, nc, size, batch_size, type, layers, omega, lr, epoch,
            print_iter.

    Side effects: writes options, images, periodic loss logs, final
    network parameters, and a loss plot through ``Logger``; prints
    progress to stdout.
    """
    # "256,256,256" -> [256, 256, 256]
    layers = [int(l) for l in args.layers.split(",")]

    Model = get_model_cls_by_type(args.type)
    DataLoader = get_data_loader_cls_by_type(args.type)

    data_loader = DataLoader(args.file, args.nc, args.size, args.batch_size)
    model = Model(layers, args.nc, args.omega)
    optimizer = JaxOptimizer("adam", model, args.lr)

    # Derive the run name by stripping only the extension.
    # The original `args.file.split(".")[0]` breaks on paths with other
    # dots, e.g. "./dir/img.png" -> "" instead of "./dir/img".
    name = os.path.splitext(args.file)[0]
    logger = Logger(name)
    logger.save_option(vars(args))

    gt_img = data_loader.get_ground_truth_image()
    logger.save_image("original", data_loader.original_pil_img)
    logger.save_image("gt", gt_img)

    iter_timer = Timer()
    iter_timer.start()

    def interm_callback(i, data, params):
        # Periodic progress record: loss, iteration, avg seconds/iter
        # since the last report.
        log = {}
        loss = model.loss_func(params, data)
        log["loss"] = float(loss)
        log["iter"] = i
        log["duration_per_iter"] = iter_timer.get_dt() / args.print_iter
        logger.save_log(log)
        print(log)

    print("Training Start")
    print(vars(args))
    total_timer = Timer()
    total_timer.start()

    last_data = None
    for _ in range(args.epoch):
        # Recreate the loader each epoch so batch iteration restarts.
        data_loader = DataLoader(args.file, args.nc, args.size, args.batch_size)
        for data in data_loader:
            optimizer.step(data)
            last_data = data
            if optimizer.iter_cnt % args.print_iter == 0:
                interm_callback(
                    optimizer.iter_cnt, data, optimizer.get_optimized_params()
                )

    # Final report if the last iteration wasn't already logged above.
    # Use last_data: the original read the loop variable `data`, which is
    # undefined (NameError) when no training step ever ran.
    if last_data is not None and optimizer.iter_cnt % args.print_iter != 0:
        interm_callback(
            optimizer.iter_cnt, last_data, optimizer.get_optimized_params()
        )

    train_duration = total_timer.get_dt()
    print("Training Duration: {} sec".format(train_duration))
    logger.save_net_params(optimizer.get_optimized_params())
    logger.save_losses_plot()
if __name__ == "__main__":
    # Script entry point: parse CLI options, then run training.
    main(parse_args())