prior running mean
albertz committed Jan 3, 2025
1 parent d9447de commit 3b55bce
Showing 2 changed files with 74 additions and 0 deletions.
users/zeyer/experiments/exp2024_04_23_baselines/ctc.py: 24 additions & 0 deletions
@@ -2115,6 +2115,12 @@ def __init__(
auxiliary=True,
non_critical_for_restore=True,
)
self.prior_running_mean_momentum = config.typed_value("prior_running_mean_momentum", None)
self.prior_running_mean = None
if self.prior_running_mean_momentum is not None:
self.prior_running_mean = rf.Parameter(
[self.wb_target_dim], auxiliary=True, initial=1.0 / self.wb_target_dim.dimension
)

if target_dim.vocab and not wb_target_dim.vocab:
from returnn.datasets.util.vocabulary import Vocabulary
@@ -2305,6 +2311,20 @@ def log_probs_wb_from_logits(self, logits: Tensor) -> Tensor:
out_dim=self.wb_target_dim,
)
log_probs.feature_dim = self.wb_target_dim

if self.prior_running_mean_momentum is not None:

def _update_running_stats():
batch_prior = rf.reduce_mean(
rf.exp(log_probs), axis=[d for d in log_probs.dims if d != self.wb_target_dim]
)
assert batch_prior.dims == (self.wb_target_dim,)
self.prior_running_mean.assign_add(
self.prior_running_mean_momentum * (batch_prior - self.prior_running_mean)
)

rf.cond(rf.get_run_ctx().train_flag, _update_running_stats, lambda: None)

log_probs = self._maybe_apply_on_log_probs(log_probs)
if self.ctc_am_scale == 1 and self.ctc_prior_scale == 0: # fast path
return log_probs
@@ -2319,6 +2339,10 @@ def log_probs_wb_from_logits(self, logits: Tensor) -> Tensor:
elif self.ctc_prior_type == "static":
log_prob_prior = self.static_prior
assert log_prob_prior.dims == (self.wb_target_dim,)
elif self.ctc_prior_type == "running_mean":
assert self.prior_running_mean is not None
log_prob_prior = rf.safe_log(self.prior_running_mean)
assert log_prob_prior.dims == (self.wb_target_dim,)
else:
raise ValueError(f"invalid ctc_prior_type {self.ctc_prior_type!r}")
log_probs -= log_prob_prior * self.ctc_prior_scale
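
For reference, the running-mean prior added above is an exponential moving average over the batch-averaged label posteriors, updated only under the train flag (via rf.cond) and then used like the other prior types via rf.safe_log. Below is a minimal standalone sketch of the same idea in plain PyTorch, not the RETURNN rf code from this commit; the [batch, time, labels] shape, the clamp-based stand-in for rf.safe_log, and the simple am/prior scaling at the end are illustrative assumptions:

import torch

momentum = 0.001  # plays the role of prior_running_mean_momentum
num_labels = 5    # plays the role of wb_target_dim (labels incl. blank)

# Running prior, initialized uniformly like the rf.Parameter above.
prior_running_mean = torch.full((num_labels,), 1.0 / num_labels)

def update_running_prior(log_probs: torch.Tensor) -> None:
    # log_probs: [batch, time, num_labels]. Average the posteriors over every axis
    # except the label axis, then move the running mean towards that batch prior:
    # new = old + momentum * (batch_prior - old), i.e. an exponential moving average.
    # (The commit runs this only when training, via rf.cond on the train flag.)
    with torch.no_grad():
        batch_prior = log_probs.exp().mean(dim=(0, 1))
        prior_running_mean.add_(momentum * (batch_prior - prior_running_mean))

def apply_am_and_prior_scale(log_probs: torch.Tensor, am_scale: float = 0.7, prior_scale: float = 0.2) -> torch.Tensor:
    # Scaled posterior minus scaled log-prior; clamp_min stands in for rf.safe_log.
    # The real log_probs_wb_from_logits supports more prior types, as in the diff above.
    log_prior = prior_running_mean.clamp_min(1e-20).log()
    return am_scale * log_probs - prior_scale * log_prior

logits = torch.randn(2, 10, num_labels)           # [batch, time, labels]
log_probs = logits.log_softmax(dim=-1)
update_running_prior(log_probs)
print(apply_am_and_prior_scale(log_probs).shape)  # torch.Size([2, 10, 5])
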
users/zeyer/experiments/exp2024_04_23_baselines/ctc_claix2023.py: 50 additions & 0 deletions
@@ -707,6 +707,56 @@ def py():
else (),
)

# Diff am/prior scales, with downsampling 6, spm10k.
for am_scale, prior_scale, name, prior_type, extra_train_opts in [
# Baseline (1.0, 0.0, None):
(0.7, 0.0, "", None, {}),
(0.5, 0.2, "-priorBatch", "batch", {}),
(0.7, 0.2, "-priorBatch", "batch", {}),
(0.7, 0.2, "-priorRunningMean1e_3", "running_mean", {"prior_running_mean_momentum": 0.001}),
]:
ctc_train_exp(
f"n12-spm10k-am{am_scale}-prior{prior_scale}{name}-auxAED-b150k",
config_96gb_bf16_accgrad1,
model_config={
"enc_conformer_layer": rf.build_dict(
ConformerEncoderLayer,
ff=rf.build_dict(
ConformerPositionwiseFeedForward, activation=rf.build_dict(rf.relu_square), with_bias=False
),
num_heads=8,
),
"feature_batch_norm": True,
"num_enc_layers": 12,
},
config_updates={
**_get_cfg_lrlin_oclr_by_bs_nep_v3(150_000, 100, batch_size_factor=_batch_size_factor),
"optimizer.weight_decay": 1e-2,
"max_seq_length_default_target": None,
# Note on max seq len stats: Before, when we used max_seq_length_default_target=75 with bpe10k,
# out of 281241 seqs in train, we removed only 71 seqs.
# With max seq len 19.5 secs on the audio, we also remove exactly 71 seqs.
"max_seq_length_default_input": 19.5 * _raw_sample_rate,
"__train_audio_preprocess": speed_pert_librosa_config,
"speed_pert_discrete_values": [0.7, 0.8, 0.9, 1.0, 1.1],
"aux_attention_decoder": rf.build_dict(TransformerDecoder, num_layers=6), # purely used for training
# Only for training:
"ctc_am_scale": am_scale,
"ctc_prior_scale": prior_scale,
"ctc_prior_type": prior_type,
**extra_train_opts,
"use_fixed_ctc_grad": "v2",
},
post_config_updates={"log_grad_norm": True, "__multi_proc_dataset_opts": {"num_workers": 25}},
vocab="spm10k",
train_vocab_opts={"other_opts": {"class": "SamplingBytePairEncoding", "breadth_prob": 0.01}},
dataset_train_opts={"train_epoch_split": 1, "train_epoch_wise_filter": None},
# avoid OOM
env_updates={"PYTORCH_CUDA_ALLOC_CONF": "expandable_segments:True"},
)

# TODO exp joint FF, or maybe joint Conformer wo att (with shared params?), ...

# Time downsampling 6.
# Comparing different vocabs, samplings (using max_seq_length_default_input).
for vocab, sample, alpha in [
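
As a quick, commit-independent sanity check of the prior_running_mean_momentum=0.001 used in the "-priorRunningMean1e_3" experiment above: each update x <- x + m * (b - x) scales the previous running mean by (1 - m), so older statistics decay as (1 - m)**n and halve after roughly ln(2)/m train steps.

import math

m = 0.001               # prior_running_mean_momentum in the experiment above
print(math.log(2) / m)  # ~693 steps until an old batch prior has lost half its weight
print(1 / m)            # ~1000 steps: rule-of-thumb effective averaging horizon
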
