-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy path: run_nnlm_lrs2_lstm.sh
executable file
·136 lines (127 loc) · 5.95 KB
/
run_nnlm_lrs2_lstm.sh
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
#!/usr/bin/env bash
# Copyright 2012 Johns Hopkins University (author: Daniel Povey)
#           2020 Ke Li
# This script trains an RNN (including LSTM and GRU) or Transformer-based language model with PyTorch and performs N-best rescoring
#
# All variables below can be overridden from the command line thanks to
# utils/parse_options.sh, e.g.:  ./run_nnlm_lrs2_lstm.sh --stage 2 --gpu 0
stage=0   # first stage to run: 0 = data prep (disabled), 1 = LM training, 2 = rescoring
gpu=1     # GPU id; exported via CUDA_VISIBLE_DEVICES below
LM=ami_fsh.o3g.kn.pr1-7 # 4gram ami.o3g.kn.pr1-7
#ac_model_dir=exp/chain/tdnn7q_sp
#ac_model_dir=data/pytorchnn_ami/rescore/exp/$mic
#decode_dir_suffix=
#pytorch_path=exp/pytorch_lstm_bz32_hdim1024_ami+fisher+swbd
#pytorch_path=exp/pytorch_lstm_bz32_hdim1024_ami
#nn_model=$pytorch_path/model.pt
model_type=LSTM # LSTM, GRU or Transformer
embedding_dim=1024 # 512 for Transformer (to reproduce the perplexities and WERs above)
hidden_dim=1024 # 512 for Transformer
nlayers=2 # 6 for Transformer
nhead=8 # for Transformer
learning_rate=5 # 0.1 for Transformer
seq_len=100 # BPTT sequence length used by train.py
uncertainty=Gaussian # none for baseline, options: Bayesian, Gaussian
L_bayes_pos=0 # LSTM Bayesian position: [0: standard | 1: input_gate | 2: forget_gate | 3: cell_gate | 4: output_gate]
L_gauss_pos=00 # LSTM Gaussian Activation position: [0: None | 1: input_gate | 2: forget_gate | 3: cell_gate | 4: output_gate | 5: cell | 6: hidden | 7: inputs]
prior_path=steps/pytorchnn/prior/lstm # load pretrained prior model
prior=False # load a pretrained model
mark=marks # tag appended to output dir names to keep experiment runs distinct
inter_flag=0   # passed to rescoring as --interpolation_flag
inter_alpha=0.8 # passed to rescoring as --inter_alpha
L_v_pos=00 # variational position tag used when uncertainty is neither Bayesian nor Gaussian
##################################################################################################
dropout=0.2 # baseline 0.2 | bayesian initial 0.0
##################################################################################################
itpr=1.0 # NOTE(review): presumably the n-gram/NN LM interpolation weight fed positionally to the rescoring script — confirm
. ./cmd.sh
. ./path.sh
. ./utils/parse_options.sh   # must come after the defaults above so CLI flags can override them
set -e
export CUDA_VISIBLE_DEVICES=$gpu
ac_model_dir=exp/chain/lrs2/
# Build the experiment tag once instead of duplicating the long string in
# every branch (the three original branches differed only in the
# uncertainty-position segment). Resulting paths are byte-identical to the
# old per-branch constructions.
case "$uncertainty" in
  Bayesian) pos_tag=${L_bayes_pos} ;;      # Bayesian LSTM gate position
  Gaussian) pos_tag=GP${L_gauss_pos} ;;    # Gaussian activation position (GP prefix)
  *)        pos_tag=${L_v_pos} ;;          # baseline / other
esac
model_tag=pytorch-${model_type}-emb${embedding_dim}_hid${hidden_dim}_nly${nlayers}-${dropout}-${uncertainty}-${pos_tag}-pre${prior}-${mark}
pytorch_path=lrs2/${model_tag}             # where the trained model + log go
nn_model=$pytorch_path/model.pt            # checkpoint written by train.py
# Decode dirs additionally record the rescoring knobs (itpr / interpolation).
decode_dir_suffix=${model_tag}-itpr${itpr}-ib${inter_flag}-${inter_alpha}
data_dir=data_lrs2/
mkdir -p "$pytorch_path"
# Abort early (guard clause) if the python interpreter cannot import PyTorch;
# the install-hint text below is identical to what the original echos printed.
if ! python steps/pytorchnn/check_py.py 2>/dev/null; then
  cat <<'EOF'
