-
Notifications
You must be signed in to change notification settings - Fork 0
/
trainEval.py
69 lines (55 loc) · 2.18 KB
/
trainEval.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
def train_and_evaluate(args):
    """Trains and evaluates the Keras model.

    Uses the Keras model defined in model.py and trains on data loaded and
    preprocessed in util.py. Saves the trained model in TensorFlow SavedModel
    format to the path defined in part by the --job-dir argument.

    Args:
        args: parsed arguments, read by attribute (not by key) - see
            get_args() for details. This function reads `learning_rate`,
            `num_epochs`, `batch_size` and `job_dir`.
    """
    train_x, train_y, eval_x, eval_y = util.load_data()

    # Dimensions
    num_train_examples, input_dim = train_x.shape
    num_eval_examples = eval_x.shape[0]

    # Whole batches per training epoch. Floor division avoids a float
    # round-trip, and the lower bound of 1 keeps fit() from being asked for
    # zero steps when batch_size exceeds the number of training examples.
    steps_per_epoch = max(1, num_train_examples // args.batch_size)

    # Create the Keras Model
    keras_model = model.create_keras_model(
        input_dim=input_dim, learning_rate=args.learning_rate)

    # Features are passed as a numpy array via DataFrame.values; labels are
    # passed through unchanged.
    training_dataset = model.input_fn(
        features=train_x.values,
        labels=train_y,
        shuffle=True,
        num_epochs=args.num_epochs,
        batch_size=args.batch_size)

    # The evaluation set is served as one batch holding every eval example,
    # which is why fit() below runs a single validation step per epoch.
    validation_dataset = model.input_fn(
        features=eval_x.values,
        labels=eval_y,
        shuffle=False,
        num_epochs=args.num_epochs,
        batch_size=num_eval_examples)

    # Setup Learning Rate decay: the extra 0.02 * 0.5 ** (1 + epoch) term
    # halves every epoch, so the rate falls towards args.learning_rate.
    lr_decay_cb = tf.keras.callbacks.LearningRateScheduler(
        lambda epoch: args.learning_rate + 0.02 * (0.5 ** (1 + epoch)),
        verbose=True)

    # Setup TensorBoard callback; logs go under <job_dir>/keras_tensorboard.
    tensorboard_cb = tf.keras.callbacks.TensorBoard(
        os.path.join(args.job_dir, 'keras_tensorboard'),
        histogram_freq=1)

    # Train model
    keras_model.fit(
        training_dataset,
        steps_per_epoch=steps_per_epoch,
        epochs=args.num_epochs,
        validation_data=validation_dataset,
        validation_steps=1,
        verbose=1,
        callbacks=[lr_decay_cb, tensorboard_cb])

    # Export the trained model under <job_dir>/keras_export.
    export_path = os.path.join(args.job_dir, 'keras_export')
    tf.keras.models.save_model(keras_model, export_path)
    print('Model exported to: {}'.format(export_path))
if __name__ == '__main__':
    # Script entry point: parse the arguments, set TensorFlow's log
    # verbosity from them, then train and export the model.
    strategy = tf.distribute.MirroredStrategy()
    # NOTE(review): the statements below are assumed to all run inside the
    # strategy scope, so that the model built by train_and_evaluate() is
    # created under the MirroredStrategy - confirm the intended nesting.
    with strategy.scope():
        args = get_args()
        tf.compat.v1.logging.set_verbosity(args.verbosity)
        train_and_evaluate(args)