-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtrain.py
83 lines (62 loc) · 2.75 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
"""
Contains a template command line interface for training.
As an example, the training.pytorch_training template routine is used
in combination with the models.unet.UNET model.
The data loading is provided by a generic .npy array wrapper
in utils.datasets.NpyDataset, transformations and augmentation can be
added easily in the corresponding class.
We assume that the data is split into test and training examples
in the paths
<DATA_DIR>/train/x
<DATA_DIR>/train/y
<DATA_DIR>/test/x
<DATA_DIR>/test/y
where "x" corresponds to observation features and "y" corresponds to labels.
We additionally assume that in "x" and "y", the correspondence
is provided by files carrying the same file name.
<DATA_DIR> is specified in config.config.global_config.
We do NOT provide a full argparse suite, since every training routine
needs its own arguments, which are supplied by an external training config file/dictionary.
An example config is provided in training.pytorch_training.train_config (the one used by this script);
training.fastai_training provides another.
"""
import argparse
#import pprint
import os
import warnings
from config.config import global_config
from utils.data import datasets
from utils.notify import smtp
from training import pytorch_training
def cli(argv=None):
    """Parse the command line options of the training script.

    Only options that do not depend on the training routine are exposed
    here; routine-specific settings live in the training config instead.

    Args:
        argv: Optional list of argument strings to parse. ``None`` (the
            default) makes argparse read ``sys.argv[1:]``, so existing
            callers of ``cli()`` behave exactly as before.

    Returns:
        argparse.Namespace with the attributes ``smtp`` (str or None) and
        ``warnings`` (bool, False unless ``-w``/``--warnings`` is given).
    """
    description = ("Train a model on the .npy datasets stored under "
                   "<DATA_DIR>/train and <DATA_DIR>/test.")
    parser = argparse.ArgumentParser(description=description)
    parser.add_argument("-s", "--smtp", help="Send SMTP mail notification",
                        type=str)
    parser.add_argument("-w", "--warnings", action="store_true",
                        help="Suppress all warnings")
    return parser.parse_args(argv)
def main():
    """Run the training script end to end.

    Parses the command line, optionally silences warnings, builds the train
    and test ``NpyDataset``s from ``<DATA_DIR>/{train,test}/{x,y}`` and hands
    them, together with ``pytorch_training.train_config`` and
    ``global_config``, to ``pytorch_training.train``.
    """
    args = cli()

    # Install the filter first so that "-w" really suppresses *all* warnings,
    # including any raised while setting up the notifier below.
    if args.warnings:
        warnings.filterwarnings("ignore")

    if args.smtp:
        # NOTE(review): the notifier is constructed but not used yet (see the
        # TODO below); tell the user instead of silently dropping the request.
        notifier = smtp.SMTPNotifier(args.smtp, args.smtp)
        warnings.warn("--smtp is currently ignored: mail notification is only "
                      "implemented for fastai_training.")

    # Create datasets; files in x/ and y/ are paired by identical file names.
    data_dir = global_config["DATA_DIR"]
    train_data = datasets.NpyDataset(os.path.join(data_dir, "train", "x"),
                                     os.path.join(data_dir, "train", "y"))
    test_data = datasets.NpyDataset(os.path.join(data_dir, "test", "x"),
                                    os.path.join(data_dir, "test", "y"))

    # Pass configuration and datasets to the training routine.
    pytorch_training.train(train_data,
                           test_data,
                           pytorch_training.train_config,
                           global_config)

    # TODO: port mail notification from fastai_training, where
    #     learner, log_content, name = fastai_training.train(...)
    # was followed by
    #     notifier.notify(pprint.PrettyPrinter(indent=4).pformat(log_content),
    #                     subject=name)
    # This needs pytorch_training.train to return its log as well.


if __name__ == "__main__":
    main()