-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtrain_end2end.py
49 lines (36 loc) · 1.5 KB
/
train_end2end.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
#----------------------------------------
#--------- OS related imports -----------
#----------------------------------------
import os
import argparse
import subprocess
import wandb
#----------------------------------------
#--------- Torch related imports --------
#----------------------------------------
import torch
#----------------------------------------
#--------- Config and training imports --
#----------------------------------------
from functions.config import config, update_config
from functions.train import train_net
def parse_args():
    """Parse the command line and apply the optional config file.

    Recognised options:
        --cfg            path to a config file (string, optional)
        --dist           boolean switch, off unless given
        --data_parallel  boolean switch, off unless given
        --cudnn-off      boolean switch, off unless given (stored as
                         ``cudnn_off`` on the namespace by argparse)
        --do_test        boolean switch, off unless given

    Returns:
        tuple: ``(args, config)`` where ``args`` is the parsed
        ``argparse.Namespace`` and ``config`` is the object imported from
        ``functions.config``. ``update_config(args.cfg)`` is called first
        when ``--cfg`` was supplied.
    """
    cli = argparse.ArgumentParser('Train Cognition Network')
    cli.add_argument('--cfg', type=str, help='path to config file')

    # Every remaining option is an opt-in boolean switch; keep them in a
    # table so the registration order (and thus the --help order) is explicit.
    switches = (
        ('--dist', 'whether to use distributed training'),
        ('--data_parallel', 'whether to use data parallel or not'),
        ('--cudnn-off', 'disable cudnn'),
        ('--do_test', 'testing'),
    )
    for flag, description in switches:
        cli.add_argument(flag, help=description, default=False, action='store_true')

    parsed = cli.parse_args()

    # Without --cfg the imported config is handed back untouched.
    if parsed.cfg is None:
        return parsed, config

    update_config(parsed.cfg)
    return parsed, config
def main():
    """Run training end to end, optionally followed by testing.

    Steps:
        1. Parse the command line (and config file) via ``parse_args``.
        2. Start a wandb run named after ``config.PROJECT`` /
           ``config.VERSION``, logging ``config`` as the run config.
        3. Call ``train_net(args, config)``.
        4. When ``--do_test`` was given, call ``test_net(args, config)`` on
           the main process only (rank ``None`` or ``0``).

    The previous version referenced an unassigned ``rank`` and an
    un-imported ``test_net``, so ``--do_test`` always ended in a
    ``NameError`` after training had finished.
    """
    args, config = parse_args()

    if args.do_test:
        # Resolve the test entry point *before* training so a missing module
        # fails immediately instead of after a full training run.
        # NOTE(review): test_net is not imported anywhere in this file; it is
        # assumed to live next to train_net in the `functions` package --
        # confirm the module path.
        from functions.test import test_net

    # initialize wandb
    wandb.init(project=config.PROJECT, name=config.VERSION, config=config)

    train_net(args, config)

    if args.do_test:
        # The return value of train_net is not used here, so the process rank
        # is taken from the RANK environment variable that torch.distributed
        # launchers export. Non-distributed runs have no rank (None).
        # NOTE(review): assumes distributed jobs are started with a launcher
        # that sets RANK; an unset/empty value is treated as rank 0.
        rank = int(os.environ.get('RANK') or 0) if args.dist else None
        if rank is None or rank == 0:
            test_net(args, config)
# Script entry point: run main() only when this file is executed directly,
# not when it is imported as a module.
if "__main__" == __name__:
    main()