"""
This file is used to run the project.
Notes:
- The structure of this file (and the entire project in general) is made with emphasis on flexibility for research
purposes, and the pipelining is done in a python file such that newcomers can easily use and understand the code.
- Remember that relative paths in Python are always relative to the current working directory.
Hence, if you look at the functions in make_dataset.py, the file paths are relative to the path of
this file (main.py)
"""
__author__ = "Simon Leminen Madsen"
__email__ = "slm@eng.au.dk"
import os
import argparse
import datetime
import src.utils as utils
from src.data import dataset_manager
from src.models.BasicModel import BasicModel
from src.models.logreg_example import logreg_example
from src.visualization import visualize
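
# Note on paths: because relative paths resolve against the current working
# directory, a common safeguard (illustrative sketch only; PROJECT_ROOT is a
# hypothetical helper not used elsewhere in this project) is to anchor paths
# to the location of this file:
#
#   PROJECT_ROOT = os.path.dirname(os.path.abspath(__file__))
#   raw_data_dir = os.path.join(PROJECT_ROOT, 'data', 'raw')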
"""parsing and configuration"""
def parse_args():
# ----------------------------------------------------------------------------------------------------------------------
# Define default pipeline
# ----------------------------------------------------------------------------------------------------------------------
desc = "Pipeline for running Tensorflow implementation of infoGAN"
parser = argparse.ArgumentParser(description=desc)
parser.add_argument('--make_dataset',
action='store_true',
help = 'Fetch dataset from remote source into /data/raw/. Or generate raw dataset [Defaults to False if argument is omitted]')
parser.add_argument('--process_dataset',
action='store_true',
help = 'Run preprocessing of raw data. [Defaults to False if argument is omitted]')
parser.add_argument('--train_model',
action='store_true',
help = 'Run configuration and training network [Defaults to False if argument is omitted]')
parser.add_argument('--evaluate_model',
action='store_true',
help = 'Run evaluation of the model by computing and visualizing the results [Defaults to False if argument is omitted]')
parser.add_argument('--visualize',
action='store_true',
help = 'Run visualization of results [Defaults to False if argument is omitted]')
# ----------------------------------------------------------------------------------------------------------------------
# Define the arguments used in the entire pipeline
# ----------------------------------------------------------------------------------------------------------------------
parser.add_argument('--model',
type=str,
default='BasicModel',
choices=['BasicModel',
'LogReg_example'],
#required = True,
help='The name of the network model')
parser.add_argument('--dataset',
type=str, default='MNIST',
choices=['MNIST',
'PSD_Nonsegmented',
'PSD_Segmented'],
#required = True,
help='The name of dataset')
# ----------------------------------------------------------------------------------------------------------------------
# Define the arguments for the training
# ----------------------------------------------------------------------------------------------------------------------
parser.add_argument('--id',
type= str,
default = None,
help = 'Optional ID, to distinguise experiments')
parser.add_argument('--hparams',
type=str, default = '',
help='CLI arguments for the model wrapped in a string')
return check_args(parser.parse_args())
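
# A minimal usage sketch (illustrative only): the flags below correspond to the
# arguments defined in parse_args() above, and the contents of --hparams are
# hypothetical, since their format is defined by the individual models.
#
#   python main.py --make_dataset --process_dataset --dataset MNIST
#   python main.py --train_model --evaluate_model --model BasicModel \
#       --dataset MNIST --id my_experiment --hparams '<model-specific arguments>'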
"""checking arguments"""
def check_args(args):
# Assert if training parameters are provided, when training is selected
# if args.train_model:
# try:
# assert args.hparams is ~None
# except:
# print('hparams not provided for training')
# exit()
return args
"""main"""
def main():
# parse arguments
args = parse_args()
if args is None:
exit()
# Make dataset
if args.make_dataset:
utils.show_message('Fetching raw dataset: {0}'.format(args.dataset), lvl = 1)
dataset_manager.make_dataset(args.dataset)
# Make dataset
if args.process_dataset:
utils.show_message('Processing raw dataset: {0}'.format(args.dataset), lvl = 1)
dataset_manager.process_dataset(args.dataset)
# Build and train model
if args.train_model:
utils.show_message('Configuring and Training Network: {0}'.format(args.model), lvl = 1)
if args.model == 'BasicModel':
model = BasicModel(
dataset = args.dataset,
id = args.id)
model.train(hparams_string = args.hparams)
elif args.model == 'LogReg_example':
model = logreg_example(
dataset = args.dataset,
id = args.id)
model.train(hparams_string = args.hparams)
# Evaluate model
if args.evaluate_model:
utils.show_message('Evaluating Network: {0}'.format(args.model), lvl = 1)
if args.model == 'BasicModel':
model = BasicModel(
dataset = args.dataset,
id = args.id)
model.evaluate(hparams_string = args.hparams)
elif args.model == 'LogReg_example':
model = logreg_example(
dataset = args.dataset,
id = args.id)
model.evaluate(hparams_string = args.hparams)
# Visualize results
if args.visualize:
print('Visualizing Results')
#################################
####### To Be Implemented #######
#################################
if __name__ == '__main__':
main()