diff --git a/users/zeyer/experiments/exp2024_04_23_baselines/ctc_recog_ext.py b/users/zeyer/experiments/exp2024_04_23_baselines/ctc_recog_ext.py index 122ae7e99..fab07b8f1 100644 --- a/users/zeyer/experiments/exp2024_04_23_baselines/ctc_recog_ext.py +++ b/users/zeyer/experiments/exp2024_04_23_baselines/ctc_recog_ext.py @@ -457,7 +457,6 @@ def model_recog_flashlight( out_spatial_dim, final beam_dim """ - import gc from dataclasses import dataclass import torch from flashlight.lib.text.decoder import LM, LMState @@ -598,7 +597,9 @@ def _cache_maybe_free_memory(self): if used_mem / total_mem < self._max_used_mem_fraction: break # Check again after trying to empty the cache. - gc.collect() + # Note: gc.collect() is problematic here because of how Flashlight handles the states: + # We have millions of Python objects in the mapping_states dict, + # which takes a very long time to go through. torch.cuda.empty_cache() used_mem = torch.cuda.memory_reserved(dev) if used_mem / total_mem < self._max_used_mem_fraction: