-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathtrain.py
37 lines (30 loc) · 1000 Bytes
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
from keras.losses import mean_squared_error
from keras import optimizers
from util import load
from Models import SimpleCNN
import numpy as np
# Training script: load the data, record the column-name -> index mapping,
# fit SimpleCNN with SGD on a mean-squared-error loss, and save the weights.

# Training hyper-parameters passed to model.fit below.
batch_size = 128
epochs = 25

print('Reading Train Data')
# load() comes from the project's util module; it is unpacked into inputs,
# targets and a sequence of names (presumably the target column names --
# verify against util.load).
X, y, cols_names = load()
print("X.shape == {}; X.min == {:.3f}; X.max == {:.3f}".format(
    X.shape, X.min(), X.max()))
print("y.shape == {}; y.min == {:.3f}; y.max == {:.3f}".format(
    y.shape, y.min(), y.max()))

# Map every entry of cols_names to its position in that sequence.
feature_name2KeypointIdx = {
    feature_name: idx for idx, feature_name in enumerate(cols_names)
}
# NOTE: np.save stores a dict as a pickled 0-d object array, so readers must
# use np.load('feature2kpId.npy', allow_pickle=True).item() to get it back.
# The file name and format are kept unchanged for existing consumers.
np.save('feature2kpId.npy', feature_name2KeypointIdx)

model = SimpleCNN()
# NOTE(review): `lr` and `decay` are the legacy Keras argument names; they are
# kept as-is to match the Keras version this file was written against.
sgd = optimizers.SGD(lr=0.01, decay=1e-6, momentum=0.9, nesterov=True)

print('Start Training')
# The loss is mean squared error (a regression objective), so classification
# accuracy is not a meaningful metric here; report mean absolute error instead.
model.compile(loss=mean_squared_error,
              optimizer=sgd,
              metrics=['mae'])
model.fit(X, y,
          batch_size=batch_size,
          epochs=epochs,
          verbose=1)
print('Done Training')

print('Saving Weights')
model.save_weights("FKD_weights.h5")