PyTorch not found on the python side.
Please install PyTorch first. For example, you can install it with conda:
conda install pytorch torchvision cudatoolkit=10.2 -c pytorch, or
with pip: pip install torch torchvision. If you already have PyTorch
installed somewhere else, you need to add it to your PATH.
Note: you need to install higher version than PyTorch 1.1 to train Transformer models
EOF
  exit 1
fi
echo "PyTorch is ready to use on the python side. This is good."
# Stage 0 (disabled): data preparation for the LM.
#if [ $stage -le 0 ]; then
#  local/pytorchnn/data_prep.sh $data_dir
#fi

# Stage 1: train the PyTorch neural network language model.
# The options are collected in an array so every expansion can be safely
# quoted (the original unquoted backslash chain was SC2086-prone: an empty
# or space-containing value would silently shift the argument list).
if [ $stage -le 1 ]; then
  echo "Start neural network language model training."
  #$cuda_cmd $pytorch_path/log/train.log utils/parallel/limit_num_gpus.sh \
  train_opts=(
    --data "$data_dir"
    --model "$model_type"
    --emsize "$embedding_dim"
    --nhid "$hidden_dim"
    --nlayers "$nlayers"
    --nhead "$nhead"
    --lr "$learning_rate"
    --dropout "$dropout"
    --seq_len "$seq_len"
    --clip 1.0
    --batch-size 32
    --epoch 32
    --save "$nn_model"
    --uncertainty "$uncertainty"
    --L_bayes_pos "$L_bayes_pos"
    --L_gauss_pos "$L_gauss_pos"
    --L_v_pos "$L_v_pos"
    --prior "$prior"
    --prior_path "$prior_path"
    --tied
    --cuda
  )
  # stdout (training progress) goes to the per-experiment log file.
  python steps/pytorchnn/train.py "${train_opts[@]}" > "$pytorch_path/train.log"
fi
#LM=ami_fsh.o3g.kn.pr1-7 # Using the 4-gram const arpa file as old lm
#LM=ami.o3g.kn.pr1-7
# Stage 2: rescore the N-best lists of each decode directory with the trained
# neural LM. Results land in ${decode_dir}_${decode_dir_suffix}.
if [ $stage -le 2 ]; then
echo "$0: Perform nbest-rescoring on $ac_model_dir with a PyTorch trained $model_type LM."
# clean_avsr / TF / MVDR / FAS are the LRS2 test conditions to rescore.
for decode_set in clean_avsr TF MVDR FAS; do
decode_dir=${ac_model_dir}/decode_${decode_set}
# Positional arguments (after the options) are, in order:
#   interpolation weight ($itpr), lang dir, nn model, vocab (words.txt),
#   data dir, input decode dir, output decode dir.
# NOTE(review): this order is dictated by lmrescore_nbest_pytorchnn_jwyu.sh —
# verify against that script before changing anything here.
steps/pytorchnn/lmrescore_nbest_pytorchnn_jwyu.sh \
--stage 1 \
--cmd "$decode_cmd --mem 4G" \
--N 20 \
--model-type $model_type \
--embedding_dim $embedding_dim \
--hidden_dim $hidden_dim \
--nlayers $nlayers \
--nhead $nhead \
--uncertainty $uncertainty \
--L_bayes_pos $L_bayes_pos \
--L_gauss_pos $L_gauss_pos \
--L_v_pos $L_v_pos \
--interpolation_flag $inter_flag \
--inter_alpha $inter_alpha \
$itpr data-lrs2-clean/lang/ $nn_model $data_dir/words.txt \
data-lrs2-clean/test/ ${decode_dir} \
${decode_dir}_${decode_dir_suffix}
done
fi
exit 0
# Example scoring commands (kept for reference; never executed due to the exit above):
# for i in data/pytorchnn_ami/rescore/exp/ihm/tdnnli_sp_bi/ihm/decode_dev; do grep Sum $i/*sco*/*ys | ./utils/best_wer.sh ;done
#cat exp/chain/lrs2//decode_clean_avsr/wer_* | ./utils/best_wer.sh