Skip to content

Commit

Permalink
remove align_module_device
Browse files Browse the repository at this point in the history
  • Loading branch information
kylesayrs committed Nov 19, 2024
1 parent 0b0d8b6 commit 95e5907
Showing 1 changed file with 0 additions and 47 deletions.
47 changes: 0 additions & 47 deletions src/compressed_tensors/utils/offload.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import contextlib
from typing import Any, Callable, Optional

import torch
Expand Down Expand Up @@ -42,7 +41,6 @@
"update_offload_data",
"delete_offload_parameter",
"has_offloaded_params",
"align_module_device",
]


Expand Down Expand Up @@ -243,48 +241,3 @@ def has_offloaded_params(module: torch.nn.Module) -> bool:
and isinstance(module._hf_hook, AlignDevicesHook)
and module._hf_hook.offload
)


# introduced in accelerate v1.1.0
@check_accelerate(fallback=contextlib.nullcontext())
@contextlib.contextmanager
def align_module_device(
module: torch.nn.Module, execution_device: Optional[torch.device] = None
):
"""
Context manager that moves a module's parameters to the specified execution device.
Args:
module (`torch.nn.Module`):
Module with parameters to align.
execution_device (`torch.device`, *optional*):
If provided, overrides the module's execution device within the context.
Otherwise, use hook execution device or pass
"""
if has_offloaded_params(module):
if execution_device is not None:
original_device = module._hf_hook.execution_device
module._hf_hook.execution_device = execution_device

try:
module._hf_hook.pre_forward(module)
yield
finally:
module._hf_hook.post_forward(module, None)
if execution_device is not None:
module._hf_hook.execution_device = original_device

elif execution_device is not None:
devices = {
name: param.device for name, param in module.named_parameters(recurse=False)
}
try:
for name in devices:
set_module_tensor_to_device(module, name, execution_device)
yield
finally:
for name, device in devices.items():
set_module_tensor_to_device(module, name, device)

else:
yield

0 comments on commit 95e5907

Please sign in to comment.