-
Notifications
You must be signed in to change notification settings - Fork 63
/
Copy pathmain.py
165 lines (135 loc) · 6.47 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
####################################################################################
# Implementation of the following paper: https://arxiv.org/pdf/1703.07015.pdf #
# #
# Modeling Long- and Short-Term Temporal Patterns with Deep Neural Networks #
####################################################################################
# This must be set at the very beginning, before the project imports below:
# util.model_util imports this name, so it has to exist by the time that
# import runs.
logger_name = "lstnet"
# Path appended in order to import from util
import sys
sys.path.append('..')
from util.model_util import LoadModel, SaveModel, SaveResults, SaveHistory
from util.Msglog import LogInit
from datetime import datetime
from lstnet_util import GetArguments, LSTNetInit
from lstnet_datautil import DataUtil
from lstnet_model import PreSkipTrans, PostSkipTrans, PreARTrans, PostARTrans, LSTNetModel, ModelCompile
from lstnet_plot import AutoCorrelationPlot, PlotHistory, PlotPrediction
import tensorflow as tf
# Name -> class mapping for the custom LSTNet layers. It is handed to
# LoadModel when a saved model is loaded (--load), presumably so the custom
# layers in the saved model can be resolved by name -- confirm in
# util.model_util.
custom_objects = dict(
    PreSkipTrans=PreSkipTrans,
    PostSkipTrans=PostSkipTrans,
    PreARTrans=PreARTrans,
    PostARTrans=PostARTrans,
)
def train(model, data, init, tensorboard=None, logger=None):
    """Fit ``model`` on the training split and return what ``fit`` returns.

    Args:
        model: compiled model exposing a Keras-style ``fit`` method.
        data: object whose ``train`` attribute (and ``valid`` attribute, when
            validating) is indexable as ``(inputs, targets)`` via ``[0]``/``[1]``.
        init: configuration object providing ``validate``, ``epochs`` and
            ``batchsize``.
        tensorboard: optional callback forwarded to ``fit``; when falsy, no
            callbacks are passed.
        logger: object with an ``info`` method used to report the training
            duration. Defaults to the module-level ``log`` created when this
            file runs as a script, or to ``logging.getLogger(logger_name)``
            when the module is imported instead.

    Returns:
        The object returned by ``model.fit`` (for Keras, a ``History`` whose
        ``.history`` dict holds the per-epoch metrics).
    """
    if logger is None:
        # `log` is only bound inside the __main__ block; look it up lazily so
        # that importing and calling train() does not raise NameError.
        logger = globals().get("log")
        if logger is None:
            import logging
            # Presumably the same logger LogInit configures -- confirm in
            # util.Msglog.
            logger = logging.getLogger(logger_name)

    # Only pass a validation set to fit() when validation is enabled.
    val_data = (data.valid[0], data.valid[1]) if init.validate else None

    start_time = datetime.now()
    history = model.fit(
        x=data.train[0],
        y=data.train[1],
        epochs=init.epochs,
        batch_size=init.batchsize,
        validation_data=val_data,
        callbacks=[tensorboard] if tensorboard else None,
    )
    end_time = datetime.now()

    logger.info("Training time took: %s", str(end_time - start_time))
    return history
