Skip to content

Commit

Permalink
implement recursive case
Browse files Browse the repository at this point in the history
  • Loading branch information
kylesayrs committed Dec 6, 2024
1 parent 9af736f commit 35fa1cd
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 6 deletions.
20 changes: 14 additions & 6 deletions src/compressed_tensors/utils/offload.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,15 +249,23 @@ def delete_offload_parameter(module: torch.nn.Module, name: str):
@check_accelerate(fallback=contextlib.nullcontext())
@contextlib.contextmanager
def disable_hf_hook(module: torch.nn.Module, recurse: bool = False):
offloaded = has_offloaded_params(module)
if offloaded:
hook = module._hf_hook
remove_hook_from_module(module, recurse=recurse)
hooks = {}
def collect_hooks(module):
nonlocal hooks
if hasattr(module, "_hf_hook"):
hooks[module] = module._hf_hook
remove_hook_from_module(module)

for submodule in module.children():
print(submodule)
collect_hooks(submodule)

collect_hooks(module)

yield

if offloaded:
add_hook_to_module(module, hook)
for submodule, hook in hooks.items():
add_hook_to_module(submodule, hook)


""" Upstreamed Functions """
Expand Down
20 changes: 20 additions & 0 deletions tests/test_utils/test_offload.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,3 +161,23 @@ def custom_forward():

assert hasattr(module, "_hf_hook")
assert module._old_forward == custom_forward


@requires_accelerate()
def test_disable_hf_hook_model_recurse():
from accelerate.hooks import attach_align_device_hook

module0 = ExampleModule()
module1 = ExampleModule()
module2 = ExampleModule()
model = torch.nn.Sequential(module0, torch.nn.Sequential(module1, module2))
attach_align_device_hook(model, offload=True, weights_map=model.state_dict())

with disable_hf_hook(model):
assert not hasattr(module0, "_hf_hook")
assert not hasattr(module1, "_hf_hook")
assert not hasattr(module2, "_hf_hook")

assert hasattr(module0, "_hf_hook")
assert hasattr(module1, "_hf_hook")
assert hasattr(module2, "_hf_hook")

0 comments on commit 35fa1cd

Please sign in to comment.