diff --git a/src/compressed_tensors/quantization/lifecycle/initialize.py b/src/compressed_tensors/quantization/lifecycle/initialize.py index 5aecae0d..8dd8fc51 100644 --- a/src/compressed_tensors/quantization/lifecycle/initialize.py +++ b/src/compressed_tensors/quantization/lifecycle/initialize.py @@ -29,7 +29,11 @@ from compressed_tensors.quantization.quant_config import QuantizationStatus from compressed_tensors.quantization.quant_scheme import QuantizationScheme from compressed_tensors.quantization.utils import is_kv_cache_quant_scheme -from compressed_tensors.utils import get_execution_device, is_module_offloaded +from compressed_tensors.utils import ( + disable_hf_hook, + has_offloaded_params, + register_offload_parameter, +) from torch.nn import Module, Parameter @@ -112,43 +116,10 @@ def initialize_module_for_quantization( module.quantization_scheme = scheme module.quantization_status = QuantizationStatus.INITIALIZED - offloaded = False - # What is this doing/why isn't this in the attn case? - if is_module_offloaded(module): - try: - from accelerate.hooks import add_hook_to_module, remove_hook_from_module - from accelerate.utils import PrefixedDataset - except ModuleNotFoundError: - raise ModuleNotFoundError( - "Offloaded model detected. To use CPU offloading with " - "compressed-tensors the `accelerate` package must be installed, " - "run `pip install compressed-tensors[accelerate]`" - ) - - offloaded = True - hook = module._hf_hook - prefix_dict = module._hf_hook.weights_map - new_prefix = {} - - # recreate the prefix dict (since it is immutable) - # and add quantization parameters - for key, data in module.named_parameters(): - if key not in prefix_dict: - new_prefix[f"{prefix_dict.prefix}{key}"] = data - else: - new_prefix[f"{prefix_dict.prefix}{key}"] = prefix_dict[key] - new_prefix_dict = PrefixedDataset(new_prefix, prefix_dict.prefix) - remove_hook_from_module(module) - - # wrap forward call of module to perform - # quantized actions based on calltime status - wrap_module_forward_quantized(module, scheme) - - if offloaded: - # we need to re-add the hook for offloading now that we've wrapped forward - add_hook_to_module(module, hook) - if prefix_dict is not None: - module._hf_hook.weights_map = new_prefix_dict + with disable_hf_hook(module): + # wrap forward call of module to perform + # quantized actions based on calltime status + wrap_module_forward_quantized(module, scheme) def is_attention_module(module: Module): @@ -169,9 +140,11 @@ def _initialize_scale_zero_point( if quantization_args.dynamic: return - device = next(module.parameters()).device - if is_module_offloaded(module): - device = get_execution_device(module) + # begin on the same device as other parameters or cpu if offloaded. + # in the offloaded case, there's no point moving tensors to the execution device + # if they're going to be immediately offloaded by `register_offload_parameter` + params_device = next(module.parameters()).device + device = "cpu" if has_offloaded_params(module) else params_device # infer expected scale/zero point shape if quantization_args.strategy == QuantizationStrategy.TOKEN: @@ -196,7 +169,7 @@ def _initialize_scale_zero_point( torch.empty(expected_shape, dtype=scale_dtype, device=device), requires_grad=False, ) - module.register_parameter(f"{base_name}_scale", init_scale) + register_offload_parameter(module, f"{base_name}_scale", init_scale) if force_zero_point or not quantization_args.symmetric: zp_dtype = quantization_args.pytorch_dtype() @@ -204,7 +177,7 @@ def _initialize_scale_zero_point( torch.zeros(expected_shape, device=device, dtype=zp_dtype), requires_grad=False, ) - module.register_parameter(f"{base_name}_zero_point", init_zero_point) + register_offload_parameter(module, f"{base_name}_zero_point", init_zero_point) # only grouped activation ordering has g_idx if quantization_args.actorder == ActivationOrdering.GROUP: @@ -214,7 +187,7 @@ def _initialize_scale_zero_point( torch.full(g_idx_shape, -1, device=device, dtype=g_idx_dtype), requires_grad=False, ) - module.register_parameter(f"{base_name}_g_idx", init_g_idx) + register_offload_parameter(module, f"{base_name}_g_idx", init_g_idx) def _initialize_attn_scales(module: Module) -> None: diff --git a/src/compressed_tensors/utils/helpers.py b/src/compressed_tensors/utils/helpers.py index db77bccb..910436eb 100644 --- a/src/compressed_tensors/utils/helpers.py +++ b/src/compressed_tensors/utils/helpers.py @@ -12,7 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Dict, Optional +import warnings +from functools import wraps +from typing import Any, Callable, Dict, Optional import torch from transformers import AutoConfig @@ -24,6 +26,8 @@ "tensor_follows_mask_structure", "replace_module", "is_compressed_tensors_config", + "getattr_chain", + "deprecated", "Aliasable", ] @@ -122,6 +126,65 @@ def is_compressed_tensors_config(compression_config: Any) -> bool: return False +def getattr_chain(obj: Any, chain_str: str, *args, **kwargs) -> Any: + """ + Chain multiple getattr calls, separated by `.` + + :param obj: base object whose attributes are being retrieved + :param chain_str: attribute names separated by `.` + :param default: default value, throw error otherwise + """ + if len(args) >= 1: + has_default = True + default = args[0] + elif "default" in kwargs: + has_default = True + default = kwargs["default"] + else: + has_default = False + + attr_names = chain_str.split(".") + + res = obj + for attr_name in attr_names: + if not hasattr(res, attr_name): + if has_default: + return default + else: + raise AttributeError(f"{res} object has no attribute {attr_name}") + res = getattr(res, attr_name) + + return res + + +def deprecated(future_name: Optional[str] = None, message: Optional[str] = None): + """ + Decorator to mark functions as deprecated + + :param new_function: Function called in place of depreciated function + :param message: Depreciation message, replaces default depreciation message + """ + + def decorator(func: Callable[[Any], Any]): + nonlocal message + + if message is None: + message = ( + f"{func.__name__} is deprecated and will be removed in a future release" + ) + if future_name is not None: + message += f". Please use {future_name} instead." + + @wraps(func) + def wrapped(*args, **kwargs): + warnings.warn(message, DeprecationWarning, stacklevel=2) + return func(*args, **kwargs) + + return wrapped + + return decorator + + class Aliasable: """ A mixin for enums to allow aliasing of enum members diff --git a/src/compressed_tensors/utils/offload.py b/src/compressed_tensors/utils/offload.py index 9dd7b22d..b3c77c58 100644 --- a/src/compressed_tensors/utils/offload.py +++ b/src/compressed_tensors/utils/offload.py @@ -11,9 +11,48 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +""" +Utilities associated with offloading functionality provided by `accelerate`. + +| ----------------------------------------------------------------------------------------------------- | # noqa: E501 +| Operation | Without offloading support | With offloading support | # noqa: E501 +| --------- | -------------------------------------- | ------------------------------------------------ | # noqa: E501 +| Add | module.register_parameter(name, param) | register_offload_parameter(module, name, param) | # noqa: E501 +| Check | N/A | has_offloaded_params(module) | # noqa: E501 +| Onload | N/A | with align_module_device(module) | # noqa: E501 +| Update | module.name.data.copy_(new_data) | update_offload_parameter(module, name, new_data) | # noqa: E501 +| Delete | del module.name | delete_offload_parameter(module, name) | # noqa: E501 +| ----------------------------------------------------------------------------------------------------- | # noqa: E501 +""" + +import contextlib +from functools import wraps +from typing import Any, Callable, Dict, Literal, Optional, Union import torch -from torch.nn import Module + + +try: + from accelerate.hooks import ( + AlignDevicesHook, + add_hook_to_module, + remove_hook_from_module, + ) + from accelerate.utils import ( + OffloadedWeightsLoader, + PrefixedDataset, + set_module_tensor_to_device, + ) + + _has_accelerate = True +except ImportError: + _has_accelerate = False + AlignDevicesHook = None + add_hook_to_module = None + remove_hook_from_module = None + OffloadedWeightsLoader = None + PrefixedDataset = None + set_module_tensor_to_device = None __all__ = [ @@ -22,23 +61,44 @@ "get_offloaded_device", "update_prefix_dict", "update_parameter_data", + "register_offload_parameter", + "update_offload_parameter", + "delete_offload_parameter", + "has_offloaded_params", + "disable_hf_hook", + "align_module_device", ] -def is_module_offloaded(module: Module) -> bool: - """ - :param module: layer to check - :return: True if layer is offloaded from GPU, False otherwise - """ - return hasattr(module, "_hf_hook") and module._hf_hook.offload +def check_accelerate(fallback: Any): + def decorator(func: Callable[[Any], Any]): + if not _has_accelerate: + + @wraps(func) + def fallback_fn(*args, **kwargs): + return fallback + + return fallback_fn + + return func + return decorator -def get_execution_device(module: Module) -> torch.device: + +""" Candidates for Depreciation """ + + +@check_accelerate(fallback=False) +def is_module_offloaded(module: torch.nn.Module) -> bool: + return has_offloaded_params(module) + + +def get_execution_device(module: torch.nn.Module) -> torch.device: """ - :param module: layer to check - :return: device layer is loaded onto during forward pass + :param module: module to check + :return: device module is loaded onto during forward pass """ - if is_module_offloaded(module): + if has_offloaded_params(module): return module._hf_hook.execution_device device = next(module.parameters()).device @@ -49,68 +109,296 @@ def get_execution_device(module: Module) -> torch.device: return device -def get_offloaded_device(module: Module) -> torch.device: +def get_offloaded_device(module: torch.nn.Module) -> torch.device: """ - :param module: layer to check - :return: device layer is offloaded to onto after forward pass + :param module: module to check + :return: device module is offloaded to onto after forward pass """ - if is_module_offloaded(module): + if has_offloaded_params(module): first_key = list(module._hf_hook.weights_map.keys())[0] prefix_dataset = module._hf_hook.weights_map.dataset return prefix_dataset[first_key].device return next(module.parameters()).device -def update_prefix_dict(module: Module, key: str, data: torch.Tensor): +@check_accelerate(fallback=None) +def update_prefix_dict(module: torch.nn.Module, key: str, data: torch.Tensor): """ Updates the offloaded state dict for a given module. Parameter named key is replaced by data. This is neccesary because parameter updates for offloaded modules do not persist automatically between loads. This function only affects the offloaded state dict and not the current state of the loaded module. - :param module: layer containing the parameter to update + :param module: module containing the parameter to update :param key: name of parameter to update :param data: tensor to update parameter with in the offloaded state dict """ - if not is_module_offloaded(module): + if not has_offloaded_params(module): raise ValueError("Prefix dict is only applicable to offloaded modules") - prefix_dict = module._hf_hook.weights_map - prefix_dict.dataset[f"{prefix_dict.prefix}{key}"] = data + + weights_map = module._hf_hook.weights_map + offload_to_weights_map(weights_map, key, data) def update_parameter_data( - module: Module, new_param_data: torch.Tensor, param_name: str + module: torch.nn.Module, new_param_data: torch.Tensor, param_name: str ): """ - Updates the paramter value named param_name for a given module. This function - updates both the current loaded module state and the offloaded state dict if - the module is offloaded. This is neccesary because parameter updates for offloaded - modules do not persist automatically between loads. + Update the data of an existing parameter and its offload dict. Supports both + parameters of offloaded modules and non-offloaded modules - :param module: layer containing the parameter to update + :param module: module containing the parameter to update :param new_param_data: tensor to update parameter with - :param param_name: name of layer parameter to update + :param param_name: name of module parameter to update """ - if not hasattr(module, param_name): - return + update_offload_parameter(module, param_name, new_param_data) + + +""" Candidates for Upstreaming """ + + +def register_offload_parameter( + module: torch.nn.Module, + name: str, + parameter: torch.nn.Parameter, + offload_device: Optional[Union[torch.device, Literal["disk"]]] = None, +): + """ + Register a parameter to the given module which may be offloaded + + :param module: maybe offloaded module + :param name: name of newly registered parameter + :param parameter: parameter being registered + :param offload_device: device on which weight will be offloaded to. If None is + provided, then infer device from parameters on module + """ + has_onload = any(p.device != torch.device("meta") for p in module.parameters()) + module.register_parameter(name, parameter) + + if has_offloaded_params(module): + weights_map = module._hf_hook.weights_map + offload_to_weights_map(weights_map, name, parameter.data, offload_device) + if not has_onload: + set_module_tensor_to_device(module, name, "meta") + + +def update_offload_parameter( + module: torch.nn.Module, + name: str, + data: Optional[torch.Tensor], + offload_device: Optional[Union[torch.device, Literal["disk"]]] = None, +): + """ + Update the data of an existing parameter and its offload dict. Supports both + parameters of offloaded modules and non-offloaded modules + + :param module: module containing the parameter to update + :param name: name of module parameter to update + :param data: tensor to update parameter with + :param offload_device: device on which weight will be offloaded to. If None is + provided, then infer device from parameters on module + """ + param = getattr(module, name) + data = data.to(param.dtype) + + # copy data into onloaded parameter if applicable + if param.device != "meta": + param.data.copy_(data) + + # update offload dict + if has_offloaded_params(module): + weights_map = module._hf_hook.weights_map + offload_to_weights_map(weights_map, name, data, offload_device) + + +def delete_offload_parameter(module: torch.nn.Module, name: str): + """ + Delete a parameter from a module which may be offloaded + + :param module: maybe offloaded module + :param name: name of parameter being deleted + """ + delattr(module, name) + + if has_offloaded_params(module): + weights_map = module._hf_hook.weights_map + delete_from_weights_map(weights_map, name) - device = next(module.parameters()).device - offloaded = False - if is_module_offloaded(module): - offload_device = get_offloaded_device(module) - offloaded = True +@check_accelerate(fallback=contextlib.nullcontext()) +@contextlib.contextmanager +def disable_hf_hook(module: torch.nn.Module): + hooks = {} - parameter = getattr(module, param_name, None) - if parameter is None: - raise ValueError("Attempted to update uninitialized parameter") + def collect_hooks(module): + nonlocal hooks + if hasattr(module, "_hf_hook"): + hooks[module] = module._hf_hook + remove_hook_from_module(module) - dtype = parameter.dtype - parameter.data = new_param_data.to(device).to(dtype) + module.apply(collect_hooks) - if offloaded: - prefix_dict = module._hf_hook.weights_map.dataset - prefix = module._hf_hook.weights_map.prefix - prefix_dict[f"{prefix}{param_name}"] = new_param_data.to(offload_device).to( - dtype + yield + + for submodule, hook in hooks.items(): + add_hook_to_module(submodule, hook) + + +@check_accelerate(fallback=None) +def offload_to_weights_map( + weights_map: Union[PrefixedDataset, Dict, OffloadedWeightsLoader], + key: str, + value: torch.Tensor, + offload_device: Optional[Union[torch.device, Literal["disk"]]] = None, +): + """ + Helper function which implements offloaded item assignment for PrefixedDataset, + OffloadedWeightsLoader, and Dict types. + + :param weights_map: weight map to be updated with offload information + :param key: key used to identify weight location + :param value: weight being offloaded + :param offload_device: device on which weight will be offloaded to. If None is + provided, then infer device from parameters in weights_map + """ + if isinstance(weights_map, PrefixedDataset): + if offload_device == "disk": + raise ValueError(f"Cannot offload to disk with type {type(weights_map)}") + + dataset = weights_map.dataset + key = f"{weights_map.prefix}{key}" + offload_to_weights_map(dataset, key, value, offload_device) + + elif isinstance(weights_map, OffloadedWeightsLoader): + if key not in weights_map.all_keys: + weights_map.all_keys.append(key) + + if len(weights_map.index) <= 0 and offload_device != "disk": + offload_to_weights_map(weights_map.state_dict, key, value, offload_device) + + else: + raise NotImplementedError( + "Updating weights_map with disk offloading is not implemented yet" + ) + + elif isinstance(weights_map, dict): + if offload_device == "disk": + raise ValueError(f"Cannot offload to disk with type {type(weights_map)}") + + # infer offload device + if offload_device is None: + if key in weights_map: + offload_device = weights_map[key].device + else: + tens = next(iter(weights_map.values()), None) + if tens is None: + raise ValueError( + "Cannot infer offload device from empty weights_map" + ) + offload_device = tens.device + + weights_map[key] = value.to(device=offload_device) + + else: + raise NotImplementedError( + "Updating offload data not implemented for weights_map of type " + f"{type(weights_map)}" + ) + + +@check_accelerate(fallback=None) +def delete_from_weights_map( + weights_map: Union[PrefixedDataset, Dict, OffloadedWeightsLoader], + key: str, +): + if isinstance(weights_map, PrefixedDataset): + dataset = weights_map.dataset + key = f"{weights_map.prefix}{key}" + delete_from_weights_map(dataset, key) + + elif isinstance(weights_map, OffloadedWeightsLoader): + if len(weights_map.index) <= 0: + delete_from_weights_map(weights_map.state_dict, key) + + else: + raise NotImplementedError( + "Delete from weights_map with disk offloading is not implemented yet" + ) + + elif isinstance(weights_map, dict): + del weights_map[key] + + else: + raise NotImplementedError( + "Updating offload data not implemented for weights_map of type " + f"{type(weights_map)}" ) + + +""" Upstreamed Functions """ + + +# introduced in accelerate v1.1.0 +@check_accelerate(fallback=False) +def has_offloaded_params(module: torch.nn.Module) -> bool: + """ + Checks if a module has offloaded parameters by checking if the given module has a + AlignDevicesHook attached with offloading enabled + + Args: + module (`torch.nn.Module`): The module to check for an offload hook. + + Returns: + bool: `True` if the module has an offload hook and offloading is enabled, + `False` otherwise. + """ + return ( + hasattr(module, "_hf_hook") + and isinstance(module._hf_hook, AlignDevicesHook) + and module._hf_hook.offload + ) + + +# introduced in accelerate v1.1.0 +@check_accelerate(fallback=contextlib.nullcontext()) +@contextlib.contextmanager +def align_module_device( + module: torch.nn.Module, execution_device: Optional[torch.device] = None +): + """ + Context manager that moves a module's parameters to the specified execution device. + + Args: + module (`torch.nn.Module`): + Module with parameters to align. + execution_device (`torch.device`, *optional*): + If provided, overrides the module's execution device within the context. + Otherwise, use hook execution device or pass + """ + if has_offloaded_params(module): + if execution_device is not None: + original_device = module._hf_hook.execution_device + module._hf_hook.execution_device = execution_device + + try: + module._hf_hook.pre_forward(module) + yield + finally: + module._hf_hook.post_forward(module, None) + if execution_device is not None: + module._hf_hook.execution_device = original_device + + elif execution_device is not None: + devices = { + name: param.device for name, param in module.named_parameters(recurse=False) + } + try: + for name in devices: + set_module_tensor_to_device(module, name, execution_device) + yield + finally: + for name, device in devices.items(): + set_module_tensor_to_device(module, name, device) + + else: + yield diff --git a/tests/test_quantization/lifecycle/test_initialize.py b/tests/test_quantization/lifecycle/test_initialize.py index 215f2130..80a1629d 100644 --- a/tests/test_quantization/lifecycle/test_initialize.py +++ b/tests/test_quantization/lifecycle/test_initialize.py @@ -24,6 +24,7 @@ from compressed_tensors.quantization.lifecycle.initialize import ( initialize_module_for_quantization, ) +from tests.testing_utils import requires_accelerate from torch.nn import Linear @@ -35,6 +36,11 @@ } +@pytest.fixture +def layer(): + return Linear(4, 4) + + @pytest.mark.parametrize( "weights,input_activations", [ @@ -53,14 +59,13 @@ ], ) def test_initialize_module_for_quantization( - create_quantization_scheme, weights, input_activations + create_quantization_scheme, weights, input_activations, layer ): quantization_scheme = create_quantization_scheme( targets=["*"], weights=weights, input_activations=input_activations, ) - layer = Linear(4, 4) assert not hasattr(layer, "quantization_scheme") assert not hasattr(layer, "quantization_status") @@ -89,6 +94,39 @@ def test_initialize_module_for_quantization( assert layer.quantization_status == QuantizationStatus.INITIALIZED +@requires_accelerate() +@pytest.mark.parametrize( + "weights,input_activations", + [ + ( + QuantizationArgs(num_bits=NUM_BITS, symmetric=True), + None, + ), + ( + None, + QuantizationArgs(num_bits=NUM_BITS, symmetric=True), + ), + ( + QuantizationArgs(num_bits=NUM_BITS, symmetric=True), + QuantizationArgs(num_bits=NUM_BITS, symmetric=True), + ), + ], +) +def test_initialize_module_for_quantization_offloaded( + create_quantization_scheme, weights, input_activations, layer +): + from accelerate.hooks import attach_align_device_hook + + attach_align_device_hook(layer, offload=True) + + test_initialize_module_for_quantization( + create_quantization_scheme, + weights, + input_activations, + layer, + ) + + @pytest.mark.parametrize( "weights,input_activations", [ diff --git a/tests/test_utils/test_offload.py b/tests/test_utils/test_offload.py new file mode 100644 index 00000000..1002a4f5 --- /dev/null +++ b/tests/test_utils/test_offload.py @@ -0,0 +1,255 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import pytest +import torch +from compressed_tensors.utils import ( + align_module_device, + delete_offload_parameter, + disable_hf_hook, + has_offloaded_params, + register_offload_parameter, + update_offload_parameter, +) +from compressed_tensors.utils.offload import offload_to_weights_map +from tests.testing_utils import requires_accelerate + + +class ExampleModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.a = torch.nn.Parameter(torch.tensor(0).float()) + self.b = torch.nn.Parameter(torch.tensor(0).float()) + + def forward(self, x): + return x * self.a + self.b + + +@requires_accelerate() +def test_has_offloaded_params(): + from accelerate.big_modeling import cpu_offload_with_hook + from accelerate.hooks import attach_align_device_hook, remove_hook_from_module + + module = ExampleModule() + assert not has_offloaded_params(module) + + attach_align_device_hook(module, offload=False) + assert not has_offloaded_params(module) + + remove_hook_from_module(module) + module, _ = cpu_offload_with_hook(module) + assert not has_offloaded_params(module) + + remove_hook_from_module(module) + attach_align_device_hook(module, offload=True, weights_map=module.state_dict()) + assert has_offloaded_params(module) + + +@requires_accelerate() +def test_register_offload_parameter(): + from accelerate.hooks import attach_align_device_hook + + module = ExampleModule() + parameter = torch.nn.Parameter(torch.tensor(1.0)) + + # register a param prior to offloading + register_offload_parameter(module, "c", parameter) + assert hasattr(module, "c") and module.c == parameter + + # offloading, check that added param was offloaded + attach_align_device_hook(module, offload=True, weights_map=module.state_dict()) + assert "c" in module._hf_hook.weights_map + + # register a param after offloading, check that added param was offloaded + register_offload_parameter(module, "d", parameter) + assert hasattr(module, "d") and module.d.device == torch.device("meta") + assert module._hf_hook.weights_map["d"].device == torch.device("cpu") + + # added parameters can be onloaded and offloaded + with align_module_device(module, execution_device="cpu"): + assert module.c.device == torch.device("cpu") + assert module.d.device == torch.device("cpu") + assert module.c.device == torch.device("meta") + assert module.d.device == torch.device("meta") + + # parameters can be added during onload + with align_module_device(module, execution_device="cpu"): + register_offload_parameter(module, "e", parameter) + assert module.e.device == torch.device("cpu") + + # parameters can be added before onload and with explicit offload + register_offload_parameter(module, "f", parameter, offload_device="cpu") + assert module._hf_hook.weights_map["f"].device == torch.device("cpu") + with align_module_device(module, execution_device="cpu"): + assert module.f.device == torch.device("cpu") + assert module._hf_hook.weights_map["f"].device == torch.device("cpu") + + +@requires_accelerate() +def test_update_offload_parameter(): + from accelerate.hooks import attach_align_device_hook + + module = ExampleModule() + param_a = torch.nn.Parameter(torch.tensor(1.0)) + param_b = torch.nn.Parameter(torch.tensor(2.0)) + + # can update modules which are not offloaded + update_offload_parameter(module, "a", param_a) + assert module.a == param_a + + # can update modules which are offloaded + attach_align_device_hook(module, offload=True, weights_map=module.state_dict()) + update_offload_parameter(module, "b", param_b) + assert module.b.device == torch.device("meta") + assert module._hf_hook.weights_map["b"] == param_b.data + + # data persists across onloading + with align_module_device(module, execution_device="cpu"): + assert module.a == param_a + assert module.b == param_b + assert module._hf_hook.weights_map["a"] == param_a.data + assert module._hf_hook.weights_map["b"] == param_b.data + + # data persists across offloading + assert module.a.device == torch.device("meta") + assert module.b.device == torch.device("meta") + assert module._hf_hook.weights_map["a"] == param_a.data + assert module._hf_hook.weights_map["b"] == param_b.data + + +@requires_accelerate() +def test_delete_offload_parameter(): + from accelerate.hooks import attach_align_device_hook + + module = ExampleModule() + param_c = torch.nn.Parameter(torch.tensor(1.0)) + param_d = torch.nn.Parameter(torch.tensor(2.0)) + register_offload_parameter(module, "c", param_c) + register_offload_parameter(module, "d", param_d) + + # parameters are deleted + delete_offload_parameter(module, "a") + delete_offload_parameter(module, "c") + assert not hasattr(module, "a") + assert hasattr(module, "b") + assert not hasattr(module, "c") + assert hasattr(module, "d") + + # parameters and their offload are deleted + attach_align_device_hook(module, offload=True, weights_map=module.state_dict()) + delete_offload_parameter(module, "b") + delete_offload_parameter(module, "d") + assert not hasattr(module, "a") + assert not hasattr(module, "b") + assert not hasattr(module, "c") + assert not hasattr(module, "d") + assert "a" not in module._hf_hook.weights_map + assert "b" not in module._hf_hook.weights_map + assert "c" not in module._hf_hook.weights_map + assert "d" not in module._hf_hook.weights_map + + +@requires_accelerate() +def test_disable_hf_hook(): + from accelerate.hooks import attach_align_device_hook + + module = ExampleModule() + + def custom_forward(): + pass + + attach_align_device_hook(module, offload=True, weights_map=module.state_dict()) + with disable_hf_hook(module): + assert not hasattr(module, "_hf_hook") + module.forward = custom_forward + + assert hasattr(module, "_hf_hook") + assert module._old_forward == custom_forward + + +@requires_accelerate() +def test_disable_hf_hook_model_recurse(): + from accelerate.hooks import attach_align_device_hook + + module0 = ExampleModule() + module1 = ExampleModule() + module2 = ExampleModule() + model = torch.nn.Sequential(module0, torch.nn.Sequential(module1, module2)) + attach_align_device_hook(model, offload=True, weights_map=model.state_dict()) + + with disable_hf_hook(model): + assert not hasattr(module0, "_hf_hook") + assert not hasattr(module1, "_hf_hook") + assert not hasattr(module2, "_hf_hook") + + assert hasattr(module0, "_hf_hook") + assert hasattr(module1, "_hf_hook") + assert hasattr(module2, "_hf_hook") + + +@requires_accelerate() +def test_offload_to_weights_map(): + from accelerate.utils import OffloadedWeightsLoader, PrefixedDataset + + name = "name" + old_value = torch.tensor(0.0) + new_value = torch.tensor(1.0) + prefix = "prefix" + + # Dict empty + weights_map = {} + with pytest.raises(ValueError): + offload_to_weights_map(weights_map, name, new_value) + offload_to_weights_map(weights_map, name, new_value, offload_device="cpu") + assert weights_map[name] == new_value + + # Dict populated + weights_map = {name: old_value} + offload_to_weights_map(weights_map, name, new_value) + assert weights_map[name] == new_value + + # OffloadedWeightsLoader[Dict] empty + weights_map = OffloadedWeightsLoader({}) + with pytest.raises(ValueError): + offload_to_weights_map(weights_map, name, new_value) + offload_to_weights_map(weights_map, name, new_value, offload_device="cpu") + assert weights_map[name] == new_value + + # OffloadedWeightsLoader[Dict] populated + weights_map = OffloadedWeightsLoader({name: old_value}) + offload_to_weights_map(weights_map, name, new_value) + assert weights_map[name] == new_value + + # PrefixedDataset[Dict] empty + weights_map = PrefixedDataset({}, prefix) + with pytest.raises(ValueError): + offload_to_weights_map(weights_map, name, new_value) + offload_to_weights_map(weights_map, name, new_value, offload_device="cpu") + assert weights_map[name] == new_value + + # PrefixedDataset[Dict] populated + weights_map = PrefixedDataset({name: old_value}, prefix) + offload_to_weights_map(weights_map, name, new_value) + assert weights_map[name] == new_value + + # PrefixedDataset[OffloadedWeightsLoader[Dict]] empty + weights_map = PrefixedDataset(OffloadedWeightsLoader({}), prefix) + with pytest.raises(ValueError): + offload_to_weights_map(weights_map, name, new_value) + offload_to_weights_map(weights_map, name, new_value, offload_device="cpu") + assert weights_map[name] == new_value + + # PrefixedDataset[OffloadedWeightsLoader[Dict]] populated + weights_map = PrefixedDataset(OffloadedWeightsLoader({name: old_value}), prefix) + offload_to_weights_map(weights_map, name, new_value) + assert weights_map[name] == new_value