-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathrun.py
92 lines (73 loc) · 3.77 KB
/
run.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
import os
import argparse
from work_module import prepare_data, main_dnn, evaluate
# Last component of the current working directory; it names the workspace
# folder chosen below.
CURRENT_PATH = os.path.basename(os.getcwd())

# Dataset switch: True selects the relative mini_data/ directories,
# False selects the absolute D:/ directories.
MINIDATA = True

if not MINIDATA:
    WORKSPACE = "D:/Python_output/" + CURRENT_PATH
    TR_SPEECH_DIR, TR_NOISE_DIR = "D:/train/speech", "D:/noise"
    TE_SPEECH_DIR, TE_NOISE_DIR = "D:/test", "D:/noise"
else:
    # The "_MINIDATA" suffix gives the mini-data run its own workspace folder.
    WORKSPACE = os.path.join("D:/Python_output", CURRENT_PATH + "_MINIDATA")
    TR_SPEECH_DIR, TR_NOISE_DIR = "mini_data/train_speech", "mini_data/train_noise"
    TE_SPEECH_DIR, TE_NOISE_DIR = "mini_data/test_speech", "mini_data/test_noise"
def get_args(argv=None):
    """Build the command-line parser for this script and parse the options.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) argparse reads the process command line, which is
            what a plain ``get_args()`` call has always done. Passing a list
            lets tests and other callers parse options without touching
            ``sys.argv``.

    Returns:
        argparse.Namespace with the attributes ``sample_rate``, ``fft``,
        ``window``, ``overlap``, ``n_concat``, ``tr_snr``, ``te_snr``,
        ``iter``, ``debug_inter``, ``save_inter``, ``batch_size``, ``lr``,
        ``visualize``, ``train`` and ``test``.

    Raises:
        SystemExit: raised by argparse on unknown options or invalid values.
    """
    parser = argparse.ArgumentParser(description="Speech Enhancement using DNN.")

    # Signal-processing parameters.
    parser.add_argument('-sr', '--sample_rate', default=8000, type=int,
                        help="target sampling rate of audio")
    parser.add_argument('--fft', default=256, type=int,
                        help="FFT size")
    parser.add_argument('--window', default=256, type=int,
                        help="window size")
    parser.add_argument('--overlap', default=192, type=int,
                        help="overlap size of spectrogram")
    parser.add_argument('--n_concat', default=7, type=int,
                        help="number of frames to concatenate")
    parser.add_argument('--tr_snr', default=0, type=int,
                        help="SNR of training data")
    parser.add_argument('--te_snr', default=0, type=int,
                        help="SNR of test data")

    # Training-schedule parameters.
    parser.add_argument('--iter', default=10000, type=int,
                        help="number of iteration for training")
    parser.add_argument('--debug_inter', default=1000, type=int,
                        help="Interval to debug model")
    parser.add_argument('--save_inter', default=5000, type=int,
                        help="Interval to save model")
    parser.add_argument('-b', '--batch_size', default=32, type=int)
    parser.add_argument('--lr', default=0.0001, type=float,
                        help="Initial learning rate")

    # 0/1 switches (kept as ints so existing command lines keep working).
    parser.add_argument('-visual', '--visualize', default=1, type=int, choices=[0, 1],
                        help="If value is 1, visualization of result of inference")
    parser.add_argument('--train', default=1, type=int, choices=[0, 1],
                        help="If the value is 1, run training")
    parser.add_argument('--test', default=1, type=int, choices=[0, 1],
                        help="If the value is 1, run test")

    return parser.parse_args(argv)
if __name__ == '__main__':
    # Script entry point: parse options, prepare the data, then optionally
    # train and/or test according to the --train / --test switches.
    # NOTE(review): this file reached review with its indentation stripped;
    # the nesting under the two `if args.*==1` branches below is the most
    # plausible reading (each branch owns the calls up to the next `if`) --
    # confirm against the upstream file.
    args = get_args()

    # Directory settings handed to every work_module call below.
    DIRECTORY = {
        'WORKSPACE': WORKSPACE,
        'TR_SPEECH_DIR': TR_SPEECH_DIR,
        'TR_NOISE_DIR': TR_NOISE_DIR,
        'TE_SPEECH_DIR': TE_SPEECH_DIR,
        'TE_NOISE_DIR': TE_NOISE_DIR,
    }

    # Validate with an explicit raise: an `assert` is removed when Python runs
    # with -O, which would silently skip this check on user-supplied options.
    if args.iter < args.debug_inter or args.iter < args.save_inter:
        raise ValueError("Number of training iterations should greater than or equal to the debugging and store interval")

    # Each preparation step runs for 'train' first and then 'test', and a step
    # finishes both modes before the next step starts.
    for prepare_step in (prepare_data.create_mixture_csv,
                         prepare_data.calculate_mixture_features,
                         prepare_data.pack_features):
        for mode in ('train', 'test'):
            prepare_step(DIRECTORY, args, mode=mode)

    if args.train == 1:
        prepare_data.compute_scaler(DIRECTORY, args, mode='train')
        main_dnn.train(DIRECTORY, args)
        evaluate.plot_training_stat(DIRECTORY, args, bgn_iter=0, fin_iter=args.iter,
                                    interval_iter=args.debug_inter)

    if args.test == 1:
        main_dnn.inference(DIRECTORY, args)
        evaluate.calculate_pesq(DIRECTORY, args)
        evaluate.get_stats(DIRECTORY, args)