-
Notifications
You must be signed in to change notification settings - Fork 0
/
trainer.py
41 lines (28 loc) · 1.17 KB
/
trainer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
# Preprocess the dataset and train a model
import numpy as np
from tensorflow.keras import models,layers
import tensorflow as tf
# --- Time-series preprocessing ----------------------------------------------
# Split the raw series into train / validation / test windows using the
# project's `truncate` helper.
from preprocess import truncate

LOOKBACK = 168    # passed to truncate(lookback=...)
FORECAST = 48     # passed to truncate(forecast=...)
LEN_TEST = 10000  # passed to truncate(len_test=...)
LEN_VAL = 10000   # passed to truncate(len_val=...)

# NOTE(review): `data` is not defined anywhere in this file; it must be loaded
# before this point — TODO confirm where the raw series comes from.
(x_train, y_train,
 x_val, y_val,
 x_test, y_test) = truncate(
    data,
    lookback=LOOKBACK,
    forecast=FORECAST,
    len_test=LEN_TEST,
    len_val=LEN_VAL,
    target='full',
)

# Print all six split shapes on one line. Recorded output of a previous run:
# (47800, 168, 10) (47800, 48, 3) (10000, 168, 10) (10000, 48, 3) (10000, 168, 10) (10000, 48, 3)
print(*(split.shape for split in (x_train, y_train, x_val, y_val, x_test, y_test)))
# --- Model definition (basic LSTM) ------------------------------------------
# One 16-unit LSTM layer with return_sequences=False, so only its final output
# is passed on. A Dense layer then emits every (forecast step, target) value at
# once, and the result is reshaped to match y_train's per-sample shape.

# NOTE(review): the dropout rate was never defined in this script; 0.2 is a
# placeholder — tune/confirm.
DROPOUT = 0.2

# Derive the shapes from the data instead of hard-coding (168, 10) / (48, 3).
# Changing lookback/forecast in preprocessing then cannot desync the model.
input_shape = tuple(x_train.shape[1:])         # (168, 10) in the recorded run
forecast_steps, n_targets = y_train.shape[1:]  # (48, 3) in the recorded run

model = models.Sequential()
model.add(layers.Input(shape=input_shape))
model.add(layers.LSTM(16, dropout=DROPOUT, return_sequences=False))
model.add(layers.Dense(forecast_steps * n_targets))
model.add(layers.Reshape(target_shape=(forecast_steps, n_targets)))
# --- Training ---------------------------------------------------------------
import os

# The Keras callback class is spelled `TensorBoard` (capital B).
from tensorflow.keras.callbacks import TensorBoard

name = 'model_demo'
path = 'a_path'  # NOTE(review): placeholder output directory — set before running

# TensorBoard logs are written under <path>/logs/<name>.
tensorboard = TensorBoard(log_dir=os.path.join(path, 'logs', name), histogram_freq=1)

model.compile(
    loss='mse',
    optimizer=tf.keras.optimizers.RMSprop(0.001),
    metrics=['mae'],
)
model.fit(
    x_train,
    y_train,
    epochs=30,
    batch_size=128,
    validation_data=(x_val, y_val),
    callbacks=[tensorboard],
)

# Build the file path with os.path.join so a `path` without a trailing
# separator still yields '<path>/<name>.h5'. Create the directory first so
# saving does not fail on a fresh checkout.
os.makedirs(path, exist_ok=True)
model.save(os.path.join(path, name + '.h5'), overwrite=True)