-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtrain.py
82 lines (65 loc) · 2.45 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
import config
import os
import time
import argparse
import numpy as np
from model import get_model, load_model
from utils import generate_from_file_list, random_shuffle, augmentation
# Command-line options:
#   -i/--input   saved model to resume from (omit to build a new model)
#   -o/--output  file name for the trained model; it is joined onto
#                config.model_path when saving (defaults to "<timestamp>.h5")
#   -m/--model   architecture name forwarded to get_model
#   -e/--epochs  number of training epochs
parser = argparse.ArgumentParser()
parser.add_argument(
    '-i', '--input',
    help='input path for model',
)
parser.add_argument(
    '-o', '--output',
    default='{}.h5'.format(time.time()),
    help='output path for model',
)
parser.add_argument(
    '-m', '--model',
    default='CNN',
    help='model type',
)
parser.add_argument(
    '-e', '--epochs',
    type=int,
    default=1500,
    help='number of training epochs',
)
args = parser.parse_args()
print('\nReading Data...\n')
# Training Data: positive / negative samples are generated from the file
# lists configured in config.pos_list / config.neg_list.
pos_list = config.pos_list
neg_list = config.neg_list
pos_train = generate_from_file_list(pos_list)
neg_train = generate_from_file_list(neg_list)
np.random.shuffle(pos_train)
np.random.shuffle(neg_train)
# Optional rebalancing: keep int(len(pos_train) * data_balance) negatives.
# A falsy data_balance (0) disables it.
if config.data_balance:
    n_neg = int(len(pos_train) * config.data_balance)
    # Sample WITHOUT replacement whenever enough negatives exist, so the
    # reduced set contains no duplicated samples. Fall back to sampling with
    # replacement (oversampling) only when more negatives are requested than
    # are available.
    neg_idx = np.random.choice(len(neg_train), n_neg,
                               replace=n_neg > len(neg_train))
    neg_train = neg_train[neg_idx]
print('\nPos: {} Neg: {}\n'.format(len(pos_train), len(neg_train)))
# Stack both classes and label them: 1 for positive, 0 for negative.
x_train = np.concatenate((pos_train, neg_train), axis=0)
y_train = np.hstack((np.ones(len(pos_train)), np.zeros(len(neg_train))))
# shuffle and augmentation
x_train, y_train = random_shuffle(x_train, y_train)
# NOTE(review): y_train is not touched here, so augmentation() is assumed to
# return exactly one sample per input sample -- confirm in utils.
x_train = augmentation(x_train)
# Test Data: held-out positive and negative sets, each generated from the
# single file-list entry config.pos_test / config.neg_test and labelled
# 1 (positive) / 0 (negative).
pos_test = generate_from_file_list([config.pos_test])
neg_test = generate_from_file_list([config.neg_test])
pos_y, neg_y = np.ones(len(pos_test)), np.zeros(len(neg_test))
print('\nReading Data Done.\n')
# Load the model: resume from a saved file when --input is given, otherwise
# build a fresh network of the architecture named by --model.
print('\nTraining Begin\n')
print('Loading Model...')
pad = config.pad
if args.input:
    model = load_model(args.input)
else:
    # Square (2 * pad, 2 * pad) input. output_shape=2 presumably means one
    # output per class (pos / neg) -- confirm against model.get_model.
    model = get_model(input_shape=(2 * pad, 2 * pad),
                      output_shape=2, model_type=args.model)
# Train
print('Training...')
# Validate on the full held-out set (both classes). Validating on the
# positive test set alone only measures recall on positives and cannot
# reveal a model that labels every sample as positive.
x_test = np.concatenate((pos_test, neg_test), axis=0)
y_test = np.concatenate((pos_y, neg_y), axis=0)
model.fit(x_train, y_train, epochs=args.epochs, batch_size=256,
          validation_data=(x_test, y_test), validation_freq=100, verbose=1)
# Evaluate: positives only, negatives only, then the combined test set.
print('Evaluating...')
model.evaluate(pos_test, pos_y)
model.evaluate(neg_test, neg_y)
model.evaluate(x_test, y_test)
# Save Model
print('Saving Model...')
model_path = os.path.join(config.model_path, args.output)
# Ensure the destination directory exists so the save after a long training
# run does not fail on a missing folder.
save_dir = os.path.dirname(model_path)
if save_dir:
    os.makedirs(save_dir, exist_ok=True)
model.save(model_path)
print('Model has Saved in {} \n Dataset Pos: {} Neg:{}'.format(
    model_path, len(pos_train), len(neg_train)))