-
Notifications
You must be signed in to change notification settings - Fork 12
/
Copy pathtransfer_single_domain.py
79 lines (63 loc) · 2.63 KB
/
transfer_single_domain.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
''' Script for training all the transfer datasets for a single pretrained checkpoint
Sample usage:
python3 transfer_single_domain.py \
--domain=satellite_images --gpus=0 --exp_base_dir=/path/to/base_dir/ \
--data_root=/path/to/data --pretrain_exp_name=shed_librispeech_0.15_pytorch '--ckpt=epoch=293-step=99999.ckpt'
'''
import argparse
import os
import subprocess
from src.datasets.catalog import PRETRAIN_TO_TRANSFER_DICT
def parse_args():
    '''Parse the command-line arguments of this script.

    Returns:
        An argparse namespace with the attributes ``domain``,
        ``exp_base_dir``, ``pretrain_exp_name``, ``ckpt``, ``data_root``,
        ``gpus`` and ``debug``.
    '''
    cli = argparse.ArgumentParser()

    # Required: the domain must be one of the keys of the catalog.
    cli.add_argument(
        '--domain',
        type=str,
        required=True,
        choices=list(PRETRAIN_TO_TRANSFER_DICT),
    )
    # Required: plain string options that all share the same settings.
    for flag in ('--exp_base_dir', '--pretrain_exp_name', '--ckpt', '--data_root'):
        cli.add_argument(flag, type=str, required=True)

    # Optional
    cli.add_argument('--gpus', type=str, default='0')
    cli.add_argument('--debug', action='store_true')

    return cli.parse_args()
def run(command, debug=False):
    '''Run a shell command, reporting (but not raising on) a failure.

    Args:
        command: Command line, handed to the shell as a single string.
        debug: If True, only print the command and do not execute it.

    A command that exits with a non-zero status is reported on stdout and
    otherwise ignored, so the caller can carry on with its next command.
    '''
    print(command)
    if debug:
        return
    try:
        subprocess.run(command, check=True, shell=True)
    except subprocess.CalledProcessError as error:
        # The child's output is not captured, so error.output is always
        # None here; report the exit status and the failing command instead.
        print(f'Command failed with exit code {error.returncode}: {error.cmd}')
def _split_pretrain_exp_name(pretrain_exp_name):
    '''Return (algorithm, fraction, framework) parsed from an experiment name.

    Names look like shed_librispeech_0.15_pytorch. The algorithm is the first
    underscore-separated token and fraction/framework are the last two, so a
    dataset part that itself contains underscores is still accepted.

    Raises:
        ValueError: If the name has fewer than four underscore-separated parts.
    '''
    parts = pretrain_exp_name.split('_')
    if len(parts) < 4:
        raise ValueError(
            'pretrain_exp_name must look like '
            '<algorithm>_<dataset>_<fraction>_<framework>, '
            f'got {pretrain_exp_name!r}'
        )
    algorithm, *_, fraction, framework = parts
    return algorithm, fraction, framework


def main():
    '''Launch one transfer.py run per transfer dataset of the chosen domain.

    Prints a summary of the domain (its pretrain dataset and transfer
    datasets), then builds and runs a ``python3 transfer.py`` command for
    every transfer dataset, all starting from the same pretrained checkpoint.
    '''
    args = parse_args()
    domain = PRETRAIN_TO_TRANSFER_DICT[args.domain]

    # Print pretty domain summary
    header = f'Spawn script for {args.domain.upper()} domain'
    print(header)
    print('=' * len(header))
    print('Pretrain dataset:')
    print(f' * {domain.pretrain}')
    print('Transfer datasets:')
    for transfer in domain.transfers:
        print(f' * {transfer}')
    print('=' * len(header))

    # The checkpoint path and the parsed experiment name are the same for
    # every transfer dataset, so compute them once up front.
    ckpt = os.path.join(args.exp_base_dir, args.pretrain_exp_name, args.ckpt)
    algorithm, fraction, framework = _split_pretrain_exp_name(args.pretrain_exp_name)

    for transfer in domain.transfers:
        transfer_exp_name = f"{algorithm}_{domain.pretrain}_{transfer}_{fraction}_{framework}_transfer"
        print(transfer_exp_name)
        command = (
            'python3 transfer.py '
            # Quoted twice (shell quotes around inner double quotes) because
            # checkpoint file names contain '=' characters.
            f'\'ckpt=\"{ckpt}\"\' '
            f'dataset={transfer} '
            f'exp.name={transfer_exp_name} '
            f'exp.base_dir={args.exp_base_dir} '
            f'data_root={args.data_root} '
            f'gpus={args.gpus} '
            f'framework={framework}'
        )
        run(command, debug=args.debug)
# Run the spawn script only when executed directly, not when imported.
if __name__ == '__main__':
    main()