From df3e1860956455ec478337a90f7a8ce8a0450954 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Thu, 19 Dec 2024 20:51:14 +0000 Subject: [PATCH] use apply rather than recursion Signed-off-by: Kyle Sayers --- src/compressed_tensors/utils/offload.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/src/compressed_tensors/utils/offload.py b/src/compressed_tensors/utils/offload.py index 3d92b9ba..aec8eeb8 100644 --- a/src/compressed_tensors/utils/offload.py +++ b/src/compressed_tensors/utils/offload.py @@ -235,7 +235,7 @@ def delete_offload_parameter(module: torch.nn.Module, name: str): @check_accelerate(fallback=contextlib.nullcontext()) @contextlib.contextmanager -def disable_hf_hook(module: torch.nn.Module, recurse: bool = False): +def disable_hf_hook(module: torch.nn.Module): hooks = {} def collect_hooks(module): @@ -244,10 +244,7 @@ def collect_hooks(module): hooks[module] = module._hf_hook remove_hook_from_module(module) - for submodule in module.children(): - collect_hooks(submodule) - - collect_hooks(module) + module.apply(collect_hooks) yield