Commit 3075f5b

some cleanup
albertz committed Jan 10, 2025
1 parent 2b143ef commit 3075f5b
Showing 1 changed file with 11 additions and 7 deletions.
users/zeyer/experiments/exp2024_04_23_baselines/ctc_recog_ext.py
@@ -533,6 +533,14 @@ def __init__(self):
         self._recent_debug_log_time = -sys.maxsize
         self._max_used_mem_fraction = max_used_mem_fraction
 
+    def reset(self):
+        self.mapping_states.clear()
+        self._count_recalc_whole_seq = 0
+        self._recent_debug_log_time = -sys.maxsize
+        self._max_used_mem_fraction = max_used_mem_fraction
+        self._calc_next_lm_state.cache_clear()
+        self._calc_next_lm_state.cache_set_maxsize(start_lru_cache_size)
+
     @lru_cache(maxsize=start_lru_cache_size)
     def _calc_next_lm_state(self, state: LMState) -> Tuple[Any, torch.Tensor]:
         """
@@ -612,12 +620,7 @@ def start(self, start_with_nothing: bool):
             start_with_nothing (bool): whether or not to start sentence with sil token.
         """
         start_with_nothing  # noqa  # not sure how to handle this?
-        self.mapping_states.clear()
-        self._count_recalc_whole_seq = 0
-        self._recent_debug_log_time = -sys.maxsize
-        self._max_used_mem_fraction = max_used_mem_fraction
-        self._calc_next_lm_state.cache_clear()
-        self._calc_next_lm_state.cache_set_maxsize(start_lru_cache_size)
+        self.reset()
         state = LMState()
         self.mapping_states[state] = FlashlightLMState(label_seq=[model.bos_idx], prev_state=state)
         return state
@@ -725,7 +728,7 @@ def finish(self, state: LMState):
             f" best seq {_format_align_label_seq(results[0].tokens, model.wb_target_dim)},"
             f" worst score: {scores_per_batch[-1]},"
             f" LM cache info {fl_lm._calc_next_lm_state.cache_info()},"
-            f" LM recalc whole seq count {fl_lm._count_recalc_whole_seq}"
+            f" LM recalc whole seq count {fl_lm._count_recalc_whole_seq},"
             f" mem usage {dev_s}: {' '.join(_collect_mem_stats())}"
         )
         assert all(
@@ -742,6 +745,7 @@ def finish(self, state: LMState):
         assert all(len(hyp) == max_seq_len for hyp in hyps_per_batch)
         hyps.append(hyps_per_batch)
         scores.append(scores_per_batch)
+        fl_lm.reset()
     hyps_pt = torch.tensor(hyps, dtype=torch.int32)
     assert hyps_pt.shape == (batch_size, n_best, max_seq_len)
     scores_pt = torch.tensor(scores, dtype=torch.float32)
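
The fl_lm.reset() added at the end of the per-sequence loop drops all cached LM states before the next sequence is decoded, so the LRU cache and mapping_states do not keep growing across the batch. A hedged sketch of that calling pattern, reusing DemoLM from the sketch above (decode_one_seq and the loop bound are hypothetical stand-ins, not the Flashlight decoder API):

    # Sketch of the per-batch pattern; the loop structure is inferred from the
    # diff context and may not match the surrounding function exactly.
    def decode_one_seq(lm, batch_idx):
        lm.mapping_states[batch_idx] = "dummy state"  # simulate per-seq cache growth
        return [batch_idx], [0.0]

    lm = DemoLM()
    hyps, scores = [], []
    for batch_idx in range(3):  # stands in for batch_size
        hyps_per_batch, scores_per_batch = decode_one_seq(lm, batch_idx)
        hyps.append(hyps_per_batch)
        scores.append(scores_per_batch)
        lm.reset()  # the commit's addition: drop cached LM states before the next sequence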
