From 95e59075feaf215f29c2fd7bb82e23a0762d0083 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Tue, 19 Nov 2024 02:31:23 +0000 Subject: [PATCH] remove align_module_device --- src/compressed_tensors/utils/offload.py | 47 ------------------------- 1 file changed, 47 deletions(-) diff --git a/src/compressed_tensors/utils/offload.py b/src/compressed_tensors/utils/offload.py index 29970bd9..0df00ec2 100644 --- a/src/compressed_tensors/utils/offload.py +++ b/src/compressed_tensors/utils/offload.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import contextlib from typing import Any, Callable, Optional import torch @@ -42,7 +41,6 @@ "update_offload_data", "delete_offload_parameter", "has_offloaded_params", - "align_module_device", ] @@ -243,48 +241,3 @@ def has_offloaded_params(module: torch.nn.Module) -> bool: and isinstance(module._hf_hook, AlignDevicesHook) and module._hf_hook.offload ) - - -# introduced in accelerate v1.1.0 -@check_accelerate(fallback=contextlib.nullcontext()) -@contextlib.contextmanager -def align_module_device( - module: torch.nn.Module, execution_device: Optional[torch.device] = None -): - """ - Context manager that moves a module's parameters to the specified execution device. - - Args: - module (`torch.nn.Module`): - Module with parameters to align. - execution_device (`torch.device`, *optional*): - If provided, overrides the module's execution device within the context. - Otherwise, use hook execution device or pass - """ - if has_offloaded_params(module): - if execution_device is not None: - original_device = module._hf_hook.execution_device - module._hf_hook.execution_device = execution_device - - try: - module._hf_hook.pre_forward(module) - yield - finally: - module._hf_hook.post_forward(module, None) - if execution_device is not None: - module._hf_hook.execution_device = original_device - - elif execution_device is not None: - devices = { - name: param.device for name, param in module.named_parameters(recurse=False) - } - try: - for name in devices: - set_module_tensor_to_device(module, name, execution_device) - yield - finally: - for name, device in devices.items(): - set_module_tensor_to_device(module, name, device) - - else: - yield