if __name__ == '__main__':
    # Script entry point: parse arguments, set up logging, load the data,
    # load or create and compile the LSTNet model, then -- as selected by the
    # LSTNetInit configuration -- train, validate, evaluate on the test set,
    # save model/history/results and predict/plot.
    try:
        args = GetArguments()
    except SystemExit as err:
        # GetArguments reports via SystemExit (presumably argparse-based): a
        # falsy code (0/None, e.g. after --help) is a normal exit, anything
        # else is a parse failure. Propagate the code so failures are not
        # reported to the shell as success.
        if err.code:
            print("Error reading arguments")
        sys.exit(err.code)

    test_result = None

    # Initialise parameters
    lstnet_init = LSTNetInit(args)

    # Initialise logging. `log` must stay a module-level name: train() uses it.
    log = LogInit(logger_name, lstnet_init.logfilename, lstnet_init.debuglevel, lstnet_init.log)
    log.info("Python version: %s", sys.version)
    log.info("Tensorflow version: %s", tf.__version__)
    log.info("Keras version: %s ... Using tensorflow embedded keras", tf.keras.__version__)

    # Dumping configuration
    lstnet_init.dump()

    # Reading data
    Data = DataUtil(lstnet_init.data,
                    lstnet_init.trainpercent,
                    lstnet_init.validpercent,
                    lstnet_init.horizon,
                    lstnet_init.window,
                    lstnet_init.normalise)

    # If the file does not exist, DataUtil leaves Data without a 'data' attribute
    if not hasattr(Data, 'data'):
        log.critical("Could not load data!! Exiting")
        sys.exit(1)

    log.info("Training shape: X:%s Y:%s", str(Data.train[0].shape), str(Data.train[1].shape))
    log.info("Validation shape: X:%s Y:%s", str(Data.valid[0].shape), str(Data.valid[1].shape))
    log.info("Testing shape: X:%s Y:%s", str(Data.test[0].shape), str(Data.test[1].shape))

    if lstnet_init.plot and lstnet_init.autocorrelation is not None:
        AutoCorrelationPlot(Data, lstnet_init)

    # If --load is set, load model from file, otherwise create model
    if lstnet_init.load is not None:
        log.info("Load model from %s", lstnet_init.load)
        lstnet = LoadModel(lstnet_init.load, custom_objects)
    else:
        log.info("Creating model")
        lstnet = LSTNetModel(lstnet_init, Data.train[0].shape)

    if lstnet is None:
        log.critical("Model could not be loaded or created ... exiting!!")
        sys.exit(1)

    # Compile model; the returned TensorBoard callback (or None) is later
    # handed to train() as its callback.
    lstnet_tensorboard = ModelCompile(lstnet, lstnet_init)
    if lstnet_tensorboard is not None:
        log.info("Model compiled ... Open tensorboard in order to visualise it!")
    else:
        log.info("Model compiled ... No tensorboard visualisation is available")

    # Model Training
    if lstnet_init.train:
        log.info("Training model ... ")
        h = train(lstnet, Data, lstnet_init, lstnet_tensorboard)

        # Plot training metrics
        if lstnet_init.plot:
            PlotHistory(h.history, ['loss', 'rse', 'corr'], lstnet_init)

        # Save the model only after training -- there is no reason to save an
        # untrained one. SaveModel is called unconditionally and is presumably
        # a no-op when lstnet_init.save is None; confirm in util.model_util.
        SaveModel(lstnet, lstnet_init.save)
        if lstnet_init.savehistory:
            SaveHistory(lstnet_init.save, h.history)

    # Validation: when training ran, report the last epoch's validation
    # metrics from the history; otherwise evaluate the model directly.
    if lstnet_init.validate:
        if lstnet_init.train:
            log.info("Validation on the validation set returned: Loss:%f, RSE:%f, Correlation:%f",
                     h.history['val_loss'][-1], h.history['val_rse'][-1], h.history['val_corr'][-1])
        else:
            loss, rse, corr = lstnet.evaluate(Data.valid[0], Data.valid[1])
            log.info("Validation on the validation set returned: Loss:%f, RSE:%f, Correlation:%f", loss, rse, corr)

    # Testing evaluation
    if lstnet_init.evaltest:
        loss, rse, corr = lstnet.evaluate(Data.test[0], Data.test[1])
        log.info("Validation on the test set returned: Loss:%f, RSE:%f, Correlation:%f", loss, rse, corr)
        test_result = {'loss': loss, 'rse': rse, 'corr': corr}

    # Results are saved only after the test evaluation so that test_result
    # carries the test-set metrics (it stays None when evaltest is off).
    if lstnet_init.train and lstnet_init.saveresults:
        SaveResults(lstnet, lstnet_init, h.history, test_result, ['loss', 'rse', 'corr'])

    # Prediction: 'all' predicts every split, otherwise only the named one.
    if lstnet_init.predict is not None:
        trainPredict = validPredict = testPredict = None

        if lstnet_init.predict in ('trainingdata', 'all'):
            log.info("Predict training data")
            trainPredict = lstnet.predict(Data.train[0])

        if lstnet_init.predict in ('validationdata', 'all'):
            log.info("Predict validation data")
            validPredict = lstnet.predict(Data.valid[0])

        if lstnet_init.predict in ('testingdata', 'all'):
            log.info("Predict testing data")
            testPredict = lstnet.predict(Data.test[0])

        if lstnet_init.plot:
            PlotPrediction(Data, lstnet_init, trainPredict, validPredict, testPredict)