-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathtrain.py
144 lines (120 loc) · 4.38 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
import os
import time
import shutil
import click
import numpy as np
from keras import callbacks, optimizers
from IPython import embed
from model import deconv_frontend, dilated_frontend, add_softmax
from image_reader import RandomTransformer, SegmentationDataGenerator
# --- Configuration (machine-specific paths; edit before running) ---
# Text files listing image basenames, one per line, for train/val splits.
train_list_fname = '/home/v-yurzho/benchmark_RELEASE/dataset/train.txt'
val_list_fname = '/home/v-yurzho/benchmark_RELEASE/dataset/val.txt'
# Root directories holding the .jpg images and .png segmentation masks;
# basenames from the lists above are joined against these in build_abs_paths.
img_root = '/home/v-yurzho/benchmark_RELEASE/dataset/img'
mask_root = '/home/v-yurzho/benchmark_RELEASE/dataset/pngs'
# Checkpoint used by load_weights() to warm-start training (the call in
# train() is currently commented out).
weights_path = '/home/v-yurzho/FCN-for-Semantic-Segmentation/trained_log/2017-10-27 17:29-lr1e-04-bs001/ep42-vl1.0240.hdf5'
# Training hyperparameters.
batch_size = 4
learning_rate = 2e-7
# NOTE(review): modeltype is never read below — presumably intended to
# select between deconv_frontend and dilated_frontend; confirm and wire up.
modeltype = "deconv"
def load_weights(model, weights_path):
    """Load pretrained weights into ``model`` from ``weights_path``.

    Two formats are supported, selected by file extension:
      * ``.npy``  -- a pickled dict (Caffe weights converted to TF) mapping
        layer names to ``{'weights': ..., 'biases': ...}``.
      * ``.hdf5`` -- a standard Keras checkpoint.

    Args:
        model: a compiled/built Keras model whose layer names match the
            checkpoint's keys.
        weights_path: path to the ``.npy`` or ``.hdf5`` file.

    Raises:
        ValueError: if the file extension is neither ``.npy`` nor ``.hdf5``.
            (ValueError subclasses Exception, so callers catching the old
            bare Exception keep working.)
    """
    print("*********** load weights ***********")

    def load_tf_weights():
        """Load pretrained weights converted from Caffe to TF."""
        # 'latin1' enables loading .npy files created with python2;
        # allow_pickle=True is required for object arrays on numpy >= 1.16.3,
        # where it became False by default.
        weights_data = np.load(weights_path, encoding='latin1',
                               allow_pickle=True).item()
        # Only layers present in the checkpoint are updated; others keep
        # their current (e.g. randomly initialized) weights.
        for layer in model.layers:
            if layer.name in weights_data:
                layer_weights = weights_data[layer.name]
                layer.set_weights((layer_weights['weights'],
                                   layer_weights['biases']))

    def load_keras_weights():
        """Load a Keras checkpoint."""
        model.load_weights(weights_path)

    if weights_path.endswith('.npy'):
        load_tf_weights()
    elif weights_path.endswith('.hdf5'):
        load_keras_weights()
    else:
        raise ValueError("Unknown weights format: {}".format(weights_path))
def build_abs_paths(basenames, img_dir=None, mask_dir=None):
    """Map dataset basenames to parallel lists of image and mask paths.

    Args:
        basenames: iterable of file stems (no directory, no extension).
        img_dir: directory containing the ``.jpg`` images; defaults to the
            module-level ``img_root``.
        mask_dir: directory containing the ``.png`` masks; defaults to the
            module-level ``mask_root``.

    Returns:
        ``(img_fnames, mask_fnames)`` — two lists of the same length as
        ``basenames``.
    """
    # The previous ``global`` declarations were unnecessary: the module
    # globals were only read, never assigned.
    if img_dir is None:
        img_dir = img_root
    if mask_dir is None:
        mask_dir = mask_root
    img_fnames = [os.path.join(img_dir, f) + '.jpg' for f in basenames]
    mask_fnames = [os.path.join(mask_dir, f) + '.png' for f in basenames]
    return img_fnames, mask_fnames
