Skip to content

Commit

Permalink
more
Browse files Browse the repository at this point in the history
  • Loading branch information
albertz committed Jan 6, 2025
1 parent f83cfe4 commit 603eef5
Showing 1 changed file with 43 additions and 0 deletions.
43 changes: 43 additions & 0 deletions users/zeyer/experiments/exp2024_04_23_baselines/ctc_claix2023.py
Original file line number Diff line number Diff line change
Expand Up @@ -480,6 +480,49 @@ def py():
env_updates={"PYTORCH_CUDA_ALLOC_CONF": "expandable_segments:True"},
)

from .model_ext.ctc_sep_net import ModelSepNet, FeedForwardNet, ctc_training_with_sep_net

# Time downsampling 6 (standard), spm10k.
# Separate FF net.
# Registers one CTC training experiment via ctc_train_exp, trained with ctc_training_with_sep_net.
# The name fragments mirror settings below: n12 = num_enc_layers, spm10k = vocab,
# sepFf = separate_enc_net (FeedForwardNet), alpha05 = sep_net_grad_interpolate_alpha,
# auxAED = aux_attention_decoder.
# NOTE(review): b150k presumably refers to the 150_000 passed to the LR/batch-size helper -- confirm.
ctc_train_exp(
    # Plain string: there are no placeholders, so the f-prefix was redundant (same value as before).
    "n12-spm10k-sepFf_alpha05-auxAED-b150k",
    config_96gb_bf16_accgrad1,
    train_def=ctc_training_with_sep_net,
    model_config={
        # Only the "class" entry of the built dict is passed here, unlike separate_enc_net below,
        # which gets the whole build_dict result.
        "ctc_model_cls": rf.build_dict(ModelSepNet)["class"],
        "separate_enc_net": rf.build_dict(FeedForwardNet),
        "enc_conformer_layer": rf.build_dict(
            ConformerEncoderLayer,
            ff=rf.build_dict(
                ConformerPositionwiseFeedForward, activation=rf.build_dict(rf.relu_square), with_bias=False
            ),
            num_heads=8,
        ),
        "feature_batch_norm": True,
        "num_enc_layers": 12,
    },
    config_updates={
        **_get_cfg_lrlin_oclr_by_bs_nep_v3(150_000, 100, batch_size_factor=_batch_size_factor),
        "optimizer.weight_decay": 1e-2,
        # NOTE(review): presumably None disables the target-length filter, so that only the
        # input-length limit below applies -- confirm.
        "max_seq_length_default_target": None,
        # Note on max seq len stats: Before, when we used max_seq_length_default_target=75 with bpe10k,
        # out of 281241 seqs in train, we removed only 71 seqs.
        # With max seq len 19.5 secs on the audio, we also remove exactly 71 seqs.
        "max_seq_length_default_input": 19.5 * _raw_sample_rate,
        "__train_audio_preprocess": speed_pert_librosa_config,
        "speed_pert_discrete_values": [0.7, 0.8, 0.9, 1.0, 1.1],
        "aux_attention_decoder": rf.build_dict(TransformerDecoder, num_layers=6),  # purely used for training
        "use_fixed_ctc_grad": "v2",
        "sep_net_grad_interpolate_alpha": 0.5,
    },
    post_config_updates={"log_grad_norm": True, "__multi_proc_dataset_opts": {"num_workers": 25}},
    vocab="spm10k",
    train_vocab_opts={"other_opts": {"class": "SamplingBytePairEncoding", "breadth_prob": 0.01}},
    dataset_train_opts={"train_epoch_split": 1, "train_epoch_wise_filter": None},
    # avoid OOM
    env_updates={"PYTORCH_CUDA_ALLOC_CONF": "expandable_segments:True"},
)

from i6_experiments.common.setups import serialization
from sisyphus import gs

Expand Down

0 comments on commit 603eef5

Please sign in to comment.