-
Notifications
You must be signed in to change notification settings - Fork 6
/
Copy pathcache_manager.py
46 lines (35 loc) · 889 Bytes
/
cache_manager.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
import gc
from collections import UserDict
from contextlib import ContextDecorator
from functools import wraps
import torch
# Process-wide slots for a preloaded model and its companions.
# Every slot starts out as None; keys keep the declaration order below.
model_cache = UserDict(
    dict.fromkeys(
        (
            "preloaded_model_id",
            "preloaded_model",
            "preloaded_tokenizer",
            "preloaded_streamer",
            "preloaded_device",
        )
    )
)
def torch_gc():
    """Release PyTorch's cached CUDA memory; do nothing when CUDA is unavailable.

    Returns:
        None.
    """
    if not torch.cuda.is_available():
        # CPU-only environment: there is no CUDA cache to release.
        return
    # Hand unoccupied blocks held by PyTorch's CUDA caching allocator back.
    torch.cuda.empty_cache()
    # Also reclaim GPU memory from CUDA IPC shared tensors that were released.
    torch.cuda.ipc_collect()
def clear_cache():
    """Run Python's garbage collector, then release PyTorch's CUDA cache.

    Returns:
        None.
    """
    # The Python collector runs before torch_gc(), presumably so that memory
    # freed by collecting unreachable objects is already back in PyTorch's
    # cache by the time that cache is emptied.
    for release_step in (gc.collect, torch_gc):
        release_step()
def clear_cache_decorator(func):
    """Wrap ``func`` so caches are cleared before and after every call.

    The trailing ``clear_cache()`` runs in a ``finally`` clause, so it also
    happens when ``func`` raises; the exception then propagates unchanged.
    This mirrors ``ClearCacheContext``, whose ``__exit__`` clears the cache on
    both normal and exceptional exit.

    Args:
        func: Callable to wrap.

    Returns:
        A wrapper with ``func``'s metadata (via ``functools.wraps``) that
        returns whatever ``func`` returns.
    """
    @wraps(func)
    def wrapper(*args, **kwargs):
        clear_cache()
        try:
            return func(*args, **kwargs)
        finally:
            # Previously skipped on exceptions, leaving memory from the
            # failed call cached; always release it now.
            clear_cache()
    return wrapper
class ClearCacheContext(ContextDecorator):
    """Clear caches on entry and on exit; usable as ``with`` block or decorator.

    Inherits from ``contextlib.ContextDecorator``, so an instance can also
    decorate a function, wrapping each call in this context.
    """

    def __enter__(self):
        # Start the guarded region from a freshly cleared cache.
        clear_cache()
        return self

    def __exit__(self, *exc_details):
        # Runs on normal exit and while an exception propagates.
        clear_cache()
        # Falsy result: exceptions raised inside the block are not suppressed.
        return False