From 8cd69ef42b97f58bc3a0eddbabd66ee0f08d8809 Mon Sep 17 00:00:00 2001
From: Kyle Sayers
Date: Tue, 19 Nov 2024 00:05:49 +0000
Subject: [PATCH] update to align_module_device

Signed-off-by: Kyle Sayers
---
 src/compressed_tensors/utils/offload.py | 62 +++++++++++--------------
 1 file changed, 26 insertions(+), 36 deletions(-)

diff --git a/src/compressed_tensors/utils/offload.py b/src/compressed_tensors/utils/offload.py
index 6c2a6e1c..5bc00c05 100644
--- a/src/compressed_tensors/utils/offload.py
+++ b/src/compressed_tensors/utils/offload.py
@@ -30,9 +30,6 @@
     _has_accelerate = True
 except ImportError:
     _has_accelerate = False
-    AlignDevicesHook = None
-    OffloadedWeightsLoader = None
-    PrefixedDataset = None
 
 
 __all__ = [
@@ -45,7 +42,7 @@
     "update_offload_data",
     "delete_offload_parameter",
     "has_offloaded_params",
-    "align_module",
+    "align_module_device",
 ]
 
 
@@ -251,52 +248,45 @@ def has_offloaded_params(module: torch.nn.Module) -> bool:
 
 
 # introduced in accelerate v1.1.0
+@check_accelerate(fallback=contextlib.nullcontext())
 @contextlib.contextmanager
-def align_module(
+def align_module_device(
     module: torch.nn.Module, execution_device: Optional[torch.device] = None
 ):
     """
-    Moves a module's parameters to the specified execution device.
+    Context manager that moves a module's parameters to the specified execution device.
 
     Args:
-        module (torch.nn.Module): Module with parameters to align.
-        execution_device (Optional[torch.device]): If provided, overrides the
-            module's execution device within the context.
-
-    Yields:
-        None: Yields control while the module's parameters are aligned to the execution
-        device.
+        module (`torch.nn.Module`):
+            Module with parameters to align.
+        execution_device (`torch.device`, *optional*):
+            If provided, overrides the module's execution device within the context.
+            Otherwise, uses the hook execution device if the module has offloaded parameters.
     """
     if has_offloaded_params(module):
         if execution_device is not None:
             original_device = module._hf_hook.execution_device
             module._hf_hook.execution_device = execution_device
 
-        module._hf_hook.pre_forward(module)
-        yield
-        module._hf_hook.post_forward(module, None)
-
-        if execution_device is not None:
-            module._hf_hook.execution_device = original_device
+        try:
+            module._hf_hook.pre_forward(module)
+            yield
+        finally:
+            module._hf_hook.post_forward(module, None)
+            if execution_device is not None:
+                module._hf_hook.execution_device = original_device
 
     elif execution_device is not None:
-        devices = {}
-        for name, param in module.named_parameters():
-            devices[name] = param.device
-            set_module_tensor_to_device(
-                module,
-                name,
-                execution_device,
-            )
-
-        yield
-
-        for name, param in module.named_parameters():
-            set_module_tensor_to_device(
-                module,
-                name,
-                devices[name],
-            )
+        devices = {
+            name: param.device for name, param in module.named_parameters(recurse=False)
+        }
+        try:
+            for name in devices:
+                set_module_tensor_to_device(module, name, execution_device)
+            yield
+        finally:
+            for name, device in devices.items():
+                set_module_tensor_to_device(module, name, device)
 
     else:
         yield
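
Not part of the patch: a minimal usage sketch of the renamed context manager, assuming this patch is applied so that align_module_device is importable from compressed_tensors.utils.offload. The torch.nn.Linear module and the CPU target device below are illustrative only.

import torch

# Illustrative only: assumes the patched compressed_tensors package (and, for
# offloaded modules, accelerate >= 1.1.0) is installed.
from compressed_tensors.utils.offload import align_module_device

module = torch.nn.Linear(16, 16)  # stands in for a possibly offloaded/dispatched module

# Inside the context, the module's direct parameters are moved to the requested
# device (or to the hook's execution device when no override is given); on exit
# they are restored to their original devices, even if the body raises.
with align_module_device(module, execution_device=torch.device("cpu")):
    output = module(torch.randn(1, 16))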