From 1702c20f4932e112b03039b1ddccb81a9387379c Mon Sep 17 00:00:00 2001 From: marvin84 Date: Wed, 8 Jan 2025 19:41:17 +0100 Subject: [PATCH] latest after ilm --- .../configs/LFR_factored/baseline/config.py | 10 ++++ .../decoder/BASE_factored_hybrid_search.py | 17 ++---- .../setups/common/helpers/decode/iLM.py | 56 +++++++++++++++++++ .../decoder/LBS_factored_hybrid_search.py | 12 ++-- 4 files changed, 77 insertions(+), 18 deletions(-) create mode 100644 users/raissi/setups/common/helpers/decode/iLM.py diff --git a/users/raissi/experiments/librispeech/configs/LFR_factored/baseline/config.py b/users/raissi/experiments/librispeech/configs/LFR_factored/baseline/config.py index e1c9b9282..e36712110 100644 --- a/users/raissi/experiments/librispeech/configs/LFR_factored/baseline/config.py +++ b/users/raissi/experiments/librispeech/configs/LFR_factored/baseline/config.py @@ -90,6 +90,16 @@ out_joint_diphone="output/output_batch_major", ) +CONF_FH_TRIPHONE_FS_DECODING_TENSOR_CONFIG_V2_ILM = dataclasses.replace( + DecodingTensorMap.default(), + in_encoder_output="conformer_12_output/add", + out_encoder_output="encoder__output/output_batch_major", + out_right_context="right__output/output_batch_major", + out_left_context="left__output/output_batch_major", + out_center_state="output_sub_iLM/output_batch_major", + out_joint_diphone="output/output_batch_major", +) + BLSTM_FH_DECODING_TENSOR_CONFIG = dataclasses.replace( CONF_FH_DECODING_TENSOR_CONFIG, in_encoder_output="concat_lstm_fwd_6_lstm_bwd_6/concat_sources/concat", diff --git a/users/raissi/setups/common/decoder/BASE_factored_hybrid_search.py b/users/raissi/setups/common/decoder/BASE_factored_hybrid_search.py index c151243e8..c75a2159e 100644 --- a/users/raissi/setups/common/decoder/BASE_factored_hybrid_search.py +++ b/users/raissi/setups/common/decoder/BASE_factored_hybrid_search.py @@ -613,7 +613,6 @@ def recognize_count_lm( rerun_after_opt_lm=False, name_override: Union[str, None] = None, name_prefix: str = "", - gpu: Optional[bool] = None, cpu_rqmt: Optional[int] = None, mem_rqmt: Optional[int] = None, crp_update: Optional[Callable[[rasr.RasrConfig], Any]] = None, @@ -633,7 +632,6 @@ def recognize_count_lm( num_encoder_output=num_encoder_output, search_parameters=search_parameters, calculate_stats=calculate_stats, - gpu=gpu, cpu_rqmt=cpu_rqmt, mem_rqmt=mem_rqmt, is_min_duration=is_min_duration, @@ -690,6 +688,7 @@ def recognize( create_lattice: bool = True, adv_search_extra_config: Optional[rasr.RasrConfig] = None, adv_search_extra_post_config: Optional[rasr.RasrConfig] = None, + lm_lookahead_options: Optional = {}, search_rqmt_update=None, cpu_omp_thread=2, separate_lm_image_gc_generation: bool = False, @@ -876,7 +875,7 @@ def recognize( la_options = self.get_lookahead_options(clow=0, chigh=10) name += "-cheating" else: - la_options = self.get_lookahead_options() + la_options = self.get_lookahead_options(**lm_lookahead_options) adv_search_extra_config = ( copy.deepcopy(adv_search_extra_config) if adv_search_extra_config is not None else rasr.RasrConfig() ) @@ -926,6 +925,7 @@ def recognize( else "decoding" ) + search = recog.AdvancedTreeSearchJob( crp=search_crp, feature_flow=self.feature_scorer_flow, @@ -934,8 +934,8 @@ def recognize( lm_lookahead=True, lookahead_options=la_options, eval_best_in_lattice=True, - use_gpu=gpu if gpu is not None else self.gpu, - rtf=rtf_gpu if rtf_gpu is not None and gpu else rtf_cpu if rtf_cpu is not None else rqms["rtf"], + use_gpu=self.gpu, + rtf=rtf_gpu if rtf_gpu is not None else rtf_cpu if rtf_cpu is not None else rqms["rtf"], mem=rqms["mem"] if mem_rqmt is None else mem_rqmt, cpu=2 if cpu_rqmt is None else cpu_rqmt, lmgc_scorer=rasr.DiagonalMaximumScorer(self.mixtures) if self.lm_gc_simple_hash else None, @@ -1107,7 +1107,6 @@ def recognize_optimize_scales( altas_value=14.0, altas_beam=14.0, keep_value=10, - gpu: Optional[bool] = None, cpu_rqmt: Optional[int] = None, mem_rqmt: Optional[int] = None, crp_update: Optional[Callable[[rasr.RasrConfig], Any]] = None, @@ -1145,7 +1144,6 @@ def recognize_optimize_scales( calculate_stats=False, cpu_rqmt=cpu_rqmt, crp_update=crp_update, - gpu=gpu, is_min_duration=False, keep_value=keep_value, label_info=label_info, @@ -1170,7 +1168,6 @@ def recognize_optimize_scales( calculate_stats=False, cpu_rqmt=cpu_rqmt, crp_update=crp_update, - gpu=gpu, is_min_duration=False, keep_value=keep_value, label_info=label_info, @@ -1288,7 +1285,6 @@ def recognize_optimize_scales_v2( altas_value=14.0, altas_beam=14.0, keep_value=10, - gpu: Optional[bool] = None, cpu_rqmt: Optional[int] = None, mem_rqmt: Optional[int] = None, crp_update: Optional[Callable[[rasr.RasrConfig], Any]] = None, @@ -1329,7 +1325,6 @@ def recognize_optimize_scales_v2( calculate_stats=False, cpu_rqmt=cpu_rqmt, crp_update=crp_update, - gpu=gpu, is_min_duration=False, keep_value=keep_value, label_info=label_info, @@ -1358,7 +1353,6 @@ def recognize_optimize_scales_v2( calculate_stats=False, cpu_rqmt=cpu_rqmt, crp_update=crp_update, - gpu=gpu, is_min_duration=False, keep_value=keep_value, label_info=label_info, @@ -1496,7 +1490,6 @@ def recognize_optimize_transtition_values( calculate_stats=False, cpu_rqmt=cpu_rqmt, crp_update=crp_update, - gpu=gpu, is_min_duration=False, keep_value=keep_value, label_info=label_info, diff --git a/users/raissi/setups/common/helpers/decode/iLM.py b/users/raissi/setups/common/helpers/decode/iLM.py new file mode 100644 index 000000000..fcabb556d --- /dev/null +++ b/users/raissi/setups/common/helpers/decode/iLM.py @@ -0,0 +1,56 @@ +from dataclasses import dataclass +from typing import Dict + +from i6_experiments.users.raissi.setups.common.data.factored_label import PhoneticContext, LabelInfo +from i6_experiments.users.raissi.setups.common.helpers.network import add_mlp + +@dataclass +class ILMRenormINFO: + renormalize: bool + label_start_idx: int + label_end_idx: int + + +def add_zero_ilm_to_returnn_dict(network: Dict, context_type: PhoneticContext, label_info: LabelInfo, ilm_renorm_info: ILMRenormINFO, ilm_scale: float): + + assert context_type in [PhoneticContext.diphone, PhoneticContext.triphone_forward], "Zero iLM can be done only for factored context-dependent models" + + network["zero_enc"] = {"class": "eval", "from": "encoder-output", "eval": "source(0) * 0"} + if context_type == PhoneticContext.diphone: + + network["input-ilm"] = { + "class": "copy", + "from": ["zero_enc", "pastEmbed"], + } + + ilm_ff_layer = add_mlp(network=network, layer_name="input-ilm", source_layer="input-ilm", size=network["linear1-diphone"]["n_out"], n_layers=2) + network[ilm_ff_layer.replace("2", "1")]["reuse_params"] = "linear1-diphone" + network[ilm_ff_layer]["reuse_params"] = "linear2-diphone" + + ilm_layer = "iLM" + network[ilm_layer] = { + "class": "linear", + "from": ilm_ff_layer, + "activation": "log_softmax", + "n_out": label_info.get_n_state_classes(), + "reuse_params": "center-output", + } + + if ilm_renorm_info.renormalize: + start = ilm_renorm_info.label_start_idx + end = ilm_renorm_info.label_end_idx + network["iLM-renorm"] = { + "class": "eval", + "from": [ilm_layer], + "eval": f"tf.concat([source(0)[:, :81] - tf.math.log(1.0 - tf.exp(source(0)[:, 81:82])), tf.zeros(tf.shape(source(0)[:, 81:82])), source(0)[:, 82:] - tf.math.log(1.0 - tf.exp(source(0)[:, 81:82]))], axis=1)", + } + ilm_layer = "iLM-renorm" + + network["output_sub_iLM"] = { + "class": "eval", + "from": ["center-output", ilm_layer], + "eval": f"tf.exp(safe_log(source(0)) - {ilm_scale} * source(1))", + "is_output_layer": True + } + + return network \ No newline at end of file diff --git a/users/raissi/setups/librispeech/decoder/LBS_factored_hybrid_search.py b/users/raissi/setups/librispeech/decoder/LBS_factored_hybrid_search.py index b444bea16..be0ce59c0 100644 --- a/users/raissi/setups/librispeech/decoder/LBS_factored_hybrid_search.py +++ b/users/raissi/setups/librispeech/decoder/LBS_factored_hybrid_search.py @@ -94,11 +94,11 @@ def get_kazuki_lstm_config( #model and graph info trafo_config.loader.type = "meta" trafo_config.loader.meta_graph_file = tk.Path("/u/raissi/Desktop/debug/lstm_lm/mini2.3.graph.meta", cached=True) - trafo_config.loader.saved_model_file = DelayedFormat("//u/raissi/Desktop/debug/lstm_lm/models/net-model-mini2.2/network.030") + trafo_config.loader.saved_model_file = DelayedFormat("/u/raissi/Desktop/debug/lstm_lm/models/net-model-mini2.3/network.050") trafo_config.loader.required_libraries = self.library_path trafo_config.type = "tfrnn" - trafo_config.vocab_file = "/u/zhou/asr-exps/librispeech/dependencies/kazuki_lstmlm_27062019/vocabulary" + trafo_config.vocab_file = "/work/asr4/rossenbach/custom_projects/kazuki_replicate_lm_training/vocab.word.freq_sorted.200k.alternative.txt" trafo_config.transform_output_negate = True trafo_config.vocab_unknown_word = "" @@ -356,7 +356,6 @@ def recognize_ls_trafo_lm( rerun_after_opt_lm=False, name_override: Union[str, None] = None, name_prefix: str = "", - gpu: Optional[bool] = None, cpu_rqmt: Optional[int] = None, mem_rqmt: Optional[int] = None, crp_update: Optional[Callable[[rasr.RasrConfig], Any]] = None, @@ -369,7 +368,6 @@ def recognize_ls_trafo_lm( return self.recognize( add_sis_alias_and_output=add_sis_alias_and_output, calculate_stats=calculate_stats, - gpu=gpu, cpu_rqmt=cpu_rqmt, mem_rqmt=mem_rqmt, is_min_duration=is_min_duration, @@ -413,20 +411,21 @@ def recognize_lstm_lm( rerun_after_opt_lm=False, name_override: Union[str, None] = None, name_prefix: str = "", - gpu: Optional[bool] = None, cpu_rqmt: Optional[int] = None, mem_rqmt: Optional[int] = None, crp_update: Optional[Callable[[rasr.RasrConfig], Any]] = None, rtf_gpu: Optional[float] = None, rtf_cpu: Optional[float] = None, create_lattice: bool = True, + lm_lookahead_options: Optional = None, adv_search_extra_config: Optional[rasr.RasrConfig] = None, adv_search_extra_post_config: Optional[rasr.RasrConfig] = None, ) -> DecodingJobs: + if lm_lookahead_options is None: + lm_lookahead_options = {"clow": 2000, "chigh": 3000} return self.recognize( add_sis_alias_and_output=add_sis_alias_and_output, calculate_stats=calculate_stats, - gpu=gpu, cpu_rqmt=cpu_rqmt, mem_rqmt=mem_rqmt, is_min_duration=is_min_duration, @@ -447,6 +446,7 @@ def recognize_lstm_lm( crp_update=crp_update, rtf_cpu=rtf_cpu, rtf_gpu=rtf_gpu, + lm_lookahead_options=lm_lookahead_options, create_lattice=create_lattice, adv_search_extra_config=adv_search_extra_config, adv_search_extra_post_config=adv_search_extra_post_config,