From 35fa1cd17f0883746b8565518350914adfe1fb78 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Fri, 6 Dec 2024 07:40:23 +0000 Subject: [PATCH] implement recursive case --- src/compressed_tensors/utils/offload.py | 20 ++++++++++++++------ tests/test_utils/test_offload.py | 20 ++++++++++++++++++++ 2 files changed, 34 insertions(+), 6 deletions(-) diff --git a/src/compressed_tensors/utils/offload.py b/src/compressed_tensors/utils/offload.py index ee59933b..40f11353 100644 --- a/src/compressed_tensors/utils/offload.py +++ b/src/compressed_tensors/utils/offload.py @@ -249,15 +249,23 @@ def delete_offload_parameter(module: torch.nn.Module, name: str): @check_accelerate(fallback=contextlib.nullcontext()) @contextlib.contextmanager def disable_hf_hook(module: torch.nn.Module, recurse: bool = False): - offloaded = has_offloaded_params(module) - if offloaded: - hook = module._hf_hook - remove_hook_from_module(module, recurse=recurse) + hooks = {} + def collect_hooks(module): + nonlocal hooks + if hasattr(module, "_hf_hook"): + hooks[module] = module._hf_hook + remove_hook_from_module(module) + + for submodule in module.children(): + print(submodule) + collect_hooks(submodule) + + collect_hooks(module) yield - if offloaded: - add_hook_to_module(module, hook) + for submodule, hook in hooks.items(): + add_hook_to_module(submodule, hook) """ Upstreamed Functions """ diff --git a/tests/test_utils/test_offload.py b/tests/test_utils/test_offload.py index 80cb55f7..c779c0d9 100644 --- a/tests/test_utils/test_offload.py +++ b/tests/test_utils/test_offload.py @@ -161,3 +161,23 @@ def custom_forward(): assert hasattr(module, "_hf_hook") assert module._old_forward == custom_forward + + +@requires_accelerate() +def test_disable_hf_hook_model_recurse(): + from accelerate.hooks import attach_align_device_hook + + module0 = ExampleModule() + module1 = ExampleModule() + module2 = ExampleModule() + model = torch.nn.Sequential(module0, torch.nn.Sequential(module1, module2)) + attach_align_device_hook(model, offload=True, weights_map=model.state_dict()) + + with disable_hf_hook(model): + assert not hasattr(module0, "_hf_hook") + assert not hasattr(module1, "_hf_hook") + assert not hasattr(module2, "_hf_hook") + + assert hasattr(module0, "_hf_hook") + assert hasattr(module1, "_hf_hook") + assert hasattr(module2, "_hf_hook") \ No newline at end of file