Skip to content

Commit

Permalink
revert get_offloaded_device
Browse files Browse the repository at this point in the history
Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>
  • Loading branch information
kylesayrs committed Nov 18, 2024
1 parent cb70047 commit 98a2889
Showing 1 changed file with 5 additions and 9 deletions.
14 changes: 5 additions & 9 deletions src/compressed_tensors/utils/offload.py
Original file line number Diff line number Diff line change
@@ -88,15 +88,11 @@ def get_offloaded_device(module: torch.nn.Module) -> torch.device:
:param module: module to check
:return: device module is offloaded to onto after forward pass
"""
if not has_offloaded_params(module):
raise ValueError("Cannot infer offload device from non-offloaded module")

first_key = next(module._hf_hook.weights_map.keys(), None)
if first_key is None:
raise ValueError("Cannot infer offload device from empty weights map")

prefix_dataset = module._hf_hook.weights_map.dataset
return prefix_dataset[first_key].device
if has_offloaded_params(module):
first_key = list(module._hf_hook.weights_map.keys())[0]
prefix_dataset = module._hf_hook.weights_map.dataset
return prefix_dataset[first_key].device
return next(module.parameters()).device


@check_accelerate(fallback=None)

0 comments on commit 98a2889

Please sign in to comment.