Skip to content

Commit

Permalink
support OffloadedWeightsLoader
Browse files Browse the repository at this point in the history
  • Loading branch information
kylesayrs committed Dec 6, 2024
1 parent 38765bd commit 64f4d98
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 2 deletions.
7 changes: 6 additions & 1 deletion src/compressed_tensors/utils/offload.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,11 @@ def update_offload_parameter(
if key in dataset
else next(iter(dataset.values())).device
)
dataset[key] = data.to(device=offload_device)

if isinstance(dataset, OffloadedWeightsLoader):
dataset.state_dict[key] = data.to(device=offload_device)
else:
dataset[key] = data.to(device=offload_device)

elif isinstance(weights_map, dict):
offload_device = (
Expand Down Expand Up @@ -250,6 +254,7 @@ def delete_offload_parameter(module: torch.nn.Module, name: str):
@contextlib.contextmanager
def disable_hf_hook(module: torch.nn.Module, recurse: bool = False):
hooks = {}

def collect_hooks(module):
nonlocal hooks
if hasattr(module, "_hf_hook"):
Expand Down
2 changes: 1 addition & 1 deletion tests/test_utils/test_offload.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,4 +180,4 @@ def test_disable_hf_hook_model_recurse():

assert hasattr(module0, "_hf_hook")
assert hasattr(module1, "_hf_hook")
assert hasattr(module2, "_hf_hook")
assert hasattr(module2, "_hf_hook")

0 comments on commit 64f4d98

Please sign in to comment.