Commit

update to align_module_device
Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>
kylesayrs committed Nov 19, 2024
1 parent 98a2889 commit 8cd69ef
Showing 1 changed file with 26 additions and 36 deletions.
62 changes: 26 additions & 36 deletions src/compressed_tensors/utils/offload.py
@@ -30,9 +30,6 @@
     _has_accelerate = True
 except ImportError:
     _has_accelerate = False
-    AlignDevicesHook = None
-    OffloadedWeightsLoader = None
-    PrefixedDataset = None


 __all__ = [
@@ -45,7 +42,7 @@
     "update_offload_data",
     "delete_offload_parameter",
     "has_offloaded_params",
-    "align_module",
+    "align_module_device",
 ]


@@ -251,52 +248,45 @@ def has_offloaded_params(module: torch.nn.Module) -> bool:


 # introduced in accelerate v1.1.0
 @check_accelerate(fallback=contextlib.nullcontext())
 @contextlib.contextmanager
-def align_module(
+def align_module_device(
     module: torch.nn.Module, execution_device: Optional[torch.device] = None
 ):
     """
-    Moves a module's parameters to the specified execution device.
+    Context manager that moves a module's parameters to the specified execution device.
     Args:
-        module (torch.nn.Module): Module with parameters to align.
-        execution_device (Optional[torch.device]): If provided, overrides the
-            module's execution device within the context.
-    Yields:
-        None: Yields control while the module's parameters are aligned to the execution
-        device.
+        module (`torch.nn.Module`):
+            Module with parameters to align.
+        execution_device (`torch.device`, *optional*):
+            If provided, overrides the module's execution device within the context.
+            Otherwise, use hook execution device or pass
     """
     if has_offloaded_params(module):
         if execution_device is not None:
             original_device = module._hf_hook.execution_device
             module._hf_hook.execution_device = execution_device

-        module._hf_hook.pre_forward(module)
-        yield
-        module._hf_hook.post_forward(module, None)
-
-        if execution_device is not None:
-            module._hf_hook.execution_device = original_device
+        try:
+            module._hf_hook.pre_forward(module)
+            yield
+        finally:
+            module._hf_hook.post_forward(module, None)
+            if execution_device is not None:
+                module._hf_hook.execution_device = original_device

     elif execution_device is not None:
-        devices = {}
-        for name, param in module.named_parameters():
-            devices[name] = param.device
-            set_module_tensor_to_device(
-                module,
-                name,
-                execution_device,
-            )
-
-        yield
-
-        for name, param in module.named_parameters():
-            set_module_tensor_to_device(
-                module,
-                name,
-                devices[name],
-            )
+        devices = {
+            name: param.device for name, param in module.named_parameters(recurse=False)
+        }
+        try:
+            for name in devices:
+                set_module_tensor_to_device(module, name, execution_device)
+            yield
+        finally:
+            for name, device in devices.items():
+                set_module_tensor_to_device(module, name, device)

     else:
         yield

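For reference, a minimal usage sketch of the renamed context manager (not part of this commit). It assumes align_module_device is importable from compressed_tensors.utils.offload, as exported via __all__ above; the toy Linear module and the device selection are illustrative only.

import torch

from compressed_tensors.utils.offload import align_module_device

# Toy module; in practice this could be a submodule whose weights accelerate
# keeps offloaded (i.e. has_offloaded_params(module) is True).
module = torch.nn.Linear(16, 16)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

with align_module_device(module, execution_device=device):
    # Inside the context the module's parameters live on `device`.
    output = module(torch.randn(1, 16, device=device))

# On exit, the try/finally blocks restore the original parameter placement,
# whether the module was hook-offloaded or moved via set_module_tensor_to_device.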