From fc66e060264d50162a902fa794b2ba774c58f5ae Mon Sep 17 00:00:00 2001
From: Albert Zeyer
Date: Thu, 9 Jan 2025 16:06:05 +0100
Subject: [PATCH] free mem

---
 .../exp2024_04_23_baselines/ctc_recog_ext.py | 12 ++++++++
 users/zeyer/utils/lru_cache.py               | 29 +++++++++++++++++++---
 2 files changed, 37 insertions(+), 4 deletions(-)

diff --git a/users/zeyer/experiments/exp2024_04_23_baselines/ctc_recog_ext.py b/users/zeyer/experiments/exp2024_04_23_baselines/ctc_recog_ext.py
index 1ac22e442..c9171422a 100644
--- a/users/zeyer/experiments/exp2024_04_23_baselines/ctc_recog_ext.py
+++ b/users/zeyer/experiments/exp2024_04_23_baselines/ctc_recog_ext.py
@@ -518,6 +518,8 @@ def __init__(self):
             self._count_recalc_whole_seq = 0
             self._recent_debug_log_time = -sys.maxsize
 
+        # Use an LRU cache. Note that in addition to the maxsize given here,
+        # we free more entries when we run low on CUDA memory.
         @lru_cache(maxsize=1024)
         def _calc_next_lm_state(self, state: LMState) -> Tuple[Any, torch.Tensor]:
             """
@@ -529,6 +531,16 @@ def _calc_next_lm_state(self, state: LMState) -> Tuple[Any, torch.Tensor]:
                 prev_lm_state = lm_initial_state
             else:
                 prev_lm_state, _ = self._calc_next_lm_state.cache_peek(state_.prev_state, fallback=(None, None))
+
+            if dev.type == "cuda":
+                # Maybe check if we should free some more memory.
+                while self._calc_next_lm_state.cache_len() > 0:
+                    free, total = torch.cuda.mem_get_info(dev)
+                    if free / total < 0.2:
+                        self._calc_next_lm_state.cache_pop_oldest()
+                    else:
+                        break
+
             if prev_lm_state is not None or lm_initial_state is None:
                 # We have the prev state, or there is no state at all.
                 # So we can do a single step.
diff --git a/users/zeyer/utils/lru_cache.py b/users/zeyer/utils/lru_cache.py
index e4ec7c952..5b3c16bd9 100644
--- a/users/zeyer/utils/lru_cache.py
+++ b/users/zeyer/utils/lru_cache.py
@@ -20,11 +20,13 @@ def lru_cache(maxsize: int = 128, typed: bool = False):
 
     Arguments to the cached function must be hashable.
 
+    Use f.cache_len() to see the current size of the cache.
     Use f.cache_peek(*args, update_statistics=False, fallback=None, **kwargs)
     to peek the cache, without ever calling the user function.
 
     View the cache statistics named tuple (hits, misses, maxsize, currsize)
     with f.cache_info(). Clear the cache and statistics with f.cache_clear().
+    Remove the oldest entry from the cache with f.cache_pop_oldest().
     Access the underlying function with f.__wrapped__.
 
     See: https://en.wikipedia.org/wiki/Cache_replacement_policies#Least_recently_used_(LRU)
@@ -109,7 +111,6 @@ def wrapper(*args, **kwds):
                     # still adjusting the links.
                     root = oldroot[NEXT]
                     oldkey = root[KEY]
-                    oldresult = root[RESULT]
                     root[KEY] = root[RESULT] = None
                     # Now update the cache dictionary.
                     del cache[oldkey]
@@ -150,19 +151,39 @@ def cache_peek(*args, update_statistics: bool = True, fallback=None, **kwargs):
         with lock:
             link = cache_get(key)
             if link is not None:
-                # Move the link to the front of the circular queue
-                _, _, _, result = link
                 if update_statistics:
                     hits += 1
-                return result
+                return link[RESULT]
             if update_statistics:
                 misses += 1
             return fallback
 
+    not_specified = object()
+
+    def cache_pop_oldest(*, fallback=not_specified):
+        nonlocal root
+        with lock:
+            if not cache:
+                if fallback is not_specified:
+                    raise KeyError("cache is empty")
+                return fallback
+            assert cache
+            oldroot = root
+            root = oldroot[NEXT]
+            oldkey = root[KEY]
+            oldvalue = root[RESULT]
+            del cache[oldkey]
+            root[KEY] = root[RESULT] = None
+            oldroot[PREV][NEXT] = root
+            root[PREV] = oldroot[PREV]  # keep the backward link consistent
+            return oldvalue
+
     wrapper.cache_info = cache_info
     wrapper.cache_clear = cache_clear
     wrapper.cache_parameters = cache_parameters
     wrapper.cache_peek = cache_peek
+    wrapper.cache_len = cache_len
+    wrapper.cache_pop_oldest = cache_pop_oldest
     update_wrapper(wrapper, user_function)
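
Usage sketch (not part of the commit above): the snippet below shows how the two new hooks, f.cache_len() and f.cache_pop_oldest(), can be combined with torch.cuda.mem_get_info() to shed cached results under GPU memory pressure, mirroring the eviction loop added to _calc_next_lm_state. The import path follows the file location in this patch; the decorated expensive_step function, the maybe_free_cache helper, and the min_free_ratio parameter are illustrative assumptions, not code from the repository.

# Illustrative sketch only. Assumes the patched lru_cache module is importable
# under the path used in this patch; `expensive_step` is a made-up stand-in
# for any cached computation whose result lives on the GPU.
from typing import Any

import torch

from users.zeyer.utils.lru_cache import lru_cache


@lru_cache(maxsize=1024)
def expensive_step(key: Any) -> torch.Tensor:
    # Placeholder for an expensive computation (e.g. one LM state expansion).
    dev = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    return torch.zeros(1024, 1024, device=dev)


def maybe_free_cache(dev: torch.device, min_free_ratio: float = 0.2) -> None:
    # Drop least-recently-used cache entries while free CUDA memory is scarce.
    if dev.type != "cuda":
        return
    while expensive_step.cache_len() > 0:
        free, total = torch.cuda.mem_get_info(dev)
        if free / total < min_free_ratio:
            # Evict the oldest entry; releasing the cache's reference to its
            # tensor lets PyTorch reuse that memory.
            expensive_step.cache_pop_oldest()
        else:
            break  # enough headroom again

The two mechanisms complement each other: maxsize bounds the number of cached entries, while the pop-oldest loop shrinks the cache further whenever less than 20% of the device memory is free.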