-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmain.py
29 lines (22 loc) · 895 Bytes
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
from __future__ import print_function, division
import torch
import torchvision
import Functions.Pipeline as pipe
# Print the installed PyTorch and Torchvision versions; this runs as soon as
# the module is loaded, so the versions appear at the top of every run's output.
_VERSION_REPORT = (
    ("PyTorch Version: ", torch.__version__),
    ("Torchvision Version: ", torchvision.__version__),
)
for _label, _version in _VERSION_REPORT:
    print(_label, _version)
def main(*pipeparts):
    """Run the selected stages of the pipeline and return the predictions.

    Each positional argument is a stage label. Selected stages always run in
    the fixed order A -> E below, regardless of the order they are passed in:

    * ``'A'`` -- ``pipe.A_Folderize(force=False)``
    * ``'B'`` -- ``pipe.B_InitModel()``, producing model, optimizer, scheduler
    * ``'C'`` -- ``pipe.C_PrepareData()``, producing the data loaders
    * ``'D'`` -- ``pipe.D_TrainModel(model, optimizer, scheduler, loaders)``,
      producing the trained model
    * ``'E'`` -- ``pipe.E_PredictModel(model, loaders['test'])``, producing
      the predictions

    Labels other than these are ignored.

    Returns:
        The value returned by stage ``'E'``, or ``None`` when ``'E'`` is not
        selected.

    Raises:
        ValueError: if ``'D'`` or ``'E'`` is selected without both ``'B'`` and
            ``'C'``. This is checked before any stage runs, so an invalid
            selection fails immediately instead of after earlier stages.
    """
    # D and E consume the model/optimizer/scheduler created by B and the
    # loaders created by C. Those objects only exist inside this call, so the
    # producing stages must be selected in the same call.
    prerequisites = {'D': ('B', 'C'), 'E': ('B', 'C')}
    for stage in ('D', 'E'):
        if stage not in pipeparts:
            continue
        missing = [p for p in prerequisites[stage] if p not in pipeparts]
        if missing:
            raise ValueError(
                "pipeline stage %r requires stage(s) %s in the same call"
                % (stage, ", ".join(repr(m) for m in missing))
            )

    # Stays None unless stage E runs.
    predictions = None

    if 'A' in pipeparts:
        pipe.A_Folderize(force=False)
    if 'B' in pipeparts:
        model, optimizer, scheduler = pipe.B_InitModel()
    if 'C' in pipeparts:
        loaders = pipe.C_PrepareData()
    if 'D' in pipeparts:
        # The trainer's second return value (by its original name, a
        # validation-accuracy history) is not used here.
        model, _ = pipe.D_TrainModel(model, optimizer, scheduler, loaders)
    if 'E' in pipeparts:
        predictions = pipe.E_PredictModel(model, loaders['test'])
    return predictions
if __name__ == '__main__':
    # Pipeline stages to execute. Stage 'A' is needed on the first run only;
    # it can be dropped from this tuple on later runs.
    stages = ('A', 'B', 'C', 'D', 'E')
    main(*stages)