Commit 1702c20
latest after ilm
Marvin84 committed Jan 8, 2025
1 parent c663e3b commit 1702c20
Showing 4 changed files with 77 additions and 18 deletions.
@@ -90,6 +90,16 @@
    out_joint_diphone="output/output_batch_major",
)

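# Tensor map for triphone fast-search decoding with zero-ILM subtraction:
# out_center_state is read from the ILM-corrected layer "output_sub_iLM"
# (defined in users/raissi/setups/common/helpers/decode/iLM.py below).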
CONF_FH_TRIPHONE_FS_DECODING_TENSOR_CONFIG_V2_ILM = dataclasses.replace(
    DecodingTensorMap.default(),
    in_encoder_output="conformer_12_output/add",
    out_encoder_output="encoder__output/output_batch_major",
    out_right_context="right__output/output_batch_major",
    out_left_context="left__output/output_batch_major",
    out_center_state="output_sub_iLM/output_batch_major",
    out_joint_diphone="output/output_batch_major",
)

BLSTM_FH_DECODING_TENSOR_CONFIG = dataclasses.replace(
    CONF_FH_DECODING_TENSOR_CONFIG,
    in_encoder_output="concat_lstm_fwd_6_lstm_bwd_6/concat_sources/concat",
17 changes: 5 additions & 12 deletions users/raissi/setups/common/decoder/BASE_factored_hybrid_search.py
@@ -613,7 +613,6 @@ def recognize_count_lm(
        rerun_after_opt_lm=False,
        name_override: Union[str, None] = None,
        name_prefix: str = "",
        gpu: Optional[bool] = None,
        cpu_rqmt: Optional[int] = None,
        mem_rqmt: Optional[int] = None,
        crp_update: Optional[Callable[[rasr.RasrConfig], Any]] = None,
@@ -633,7 +632,6 @@ def recognize_count_lm(
            num_encoder_output=num_encoder_output,
            search_parameters=search_parameters,
            calculate_stats=calculate_stats,
            gpu=gpu,
            cpu_rqmt=cpu_rqmt,
            mem_rqmt=mem_rqmt,
            is_min_duration=is_min_duration,
@@ -690,6 +688,7 @@ def recognize(
        create_lattice: bool = True,
        adv_search_extra_config: Optional[rasr.RasrConfig] = None,
        adv_search_extra_post_config: Optional[rasr.RasrConfig] = None,
        lm_lookahead_options: Optional[Dict[str, Any]] = None,
        search_rqmt_update=None,
        cpu_omp_thread=2,
        separate_lm_image_gc_generation: bool = False,
@@ -876,7 +875,7 @@ def recognize(
            la_options = self.get_lookahead_options(clow=0, chigh=10)
            name += "-cheating"
        else:
            la_options = self.get_lookahead_options()
            la_options = self.get_lookahead_options(**(lm_lookahead_options or {}))
        adv_search_extra_config = (
            copy.deepcopy(adv_search_extra_config) if adv_search_extra_config is not None else rasr.RasrConfig()
        )
@@ -926,6 +925,7 @@
else "decoding"
)


        search = recog.AdvancedTreeSearchJob(
            crp=search_crp,
            feature_flow=self.feature_scorer_flow,
@@ -934,8 +934,8 @@ def recognize(
            lm_lookahead=True,
            lookahead_options=la_options,
            eval_best_in_lattice=True,
            use_gpu=gpu if gpu is not None else self.gpu,
            rtf=rtf_gpu if rtf_gpu is not None and gpu else rtf_cpu if rtf_cpu is not None else rqms["rtf"],
            use_gpu=self.gpu,
            rtf=rtf_gpu if rtf_gpu is not None else rtf_cpu if rtf_cpu is not None else rqms["rtf"],
            mem=rqms["mem"] if mem_rqmt is None else mem_rqmt,
            cpu=2 if cpu_rqmt is None else cpu_rqmt,
            lmgc_scorer=rasr.DiagonalMaximumScorer(self.mixtures) if self.lm_gc_simple_hash else None,
@@ -1107,7 +1107,6 @@ def recognize_optimize_scales(
        altas_value=14.0,
        altas_beam=14.0,
        keep_value=10,
        gpu: Optional[bool] = None,
        cpu_rqmt: Optional[int] = None,
        mem_rqmt: Optional[int] = None,
        crp_update: Optional[Callable[[rasr.RasrConfig], Any]] = None,
@@ -1145,7 +1144,6 @@ def recognize_optimize_scales(
            calculate_stats=False,
            cpu_rqmt=cpu_rqmt,
            crp_update=crp_update,
            gpu=gpu,
            is_min_duration=False,
            keep_value=keep_value,
            label_info=label_info,
@@ -1170,7 +1168,6 @@ def recognize_optimize_scales(
            calculate_stats=False,
            cpu_rqmt=cpu_rqmt,
            crp_update=crp_update,
            gpu=gpu,
            is_min_duration=False,
            keep_value=keep_value,
            label_info=label_info,
@@ -1288,7 +1285,6 @@ def recognize_optimize_scales_v2(
        altas_value=14.0,
        altas_beam=14.0,
        keep_value=10,
        gpu: Optional[bool] = None,
        cpu_rqmt: Optional[int] = None,
        mem_rqmt: Optional[int] = None,
        crp_update: Optional[Callable[[rasr.RasrConfig], Any]] = None,
@@ -1329,7 +1325,6 @@ def recognize_optimize_scales_v2(
            calculate_stats=False,
            cpu_rqmt=cpu_rqmt,
            crp_update=crp_update,
            gpu=gpu,
            is_min_duration=False,
            keep_value=keep_value,
            label_info=label_info,
@@ -1358,7 +1353,6 @@ def recognize_optimize_scales_v2(
            calculate_stats=False,
            cpu_rqmt=cpu_rqmt,
            crp_update=crp_update,
            gpu=gpu,
            is_min_duration=False,
            keep_value=keep_value,
            label_info=label_info,
@@ -1496,7 +1490,6 @@ def recognize_optimize_transtition_values(
            calculate_stats=False,
            cpu_rqmt=cpu_rqmt,
            crp_update=crp_update,
            gpu=gpu,
            is_min_duration=False,
            keep_value=keep_value,
            label_info=label_info,
56 changes: 56 additions & 0 deletions users/raissi/setups/common/helpers/decode/iLM.py
@@ -0,0 +1,56 @@
from dataclasses import dataclass
from typing import Dict

from i6_experiments.users.raissi.setups.common.data.factored_label import PhoneticContext, LabelInfo
from i6_experiments.users.raissi.setups.common.helpers.network import add_mlp

@dataclass
class ILMRenormINFO:
    renormalize: bool
    label_start_idx: int
    label_end_idx: int


def add_zero_ilm_to_returnn_dict(
    network: Dict,
    context_type: PhoneticContext,
    label_info: LabelInfo,
    ilm_renorm_info: ILMRenormINFO,
    ilm_scale: float,
) -> Dict:
    assert context_type in [
        PhoneticContext.diphone,
        PhoneticContext.triphone_forward,
    ], "Zero iLM is only supported for factored context-dependent models"

    # Zero out the encoder output: with zeroed acoustic evidence, the label
    # posterior degenerates to the model's internal LM (label prior given context).
    network["zero_enc"] = {"class": "eval", "from": "encoder-output", "eval": "source(0) * 0"}

    if context_type == PhoneticContext.diphone:
        network["input-ilm"] = {
            "class": "copy",
            "from": ["zero_enc", "pastEmbed"],
        }

        # Two-layer MLP for the ILM estimate, with parameters tied to the
        # diphone context MLP ("linear1-diphone"/"linear2-diphone").
        ilm_ff_layer = add_mlp(
            network=network,
            layer_name="input-ilm",
            source_layer="input-ilm",
            size=network["linear1-diphone"]["n_out"],
            n_layers=2,
        )
        network[ilm_ff_layer.replace("2", "1")]["reuse_params"] = "linear1-diphone"
        network[ilm_ff_layer]["reuse_params"] = "linear2-diphone"

    # NOTE: only the diphone input wiring is built above, so ilm_ff_layer is
    # undefined for PhoneticContext.triphone_forward as written.
    ilm_layer = "iLM"
    network[ilm_layer] = {
        "class": "linear",
        "from": ilm_ff_layer,
        "activation": "log_softmax",
        "n_out": label_info.get_n_state_classes(),
        "reuse_params": "center-output",
    }

    if ilm_renorm_info.renormalize:
        # Mask out the labels in [label_start_idx, label_end_idx), e.g. silence,
        # and renormalize the remaining log-probabilities by log(1 - p_masked).
        start = ilm_renorm_info.label_start_idx
        end = ilm_renorm_info.label_end_idx
        network["iLM-renorm"] = {
            "class": "eval",
            "from": [ilm_layer],
            "eval": f"tf.concat([source(0)[:, :{start}] - tf.math.log(1.0 - tf.exp(source(0)[:, {start}:{end}])), tf.zeros(tf.shape(source(0)[:, {start}:{end}])), source(0)[:, {end}:] - tf.math.log(1.0 - tf.exp(source(0)[:, {start}:{end}]))], axis=1)",
        }
        ilm_layer = "iLM-renorm"

    # Subtract the scaled ILM log-prior from the center-state posterior.
    network["output_sub_iLM"] = {
        "class": "eval",
        "from": ["center-output", ilm_layer],
        "eval": f"tf.exp(safe_log(source(0)) - {ilm_scale} * source(1))",
        "is_output_layer": True,
    }

    return network
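
The "output_sub_iLM" layer computes the unnormalized score exp(log p_AM - ilm_scale * log p_ILM), i.e. it subtracts the scaled internal-LM log-prior (estimated by zeroing the encoder output) from the center-state posterior; with renormalization enabled, the ILM probabilities of the labels in [label_start_idx, label_end_idx), e.g. a silence class, are masked out and the rest renormalized by 1 - p_masked. A minimal usage sketch follows; `decoding_network`, `label_info`, the index pair 81/82, and the scale 0.2 are placeholders, and the network is assumed to already define the `encoder-output`, `pastEmbed`, `linear1-diphone`, `linear2-diphone`, and `center-output` layers that the helper references:

from i6_experiments.users.raissi.setups.common.data.factored_label import PhoneticContext
from i6_experiments.users.raissi.setups.common.helpers.decode.iLM import (
    ILMRenormINFO,
    add_zero_ilm_to_returnn_dict,
)

# decoding_network: an existing RETURNN network dict for the diphone decoder
network = add_zero_ilm_to_returnn_dict(
    network=decoding_network,
    context_type=PhoneticContext.diphone,
    label_info=label_info,  # provides get_n_state_classes()
    ilm_renorm_info=ILMRenormINFO(renormalize=True, label_start_idx=81, label_end_idx=82),
    ilm_scale=0.2,  # placeholder value; tuned on a dev set in practice
)
# The decoder then reads the corrected posteriors via out_center_state
# ("output_sub_iLM/output_batch_major", cf. the tensor map above).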
@@ -94,11 +94,11 @@ def get_kazuki_lstm_config(
        # model and graph info
        trafo_config.loader.type = "meta"
        trafo_config.loader.meta_graph_file = tk.Path("/u/raissi/Desktop/debug/lstm_lm/mini2.3.graph.meta", cached=True)
        trafo_config.loader.saved_model_file = DelayedFormat("//u/raissi/Desktop/debug/lstm_lm/models/net-model-mini2.2/network.030")
        trafo_config.loader.saved_model_file = DelayedFormat("/u/raissi/Desktop/debug/lstm_lm/models/net-model-mini2.3/network.050")
        trafo_config.loader.required_libraries = self.library_path

        trafo_config.type = "tfrnn"
        trafo_config.vocab_file = "/u/zhou/asr-exps/librispeech/dependencies/kazuki_lstmlm_27062019/vocabulary"
        trafo_config.vocab_file = "/work/asr4/rossenbach/custom_projects/kazuki_replicate_lm_training/vocab.word.freq_sorted.200k.alternative.txt"
        trafo_config.transform_output_negate = True
        trafo_config.vocab_unknown_word = "<UNK>"

@@ -356,7 +356,6 @@ def recognize_ls_trafo_lm(
        rerun_after_opt_lm=False,
        name_override: Union[str, None] = None,
        name_prefix: str = "",
        gpu: Optional[bool] = None,
        cpu_rqmt: Optional[int] = None,
        mem_rqmt: Optional[int] = None,
        crp_update: Optional[Callable[[rasr.RasrConfig], Any]] = None,
@@ -369,7 +368,6 @@
        return self.recognize(
            add_sis_alias_and_output=add_sis_alias_and_output,
            calculate_stats=calculate_stats,
            gpu=gpu,
            cpu_rqmt=cpu_rqmt,
            mem_rqmt=mem_rqmt,
            is_min_duration=is_min_duration,
@@ -413,20 +411,21 @@ def recognize_lstm_lm(
        rerun_after_opt_lm=False,
        name_override: Union[str, None] = None,
        name_prefix: str = "",
        gpu: Optional[bool] = None,
        cpu_rqmt: Optional[int] = None,
        mem_rqmt: Optional[int] = None,
        crp_update: Optional[Callable[[rasr.RasrConfig], Any]] = None,
        rtf_gpu: Optional[float] = None,
        rtf_cpu: Optional[float] = None,
        create_lattice: bool = True,
        lm_lookahead_options: Optional[Dict[str, Any]] = None,
        adv_search_extra_config: Optional[rasr.RasrConfig] = None,
        adv_search_extra_post_config: Optional[rasr.RasrConfig] = None,
    ) -> DecodingJobs:
        if lm_lookahead_options is None:
            lm_lookahead_options = {"clow": 2000, "chigh": 3000}
        return self.recognize(
            add_sis_alias_and_output=add_sis_alias_and_output,
            calculate_stats=calculate_stats,
            gpu=gpu,
            cpu_rqmt=cpu_rqmt,
            mem_rqmt=mem_rqmt,
            is_min_duration=is_min_duration,
@@ -447,6 +446,7 @@ def recognize_lstm_lm(
            crp_update=crp_update,
            rtf_cpu=rtf_cpu,
            rtf_gpu=rtf_gpu,
            lm_lookahead_options=lm_lookahead_options,
            create_lattice=create_lattice,
            adv_search_extra_config=adv_search_extra_config,
            adv_search_extra_post_config=adv_search_extra_post_config,
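
The new `lm_lookahead_options` parameter threads caller-supplied lookahead pruning limits through `recognize` into `get_lookahead_options`; when omitted, `recognize_lstm_lm` falls back to `{"clow": 2000, "chigh": 3000}`. A hypothetical override (the `decoder` instance and its remaining task-specific arguments are assumed to be configured elsewhere; the limit values are placeholders):

jobs = decoder.recognize_lstm_lm(
    # ... other task-specific arguments as before ...
    lm_lookahead_options={"clow": 1000, "chigh": 4000},  # placeholder limits
    create_lattice=True,
)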