-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy patharguments.py
135 lines (123 loc) · 6.85 KB
/
arguments.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
import argparse
class ArgParser(object):
    """Command-line argument parser for model, data, misc and training options.

    The constructor registers the model, data and misc options on
    ``self.parser``. Training/optimization options are registered by
    ``add_train_arguments``, which ``parse_train_arguments`` calls before
    parsing.
    """

    def __init__(self):
        parser = argparse.ArgumentParser()
        # Model related arguments
        parser.add_argument('--lamda', default=1, type=float,
                            help="lamda for mask")
        parser.add_argument('--id', default='',
                            help="a name for identifying the model")
        parser.add_argument('--num_mix', default=2, type=int,
                            help="number of sounds to mix")
        parser.add_argument('--num_class', default=11, type=int,
                            help="number of classes in the dataset")
        parser.add_argument('--arch_sound', default='unet7',
                            help="architecture of net_sound")
        parser.add_argument('--arch_frame', default='resnet18dilated',
                            help="architecture of net_frame")
        parser.add_argument('--arch_avol', default='AVOL',
                            help="architecture of net_avol")
        parser.add_argument('--weights_model', default='',
                            help="weights to finetune whole model")
        parser.add_argument('--weights_sound', default='',
                            help="weights to finetune net_sound")
        parser.add_argument('--weights_frame', default='',
                            help="weights to finetune net_frame")
        parser.add_argument('--weights_avol', default='',
                            help="weights to finetune net_avol")
        parser.add_argument('--num_channels', default=32, type=int,
                            help='number of channels')
        parser.add_argument('--num_frames', default=1, type=int,
                            help='number of frames')
        parser.add_argument('--stride_frames', default=1, type=int,
                            help='sampling stride of frames')
        parser.add_argument('--img_pool', default='maxpool',
                            help="avg or max pool image features")
        parser.add_argument('--img_activation', default='sigmoid',
                            help="activation on the image features")
        parser.add_argument('--sound_activation', default='no',
                            help="activation on the sound features")
        parser.add_argument('--output_activation', default='sigmoid',
                            help="activation on the output")
        # Parsed as int (0/1) rather than bool; kept for compatibility.
        parser.add_argument('--binary_mask', default=1, type=int,
                            help="whether to use binary masks")
        parser.add_argument('--mask_thres', default=0.5, type=float,
                            help="threshold in the case of binary masks")
        parser.add_argument('--loss', default='l1',
                            help="loss function to use")
        parser.add_argument('--weighted_loss', default=0, type=int,
                            help="weighted loss")
        parser.add_argument('--log_freq', default=1, type=int,
                            help="log frequency scale")
        # Data related arguments
        parser.add_argument('--num_gpus', default=1, type=int,
                            help='number of gpus to use')
        parser.add_argument('--batch_size_per_gpu', default=32, type=int,
                            help='input batch size')
        parser.add_argument('--workers', default=32, type=int,
                            help='number of data loading workers')
        parser.add_argument('--num_val', default=-1, type=int,
                            help='number of images to evaluate')
        parser.add_argument('--num_vis', default=40, type=int,
                            help='number of images to visualize')
        parser.add_argument('--audLen', default=65535, type=int,
                            help='sound length')
        parser.add_argument('--audRate', default=11025, type=int,
                            help='sound sampling rate')
        parser.add_argument('--stft_frame', default=1022, type=int,
                            help="stft frame length")
        parser.add_argument('--stft_hop', default=256, type=int,
                            help="stft hop length")
        parser.add_argument('--imgSize', default=224, type=int,
                            help='size of input frame')
        parser.add_argument('--frameRate', default=8, type=float,
                            help='video frame sampling rate')
        parser.add_argument('--dataset', default='MUSIC', type=str,
                            help='Used dataset (MUSIC | audioset )')
        parser.add_argument('--best_err', default=300, type=float,
                            help='best err')
        parser.add_argument('--best_sdr', default=-100, type=float,
                            help='best sdr')
        # Misc arguments
        parser.add_argument('--seed', default=1234, type=int,
                            help='manual seed')
        parser.add_argument('--ckpt', default='./myckpt',
                            help='folder to output checkpoints')
        parser.add_argument('--disp_iter', type=int, default=20,
                            help='frequency to display')
        parser.add_argument('--eval_epoch', type=int, default=1,
                            help='frequency to evaluate')
        self.parser = parser
        # argparse rejects registering the same option string twice, so
        # remember whether the training options were already added.
        self._train_args_added = False

    def add_train_arguments(self):
        """Register training/optimization options on ``self.parser``.

        Safe to call more than once: repeated calls are a no-op instead of
        raising ``argparse.ArgumentError`` for conflicting option strings.
        """
        # getattr keeps this working for instances that bypassed __init__.
        if getattr(self, '_train_args_added', False):
            return
        parser = self.parser
        parser.add_argument('--mode', default='train',
                            help="train/eval")
        parser.add_argument('--list_train',
                            default='data/train.csv',
                            help='path to the training data list')
        parser.add_argument('--list_val',
                            default='data/val.csv',
                            help='path to the validation data list')
        parser.add_argument('--dup_trainset', default=100, type=int,
                            help='duplicate so that one epoch has more iters')
        # optimization related arguments
        parser.add_argument('--num_epoch', default=100, type=int,
                            help='epochs to train for')
        parser.add_argument('--lr_frame', default=1e-4, type=float,
                            help='learning rate for net_frame')
        parser.add_argument('--lr_sound', default=1e-3, type=float,
                            help='learning rate for net_sound')
        parser.add_argument('--lr_avol', default=1e-3, type=float,
                            help='learning rate for net_avol')
        parser.add_argument('--lr_steps',
                            nargs='+', type=int, default=[40, 60],
                            help='steps to drop LR in epochs')
        parser.add_argument('--beta1', default=0.9, type=float,
                            help='momentum for sgd, beta1 for adam')
        parser.add_argument('--weight_decay', default=1e-4, type=float,
                            help='weights regularizer')
        self._train_args_added = True

    def print_arguments(self, args):
        """Print each parsed argument as a ``name value`` line.

        Args:
            args: an ``argparse.Namespace`` (anything ``vars()`` accepts);
                names are left-aligned in a 16-character column.
        """
        print("Input arguments:")
        for key, val in vars(args).items():
            print("{:16} {}".format(key, val))

    def parse_train_arguments(self, argv=None):
        """Register training options, parse the command line and echo it.

        Args:
            argv: list of argument strings to parse. ``None`` (the default)
                makes argparse read ``sys.argv[1:]``, as before.

        Returns:
            The parsed ``argparse.Namespace``.
        """
        self.add_train_arguments()
        args = self.parser.parse_args(argv)
        self.print_arguments(args)
        return args