
Commit

free mem
albertz committed Jan 9, 2025
1 parent 4927c0a commit fc66e06
Showing 2 changed files with 34 additions and 4 deletions.
10 changes: 10 additions & 0 deletions users/zeyer/experiments/exp2024_04_23_baselines/ctc_recog_ext.py
@@ -518,6 +518,8 @@ def __init__(self):
self._count_recalc_whole_seq = 0
self._recent_debug_log_time = -sys.maxsize

        # Use LRU cache. Note that in addition to the maxsize here,
        # we also free more entries when we run low on CUDA memory.
@lru_cache(maxsize=1024)
def _calc_next_lm_state(self, state: LMState) -> Tuple[Any, torch.Tensor]:
"""
@@ -529,6 +531,14 @@ def _calc_next_lm_state(self, state: LMState) -> Tuple[Any, torch.Tensor]:
prev_lm_state = lm_initial_state
else:
prev_lm_state, _ = self._calc_next_lm_state.cache_peek(state_.prev_state, fallback=(None, None))

if dev.type == "cuda":
# Maybe check if we should free some more memory.
while self._calc_next_lm_state.cache_len() > 0:
free, total = torch.cuda.mem_get_info(dev)
                    if free / total < 0.2:
                        self._calc_next_lm_state.cache_pop_oldest()
                    else:
                        break  # enough free memory again, stop evicting

if prev_lm_state is not None or lm_initial_state is None:
# We have the prev state, or there is no state at all.
# So we can do a single step.
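For context, the change above combines the size bound of the LRU cache (maxsize=1024) with an extra eviction loop driven by CUDA memory pressure: while less than 20% of device memory is free, the oldest cached LM states are dropped. Below is a minimal, self-contained sketch of the same pattern; it uses a plain OrderedDict instead of the lru_cache decorator from users/zeyer/utils/lru_cache.py, and cached_step with its dummy computation is purely illustrative, not part of the commit.

from collections import OrderedDict

import torch

_MAX_SIZE = 1024
_cache: "OrderedDict[int, torch.Tensor]" = OrderedDict()


def cached_step(key: int, dev: torch.device) -> torch.Tensor:
    """Toy stand-in for _calc_next_lm_state: LRU caching plus memory-pressure eviction."""
    if key in _cache:
        _cache.move_to_end(key)  # cache hit: mark as most recently used
        return _cache[key]
    value = torch.zeros(1024, device=dev)  # placeholder for the real LM step
    _cache[key] = value
    if len(_cache) > _MAX_SIZE:
        _cache.popitem(last=False)  # enforce the size bound: drop least recently used
    if dev.type == "cuda":
        # Additionally evict entries while device memory is scarce.
        while _cache:
            free, total = torch.cuda.mem_get_info(dev)
            if free / total >= 0.2:
                break  # enough memory is free again
            _cache.popitem(last=False)
    return value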
28 changes: 24 additions & 4 deletions users/zeyer/utils/lru_cache.py
@@ -20,11 +20,13 @@ def lru_cache(maxsize: int = 128, typed: bool = False):
Arguments to the cached function must be hashable.
Use f.cache_len() to see the current size of the cache.
Use f.cache_peek(*args, update_statistics=False, fallback=None, **kwargs)
to peek the cache, without ever calling the user function.
View the cache statistics named tuple (hits, misses, maxsize, currsize)
with f.cache_info().
Clear the cache and statistics with f.cache_clear().
Remove the oldest entry from the cache with f.cache_pop_oldest().
Access the underlying function with f.__wrapped__.
See: https://en.wikipedia.org/wiki/Cache_replacement_policies#Least_recently_used_(LRU)
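A short usage sketch of the interface described in this docstring; the square function is illustrative, and the import assumes this repo is importable as the i6_experiments package.

from i6_experiments.users.zeyer.utils.lru_cache import lru_cache


@lru_cache(maxsize=2)
def square(x: int) -> int:
    return x * x


square(2)
square(3)
assert square.cache_len() == 2
# Peek without calling square() and, here, without counting a hit.
assert square.cache_peek(2, update_statistics=False) == 4
assert square.cache_peek(5) is None  # not cached -> fallback (None by default)
# Drop the least recently used entry (the result for x=2; peeking did not reorder it).
assert square.cache_pop_oldest() == 4
assert square.cache_len() == 1
print(square.cache_info())  # (hits, misses, maxsize, currsize)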
@@ -109,7 +111,6 @@ def wrapper(*args, **kwds):
# still adjusting the links.
root = oldroot[NEXT]
oldkey = root[KEY]
oldresult = root[RESULT]
root[KEY] = root[RESULT] = None
# Now update the cache dictionary.
del cache[oldkey]
@@ -150,19 +151,38 @@ def cache_peek(*args, update_statistics: bool = True, fallback=None, **kwargs):
with lock:
link = cache_get(key)
if link is not None:
                if update_statistics:
                    hits += 1
                return link[RESULT]
if update_statistics:
misses += 1
return fallback

not_specified = object()

def cache_pop_oldest(*, fallback=not_specified):
nonlocal root
with lock:
if not cache:
if fallback is not_specified:
raise KeyError("cache is empty")
return fallback
assert cache
oldroot = root
root = oldroot[NEXT]
oldkey = root[KEY]
oldvalue = root[RESULT]
del cache[oldkey]
root[KEY] = root[RESULT] = None
                oldroot[PREV][NEXT] = root
                root[PREV] = oldroot[PREV]  # also relink PREV, so the old sentinel is fully unlinked
return oldvalue

wrapper.cache_info = cache_info
wrapper.cache_clear = cache_clear
wrapper.cache_parameters = cache_parameters
wrapper.cache_peek = cache_peek
wrapper.cache_len = cache_len
wrapper.cache_pop_oldest = cache_pop_oldest

update_wrapper(wrapper, user_function)

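The empty-cache behavior of the new cache_pop_oldest mirrors dict.pop: with no fallback it raises KeyError("cache is empty"), with an explicit fallback it returns that instead. A small sketch, again with an illustrative function and the assumed import path from above.

from i6_experiments.users.zeyer.utils.lru_cache import lru_cache


@lru_cache(maxsize=4)
def double(x: int) -> int:
    return 2 * x


try:
    double.cache_pop_oldest()  # empty cache, no fallback given
except KeyError:
    pass  # raised with "cache is empty"
assert double.cache_pop_oldest(fallback=None) is None  # empty cache, explicit fallback
double(21)
assert double.cache_pop_oldest() == 42  # returns the evicted entry's result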
