run.py
"""
Main file for running experiments.
"""
import json
from typing import Union
import argparse
import torch
from stepback.base import Base
from stepback.log import Container
from stepback.config import ConfigManager
from stepback.defaults import DEFAULTS

parser = argparse.ArgumentParser(description='Run stepback from the terminal.')
parser.add_argument('-i', '--id', nargs='?', type=str, default='test1', help="The id of the config (its file name).")
parser.add_argument('-cdir', '--config-dir', nargs='?', type=str, default=DEFAULTS.config_dir, help="The config directory.")
parser.add_argument('-odir', '--output-dir', nargs='?', type=str, default=DEFAULTS.output_dir, help="The output directory.")
parser.add_argument('-ddir', '--data-dir', nargs='?', type=str, default=DEFAULTS.data_dir, help="The data directory.")
parser.add_argument('--device', nargs='?', type=str, default=DEFAULTS.device, help="Device to run on.")
parser.add_argument('-nw', '--num-workers', nargs='?', type=int, default=DEFAULTS.num_workers, help="Number of workers for DataLoader.")
parser.add_argument('--data-parallel', nargs='+', default=DEFAULTS.data_parallel, help='Device list for DataParallel in PyTorch.')
parser.add_argument('--verbose', action="store_true", help="Verbose mode.")
parser.add_argument('--force-deterministic', action="store_true", help="Use deterministic mode in PyTorch. Might require setting environment variables.")
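
# Example invocation (a sketch; assumes a config with id 'test1' exists in the config directory):
#   python run.py -i test1 --device cuda --verbose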


def run_one(exp_id: str,
            config_dir: str = DEFAULTS.config_dir,
            output_dir: str = DEFAULTS.output_dir,
            data_dir: str = DEFAULTS.data_dir,
            device: str = DEFAULTS.device,
            num_workers: int = DEFAULTS.num_workers,
            data_parallel: Union[list, None] = DEFAULTS.data_parallel,
            verbose: bool = DEFAULTS.verbose,
            force_deterministic: bool = DEFAULTS.force_deterministic
            ):
"""Function for running all runs from one config file.
Default values for all arguments can be found in ``stepback/defaults.py``.
Parameters
----------
exp_id : str
The experiment ID, equal to the name of the config file.
config_dir : str, optional
Directory where config file is stored, by default DEFAULTS.config_dir
output_dir : str, optional
Directory where output is stored, by default DEFAULTS.output_dir
data_dir : str, optional
Directory where datasets can be found,, by default DEFAULTS.data_dir
device : str, optional
Device string, by default DEFAULTS.device
If 'cuda' is specified, but not available on system, it switches to CPU.
num_workers : int, optional
Number of workers for DataLoader, by default DEFAULTS.num_workers
data_parallel : Union[list, None], optional
If not None, this specifies the device ids for DataParallel mode in Pytorch.
See https://pytorch.org/docs/stable/generated/torch.nn.DataParallel.html.
verbose : bool, optional
Verbose mode flag.
If True, prints progress bars, model architecture and other useful information.
force_deterministic : bool, optional
Whether to run in Pytorch (full) deterministic mode.
Not recommended, as this leads to substantial slow down. Seeds are set also without setting this to True.
"""
    # load config
    Conf = ConfigManager(exp_id=exp_id, config_dir=config_dir)
    exp_list = Conf.create_config_list()
    print(f"Created {len(exp_list)} different configurations.")

    # initialize container for storing results
    C = Container(name=exp_id, output_dir=output_dir, as_json=True)

    if force_deterministic:
        torch.use_deterministic_algorithms(True)
        print("Using PyTorch deterministic mode. This might lead to substantial slowdown.")

    for j, config in enumerate(exp_list):
        # each run gets an id, by position in the list
        B = Base(name=exp_id + f'_{j}',
                 config=config,
                 device=device,
                 data_dir=data_dir,
                 num_workers=num_workers,
                 data_parallel=data_parallel,
                 verbose=verbose)
        B.setup()
        B.run()                      # train and validate
        C.append(B.results).store()  # store results after each run

    print("All experiments have completed.")
    return
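
# Example programmatic use (a sketch; assumes a config with id 'test1' exists in the config directory):
#
#   from run import run_one
#   run_one('test1', device='cpu', verbose=True)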

if __name__ == '__main__':
    args = parser.parse_args()
    print(args)

    run_one(args.id,
            config_dir=args.config_dir,
            output_dir=args.output_dir,
            data_dir=args.data_dir,
            device=args.device,
            num_workers=args.num_workers,
            data_parallel=args.data_parallel,
            verbose=args.verbose,
            force_deterministic=args.force_deterministic)