-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathwandb_fast_reid.py
145 lines (115 loc) · 3.83 KB
/
wandb_fast_reid.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
import argparse
import os
from pathlib import Path
import sys
from fastreid.config import get_cfg
from fastreid.engine import DefaultTrainer, default_setup, launch
from fastreid.utils.checkpoint import Checkpointer
sys.path.append('.')
def increment_path(path, exist_ok=False, sep='', mkdir=False):
    """Return *path*, incremented with a numeric suffix if it already exists.

    E.g. ``runs/exp`` -> ``runs/exp2``, ``runs/exp3``, ... For files the
    counter is inserted before the extension (``log.txt`` -> ``log2.txt``).

    Args:
        path: file or directory path to make unique.
        exist_ok: if True, return *path* unchanged even when it exists.
        sep: separator placed between the stem and the counter.
        mkdir: if True, create the resulting path as a directory.

    Returns:
        pathlib.Path: the (possibly incremented) path.
    """
    path = Path(path)  # os-agnostic handling
    if path.exists() and not exist_ok:
        # Keep a file's extension aside so the counter lands before it.
        if path.is_file():
            stem, suffix = path.with_suffix(''), path.suffix
        else:
            stem, suffix = path, ''
        n = 2
        while n < 9999:
            candidate = f'{stem}{sep}{n}{suffix}'
            if not os.path.exists(candidate):
                break
            n += 1
        path = Path(candidate)
    if mkdir:
        path.mkdir(parents=True, exist_ok=True)  # make directory
    return path
def setup(args):
    """Build the fastreid config from CLI arguments and run default setup.

    Loads the YAML file named by ``args.config_file``, applies the
    ``args.opts`` key/value overrides on top, freezes the result, and
    hands it to fastreid's standard initialization.

    Args:
        args: parsed namespace from ``default_argument_parser``.

    Returns:
        the frozen fastreid config node.
    """
    config = get_cfg()
    # File settings first, then command-line overrides take precedence.
    config.merge_from_file(args.config_file)
    config.merge_from_list(args.opts)
    config.freeze()
    default_setup(config, args)
    return config
def main(args):
    """Entry point: train a model, or evaluate a trained one.

    Args:
        args: parsed namespace from ``default_argument_parser``.

    Returns:
        test results when ``--eval-only`` is set, otherwise the value
        returned by ``trainer.train()``.
    """
    config = setup(args)

    if not args.eval_only:
        # Standard training path; optionally resume from the checkpoint dir.
        trainer = DefaultTrainer(config)
        trainer.resume_or_load(resume=args.resume)
        return trainer.train()

    # Evaluation-only path: skip backbone pretraining and restore the
    # trained weights before running the test loop.
    config.defrost()
    config.MODEL.BACKBONE.PRETRAIN = False
    model = DefaultTrainer.build_model(config)
    Checkpointer(model).load(config.MODEL.WEIGHTS)  # load trained model
    return DefaultTrainer.test(config, model)
def default_argument_parser():
    """Build the argparse parser shared by fastreid training scripts.

    Returns:
        argparse.ArgumentParser: parser exposing the config file path,
        distributed-training knobs, and a trailing free-form ``opts``
        list of config overrides.
    """
    parser = argparse.ArgumentParser(description="fastreid Training")
    parser.add_argument("--config-file", default="", metavar="FILE", help="path to config file")
    parser.add_argument("--resume", action="store_true",
                        help="whether to attempt to resume from the checkpoint directory")
    parser.add_argument("--eval-only", action="store_true", help="perform evaluation only")
    parser.add_argument("--num-gpus", type=int, default=1, help="number of gpus *per machine*")
    parser.add_argument("--num-machines", type=int, default=1, help="total number of machines")
    parser.add_argument("--machine-rank", type=int, default=0,
                        help="the rank of this machine (unique per machine)")
    # PyTorch still may leave orphan processes in multi-gpu training, so the
    # rendezvous port is derived deterministically from the uid: seeing the
    # port occupied makes users aware of the orphans.
    seed = os.getuid() if sys.platform != "win32" else 1
    port = 2 ** 15 + 2 ** 14 + hash(seed) % 2 ** 14
    parser.add_argument("--dist-url", default="tcp://127.0.0.1:{}".format(port))
    parser.add_argument("opts",
                        help="Modify config options using the command-line",
                        default=None,
                        nargs=argparse.REMAINDER)
    return parser
# Config-key substrings whose command-line override values are coerced to
# float before being handed to the config merge (the __main__ block below
# matches these by substring against each KEY in KEY=VALUE overrides —
# presumably because sweep tooling passes every value as a string).
# NOTE(review): 'NORM_TYPE' reads as categorical yet is float-coerced here —
# confirm against the fastreid config schema.
FLOAT_OPTS_LIST = [
'PROB',
'BRIGHTNESS',
'CONTRAST',
'HUE',
'SATURATION',
'ALPHA',
'EPSILON',
'MODEL.LOSSES.CE.SCALE',
'MODEL.LOSSES.TRI.SCALE',
'MARGIN',
'BIAS_LR_FACTOR',
'CLIP_VALUE',
'NORM_TYPE',
'GAMMA',
'HEADS_LR_FACTOR',
'MOMENTUM',
'WARMUP_FACTOR',
'WEIGHT_DECAY',
'WEIGHT_DECAY_BIAS',
'WEIGHT_DECAY_NORM'
]
if __name__ == "__main__":
    args = default_argument_parser().parse_args()
    # Sweep tools pass overrides as KEY=VALUE strings, but fastreid's
    # merge_from_list wants an alternating [KEY, VALUE, ...] list.
    # Split on the FIRST '=' only (maxsplit=1) so values that themselves
    # contain '=' are kept intact instead of corrupting the opts list.
    pairs = [x.split("=", 1) for x in args.opts]
    args.opts = []
    for pair in pairs:
        key = pair[0]
        if key == "OUTPUT_DIR":
            # Avoid clobbering an earlier run: runs/exp -> runs/exp2, ...
            pair[1] = str(increment_path(pair[1], mkdir=True))
        if any(k in key for k in FLOAT_OPTS_LIST):
            # Numeric hyperparameters arrive as strings; normalize them so
            # the config parses them as floats.
            pair[1] = str(float(pair[1]))
        args.opts += pair
    print("\n\nCommand Line Args:", args)
    launch(
        main,
        args.num_gpus,
        num_machines=args.num_machines,
        machine_rank=args.machine_rank,
        dist_url=args.dist_url,
        args=(args,),
    )