From 64f4d9850c171c499638d23ee06cb5146e8ac321 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Fri, 6 Dec 2024 21:04:48 +0000 Subject: [PATCH] support OffloadedWeightsLoader --- src/compressed_tensors/utils/offload.py | 7 ++++++- tests/test_utils/test_offload.py | 2 +- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/src/compressed_tensors/utils/offload.py b/src/compressed_tensors/utils/offload.py index 02f2c442..6c20da08 100644 --- a/src/compressed_tensors/utils/offload.py +++ b/src/compressed_tensors/utils/offload.py @@ -195,7 +195,11 @@ def update_offload_parameter( if key in dataset else next(iter(dataset.values())).device ) - dataset[key] = data.to(device=offload_device) + + if isinstance(dataset, OffloadedWeightsLoader): + dataset.state_dict[key] = data.to(device=offload_device) + else: + dataset[key] = data.to(device=offload_device) elif isinstance(weights_map, dict): offload_device = ( @@ -250,6 +254,7 @@ def delete_offload_parameter(module: torch.nn.Module, name: str): @contextlib.contextmanager def disable_hf_hook(module: torch.nn.Module, recurse: bool = False): hooks = {} + def collect_hooks(module): nonlocal hooks if hasattr(module, "_hf_hook"): diff --git a/tests/test_utils/test_offload.py b/tests/test_utils/test_offload.py index c779c0d9..46fe316e 100644 --- a/tests/test_utils/test_offload.py +++ b/tests/test_utils/test_offload.py @@ -180,4 +180,4 @@ def test_disable_hf_hook_model_recurse(): assert hasattr(module0, "_hf_hook") assert hasattr(module1, "_hf_hook") - assert hasattr(module2, "_hf_hook") \ No newline at end of file + assert hasattr(module2, "_hf_hook")