def train():
    """Train the dilated-frontend segmentation model for 50 epochs.

    Reads the train/val basename lists, creates a fresh timestamped
    checkpoint folder, compiles the model with SGD + sparse categorical
    cross-entropy, and runs ``fit_generator`` with checkpointing,
    TensorBoard, CSV logging, LR-on-plateau reduction, and per-epoch
    skipped-image logging.
    """
    # NOTE: the former ``global`` declarations were unnecessary — the
    # module-level settings are only read here, never reassigned.

    # Random horizontal/vertical flips for augmentation (applied to both
    # train and val generators, preserving the original behavior).
    train_data_gen = SegmentationDataGenerator(
        RandomTransformer(horizontal_flip=True, vertical_flip=True))
    val_data_gen = SegmentationDataGenerator(
        RandomTransformer(horizontal_flip=True, vertical_flip=True))

    # One log folder per run, keyed by timestamp and hyperparameters.
    trained_log = '{}-lr{:.0e}-bs{:03d}'.format(
        time.strftime("%Y-%m-%d %H:%M"),
        learning_rate,
        batch_size)
    checkpoints_folder = 'trained_log/' + trained_log
    try:
        os.makedirs(checkpoints_folder)
    except OSError:
        # Folder already exists (e.g. a rerun within the same minute):
        # wipe it and start over so stale checkpoints don't mix in.
        shutil.rmtree(checkpoints_folder, ignore_errors=True)
        os.makedirs(checkpoints_folder)

    # Save a checkpoint every epoch, named after epoch and val_loss.
    model_checkpoint = callbacks.ModelCheckpoint(
        checkpoints_folder + '/ep{epoch:02d}-vl{val_loss:.4f}.hdf5',
        monitor='loss')
    model_tensorboard = callbacks.TensorBoard(
        log_dir='{}/tboard'.format(checkpoints_folder),
        histogram_freq=0,
        write_graph=False,
        write_images=False)
    model_csvlogger = callbacks.CSVLogger(
        '{}/history.log'.format(checkpoints_folder))
    # Shrink the LR by 5x after 5 stagnant epochs, down to 5% of the start.
    model_reducelr = callbacks.ReduceLROnPlateau(
        monitor='val_loss',
        factor=0.2,
        patience=5,
        verbose=1,
        min_lr=0.05 * learning_rate)

    model = add_softmax(dilated_frontend(500, 500))
    # load_weights(model, weights_path)  # optionally warm-start from checkpoint
    model.compile(
        optimizer=optimizers.SGD(lr=learning_rate, momentum=0.9),
        loss='sparse_categorical_crossentropy',
        metrics=['accuracy'])

    # Read the basename lists; ``with`` closes the files (the originals
    # were left open). Validation is capped at 500 images.
    with open(train_list_fname) as f:
        train_basenames = [line.strip() for line in f]
    with open(val_list_fname) as f:
        val_basenames = [line.strip() for line in f][:500]

    train_img_fnames, train_mask_fnames = build_abs_paths(train_basenames)
    val_img_fnames, val_mask_fnames = build_abs_paths(val_basenames)

    def _log_skipped(epoch, logs):
        # Append the generator's skipped-image count once per epoch;
        # the original version leaked the file handle.
        with open('{}/skipped.txt'.format(checkpoints_folder), 'a') as f:
            f.write('{}\n'.format(train_data_gen.skipped_count))

    model_skipped = callbacks.LambdaCallback(on_epoch_end=_log_skipped)

    # Ceil division keeps the final partial batch and yields an int,
    # instead of the float the original ``/`` produced.
    steps_per_epoch = (len(train_basenames) + batch_size - 1) // batch_size
    validation_steps = (len(val_basenames) + 7) // 8

    model.fit_generator(
        train_data_gen.flow_from_list(
            train_img_fnames,
            train_mask_fnames,
            shuffle=True,
            batch_size=batch_size,
            img_target_size=(500, 500),
            mask_target_size=(16, 16)),
        steps_per_epoch=steps_per_epoch,
        epochs=50,
        validation_data=val_data_gen.flow_from_list(
            val_img_fnames,
            val_mask_fnames,
            batch_size=8,  # validation batch size is fixed, not ``batch_size``
            img_target_size=(500, 500),
            mask_target_size=(16, 16)),
        validation_steps=validation_steps,
        callbacks=[
            model_checkpoint,
            model_tensorboard,
            model_csvlogger,
            model_reducelr,
            model_skipped,
        ])
# Script entry point: kick off training when run directly (not on import).
if __name__ == '__main__':
    train()