From bddc83c50dbbd746e673441cbac89ef491b79b15 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Mon, 21 Oct 2024 18:42:53 +0000 Subject: [PATCH 01/28] wip --- .../quantization/observers/helpers.py | 2 +- src/compressed_tensors/utils/helpers.py | 33 ++++ src/compressed_tensors/utils/offload.py | 144 +++++++++++++----- 3 files changed, 137 insertions(+), 42 deletions(-) diff --git a/src/compressed_tensors/quantization/observers/helpers.py b/src/compressed_tensors/quantization/observers/helpers.py index 875a05b3..ec474303 100644 --- a/src/compressed_tensors/quantization/observers/helpers.py +++ b/src/compressed_tensors/quantization/observers/helpers.py @@ -13,7 +13,7 @@ # limitations under the License. from collections import Counter -from typing import Optional, Tuple +from typing import Tuple import torch from compressed_tensors.quantization.quant_args import ( diff --git a/src/compressed_tensors/utils/helpers.py b/src/compressed_tensors/utils/helpers.py index e1587ada..82e11ccf 100644 --- a/src/compressed_tensors/utils/helpers.py +++ b/src/compressed_tensors/utils/helpers.py @@ -24,6 +24,7 @@ "tensor_follows_mask_structure", "replace_module", "is_compressed_tensors_config", + "getattr_chain", ] FSDP_WRAPPER_NAME = "_fsdp_wrapped_module" @@ -119,3 +120,35 @@ def is_compressed_tensors_config(compression_config: Any) -> bool: return isinstance(compression_config, CompressedTensorsConfig) except ImportError: return False + + +def getattr_chain(obj: Any, chain_str: str, *args, **kwargs) -> Any: + """ + Chain multiple getattr calls, separated by `.` + + :param obj: base object whose attributes are being retrieved + :param chain_str: attribute names separated by `.` + :param default: default value, throw error otherwise + + """ + if len(args) >= 1: + has_default = True + default = args[0] + elif "default" in kwargs: + has_default = True + default = kwargs["default"] + else: + has_default = False + + attr_names = chain_str.split(".") + + res = obj + for attr_name in attr_names: + if not hasattr(res, attr_name): + if has_default: + return default + else: + raise AttributeError(f"{res} object has no attribute {attr_name}") + res = getattr(res, attr_name) + + return res diff --git a/src/compressed_tensors/utils/offload.py b/src/compressed_tensors/utils/offload.py index 9dd7b22d..51be2c29 100644 --- a/src/compressed_tensors/utils/offload.py +++ b/src/compressed_tensors/utils/offload.py @@ -12,8 +12,17 @@ # See the License for the specific language governing permissions and # limitations under the License. +import contextlib +from typing import Optional + import torch -from torch.nn import Module +from compressed_tensors.utils.helpers import getattr_chain + + +try: + from accelerate.hooks import AlignDevicesHook +except ImportError: + AlignDevicesHook = None __all__ = [ @@ -25,18 +34,32 @@ ] -def is_module_offloaded(module: Module) -> bool: +# upstream candidate +def can_offload(module: torch.nn.Module) -> bool: """ - :param module: layer to check - :return: True if layer is offloaded from GPU, False otherwise + :param module: module to check + :return: True if module has offloading capabilities """ - return hasattr(module, "_hf_hook") and module._hf_hook.offload + return ( + hasattr(module, "_hf_hook") + and isinstance(module._hf_hook, AlignDevicesHook) + and module._hf_hook.offload # offload after forward pass + ) + + +# backwards compatibility, optional package checking +def is_module_offloaded(module: torch.nn.Module) -> bool: + if AlignDevicesHook is None: + return False + + return can_offload(module) -def get_execution_device(module: Module) -> torch.device: +# depreciation candidate +def get_execution_device(module: torch.nn.Module) -> torch.device: """ - :param module: layer to check - :return: device layer is loaded onto during forward pass + :param module: module to check + :return: device module is loaded onto during forward pass """ if is_module_offloaded(module): return module._hf_hook.execution_device @@ -49,10 +72,11 @@ def get_execution_device(module: Module) -> torch.device: return device -def get_offloaded_device(module: Module) -> torch.device: +# depreciation candidate +def get_offloaded_device(module: torch.nn.Module) -> torch.device: """ - :param module: layer to check - :return: device layer is offloaded to onto after forward pass + :param module: module to check + :return: device module is offloaded to onto after forward pass """ if is_module_offloaded(module): first_key = list(module._hf_hook.weights_map.keys())[0] @@ -61,14 +85,15 @@ def get_offloaded_device(module: Module) -> torch.device: return next(module.parameters()).device -def update_prefix_dict(module: Module, key: str, data: torch.Tensor): +# depreciation candidate +def update_prefix_dict(module: torch.nn.Module, key: str, data: torch.Tensor): """ Updates the offloaded state dict for a given module. Parameter named key is replaced by data. This is neccesary because parameter updates for offloaded modules do not persist automatically between loads. This function only affects the offloaded state dict and not the current state of the loaded module. - :param module: layer containing the parameter to update + :param module: module containing the parameter to update :param key: name of parameter to update :param data: tensor to update parameter with in the offloaded state dict """ @@ -78,39 +103,76 @@ def update_prefix_dict(module: Module, key: str, data: torch.Tensor): prefix_dict.dataset[f"{prefix_dict.prefix}{key}"] = data -def update_parameter_data( - module: Module, new_param_data: torch.Tensor, param_name: str +# upstream candidate +def update_offload_parameter( + module: torch.nn.Module, + name: str, + data: torch.Tensor, + offload_device: Optional[torch.device] = None, ): """ - Updates the paramter value named param_name for a given module. This function - updates both the current loaded module state and the offloaded state dict if - the module is offloaded. This is neccesary because parameter updates for offloaded - modules do not persist automatically between loads. - - :param module: layer containing the parameter to update - :param new_param_data: tensor to update parameter with - :param param_name: name of layer parameter to update + :param module: module containing the parameter to update + :param name: name of module parameter to update + :param data: tensor to update parameter with + :param offload_device: new offload device for parameter, otherwise default to + using the existing offload device """ - if not hasattr(module, param_name): - return + param = getattr(module, name) + param.data = data - device = next(module.parameters()).device + prefix_dict = getattr_chain(module, "module._hf_hook.weights_map.dataset", None) + if prefix_dict is not None: + prefix = module._hf_hook.weights_map.prefix + key = f"{prefix}{name}" - offloaded = False - if is_module_offloaded(module): - offload_device = get_offloaded_device(module) - offloaded = True + if offload_device is None: + if key not in prefix_dict: + raise ValueError( + "Cannot initialize new offload parameter without specifying " + "offload_device" + ) + offload_device = prefix_dict[key].device - parameter = getattr(module, param_name, None) - if parameter is None: - raise ValueError("Attempted to update uninitialized parameter") + prefix_dict[key] = data.to(device=offload_device) - dtype = parameter.dtype - parameter.data = new_param_data.to(device).to(dtype) - if offloaded: - prefix_dict = module._hf_hook.weights_map.dataset - prefix = module._hf_hook.weights_map.prefix - prefix_dict[f"{prefix}{param_name}"] = new_param_data.to(offload_device).to( - dtype - ) +# backwards compatibility +def update_parameter_data( + module: torch.nn.Module, new_param_data: torch.Tensor, param_name: str +): + param = getattr(module, param_name) + new_param_data = new_param_data.to(device=param.device, dtype=param.dtype) + update_offload_parameter(module, param_name, new_param_data) + + +# upstream candidate +@contextlib.contextmanager +def align_module(module: torch.nn.Module, device: Optional[torch.device] = None): + """ + Move an offloaded module's parameters to device or module execution device + + :param module: module with parameters to align + :param device: optional device to move parameters to, if None is provided then + module execution device will be used + """ + if device is not None: + original_device = module._hf_hook.execution_device + module._hf_hook.execution_device = device + + module._hf_hook.pre_forward(module) + yield + module._hf_hook.post_forward(module, torch.tensor([])) + + if device is not None: + module._hf_hook.execution_device = original_device + + +# upstream candidate +def register_offload_parameter( + module: torch.nn.Module, + name: str, + data: torch.Tensor, + offload_device: Optional[torch.device], +): + module.register_parameter(name, torch.nn.Parameter(data)) + update_offload_parameter(module, name, data, offload_device) From 94d8c565e9a7a40d81d444bba7c8cb6fde368479 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Wed, 23 Oct 2024 05:05:19 +0000 Subject: [PATCH 02/28] add modify_offload_module --- src/compressed_tensors/utils/offload.py | 98 ++++++++++++++++--------- 1 file changed, 64 insertions(+), 34 deletions(-) diff --git a/src/compressed_tensors/utils/offload.py b/src/compressed_tensors/utils/offload.py index 51be2c29..d630d14e 100644 --- a/src/compressed_tensors/utils/offload.py +++ b/src/compressed_tensors/utils/offload.py @@ -13,6 +13,7 @@ # limitations under the License. import contextlib +from functools import wraps from typing import Optional import torch @@ -34,25 +35,32 @@ ] -# upstream candidate -def can_offload(module: torch.nn.Module) -> bool: +def has_offloaded_params(module: torch.nn.Module) -> bool: """ - :param module: module to check - :return: True if module has offloading capabilities + Checks if a module has offloaded parameters by checking if the given module + has a AlignDevicesHook attached with offloading enabled + + Args: + module (`torch.nn.Module`): The module to check for an offload hook. + + Returns: + bool: `True` if the module has an offload hook and offloading is enabled, + `False` otherwise. """ return ( - hasattr(module, "_hf_hook") - and isinstance(module._hf_hook, AlignDevicesHook) - and module._hf_hook.offload # offload after forward pass + hasattr(module, "_hf_hook") and + isinstance(module._hf_hook, AlignDevicesHook) and + module._hf_hook.offload ) -# backwards compatibility, optional package checking +# depreciation candidate +@wraps(has_offloaded_params) def is_module_offloaded(module: torch.nn.Module) -> bool: if AlignDevicesHook is None: return False - return can_offload(module) + return has_offloaded_params(module) # depreciation candidate @@ -108,14 +116,13 @@ def update_offload_parameter( module: torch.nn.Module, name: str, data: torch.Tensor, - offload_device: Optional[torch.device] = None, + init_device: Optional[torch.device] = torch.device("cpu"), ): """ :param module: module containing the parameter to update :param name: name of module parameter to update :param data: tensor to update parameter with - :param offload_device: new offload device for parameter, otherwise default to - using the existing offload device + :param init_device: offload device for newly registered parameters """ param = getattr(module, name) param.data = data @@ -125,18 +132,11 @@ def update_offload_parameter( prefix = module._hf_hook.weights_map.prefix key = f"{prefix}{name}" - if offload_device is None: - if key not in prefix_dict: - raise ValueError( - "Cannot initialize new offload parameter without specifying " - "offload_device" - ) - offload_device = prefix_dict[key].device - + offload_device = prefix_dict[key].device if key in prefix_dict else init_device prefix_dict[key] = data.to(device=offload_device) -# backwards compatibility +# depreciation candidate def update_parameter_data( module: torch.nn.Module, new_param_data: torch.Tensor, param_name: str ): @@ -147,27 +147,57 @@ def update_parameter_data( # upstream candidate @contextlib.contextmanager -def align_module(module: torch.nn.Module, device: Optional[torch.device] = None): +def align_module(module: torch.nn.Module, execution_device: Optional[torch.device] = None): """ Move an offloaded module's parameters to device or module execution device - :param module: module with parameters to align - :param device: optional device to move parameters to, if None is provided then - module execution device will be used + :param execution_device: optional device to move parameters to, if None is + provided then default module execution device will be used """ - if device is not None: - original_device = module._hf_hook.execution_device - module._hf_hook.execution_device = device + if is_module_offloaded(module): + if execution_device is not None: + original_device = module._hf_hook.execution_device + module._hf_hook.execution_device = original_device - module._hf_hook.pre_forward(module) - yield - module._hf_hook.post_forward(module, torch.tensor([])) + module._hf_hook.pre_forward(module) + yield + module._hf_hook.post_forward(module, None) - if device is not None: - module._hf_hook.execution_device = original_device + if execution_device is not None: + module._hf_hook.execution_device = original_device + elif execution_device is not None: + devices = {} + for name, param in module.named_parameters(): + devices[name] = param.device + setattr(module, name, param.to(execution_device)) -# upstream candidate + yield + + for name, param_device in module.named_parameters: + setattr(module, name, param.to(param_device)) + + else: + yield + + +@contextlib.contextmanager +def modify_offload_module( + module: torch.nn.Module, + execution_device: Optional[torch.device] = None, + offload_device: Optional[torch.device] = None, +): + with align_module(module, execution_device): + yield + + # there is little performance gain from checking if a parameter's data + # has been modified before copying since the new data must be copied + # to the offload device anyways; just update all module parameters + for name, param in module.named_parameters(): + update_offload_parameter(module, name, param.data, offload_device) + + +# upstream candidate? def register_offload_parameter( module: torch.nn.Module, name: str, From f939e987315f2a4dbb2f13386c42e2b5f9cbae8d Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Wed, 23 Oct 2024 13:37:47 +0000 Subject: [PATCH 03/28] update docs --- src/compressed_tensors/utils/offload.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/src/compressed_tensors/utils/offload.py b/src/compressed_tensors/utils/offload.py index d630d14e..6656e9bc 100644 --- a/src/compressed_tensors/utils/offload.py +++ b/src/compressed_tensors/utils/offload.py @@ -35,6 +35,7 @@ ] +# upstream candidate def has_offloaded_params(module: torch.nn.Module) -> bool: """ Checks if a module has offloaded parameters by checking if the given module @@ -149,12 +150,13 @@ def update_parameter_data( @contextlib.contextmanager def align_module(module: torch.nn.Module, execution_device: Optional[torch.device] = None): """ - Move an offloaded module's parameters to device or module execution device + Move a module's parameters to the execution device + :param module: module with parameters to align - :param execution_device: optional device to move parameters to, if None is - provided then default module execution device will be used + :param execution_device: if provided, overrides module execution device + within the context """ - if is_module_offloaded(module): + if has_offloaded_params(module): if execution_device is not None: original_device = module._hf_hook.execution_device module._hf_hook.execution_device = original_device From 167e74101f57aad2edac5872943218c0c50bd3f9 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Thu, 31 Oct 2024 20:10:36 +0000 Subject: [PATCH 04/28] WIP --- src/compressed_tensors/utils/offload.py | 130 ++++++++++++++++++------ 1 file changed, 99 insertions(+), 31 deletions(-) diff --git a/src/compressed_tensors/utils/offload.py b/src/compressed_tensors/utils/offload.py index 6656e9bc..b5a5fe0c 100644 --- a/src/compressed_tensors/utils/offload.py +++ b/src/compressed_tensors/utils/offload.py @@ -17,13 +17,19 @@ from typing import Optional import torch +import warnings from compressed_tensors.utils.helpers import getattr_chain try: from accelerate.hooks import AlignDevicesHook + from accelerate.utils import OffloadedWeightsLoader, PrefixedDataset, set_module_tensor_to_device + _has_accelerate = True except ImportError: + _has_accelerate = False AlignDevicesHook = None + OffloadedWeightsLoader = None + PrefixedDataset = None __all__ = [ @@ -58,7 +64,7 @@ def has_offloaded_params(module: torch.nn.Module) -> bool: # depreciation candidate @wraps(has_offloaded_params) def is_module_offloaded(module: torch.nn.Module) -> bool: - if AlignDevicesHook is None: + if not _has_accelerate: return False return has_offloaded_params(module) @@ -81,17 +87,25 @@ def get_execution_device(module: torch.nn.Module) -> torch.device: return device +# upstream candidate +def _infer_offload_device(module: torch.nn.Module) -> torch.device: + if not has_offloaded_params(module): + raise ValueError("Cannot infer offload device from non-offloaded module") + + first_key = next(module._hf_hook.weights_map.keys(), None) + if first_key is None: + raise ValueError("Cannot infer offload device from empty weights map") + + prefix_dataset = module._hf_hook.weights_map.dataset + return prefix_dataset[first_key].device + # depreciation candidate def get_offloaded_device(module: torch.nn.Module) -> torch.device: """ :param module: module to check :return: device module is offloaded to onto after forward pass """ - if is_module_offloaded(module): - first_key = list(module._hf_hook.weights_map.keys())[0] - prefix_dataset = module._hf_hook.weights_map.dataset - return prefix_dataset[first_key].device - return next(module.parameters()).device + return _infer_offload_device(module) # depreciation candidate @@ -112,30 +126,51 @@ def update_prefix_dict(module: torch.nn.Module, key: str, data: torch.Tensor): prefix_dict.dataset[f"{prefix_dict.prefix}{key}"] = data -# upstream candidate +# upstream candidate? def update_offload_parameter( module: torch.nn.Module, name: str, - data: torch.Tensor, - init_device: Optional[torch.device] = torch.device("cpu"), + data: Optional[torch.Tensor] = None, + offload_device: Optional[torch.device] = None, ): """ :param module: module containing the parameter to update :param name: name of module parameter to update :param data: tensor to update parameter with - :param init_device: offload device for newly registered parameters + :param offload_device: offload device for newly registered parameters """ param = getattr(module, name) - param.data = data - - prefix_dict = getattr_chain(module, "module._hf_hook.weights_map.dataset", None) - if prefix_dict is not None: - prefix = module._hf_hook.weights_map.prefix - key = f"{prefix}{name}" + if param.device == "meta" or data is not None and data.device == "meta": + raise ValueError("Cannot copy data to/from meta device. Consider calling with align_module(module)") + + if data is not None: + if param.data.dtype != data.dtype: + warnings.warn("TODO") - offload_device = prefix_dict[key].device if key in prefix_dict else init_device - prefix_dict[key] = data.to(device=offload_device) + param.data.copy_(data) + if has_offloaded_params(module): + weights_map = module._hf_hook.weights_map + + # for upstreaming, probably better to modify the weight map types so that they can be written to? + if isinstance(weights_map, PrefixedDataset): + prefix_dict = getattr_chain(module, "module._hf_hook.weights_map.dataset", None) + if prefix_dict is not None: + prefix = module._hf_hook.weights_map.prefix + key = f"{prefix}{name}" + + offload_device = ( + prefix_dict[key].device if key in prefix_dict + else offload_device if offload_device is not None + else _infer_offload_device(module) + ) + prefix_dict[key] = param.data.to(device=offload_device) + + if isinstance(weights_map, OffloadedWeightsLoader): + raise NotImplementedError() + + else: + raise NotImplementedError() # depreciation candidate def update_parameter_data( @@ -146,20 +181,23 @@ def update_parameter_data( update_offload_parameter(module, param_name, new_param_data) -# upstream candidate @contextlib.contextmanager def align_module(module: torch.nn.Module, execution_device: Optional[torch.device] = None): """ - Move a module's parameters to the execution device + Moves a module's parameters to the specified execution device. - :param module: module with parameters to align - :param execution_device: if provided, overrides module execution device - within the context + Args: + module (torch.nn.Module): Module with parameters to align. + execution_device (Optional[torch.device]): If provided, overrides the + module's execution device within the context. + + Yields: + None: Yields control while the module's parameters are aligned to the execution device. """ if has_offloaded_params(module): if execution_device is not None: original_device = module._hf_hook.execution_device - module._hf_hook.execution_device = original_device + module._hf_hook.execution_device = execution_device module._hf_hook.pre_forward(module) yield @@ -172,17 +210,26 @@ def align_module(module: torch.nn.Module, execution_device: Optional[torch.devic devices = {} for name, param in module.named_parameters(): devices[name] = param.device - setattr(module, name, param.to(execution_device)) + set_module_tensor_to_device( + module, + name, + execution_device, + ) yield - for name, param_device in module.named_parameters: - setattr(module, name, param.to(param_device)) + for name, param in module.named_parameters(): + set_module_tensor_to_device( + module, + name, + devices[name], + ) else: yield + @contextlib.contextmanager def modify_offload_module( module: torch.nn.Module, @@ -203,8 +250,29 @@ def modify_offload_module( def register_offload_parameter( module: torch.nn.Module, name: str, - data: torch.Tensor, - offload_device: Optional[torch.device], + parameter: torch.nn.Parameter, + offload_device: Optional[torch.device] = None, ): - module.register_parameter(name, torch.nn.Parameter(data)) - update_offload_parameter(module, name, data, offload_device) + module.register_parameter(name, parameter) + update_offload_parameter(module, name, parameter.data, offload_device) + + +# upstream candidate? +def delete_offload_parameter(module: torch.nn.Module, name: str): + delattr(module, name) + + if has_offloaded_params(module): + weights_map = module._hf_hook.weights_map + + # for upstreaming, probably better to modify the weight map types so that they can be written to? + if isinstance(weights_map, PrefixedDataset): + dataset = weights_map.dataset + prefix = weights_map.prefix + if dataset is not None: + del dataset[f"{prefix}{name}"] + + elif isinstance(weights_map, OffloadedWeightsLoader): + raise NotImplementedError() + + elif weights_map is not None: + raise NotImplementedError(f"Cannot delete parameter from weights_map of type {type(weights_map)}") \ No newline at end of file From cb6edb13ac7c6a19da046054cb75d75931795703 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Mon, 18 Nov 2024 23:55:03 +0000 Subject: [PATCH 05/28] cleanup functions, begin depreciation Signed-off-by: Kyle Sayers --- .../quantization/lifecycle/calibration.py | 4 +- .../quantization/lifecycle/initialize.py | 16 +- src/compressed_tensors/utils/helpers.py | 31 +- src/compressed_tensors/utils/offload.py | 264 ++++++++++-------- 4 files changed, 186 insertions(+), 129 deletions(-) diff --git a/src/compressed_tensors/quantization/lifecycle/calibration.py b/src/compressed_tensors/quantization/lifecycle/calibration.py index d444694d..c67844fa 100644 --- a/src/compressed_tensors/quantization/lifecycle/calibration.py +++ b/src/compressed_tensors/quantization/lifecycle/calibration.py @@ -16,7 +16,7 @@ import logging from compressed_tensors.quantization.quant_config import QuantizationStatus -from compressed_tensors.utils import is_module_offloaded, update_parameter_data +from compressed_tensors.utils import has_offloaded_params, update_parameter_data from torch.nn import Module @@ -56,7 +56,7 @@ def set_module_for_calibration(module: Module, quantize_weights_upfront: bool = observer = module.weight_observer g_idx = getattr(module, "weight_g_idx", None) - offloaded = is_module_offloaded(module) + offloaded = has_offloaded_params(module) if offloaded: module._hf_hook.pre_forward(module) diff --git a/src/compressed_tensors/quantization/lifecycle/initialize.py b/src/compressed_tensors/quantization/lifecycle/initialize.py index 9b98da33..27b6a803 100644 --- a/src/compressed_tensors/quantization/lifecycle/initialize.py +++ b/src/compressed_tensors/quantization/lifecycle/initialize.py @@ -30,7 +30,7 @@ from compressed_tensors.quantization.quant_config import QuantizationStatus from compressed_tensors.quantization.quant_scheme import QuantizationScheme from compressed_tensors.quantization.utils import is_kv_cache_quant_scheme -from compressed_tensors.utils import get_execution_device, is_module_offloaded +from compressed_tensors.utils import has_offloaded_params, register_offload_parameter from torch.nn import Module, Parameter @@ -109,7 +109,7 @@ def initialize_module_for_quantization( module.quantization_status = QuantizationStatus.INITIALIZED offloaded = False - if is_module_offloaded(module): + if has_offloaded_params(module): try: from accelerate.hooks import add_hook_to_module, remove_hook_from_module from accelerate.utils import PrefixedDataset @@ -164,9 +164,9 @@ def _initialize_scale_zero_point_observer( if quantization_args.dynamic: return - device = next(module.parameters()).device - if is_module_offloaded(module): - device = get_execution_device(module) + # begin on the same device as other parameters or cpu if offloaded + params_device = next(module.parameters()).device + device = "cpu" if has_offloaded_params(module) else params_device # infer expected scale/zero point shape expected_shape = 1 # per tensor @@ -188,7 +188,7 @@ def _initialize_scale_zero_point_observer( torch.empty(expected_shape, dtype=scale_dtype, device=device), requires_grad=False, ) - module.register_parameter(f"{base_name}_scale", init_scale) + register_offload_parameter(module, f"{base_name}_scale", init_scale) if force_zero_point or not quantization_args.symmetric: zp_dtype = quantization_args.pytorch_dtype() @@ -196,7 +196,7 @@ def _initialize_scale_zero_point_observer( torch.zeros(expected_shape, device=device, dtype=zp_dtype), requires_grad=False, ) - module.register_parameter(f"{base_name}_zero_point", init_zero_point) + register_offload_parameter(module, f"{base_name}_zero_point", init_zero_point) # only grouped activation ordering has g_idx if quantization_args.actorder == ActivationOrdering.GROUP: @@ -206,7 +206,7 @@ def _initialize_scale_zero_point_observer( torch.full(g_idx_shape, -1, device=device, dtype=g_idx_dtype), requires_grad=False, ) - module.register_parameter(f"{base_name}_g_idx", init_g_idx) + register_offload_parameter(module, f"{base_name}_g_idx", init_g_idx) def is_attention_module(module: Module): diff --git a/src/compressed_tensors/utils/helpers.py b/src/compressed_tensors/utils/helpers.py index 82e11ccf..8468bf0d 100644 --- a/src/compressed_tensors/utils/helpers.py +++ b/src/compressed_tensors/utils/helpers.py @@ -12,7 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Optional +import warnings +from typing import Any, Callable, Optional import torch from transformers import AutoConfig @@ -25,6 +26,7 @@ "replace_module", "is_compressed_tensors_config", "getattr_chain", + "deprecated", ] FSDP_WRAPPER_NAME = "_fsdp_wrapped_module" @@ -152,3 +154,30 @@ def getattr_chain(obj: Any, chain_str: str, *args, **kwargs) -> Any: res = getattr(res, attr_name) return res + + +def deprecated(future_name: Optional[str] = None, message: Optional[str] = None): + """ + Decorator to mark functions as deprecated + + :param new_function: Function called in place of depreciated function + :param message: Depreciation message, replaces default depreciation message + """ + + def decorator(func: Callable[[Any], Any]): + nonlocal message + + if message is None: + message = ( + f"{func.__name__} is deprecated and will be removed in a future release" + ) + if future_name is not None: + message += f". Please use {future_name} instead." + + def wrapped(*args, **kwargs): + warnings.warn(message, DeprecationWarning, stacklevel=2) + return func(*args, **kwargs) + + return wrapped + + return decorator diff --git a/src/compressed_tensors/utils/offload.py b/src/compressed_tensors/utils/offload.py index b5a5fe0c..071c0ee1 100644 --- a/src/compressed_tensors/utils/offload.py +++ b/src/compressed_tensors/utils/offload.py @@ -13,17 +13,20 @@ # limitations under the License. import contextlib -from functools import wraps -from typing import Optional +from typing import Any, Callable, Optional import torch -import warnings from compressed_tensors.utils.helpers import getattr_chain try: from accelerate.hooks import AlignDevicesHook - from accelerate.utils import OffloadedWeightsLoader, PrefixedDataset, set_module_tensor_to_device + from accelerate.utils import ( + OffloadedWeightsLoader, + PrefixedDataset, + set_module_tensor_to_device, + ) + _has_accelerate = True except ImportError: _has_accelerate = False @@ -38,45 +41,38 @@ "get_offloaded_device", "update_prefix_dict", "update_parameter_data", + "register_offload_parameter", + "update_offload_data", + "delete_offload_parameter", + "has_offloaded_params", + "align_module", ] -# upstream candidate -def has_offloaded_params(module: torch.nn.Module) -> bool: - """ - Checks if a module has offloaded parameters by checking if the given module - has a AlignDevicesHook attached with offloading enabled +def check_accelerate(fallback: Any): + def decorator(func: Callable[[Any], Any]): + if not _has_accelerate: + return lambda *args, **kwargs: fallback - Args: - module (`torch.nn.Module`): The module to check for an offload hook. + return func - Returns: - bool: `True` if the module has an offload hook and offloading is enabled, - `False` otherwise. - """ - return ( - hasattr(module, "_hf_hook") and - isinstance(module._hf_hook, AlignDevicesHook) and - module._hf_hook.offload - ) + return decorator -# depreciation candidate -@wraps(has_offloaded_params) -def is_module_offloaded(module: torch.nn.Module) -> bool: - if not _has_accelerate: - return False +""" Candidates for Depreciation """ + +@check_accelerate(fallback=False) +def is_module_offloaded(module: torch.nn.Module) -> bool: return has_offloaded_params(module) -# depreciation candidate def get_execution_device(module: torch.nn.Module) -> torch.device: """ :param module: module to check :return: device module is loaded onto during forward pass """ - if is_module_offloaded(module): + if has_offloaded_params(module): return module._hf_hook.execution_device device = next(module.parameters()).device @@ -87,11 +83,14 @@ def get_execution_device(module: torch.nn.Module) -> torch.device: return device -# upstream candidate -def _infer_offload_device(module: torch.nn.Module) -> torch.device: +def get_offloaded_device(module: torch.nn.Module) -> torch.device: + """ + :param module: module to check + :return: device module is offloaded to onto after forward pass + """ if not has_offloaded_params(module): raise ValueError("Cannot infer offload device from non-offloaded module") - + first_key = next(module._hf_hook.weights_map.keys(), None) if first_key is None: raise ValueError("Cannot infer offload device from empty weights map") @@ -99,16 +98,8 @@ def _infer_offload_device(module: torch.nn.Module) -> torch.device: prefix_dataset = module._hf_hook.weights_map.dataset return prefix_dataset[first_key].device -# depreciation candidate -def get_offloaded_device(module: torch.nn.Module) -> torch.device: - """ - :param module: module to check - :return: device module is offloaded to onto after forward pass - """ - return _infer_offload_device(module) - -# depreciation candidate +@check_accelerate(fallback=None) def update_prefix_dict(module: torch.nn.Module, key: str, data: torch.Tensor): """ Updates the offloaded state dict for a given module. Parameter named key is replaced @@ -120,69 +111,154 @@ def update_prefix_dict(module: torch.nn.Module, key: str, data: torch.Tensor): :param key: name of parameter to update :param data: tensor to update parameter with in the offloaded state dict """ - if not is_module_offloaded(module): + if not has_offloaded_params(module): raise ValueError("Prefix dict is only applicable to offloaded modules") prefix_dict = module._hf_hook.weights_map prefix_dict.dataset[f"{prefix_dict.prefix}{key}"] = data -# upstream candidate? -def update_offload_parameter( +def update_parameter_data( + module: torch.nn.Module, new_param_data: torch.Tensor, param_name: str +): + """ + Update the data of an existing parameter and its offload dict. Supports both + parameters of offloaded modules and non-offloaded modules + + :param module: module containing the parameter to update + :param new_param_data: tensor to update parameter with + :param param_name: name of module parameter to update + """ + update_offload_data(module, param_name, new_param_data) + + +""" Candidates for Upstreaming """ + + +def register_offload_parameter( + module: torch.nn.Module, + name: str, + parameter: torch.nn.Parameter, +): + """ + Register a parameter to the given module which may be offloaded + + :param module: maybe offloaded module + :param name: name of newly registered parameter + :param parameter: parameter being registered + """ + if has_offloaded_params(module): + module.register_parameter(name, parameter) + update_offload_data(module, name, parameter.data) + set_module_tensor_to_device(module, name, "meta") + else: + device = next(module.parameters()).device + parameter = parameter.to(device) + module.register_parameter(name, parameter) + + +def update_offload_data( module: torch.nn.Module, name: str, - data: Optional[torch.Tensor] = None, - offload_device: Optional[torch.device] = None, + data: Optional[torch.Tensor], ): """ + Update the data of an existing parameter and its offload dict. Supports both + parameters of offloaded modules and non-offloaded modules + :param module: module containing the parameter to update :param name: name of module parameter to update :param data: tensor to update parameter with - :param offload_device: offload device for newly registered parameters """ param = getattr(module, name) - if param.device == "meta" or data is not None and data.device == "meta": - raise ValueError("Cannot copy data to/from meta device. Consider calling with align_module(module)") - - if data is not None: - if param.data.dtype != data.dtype: - warnings.warn("TODO") + # copy data into onloaded parameter if applicable + if param.device != "meta": param.data.copy_(data) + # update offload dict if has_offloaded_params(module): weights_map = module._hf_hook.weights_map - # for upstreaming, probably better to modify the weight map types so that they can be written to? + # for upstreaming, better to add write capabilities to weight map classes first if isinstance(weights_map, PrefixedDataset): - prefix_dict = getattr_chain(module, "module._hf_hook.weights_map.dataset", None) - if prefix_dict is not None: + dataset = getattr_chain(module, "module._hf_hook.weights_map.dataset", None) + if dataset is not None: prefix = module._hf_hook.weights_map.prefix key = f"{prefix}{name}" + breakpoint() + offload_device = ( - prefix_dict[key].device if key in prefix_dict - else offload_device if offload_device is not None - else _infer_offload_device(module) + dataset[key].device + if key in dataset + else next(dataset.values()).device ) - prefix_dict[key] = param.data.to(device=offload_device) - + dataset[key] = param.data.to(device=offload_device) + if isinstance(weights_map, OffloadedWeightsLoader): raise NotImplementedError() - + else: raise NotImplementedError() -# depreciation candidate -def update_parameter_data( - module: torch.nn.Module, new_param_data: torch.Tensor, param_name: str -): - param = getattr(module, param_name) - new_param_data = new_param_data.to(device=param.device, dtype=param.dtype) - update_offload_parameter(module, param_name, new_param_data) +def delete_offload_parameter(module: torch.nn.Module, name: str): + """ + Delete a module from a module which may be offloaded + + :param module: maybe offloaded module + :param name: name of parameter being deleted + """ + delattr(module, name) + + if has_offloaded_params(module): + weights_map = module._hf_hook.weights_map + + # for upstreaming, better to add write capabilities to weight map classes first + if isinstance(weights_map, PrefixedDataset): + dataset = weights_map.dataset + prefix = weights_map.prefix + if dataset is not None: + del dataset[f"{prefix}{name}"] + + elif isinstance(weights_map, OffloadedWeightsLoader): + raise NotImplementedError() + + elif weights_map is not None: + raise NotImplementedError( + f"Cannot delete parameter from weights_map of type {type(weights_map)}" + ) + + +""" Upstreamed Functions """ + +# introduced in accelerate v1.1.0 +@check_accelerate(fallback=False) +def has_offloaded_params(module: torch.nn.Module) -> bool: + """ + Checks if a module has offloaded parameters by checking if the given module has a + AlignDevicesHook attached with offloading enabled + + Args: + module (`torch.nn.Module`): The module to check for an offload hook. + + Returns: + bool: `True` if the module has an offload hook and offloading is enabled, + `False` otherwise. + """ + return ( + hasattr(module, "_hf_hook") + and isinstance(module._hf_hook, AlignDevicesHook) + and module._hf_hook.offload + ) + + +# introduced in accelerate v1.1.0 @contextlib.contextmanager -def align_module(module: torch.nn.Module, execution_device: Optional[torch.device] = None): +def align_module( + module: torch.nn.Module, execution_device: Optional[torch.device] = None +): """ Moves a module's parameters to the specified execution device. @@ -192,7 +268,8 @@ def align_module(module: torch.nn.Module, execution_device: Optional[torch.devic module's execution device within the context. Yields: - None: Yields control while the module's parameters are aligned to the execution device. + None: Yields control while the module's parameters are aligned to the execution + device. """ if has_offloaded_params(module): if execution_device is not None: @@ -227,52 +304,3 @@ def align_module(module: torch.nn.Module, execution_device: Optional[torch.devic else: yield - - - -@contextlib.contextmanager -def modify_offload_module( - module: torch.nn.Module, - execution_device: Optional[torch.device] = None, - offload_device: Optional[torch.device] = None, -): - with align_module(module, execution_device): - yield - - # there is little performance gain from checking if a parameter's data - # has been modified before copying since the new data must be copied - # to the offload device anyways; just update all module parameters - for name, param in module.named_parameters(): - update_offload_parameter(module, name, param.data, offload_device) - - -# upstream candidate? -def register_offload_parameter( - module: torch.nn.Module, - name: str, - parameter: torch.nn.Parameter, - offload_device: Optional[torch.device] = None, -): - module.register_parameter(name, parameter) - update_offload_parameter(module, name, parameter.data, offload_device) - - -# upstream candidate? -def delete_offload_parameter(module: torch.nn.Module, name: str): - delattr(module, name) - - if has_offloaded_params(module): - weights_map = module._hf_hook.weights_map - - # for upstreaming, probably better to modify the weight map types so that they can be written to? - if isinstance(weights_map, PrefixedDataset): - dataset = weights_map.dataset - prefix = weights_map.prefix - if dataset is not None: - del dataset[f"{prefix}{name}"] - - elif isinstance(weights_map, OffloadedWeightsLoader): - raise NotImplementedError() - - elif weights_map is not None: - raise NotImplementedError(f"Cannot delete parameter from weights_map of type {type(weights_map)}") \ No newline at end of file From cb70047efe4ab266ca73a652a5cb69d5693194d2 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Mon, 18 Nov 2024 23:56:48 +0000 Subject: [PATCH 06/28] remove extra space Signed-off-by: Kyle Sayers --- src/compressed_tensors/utils/helpers.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/compressed_tensors/utils/helpers.py b/src/compressed_tensors/utils/helpers.py index 8468bf0d..2f2e745e 100644 --- a/src/compressed_tensors/utils/helpers.py +++ b/src/compressed_tensors/utils/helpers.py @@ -131,7 +131,6 @@ def getattr_chain(obj: Any, chain_str: str, *args, **kwargs) -> Any: :param obj: base object whose attributes are being retrieved :param chain_str: attribute names separated by `.` :param default: default value, throw error otherwise - """ if len(args) >= 1: has_default = True From 98a2889456361e1eb8045bf8654781e5778047ae Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Mon, 18 Nov 2024 23:59:30 +0000 Subject: [PATCH 07/28] revert get_offloaded_device Signed-off-by: Kyle Sayers --- src/compressed_tensors/utils/offload.py | 14 +++++--------- 1 file changed, 5 insertions(+), 9 deletions(-) diff --git a/src/compressed_tensors/utils/offload.py b/src/compressed_tensors/utils/offload.py index 071c0ee1..6c2a6e1c 100644 --- a/src/compressed_tensors/utils/offload.py +++ b/src/compressed_tensors/utils/offload.py @@ -88,15 +88,11 @@ def get_offloaded_device(module: torch.nn.Module) -> torch.device: :param module: module to check :return: device module is offloaded to onto after forward pass """ - if not has_offloaded_params(module): - raise ValueError("Cannot infer offload device from non-offloaded module") - - first_key = next(module._hf_hook.weights_map.keys(), None) - if first_key is None: - raise ValueError("Cannot infer offload device from empty weights map") - - prefix_dataset = module._hf_hook.weights_map.dataset - return prefix_dataset[first_key].device + if has_offloaded_params(module): + first_key = list(module._hf_hook.weights_map.keys())[0] + prefix_dataset = module._hf_hook.weights_map.dataset + return prefix_dataset[first_key].device + return next(module.parameters()).device @check_accelerate(fallback=None) From 8cd69ef42b97f58bc3a0eddbabd66ee0f08d8809 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Tue, 19 Nov 2024 00:05:49 +0000 Subject: [PATCH 08/28] update to align_module_device Signed-off-by: Kyle Sayers --- src/compressed_tensors/utils/offload.py | 62 +++++++++++-------------- 1 file changed, 26 insertions(+), 36 deletions(-) diff --git a/src/compressed_tensors/utils/offload.py b/src/compressed_tensors/utils/offload.py index 6c2a6e1c..5bc00c05 100644 --- a/src/compressed_tensors/utils/offload.py +++ b/src/compressed_tensors/utils/offload.py @@ -30,9 +30,6 @@ _has_accelerate = True except ImportError: _has_accelerate = False - AlignDevicesHook = None - OffloadedWeightsLoader = None - PrefixedDataset = None __all__ = [ @@ -45,7 +42,7 @@ "update_offload_data", "delete_offload_parameter", "has_offloaded_params", - "align_module", + "align_module_device", ] @@ -251,52 +248,45 @@ def has_offloaded_params(module: torch.nn.Module) -> bool: # introduced in accelerate v1.1.0 +@check_accelerate(fallback=contextlib.nullcontext()) @contextlib.contextmanager -def align_module( +def align_module_device( module: torch.nn.Module, execution_device: Optional[torch.device] = None ): """ - Moves a module's parameters to the specified execution device. + Context manager that moves a module's parameters to the specified execution device. Args: - module (torch.nn.Module): Module with parameters to align. - execution_device (Optional[torch.device]): If provided, overrides the - module's execution device within the context. - - Yields: - None: Yields control while the module's parameters are aligned to the execution - device. + module (`torch.nn.Module`): + Module with parameters to align. + execution_device (`torch.device`, *optional*): + If provided, overrides the module's execution device within the context. + Otherwise, use hook execution device or pass """ if has_offloaded_params(module): if execution_device is not None: original_device = module._hf_hook.execution_device module._hf_hook.execution_device = execution_device - module._hf_hook.pre_forward(module) - yield - module._hf_hook.post_forward(module, None) - - if execution_device is not None: - module._hf_hook.execution_device = original_device + try: + module._hf_hook.pre_forward(module) + yield + finally: + module._hf_hook.post_forward(module, None) + if execution_device is not None: + module._hf_hook.execution_device = original_device elif execution_device is not None: - devices = {} - for name, param in module.named_parameters(): - devices[name] = param.device - set_module_tensor_to_device( - module, - name, - execution_device, - ) - - yield - - for name, param in module.named_parameters(): - set_module_tensor_to_device( - module, - name, - devices[name], - ) + devices = { + name: param.device for name, param in module.named_parameters(recurse=False) + } + try: + for name in devices: + set_module_tensor_to_device(module, name, execution_device) + yield + finally: + for name, device in devices.items(): + set_module_tensor_to_device(module, name, device) else: yield From 0d2318363ba467f20db7d185a5fbfa1ae74c4055 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Tue, 19 Nov 2024 00:19:10 +0000 Subject: [PATCH 09/28] add requires skip for accelerate --- .../test_quantization/lifecycle/test_apply.py | 2 ++ tests/testing_utils.py | 23 ++++++++++++++++++- 2 files changed, 24 insertions(+), 1 deletion(-) diff --git a/tests/test_quantization/lifecycle/test_apply.py b/tests/test_quantization/lifecycle/test_apply.py index 5f0bd093..511cfdf7 100644 --- a/tests/test_quantization/lifecycle/test_apply.py +++ b/tests/test_quantization/lifecycle/test_apply.py @@ -29,6 +29,7 @@ apply_quantization_status, ) from compressed_tensors.quantization.utils import iter_named_leaf_modules +from tests.testing_utils import requires_accelerate from transformers import AutoModelForCausalLM @@ -226,6 +227,7 @@ def get_sample_tinyllama_quant_config(status: str = "frozen"): return QuantizationConfig.parse_obj(config_dict) +@requires_accelerate() @pytest.mark.parametrize( "ignore,should_raise_warning", [ diff --git a/tests/testing_utils.py b/tests/testing_utils.py index 2e9be7cf..e446cad3 100644 --- a/tests/testing_utils.py +++ b/tests/testing_utils.py @@ -26,8 +26,29 @@ def compressed_tensors_config_available(): return False +def accelerate_availabe(): + try: + import accelerate # noqa: F401 + + return True + + except ImportError: + return False + + +_is_compressed_tensors_config_available = compressed_tensors_config_available() +_is_accelerate_available = accelerate_availabe() + + def requires_hf_quantizer(): return pytest.mark.skipif( - not compressed_tensors_config_available(), + not _is_compressed_tensors_config_available, reason="requires transformers>=4.45 to support CompressedTensorsHfQuantizer", ) + + +def requires_accelerate(): + return pytest.mark.skipif( + not _is_accelerate_available, + reason="requires accelerate", + ) From 0b0d8b67e38df0344a0ae9c001f6907333d94fe9 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Tue, 19 Nov 2024 00:44:39 +0000 Subject: [PATCH 10/28] fix per token initialization --- src/compressed_tensors/quantization/lifecycle/initialize.py | 5 ++++- src/compressed_tensors/utils/offload.py | 2 -- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/src/compressed_tensors/quantization/lifecycle/initialize.py b/src/compressed_tensors/quantization/lifecycle/initialize.py index 2a1efccb..009215e0 100644 --- a/src/compressed_tensors/quantization/lifecycle/initialize.py +++ b/src/compressed_tensors/quantization/lifecycle/initialize.py @@ -173,7 +173,10 @@ def _initialize_scale_zero_point( device = "cpu" if has_offloaded_params(module) else params_device # infer expected scale/zero point shape - expected_shape = 1 # per tensor + if quantization_args.strategy == QuantizationStrategy.TOKEN: + expected_shape = (1, 1) + else: + expected_shape = 1 if base_name == "weight" and weight_shape is not None: if quantization_args.strategy == QuantizationStrategy.CHANNEL: diff --git a/src/compressed_tensors/utils/offload.py b/src/compressed_tensors/utils/offload.py index 5bc00c05..29970bd9 100644 --- a/src/compressed_tensors/utils/offload.py +++ b/src/compressed_tensors/utils/offload.py @@ -179,8 +179,6 @@ def update_offload_data( prefix = module._hf_hook.weights_map.prefix key = f"{prefix}{name}" - breakpoint() - offload_device = ( dataset[key].device if key in dataset From 95e59075feaf215f29c2fd7bb82e23a0762d0083 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Tue, 19 Nov 2024 02:31:23 +0000 Subject: [PATCH 11/28] remove align_module_device --- src/compressed_tensors/utils/offload.py | 47 ------------------------- 1 file changed, 47 deletions(-) diff --git a/src/compressed_tensors/utils/offload.py b/src/compressed_tensors/utils/offload.py index 29970bd9..0df00ec2 100644 --- a/src/compressed_tensors/utils/offload.py +++ b/src/compressed_tensors/utils/offload.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import contextlib from typing import Any, Callable, Optional import torch @@ -42,7 +41,6 @@ "update_offload_data", "delete_offload_parameter", "has_offloaded_params", - "align_module_device", ] @@ -243,48 +241,3 @@ def has_offloaded_params(module: torch.nn.Module) -> bool: and isinstance(module._hf_hook, AlignDevicesHook) and module._hf_hook.offload ) - - -# introduced in accelerate v1.1.0 -@check_accelerate(fallback=contextlib.nullcontext()) -@contextlib.contextmanager -def align_module_device( - module: torch.nn.Module, execution_device: Optional[torch.device] = None -): - """ - Context manager that moves a module's parameters to the specified execution device. - - Args: - module (`torch.nn.Module`): - Module with parameters to align. - execution_device (`torch.device`, *optional*): - If provided, overrides the module's execution device within the context. - Otherwise, use hook execution device or pass - """ - if has_offloaded_params(module): - if execution_device is not None: - original_device = module._hf_hook.execution_device - module._hf_hook.execution_device = execution_device - - try: - module._hf_hook.pre_forward(module) - yield - finally: - module._hf_hook.post_forward(module, None) - if execution_device is not None: - module._hf_hook.execution_device = original_device - - elif execution_device is not None: - devices = { - name: param.device for name, param in module.named_parameters(recurse=False) - } - try: - for name in devices: - set_module_tensor_to_device(module, name, execution_device) - yield - finally: - for name, device in devices.items(): - set_module_tensor_to_device(module, name, device) - - else: - yield From 81a1eabe0196ce97905dfc724a04453e5ef57a12 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Fri, 6 Dec 2024 03:59:10 +0000 Subject: [PATCH 12/28] respond to nits Signed-off-by: Kyle Sayers --- src/compressed_tensors/utils/helpers.py | 2 ++ src/compressed_tensors/utils/offload.py | 10 ++++++++-- 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/src/compressed_tensors/utils/helpers.py b/src/compressed_tensors/utils/helpers.py index fe034126..910436eb 100644 --- a/src/compressed_tensors/utils/helpers.py +++ b/src/compressed_tensors/utils/helpers.py @@ -13,6 +13,7 @@ # limitations under the License. import warnings +from functools import wraps from typing import Any, Callable, Dict, Optional import torch @@ -174,6 +175,7 @@ def decorator(func: Callable[[Any], Any]): if future_name is not None: message += f". Please use {future_name} instead." + @wraps(func) def wrapped(*args, **kwargs): warnings.warn(message, DeprecationWarning, stacklevel=2) return func(*args, **kwargs) diff --git a/src/compressed_tensors/utils/offload.py b/src/compressed_tensors/utils/offload.py index 0df00ec2..0d7b0bbe 100644 --- a/src/compressed_tensors/utils/offload.py +++ b/src/compressed_tensors/utils/offload.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +from functools import wraps from typing import Any, Callable, Optional import torch @@ -47,7 +48,12 @@ def check_accelerate(fallback: Any): def decorator(func: Callable[[Any], Any]): if not _has_accelerate: - return lambda *args, **kwargs: fallback + + @wraps(func) + def fallback_fn(*args, **kwargs): + return fallback + + return fallback_fn return func @@ -193,7 +199,7 @@ def update_offload_data( def delete_offload_parameter(module: torch.nn.Module, name: str): """ - Delete a module from a module which may be offloaded + Delete a parameter from a module which may be offloaded :param module: maybe offloaded module :param name: name of parameter being deleted From e7e1d81dbcfbb13745167114fdffa1e830b86227 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Fri, 6 Dec 2024 01:39:51 -0500 Subject: [PATCH 13/28] Accelerate Utilities Follow-up (#224) --- .../quantization/lifecycle/initialize.py | 46 +---- src/compressed_tensors/utils/offload.py | 94 +++++++++- .../lifecycle/test_initialize.py | 43 ++++- tests/test_utils/test_offload.py | 163 ++++++++++++++++++ 4 files changed, 300 insertions(+), 46 deletions(-) create mode 100644 tests/test_utils/test_offload.py diff --git a/src/compressed_tensors/quantization/lifecycle/initialize.py b/src/compressed_tensors/quantization/lifecycle/initialize.py index 009215e0..9cb99b2e 100644 --- a/src/compressed_tensors/quantization/lifecycle/initialize.py +++ b/src/compressed_tensors/quantization/lifecycle/initialize.py @@ -29,7 +29,11 @@ from compressed_tensors.quantization.quant_config import QuantizationStatus from compressed_tensors.quantization.quant_scheme import QuantizationScheme from compressed_tensors.quantization.utils import is_kv_cache_quant_scheme -from compressed_tensors.utils import has_offloaded_params, register_offload_parameter +from compressed_tensors.utils import ( + disable_hf_hook, + has_offloaded_params, + register_offload_parameter, +) from torch.nn import Module, Parameter @@ -112,42 +116,10 @@ def initialize_module_for_quantization( module.quantization_scheme = scheme module.quantization_status = QuantizationStatus.INITIALIZED - offloaded = False - if has_offloaded_params(module): - try: - from accelerate.hooks import add_hook_to_module, remove_hook_from_module - from accelerate.utils import PrefixedDataset - except ModuleNotFoundError: - raise ModuleNotFoundError( - "Offloaded model detected. To use CPU offloading with " - "compressed-tensors the `accelerate` package must be installed, " - "run `pip install compressed-tensors[accelerate]`" - ) - - offloaded = True - hook = module._hf_hook - prefix_dict = module._hf_hook.weights_map - new_prefix = {} - - # recreate the prefix dict (since it is immutable) - # and add quantization parameters - for key, data in module.named_parameters(): - if key not in prefix_dict: - new_prefix[f"{prefix_dict.prefix}{key}"] = data - else: - new_prefix[f"{prefix_dict.prefix}{key}"] = prefix_dict[key] - new_prefix_dict = PrefixedDataset(new_prefix, prefix_dict.prefix) - remove_hook_from_module(module) - - # wrap forward call of module to perform - # quantized actions based on calltime status - wrap_module_forward_quantized(module, scheme) - - if offloaded: - # we need to re-add the hook for offloading now that we've wrapped forward - add_hook_to_module(module, hook) - if prefix_dict is not None: - module._hf_hook.weights_map = new_prefix_dict + with disable_hf_hook(module): + # wrap forward call of module to perform + # quantized actions based on calltime status + wrap_module_forward_quantized(module, scheme) def is_attention_module(module: Module): diff --git a/src/compressed_tensors/utils/offload.py b/src/compressed_tensors/utils/offload.py index 0d7b0bbe..7b7cc864 100644 --- a/src/compressed_tensors/utils/offload.py +++ b/src/compressed_tensors/utils/offload.py @@ -12,15 +12,19 @@ # See the License for the specific language governing permissions and # limitations under the License. +import contextlib from functools import wraps from typing import Any, Callable, Optional import torch -from compressed_tensors.utils.helpers import getattr_chain try: - from accelerate.hooks import AlignDevicesHook + from accelerate.hooks import ( + AlignDevicesHook, + add_hook_to_module, + remove_hook_from_module, + ) from accelerate.utils import ( OffloadedWeightsLoader, PrefixedDataset, @@ -42,6 +46,8 @@ "update_offload_data", "delete_offload_parameter", "has_offloaded_params", + "disable_hf_hook", + "align_module_device", ] @@ -167,6 +173,7 @@ def update_offload_data( :param data: tensor to update parameter with """ param = getattr(module, name) + data = data.to(param.dtype) # copy data into onloaded parameter if applicable if param.device != "meta": @@ -178,7 +185,7 @@ def update_offload_data( # for upstreaming, better to add write capabilities to weight map classes first if isinstance(weights_map, PrefixedDataset): - dataset = getattr_chain(module, "module._hf_hook.weights_map.dataset", None) + dataset = getattr(weights_map, "dataset", None) if dataset is not None: prefix = module._hf_hook.weights_map.prefix key = f"{prefix}{name}" @@ -186,15 +193,26 @@ def update_offload_data( offload_device = ( dataset[key].device if key in dataset - else next(dataset.values()).device + else next(iter(dataset.values())).device ) - dataset[key] = param.data.to(device=offload_device) + dataset[key] = data.to(device=offload_device) + + elif isinstance(weights_map, dict): + offload_device = ( + weights_map[name].device + if name in weights_map + else next(iter(weights_map.values())).device + ) + weights_map[name] = data.to(device=offload_device) - if isinstance(weights_map, OffloadedWeightsLoader): + elif isinstance(weights_map, OffloadedWeightsLoader): raise NotImplementedError() else: - raise NotImplementedError() + raise NotImplementedError( + "Updating offload data not implemented for weights_map of type " + f"{type(weights_map)}" + ) def delete_offload_parameter(module: torch.nn.Module, name: str): @@ -216,6 +234,9 @@ def delete_offload_parameter(module: torch.nn.Module, name: str): if dataset is not None: del dataset[f"{prefix}{name}"] + elif isinstance(weights_map, dict): + del weights_map[name] + elif isinstance(weights_map, OffloadedWeightsLoader): raise NotImplementedError() @@ -225,6 +246,20 @@ def delete_offload_parameter(module: torch.nn.Module, name: str): ) +@check_accelerate(fallback=contextlib.nullcontext()) +@contextlib.contextmanager +def disable_hf_hook(module: torch.nn.Module, recurse: bool = False): + offloaded = has_offloaded_params(module) + if offloaded: + hook = module._hf_hook + remove_hook_from_module(module, recurse=recurse) + + yield + + if offloaded: + add_hook_to_module(module, hook) + + """ Upstreamed Functions """ @@ -247,3 +282,48 @@ def has_offloaded_params(module: torch.nn.Module) -> bool: and isinstance(module._hf_hook, AlignDevicesHook) and module._hf_hook.offload ) + + +# introduced in accelerate v1.1.0 +@check_accelerate(fallback=contextlib.nullcontext()) +@contextlib.contextmanager +def align_module_device( + module: torch.nn.Module, execution_device: Optional[torch.device] = None +): + """ + Context manager that moves a module's parameters to the specified execution device. + + Args: + module (`torch.nn.Module`): + Module with parameters to align. + execution_device (`torch.device`, *optional*): + If provided, overrides the module's execution device within the context. + Otherwise, use hook execution device or pass + """ + if has_offloaded_params(module): + if execution_device is not None: + original_device = module._hf_hook.execution_device + module._hf_hook.execution_device = execution_device + + try: + module._hf_hook.pre_forward(module) + yield + finally: + module._hf_hook.post_forward(module, None) + if execution_device is not None: + module._hf_hook.execution_device = original_device + + elif execution_device is not None: + devices = { + name: param.device for name, param in module.named_parameters(recurse=False) + } + try: + for name in devices: + set_module_tensor_to_device(module, name, execution_device) + yield + finally: + for name, device in devices.items(): + set_module_tensor_to_device(module, name, device) + + else: + yield diff --git a/tests/test_quantization/lifecycle/test_initialize.py b/tests/test_quantization/lifecycle/test_initialize.py index 987b2ae2..8252b545 100644 --- a/tests/test_quantization/lifecycle/test_initialize.py +++ b/tests/test_quantization/lifecycle/test_initialize.py @@ -19,12 +19,18 @@ ) from compressed_tensors.quantization.quant_args import QuantizationArgs from compressed_tensors.quantization.quant_config import QuantizationStatus +from tests.testing_utils import requires_accelerate from torch.nn import Linear NUM_BITS = 8 +@pytest.fixture +def layer(): + return Linear(4, 4) + + @pytest.mark.parametrize( "weights,input_activations", [ @@ -43,14 +49,13 @@ ], ) def test_initialize_module_for_quantization( - create_quantization_scheme, weights, input_activations + create_quantization_scheme, weights, input_activations, layer ): quantization_scheme = create_quantization_scheme( targets=["*"], weights=weights, input_activations=input_activations, ) - layer = Linear(4, 4) assert not hasattr(layer, "quantization_scheme") assert not hasattr(layer, "quantization_status") @@ -77,3 +82,37 @@ def test_initialize_module_for_quantization( assert hasattr(layer, "quantization_status") assert layer.quantization_status == QuantizationStatus.INITIALIZED + + +@requires_accelerate() +@pytest.mark.parametrize( + "weights,input_activations", + [ + ( + QuantizationArgs(num_bits=NUM_BITS, symmetric=True), + None, + ), + ( + None, + QuantizationArgs(num_bits=NUM_BITS, symmetric=True), + ), + ( + QuantizationArgs(num_bits=NUM_BITS, symmetric=True), + QuantizationArgs(num_bits=NUM_BITS, symmetric=True), + ), + ], +) +def test_initialize_module_for_quantization_offloaded( + create_quantization_scheme, weights, input_activations +): + from accelerate.hooks import attach_align_device_hook + + layer = Linear(4, 4) + attach_align_device_hook(layer, offload=True) + + test_initialize_module_for_quantization( + create_quantization_scheme, + weights, + input_activations, + layer, + ) diff --git a/tests/test_utils/test_offload.py b/tests/test_utils/test_offload.py new file mode 100644 index 00000000..c127dd98 --- /dev/null +++ b/tests/test_utils/test_offload.py @@ -0,0 +1,163 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import torch +from compressed_tensors.utils import ( + align_module_device, + delete_offload_parameter, + disable_hf_hook, + has_offloaded_params, + register_offload_parameter, + update_offload_data, +) +from tests.testing_utils import requires_accelerate + + +class ExampleModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.a = torch.nn.Parameter(torch.tensor(0).float()) + self.b = torch.nn.Parameter(torch.tensor(0).float()) + + def forward(self, x): + return x * self.a + self.b + + +@requires_accelerate() +def test_has_offloaded_params(): + from accelerate.big_modeling import cpu_offload_with_hook + from accelerate.hooks import attach_align_device_hook, remove_hook_from_module + + module = ExampleModule() + assert not has_offloaded_params(module) + + attach_align_device_hook(module, offload=False) + assert not has_offloaded_params(module) + + remove_hook_from_module(module) + module, _ = cpu_offload_with_hook(module) + assert not has_offloaded_params(module) + + remove_hook_from_module(module) + attach_align_device_hook(module, offload=True, weights_map=module.state_dict()) + assert has_offloaded_params(module) + + +@requires_accelerate() +def test_register_offload_parameter(): + from accelerate.hooks import attach_align_device_hook + + module = ExampleModule() + parameter = torch.nn.Parameter(torch.tensor(1.0)) + + # register a param prior to offloading + register_offload_parameter(module, "c", parameter) + assert hasattr(module, "c") and module.c == parameter + + # offloading, check that added param was offloaded + attach_align_device_hook(module, offload=True, weights_map=module.state_dict()) + assert "c" in module._hf_hook.weights_map + + # register a param after offloading, check that added param was offloaded + register_offload_parameter(module, "d", parameter) + assert hasattr(module, "d") and module.d.device == torch.device("meta") + assert "d" in module._hf_hook.weights_map + + # added parameters can be onloaded and offloaded + with align_module_device(module, execution_device="cpu"): + assert module.c.device == torch.device("cpu") + assert module.d.device == torch.device("cpu") + assert module.c.device == torch.device("meta") + assert module.d.device == torch.device("meta") + + +@requires_accelerate() +def test_update_offload_data(): + from accelerate.hooks import attach_align_device_hook + + module = ExampleModule() + param_a = torch.nn.Parameter(torch.tensor(1.0)) + param_b = torch.nn.Parameter(torch.tensor(2.0)) + + # can update modules which are not offloaded + update_offload_data(module, "a", param_a) + assert module.a == param_a + + # can update modules which are offloaded + attach_align_device_hook(module, offload=True, weights_map=module.state_dict()) + update_offload_data(module, "b", param_b) + assert module.b.device == torch.device("meta") + assert module._hf_hook.weights_map["b"] == param_b.data + + # data persists across onloading + with align_module_device(module, execution_device="cpu"): + assert module.a == param_a + assert module.b == param_b + assert module._hf_hook.weights_map["a"] == param_a.data + assert module._hf_hook.weights_map["b"] == param_b.data + + # data persists across offloading + assert module.a.device == torch.device("meta") + assert module.b.device == torch.device("meta") + assert module._hf_hook.weights_map["a"] == param_a.data + assert module._hf_hook.weights_map["b"] == param_b.data + + +@requires_accelerate() +def test_delete_offload_parameter(): + from accelerate.hooks import attach_align_device_hook + + module = ExampleModule() + param_c = torch.nn.Parameter(torch.tensor(1.0)) + param_d = torch.nn.Parameter(torch.tensor(2.0)) + register_offload_parameter(module, "c", param_c) + register_offload_parameter(module, "d", param_d) + + # parameters are deleted + delete_offload_parameter(module, "a") + delete_offload_parameter(module, "c") + assert not hasattr(module, "a") + assert hasattr(module, "b") + assert not hasattr(module, "c") + assert hasattr(module, "d") + + # parameters and their offload are deleted + attach_align_device_hook(module, offload=True, weights_map=module.state_dict()) + delete_offload_parameter(module, "b") + delete_offload_parameter(module, "d") + assert not hasattr(module, "a") + assert not hasattr(module, "b") + assert not hasattr(module, "c") + assert not hasattr(module, "d") + assert "a" not in module._hf_hook.weights_map + assert "b" not in module._hf_hook.weights_map + assert "c" not in module._hf_hook.weights_map + assert "d" not in module._hf_hook.weights_map + + +@requires_accelerate() +def test_disable_hf_hook(): + from accelerate.hooks import attach_align_device_hook + + module = ExampleModule() + + def custom_forward(): + pass + + attach_align_device_hook(module, offload=True, weights_map=module.state_dict()) + with disable_hf_hook(module): + assert not hasattr(module, "_hf_hook") + module.forward = custom_forward + + assert hasattr(module, "_hf_hook") + assert module._old_forward == custom_forward From 9af736f533534f0590bc312e0981e8e1a1ab97cd Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Fri, 6 Dec 2024 07:15:04 +0000 Subject: [PATCH 14/28] rename --- src/compressed_tensors/utils/offload.py | 8 ++++---- tests/test_utils/test_offload.py | 8 ++++---- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/src/compressed_tensors/utils/offload.py b/src/compressed_tensors/utils/offload.py index 7b7cc864..ee59933b 100644 --- a/src/compressed_tensors/utils/offload.py +++ b/src/compressed_tensors/utils/offload.py @@ -43,7 +43,7 @@ "update_prefix_dict", "update_parameter_data", "register_offload_parameter", - "update_offload_data", + "update_offload_parameter", "delete_offload_parameter", "has_offloaded_params", "disable_hf_hook", @@ -131,7 +131,7 @@ def update_parameter_data( :param new_param_data: tensor to update parameter with :param param_name: name of module parameter to update """ - update_offload_data(module, param_name, new_param_data) + update_offload_parameter(module, param_name, new_param_data) """ Candidates for Upstreaming """ @@ -151,7 +151,7 @@ def register_offload_parameter( """ if has_offloaded_params(module): module.register_parameter(name, parameter) - update_offload_data(module, name, parameter.data) + update_offload_parameter(module, name, parameter.data) set_module_tensor_to_device(module, name, "meta") else: device = next(module.parameters()).device @@ -159,7 +159,7 @@ def register_offload_parameter( module.register_parameter(name, parameter) -def update_offload_data( +def update_offload_parameter( module: torch.nn.Module, name: str, data: Optional[torch.Tensor], diff --git a/tests/test_utils/test_offload.py b/tests/test_utils/test_offload.py index c127dd98..80cb55f7 100644 --- a/tests/test_utils/test_offload.py +++ b/tests/test_utils/test_offload.py @@ -18,7 +18,7 @@ disable_hf_hook, has_offloaded_params, register_offload_parameter, - update_offload_data, + update_offload_parameter, ) from tests.testing_utils import requires_accelerate @@ -82,7 +82,7 @@ def test_register_offload_parameter(): @requires_accelerate() -def test_update_offload_data(): +def test_update_offload_parameter(): from accelerate.hooks import attach_align_device_hook module = ExampleModule() @@ -90,12 +90,12 @@ def test_update_offload_data(): param_b = torch.nn.Parameter(torch.tensor(2.0)) # can update modules which are not offloaded - update_offload_data(module, "a", param_a) + update_offload_parameter(module, "a", param_a) assert module.a == param_a # can update modules which are offloaded attach_align_device_hook(module, offload=True, weights_map=module.state_dict()) - update_offload_data(module, "b", param_b) + update_offload_parameter(module, "b", param_b) assert module.b.device == torch.device("meta") assert module._hf_hook.weights_map["b"] == param_b.data From 35fa1cd17f0883746b8565518350914adfe1fb78 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Fri, 6 Dec 2024 07:40:23 +0000 Subject: [PATCH 15/28] implement recursive case --- src/compressed_tensors/utils/offload.py | 20 ++++++++++++++------ tests/test_utils/test_offload.py | 20 ++++++++++++++++++++ 2 files changed, 34 insertions(+), 6 deletions(-) diff --git a/src/compressed_tensors/utils/offload.py b/src/compressed_tensors/utils/offload.py index ee59933b..40f11353 100644 --- a/src/compressed_tensors/utils/offload.py +++ b/src/compressed_tensors/utils/offload.py @@ -249,15 +249,23 @@ def delete_offload_parameter(module: torch.nn.Module, name: str): @check_accelerate(fallback=contextlib.nullcontext()) @contextlib.contextmanager def disable_hf_hook(module: torch.nn.Module, recurse: bool = False): - offloaded = has_offloaded_params(module) - if offloaded: - hook = module._hf_hook - remove_hook_from_module(module, recurse=recurse) + hooks = {} + def collect_hooks(module): + nonlocal hooks + if hasattr(module, "_hf_hook"): + hooks[module] = module._hf_hook + remove_hook_from_module(module) + + for submodule in module.children(): + print(submodule) + collect_hooks(submodule) + + collect_hooks(module) yield - if offloaded: - add_hook_to_module(module, hook) + for submodule, hook in hooks.items(): + add_hook_to_module(submodule, hook) """ Upstreamed Functions """ diff --git a/tests/test_utils/test_offload.py b/tests/test_utils/test_offload.py index 80cb55f7..c779c0d9 100644 --- a/tests/test_utils/test_offload.py +++ b/tests/test_utils/test_offload.py @@ -161,3 +161,23 @@ def custom_forward(): assert hasattr(module, "_hf_hook") assert module._old_forward == custom_forward + + +@requires_accelerate() +def test_disable_hf_hook_model_recurse(): + from accelerate.hooks import attach_align_device_hook + + module0 = ExampleModule() + module1 = ExampleModule() + module2 = ExampleModule() + model = torch.nn.Sequential(module0, torch.nn.Sequential(module1, module2)) + attach_align_device_hook(model, offload=True, weights_map=model.state_dict()) + + with disable_hf_hook(model): + assert not hasattr(module0, "_hf_hook") + assert not hasattr(module1, "_hf_hook") + assert not hasattr(module2, "_hf_hook") + + assert hasattr(module0, "_hf_hook") + assert hasattr(module1, "_hf_hook") + assert hasattr(module2, "_hf_hook") \ No newline at end of file From 38765bd0e77b2cf42b7d85b7c03cd6d468489064 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Fri, 6 Dec 2024 07:41:16 +0000 Subject: [PATCH 16/28] remove print --- src/compressed_tensors/utils/offload.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/compressed_tensors/utils/offload.py b/src/compressed_tensors/utils/offload.py index 40f11353..02f2c442 100644 --- a/src/compressed_tensors/utils/offload.py +++ b/src/compressed_tensors/utils/offload.py @@ -257,7 +257,6 @@ def collect_hooks(module): remove_hook_from_module(module) for submodule in module.children(): - print(submodule) collect_hooks(submodule) collect_hooks(module) From 64f4d9850c171c499638d23ee06cb5146e8ac321 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Fri, 6 Dec 2024 21:04:48 +0000 Subject: [PATCH 17/28] support OffloadedWeightsLoader --- src/compressed_tensors/utils/offload.py | 7 ++++++- tests/test_utils/test_offload.py | 2 +- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/src/compressed_tensors/utils/offload.py b/src/compressed_tensors/utils/offload.py index 02f2c442..6c20da08 100644 --- a/src/compressed_tensors/utils/offload.py +++ b/src/compressed_tensors/utils/offload.py @@ -195,7 +195,11 @@ def update_offload_parameter( if key in dataset else next(iter(dataset.values())).device ) - dataset[key] = data.to(device=offload_device) + + if isinstance(dataset, OffloadedWeightsLoader): + dataset.state_dict[key] = data.to(device=offload_device) + else: + dataset[key] = data.to(device=offload_device) elif isinstance(weights_map, dict): offload_device = ( @@ -250,6 +254,7 @@ def delete_offload_parameter(module: torch.nn.Module, name: str): @contextlib.contextmanager def disable_hf_hook(module: torch.nn.Module, recurse: bool = False): hooks = {} + def collect_hooks(module): nonlocal hooks if hasattr(module, "_hf_hook"): diff --git a/tests/test_utils/test_offload.py b/tests/test_utils/test_offload.py index c779c0d9..46fe316e 100644 --- a/tests/test_utils/test_offload.py +++ b/tests/test_utils/test_offload.py @@ -180,4 +180,4 @@ def test_disable_hf_hook_model_recurse(): assert hasattr(module0, "_hf_hook") assert hasattr(module1, "_hf_hook") - assert hasattr(module2, "_hf_hook") \ No newline at end of file + assert hasattr(module2, "_hf_hook") From b8ae38702cc252ee99129f05ca44e9097d015915 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Tue, 10 Dec 2024 18:59:59 +0000 Subject: [PATCH 18/28] add lifecycle docstring --- src/compressed_tensors/utils/offload.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/src/compressed_tensors/utils/offload.py b/src/compressed_tensors/utils/offload.py index 6c20da08..f99e2b91 100644 --- a/src/compressed_tensors/utils/offload.py +++ b/src/compressed_tensors/utils/offload.py @@ -11,6 +11,19 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +""" +Utilities associated with offloading functionality provided by `accelerate`. + +| ----------------------------------------------------------------------------------------------------- | # noqa: E501 +| Operation | Without offloading support | With offloading support | # noqa: E501 +| --------- | -------------------------------------- | ------------------------------------------------ | # noqa: E501 +| Add | module.register_parameter(name, param) | register_offload_parameter(module, name, param) | # noqa: E501 +| Check | N/A | has_offloaded_params(module) | # noqa: E501 +| Onload | N/A | with align_module_device(module) | # noqa: E501 +| Update | module.name.data.copy_(new_data) | update_offload_parameter(module, name, new_data) | # noqa: E501 +| Delete | del module.name | delete_offload_parameter(module, name) | # noqa: E501 +| ----------------------------------------------------------------------------------------------------- | # noqa: E501 +""" import contextlib from functools import wraps From 870095e00d26cbcb02633319d9bfb13abcf6c251 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Mon, 16 Dec 2024 14:01:45 -0500 Subject: [PATCH 19/28] implement offload_to_weights_map with recursive definition Signed-off-by: Kyle Sayers --- src/compressed_tensors/utils/offload.py | 112 ++++++++++++++++-------- tests/test_utils/test_offload.py | 51 +++++++++++ 2 files changed, 127 insertions(+), 36 deletions(-) diff --git a/src/compressed_tensors/utils/offload.py b/src/compressed_tensors/utils/offload.py index f99e2b91..f6ef4830 100644 --- a/src/compressed_tensors/utils/offload.py +++ b/src/compressed_tensors/utils/offload.py @@ -27,7 +27,7 @@ import contextlib from functools import wraps -from typing import Any, Callable, Optional +from typing import Any, Callable, Optional, Union import torch @@ -47,6 +47,12 @@ _has_accelerate = True except ImportError: _has_accelerate = False + AlignDevicesHook = None + add_hook_to_module = None + remove_hook_from_module = None + OffloadedWeightsLoader = None + PrefixedDataset = None + set_module_tensor_to_device = None __all__ = [ @@ -195,41 +201,7 @@ def update_offload_parameter( # update offload dict if has_offloaded_params(module): weights_map = module._hf_hook.weights_map - - # for upstreaming, better to add write capabilities to weight map classes first - if isinstance(weights_map, PrefixedDataset): - dataset = getattr(weights_map, "dataset", None) - if dataset is not None: - prefix = module._hf_hook.weights_map.prefix - key = f"{prefix}{name}" - - offload_device = ( - dataset[key].device - if key in dataset - else next(iter(dataset.values())).device - ) - - if isinstance(dataset, OffloadedWeightsLoader): - dataset.state_dict[key] = data.to(device=offload_device) - else: - dataset[key] = data.to(device=offload_device) - - elif isinstance(weights_map, dict): - offload_device = ( - weights_map[name].device - if name in weights_map - else next(iter(weights_map.values())).device - ) - weights_map[name] = data.to(device=offload_device) - - elif isinstance(weights_map, OffloadedWeightsLoader): - raise NotImplementedError() - - else: - raise NotImplementedError( - "Updating offload data not implemented for weights_map of type " - f"{type(weights_map)}" - ) + offload_to_weights_map(weights_map, name, data) def delete_offload_parameter(module: torch.nn.Module, name: str): @@ -285,6 +257,74 @@ def collect_hooks(module): add_hook_to_module(submodule, hook) +def offload_to_weights_map( + weights_map: Union[PrefixedDataset, dict, OffloadedWeightsLoader], + key: str, + value: torch.Tensor, + default_device: torch.device = torch.device("cpu"), +): + if isinstance(weights_map, PrefixedDataset): + dataset = weights_map.dataset + key = f"{weights_map.prefix}{key}" + offload_to_weights_map(dataset, key, value) + + elif isinstance(weights_map, OffloadedWeightsLoader): + if key not in weights_map.all_keys: + weights_map.all_keys.append(key) + + if len(weights_map.index) <= 0: + offload_to_weights_map(weights_map.state_dict, key, value) + + else: + raise NotImplementedError( + "Updating weights_map with disk offloading is not implemented yet" + ) + # TODO: below may not be correct and has not been tested + # FUTURE: upstream as OffloadedWeightsLoader.__set_item__ + # use_index = "safetensors_file" in next(iter(weights_map.values())) + # if use_index: + # if key not in weights_map: + # weights_map.index[key] = { + # "safetensors_file": ???, + # "weight_name": key, + # "dtype": str(value.dtype) + # } + + # weight_info = weights_map.index[key] + # file_path = weight_info["safetensors_file"] + # with safetensors.create_file(file_path) as file: + # file.write(value) + + # else: + # assert self.save_folder is not None + # weight_file = os.path.join(self.save_folder, f"{key}.dat") + # need_index_update = not os.path.exists(weight_file) + # offload_weight( + # value, + # key, + # weights_map.save_folder, + # weights_map.index + # ) + + # if need_index_update: + # save_offload_index(weights_map.index, weights_map.save_folder) + + elif isinstance(weights_map, dict): + if key in weights_map: + offload_device = weights_map[key].device + else: + tens = next(iter(weights_map.values()), None) + offload_device = tens.device if tens is not None else default_device + + weights_map[key] = value.to(device=offload_device) + + else: + raise NotImplementedError( + "Updating offload data not implemented for weights_map of type " + f"{type(weights_map)}" + ) + + """ Upstreamed Functions """ diff --git a/tests/test_utils/test_offload.py b/tests/test_utils/test_offload.py index 46fe316e..befeeb84 100644 --- a/tests/test_utils/test_offload.py +++ b/tests/test_utils/test_offload.py @@ -20,6 +20,7 @@ register_offload_parameter, update_offload_parameter, ) +from compressed_tensors.utils.offload import offload_to_weights_map from tests.testing_utils import requires_accelerate @@ -181,3 +182,53 @@ def test_disable_hf_hook_model_recurse(): assert hasattr(module0, "_hf_hook") assert hasattr(module1, "_hf_hook") assert hasattr(module2, "_hf_hook") + + +@requires_accelerate() +def test_offload_to_weights_map(): + from accelerate.utils import OffloadedWeightsLoader, PrefixedDataset + + name = "name" + old_value = torch.tensor(0.0) + new_value = torch.tensor(1.0) + prefix = "prefix" + + # Dict empty + weights_map = {} + offload_to_weights_map(weights_map, name, new_value) + assert weights_map[name] == new_value + + # Dict populated + weights_map = {name: old_value} + offload_to_weights_map(weights_map, name, new_value) + assert weights_map[name] == new_value + + # OffloadedWeightsLoader[Dict] empty + weights_map = OffloadedWeightsLoader({}) + offload_to_weights_map(weights_map, name, new_value) + assert weights_map[name] == new_value + + # OffloadedWeightsLoader[Dict] populated + weights_map = OffloadedWeightsLoader({name: old_value}) + offload_to_weights_map(weights_map, name, new_value) + assert weights_map[name] == new_value + + # PrefixedDataset[Dict] empty + weights_map = PrefixedDataset({}, prefix) + offload_to_weights_map(weights_map, name, new_value) + assert weights_map[name] == new_value + + # PrefixedDataset[Dict] populated + weights_map = PrefixedDataset({name: old_value}, prefix) + offload_to_weights_map(weights_map, name, new_value) + assert weights_map[name] == new_value + + # PrefixedDataset[OffloadedWeightsLoader[Dict]] empty + weights_map = PrefixedDataset(OffloadedWeightsLoader({}), prefix) + offload_to_weights_map(weights_map, name, new_value) + assert weights_map[name] == new_value + + # PrefixedDataset[OffloadedWeightsLoader[Dict]] populated + weights_map = PrefixedDataset(OffloadedWeightsLoader({name: old_value}), prefix) + offload_to_weights_map(weights_map, name, new_value) + assert weights_map[name] == new_value From 77411ca653aba0acde18dda46a1ff3d09d44f25c Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Mon, 16 Dec 2024 14:12:34 -0500 Subject: [PATCH 20/28] add docstring Signed-off-by: Kyle Sayers --- src/compressed_tensors/utils/offload.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/src/compressed_tensors/utils/offload.py b/src/compressed_tensors/utils/offload.py index f6ef4830..7110a250 100644 --- a/src/compressed_tensors/utils/offload.py +++ b/src/compressed_tensors/utils/offload.py @@ -263,6 +263,17 @@ def offload_to_weights_map( value: torch.Tensor, default_device: torch.device = torch.device("cpu"), ): + """ + Helper function which implements offloaded item assignment for PrefixedDataset, + OffloadedWeightsLoader, and Dict types. + + :param weights_map: weight map to be updated with offload information + :param key: key used to identify weight location + :param value: weight being offloaded + :param default_device: in the event that the weights_map does already contain + offloaded weights or use disk offloading, the weight will be offloaded to the + `default_device` + """ if isinstance(weights_map, PrefixedDataset): dataset = weights_map.dataset key = f"{weights_map.prefix}{key}" From a5b1792189656f192d75dfc0b6ef242baf43c9a8 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Mon, 16 Dec 2024 14:13:58 -0500 Subject: [PATCH 21/28] fix type hint --- src/compressed_tensors/utils/offload.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/compressed_tensors/utils/offload.py b/src/compressed_tensors/utils/offload.py index 7110a250..7b312552 100644 --- a/src/compressed_tensors/utils/offload.py +++ b/src/compressed_tensors/utils/offload.py @@ -27,7 +27,7 @@ import contextlib from functools import wraps -from typing import Any, Callable, Optional, Union +from typing import Any, Callable, Dict, Optional, Union import torch @@ -258,7 +258,7 @@ def collect_hooks(module): def offload_to_weights_map( - weights_map: Union[PrefixedDataset, dict, OffloadedWeightsLoader], + weights_map: Union[PrefixedDataset, Dict, OffloadedWeightsLoader], key: str, value: torch.Tensor, default_device: torch.device = torch.device("cpu"), From ed9ee4e32e59384d97584864cea956ee2bfc4a46 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Mon, 16 Dec 2024 14:15:14 -0500 Subject: [PATCH 22/28] add check_accelerate guard Signed-off-by: Kyle Sayers --- src/compressed_tensors/utils/offload.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/compressed_tensors/utils/offload.py b/src/compressed_tensors/utils/offload.py index 7b312552..b2038c62 100644 --- a/src/compressed_tensors/utils/offload.py +++ b/src/compressed_tensors/utils/offload.py @@ -257,6 +257,7 @@ def collect_hooks(module): add_hook_to_module(submodule, hook) +@check_accelerate(fallback=None) def offload_to_weights_map( weights_map: Union[PrefixedDataset, Dict, OffloadedWeightsLoader], key: str, From 1632cc3e610ffa9e5fae3fb587e4c4a04c7a6076 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Mon, 16 Dec 2024 16:01:07 -0500 Subject: [PATCH 23/28] make device used by clearer Signed-off-by: Kyle Sayers --- .../quantization/lifecycle/initialize.py | 4 +++- src/compressed_tensors/utils/offload.py | 7 ++----- 2 files changed, 5 insertions(+), 6 deletions(-) diff --git a/src/compressed_tensors/quantization/lifecycle/initialize.py b/src/compressed_tensors/quantization/lifecycle/initialize.py index 9cb99b2e..8dd8fc51 100644 --- a/src/compressed_tensors/quantization/lifecycle/initialize.py +++ b/src/compressed_tensors/quantization/lifecycle/initialize.py @@ -140,7 +140,9 @@ def _initialize_scale_zero_point( if quantization_args.dynamic: return - # begin on the same device as other parameters or cpu if offloaded + # begin on the same device as other parameters or cpu if offloaded. + # in the offloaded case, there's no point moving tensors to the execution device + # if they're going to be immediately offloaded by `register_offload_parameter` params_device = next(module.parameters()).device device = "cpu" if has_offloaded_params(module) else params_device diff --git a/src/compressed_tensors/utils/offload.py b/src/compressed_tensors/utils/offload.py index b2038c62..43ec43ce 100644 --- a/src/compressed_tensors/utils/offload.py +++ b/src/compressed_tensors/utils/offload.py @@ -168,14 +168,11 @@ def register_offload_parameter( :param name: name of newly registered parameter :param parameter: parameter being registered """ + module.register_parameter(name, parameter) + if has_offloaded_params(module): - module.register_parameter(name, parameter) update_offload_parameter(module, name, parameter.data) set_module_tensor_to_device(module, name, "meta") - else: - device = next(module.parameters()).device - parameter = parameter.to(device) - module.register_parameter(name, parameter) def update_offload_parameter( From 1c55a1097704cb656d949b9daef30c5e13d5437f Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Mon, 16 Dec 2024 20:34:03 -0500 Subject: [PATCH 24/28] update update_prefix_dict Signed-off-by: Kyle Sayers --- src/compressed_tensors/utils/offload.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/compressed_tensors/utils/offload.py b/src/compressed_tensors/utils/offload.py index 43ec43ce..3d92b9ba 100644 --- a/src/compressed_tensors/utils/offload.py +++ b/src/compressed_tensors/utils/offload.py @@ -135,8 +135,9 @@ def update_prefix_dict(module: torch.nn.Module, key: str, data: torch.Tensor): """ if not has_offloaded_params(module): raise ValueError("Prefix dict is only applicable to offloaded modules") - prefix_dict = module._hf_hook.weights_map - prefix_dict.dataset[f"{prefix_dict.prefix}{key}"] = data + + weights_map = module._hf_hook.weights_map + offload_to_weights_map(weights_map, key, data) def update_parameter_data( From 91776504d9e0084632f1785821e5446a667b4b39 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Tue, 17 Dec 2024 19:08:25 +0000 Subject: [PATCH 25/28] reuse fixture Signed-off-by: Kyle Sayers --- tests/test_quantization/lifecycle/test_initialize.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/test_quantization/lifecycle/test_initialize.py b/tests/test_quantization/lifecycle/test_initialize.py index 8252b545..56df92f3 100644 --- a/tests/test_quantization/lifecycle/test_initialize.py +++ b/tests/test_quantization/lifecycle/test_initialize.py @@ -103,11 +103,10 @@ def test_initialize_module_for_quantization( ], ) def test_initialize_module_for_quantization_offloaded( - create_quantization_scheme, weights, input_activations + create_quantization_scheme, weights, input_activations, layer ): from accelerate.hooks import attach_align_device_hook - layer = Linear(4, 4) attach_align_device_hook(layer, offload=True) test_initialize_module_for_quantization( From df3e1860956455ec478337a90f7a8ce8a0450954 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Thu, 19 Dec 2024 20:51:14 +0000 Subject: [PATCH 26/28] use apply rather than recursion Signed-off-by: Kyle Sayers --- src/compressed_tensors/utils/offload.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/src/compressed_tensors/utils/offload.py b/src/compressed_tensors/utils/offload.py index 3d92b9ba..aec8eeb8 100644 --- a/src/compressed_tensors/utils/offload.py +++ b/src/compressed_tensors/utils/offload.py @@ -235,7 +235,7 @@ def delete_offload_parameter(module: torch.nn.Module, name: str): @check_accelerate(fallback=contextlib.nullcontext()) @contextlib.contextmanager -def disable_hf_hook(module: torch.nn.Module, recurse: bool = False): +def disable_hf_hook(module: torch.nn.Module): hooks = {} def collect_hooks(module): @@ -244,10 +244,7 @@ def collect_hooks(module): hooks[module] = module._hf_hook remove_hook_from_module(module) - for submodule in module.children(): - collect_hooks(submodule) - - collect_hooks(module) + module.apply(collect_hooks) yield From 665c98777b3e433680e8d349d6480de0ebf10d4f Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Thu, 19 Dec 2024 20:59:49 +0000 Subject: [PATCH 27/28] clearer delete_from_weights_map --- src/compressed_tensors/utils/offload.py | 77 ++++++++++--------------- 1 file changed, 30 insertions(+), 47 deletions(-) diff --git a/src/compressed_tensors/utils/offload.py b/src/compressed_tensors/utils/offload.py index aec8eeb8..620292cf 100644 --- a/src/compressed_tensors/utils/offload.py +++ b/src/compressed_tensors/utils/offload.py @@ -213,24 +213,7 @@ def delete_offload_parameter(module: torch.nn.Module, name: str): if has_offloaded_params(module): weights_map = module._hf_hook.weights_map - - # for upstreaming, better to add write capabilities to weight map classes first - if isinstance(weights_map, PrefixedDataset): - dataset = weights_map.dataset - prefix = weights_map.prefix - if dataset is not None: - del dataset[f"{prefix}{name}"] - - elif isinstance(weights_map, dict): - del weights_map[name] - - elif isinstance(weights_map, OffloadedWeightsLoader): - raise NotImplementedError() - - elif weights_map is not None: - raise NotImplementedError( - f"Cannot delete parameter from weights_map of type {type(weights_map)}" - ) + delete_from_weights_map(weights_map, name) @check_accelerate(fallback=contextlib.nullcontext()) @@ -286,35 +269,6 @@ def offload_to_weights_map( raise NotImplementedError( "Updating weights_map with disk offloading is not implemented yet" ) - # TODO: below may not be correct and has not been tested - # FUTURE: upstream as OffloadedWeightsLoader.__set_item__ - # use_index = "safetensors_file" in next(iter(weights_map.values())) - # if use_index: - # if key not in weights_map: - # weights_map.index[key] = { - # "safetensors_file": ???, - # "weight_name": key, - # "dtype": str(value.dtype) - # } - - # weight_info = weights_map.index[key] - # file_path = weight_info["safetensors_file"] - # with safetensors.create_file(file_path) as file: - # file.write(value) - - # else: - # assert self.save_folder is not None - # weight_file = os.path.join(self.save_folder, f"{key}.dat") - # need_index_update = not os.path.exists(weight_file) - # offload_weight( - # value, - # key, - # weights_map.save_folder, - # weights_map.index - # ) - - # if need_index_update: - # save_offload_index(weights_map.index, weights_map.save_folder) elif isinstance(weights_map, dict): if key in weights_map: @@ -332,6 +286,35 @@ def offload_to_weights_map( ) +@check_accelerate(fallback=None) +def delete_from_weights_map( + weights_map: Union[PrefixedDataset, Dict, OffloadedWeightsLoader], + key: str, +): + if isinstance(weights_map, PrefixedDataset): + dataset = weights_map.dataset + key = f"{weights_map.prefix}{key}" + delete_from_weights_map(dataset, key) + + elif isinstance(weights_map, OffloadedWeightsLoader): + if len(weights_map.index) <= 0: + delete_from_weights_map(weights_map.state_dict, key) + + else: + raise NotImplementedError( + "Delete from weights_map with disk offloading is not implemented yet" + ) + + elif isinstance(weights_map, dict): + del weights_map[key] + + else: + raise NotImplementedError( + "Updating offload data not implemented for weights_map of type " + f"{type(weights_map)}" + ) + + """ Upstreamed Functions """ From 0f4760a79451b6cb94142e6d2b2b3c454307e5f7 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Thu, 19 Dec 2024 17:09:16 -0500 Subject: [PATCH 28/28] add offload_device argument (#228) Signed-off-by: Kyle Sayers --- src/compressed_tensors/utils/offload.py | 52 +++++++++++++++++-------- tests/test_utils/test_offload.py | 31 ++++++++++++--- 2 files changed, 62 insertions(+), 21 deletions(-) diff --git a/src/compressed_tensors/utils/offload.py b/src/compressed_tensors/utils/offload.py index 620292cf..b3c77c58 100644 --- a/src/compressed_tensors/utils/offload.py +++ b/src/compressed_tensors/utils/offload.py @@ -27,7 +27,7 @@ import contextlib from functools import wraps -from typing import Any, Callable, Dict, Optional, Union +from typing import Any, Callable, Dict, Literal, Optional, Union import torch @@ -161,6 +161,7 @@ def register_offload_parameter( module: torch.nn.Module, name: str, parameter: torch.nn.Parameter, + offload_device: Optional[Union[torch.device, Literal["disk"]]] = None, ): """ Register a parameter to the given module which may be offloaded @@ -168,18 +169,24 @@ def register_offload_parameter( :param module: maybe offloaded module :param name: name of newly registered parameter :param parameter: parameter being registered + :param offload_device: device on which weight will be offloaded to. If None is + provided, then infer device from parameters on module """ + has_onload = any(p.device != torch.device("meta") for p in module.parameters()) module.register_parameter(name, parameter) if has_offloaded_params(module): - update_offload_parameter(module, name, parameter.data) - set_module_tensor_to_device(module, name, "meta") + weights_map = module._hf_hook.weights_map + offload_to_weights_map(weights_map, name, parameter.data, offload_device) + if not has_onload: + set_module_tensor_to_device(module, name, "meta") def update_offload_parameter( module: torch.nn.Module, name: str, data: Optional[torch.Tensor], + offload_device: Optional[Union[torch.device, Literal["disk"]]] = None, ): """ Update the data of an existing parameter and its offload dict. Supports both @@ -188,6 +195,8 @@ def update_offload_parameter( :param module: module containing the parameter to update :param name: name of module parameter to update :param data: tensor to update parameter with + :param offload_device: device on which weight will be offloaded to. If None is + provided, then infer device from parameters on module """ param = getattr(module, name) data = data.to(param.dtype) @@ -199,7 +208,7 @@ def update_offload_parameter( # update offload dict if has_offloaded_params(module): weights_map = module._hf_hook.weights_map - offload_to_weights_map(weights_map, name, data) + offload_to_weights_map(weights_map, name, data, offload_device) def delete_offload_parameter(module: torch.nn.Module, name: str): @@ -240,7 +249,7 @@ def offload_to_weights_map( weights_map: Union[PrefixedDataset, Dict, OffloadedWeightsLoader], key: str, value: torch.Tensor, - default_device: torch.device = torch.device("cpu"), + offload_device: Optional[Union[torch.device, Literal["disk"]]] = None, ): """ Helper function which implements offloaded item assignment for PrefixedDataset, @@ -249,21 +258,23 @@ def offload_to_weights_map( :param weights_map: weight map to be updated with offload information :param key: key used to identify weight location :param value: weight being offloaded - :param default_device: in the event that the weights_map does already contain - offloaded weights or use disk offloading, the weight will be offloaded to the - `default_device` + :param offload_device: device on which weight will be offloaded to. If None is + provided, then infer device from parameters in weights_map """ if isinstance(weights_map, PrefixedDataset): + if offload_device == "disk": + raise ValueError(f"Cannot offload to disk with type {type(weights_map)}") + dataset = weights_map.dataset key = f"{weights_map.prefix}{key}" - offload_to_weights_map(dataset, key, value) + offload_to_weights_map(dataset, key, value, offload_device) elif isinstance(weights_map, OffloadedWeightsLoader): if key not in weights_map.all_keys: weights_map.all_keys.append(key) - if len(weights_map.index) <= 0: - offload_to_weights_map(weights_map.state_dict, key, value) + if len(weights_map.index) <= 0 and offload_device != "disk": + offload_to_weights_map(weights_map.state_dict, key, value, offload_device) else: raise NotImplementedError( @@ -271,11 +282,20 @@ def offload_to_weights_map( ) elif isinstance(weights_map, dict): - if key in weights_map: - offload_device = weights_map[key].device - else: - tens = next(iter(weights_map.values()), None) - offload_device = tens.device if tens is not None else default_device + if offload_device == "disk": + raise ValueError(f"Cannot offload to disk with type {type(weights_map)}") + + # infer offload device + if offload_device is None: + if key in weights_map: + offload_device = weights_map[key].device + else: + tens = next(iter(weights_map.values()), None) + if tens is None: + raise ValueError( + "Cannot infer offload device from empty weights_map" + ) + offload_device = tens.device weights_map[key] = value.to(device=offload_device) diff --git a/tests/test_utils/test_offload.py b/tests/test_utils/test_offload.py index befeeb84..1002a4f5 100644 --- a/tests/test_utils/test_offload.py +++ b/tests/test_utils/test_offload.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import pytest import torch from compressed_tensors.utils import ( align_module_device, @@ -72,7 +73,7 @@ def test_register_offload_parameter(): # register a param after offloading, check that added param was offloaded register_offload_parameter(module, "d", parameter) assert hasattr(module, "d") and module.d.device == torch.device("meta") - assert "d" in module._hf_hook.weights_map + assert module._hf_hook.weights_map["d"].device == torch.device("cpu") # added parameters can be onloaded and offloaded with align_module_device(module, execution_device="cpu"): @@ -81,6 +82,18 @@ def test_register_offload_parameter(): assert module.c.device == torch.device("meta") assert module.d.device == torch.device("meta") + # parameters can be added during onload + with align_module_device(module, execution_device="cpu"): + register_offload_parameter(module, "e", parameter) + assert module.e.device == torch.device("cpu") + + # parameters can be added before onload and with explicit offload + register_offload_parameter(module, "f", parameter, offload_device="cpu") + assert module._hf_hook.weights_map["f"].device == torch.device("cpu") + with align_module_device(module, execution_device="cpu"): + assert module.f.device == torch.device("cpu") + assert module._hf_hook.weights_map["f"].device == torch.device("cpu") + @requires_accelerate() def test_update_offload_parameter(): @@ -195,7 +208,9 @@ def test_offload_to_weights_map(): # Dict empty weights_map = {} - offload_to_weights_map(weights_map, name, new_value) + with pytest.raises(ValueError): + offload_to_weights_map(weights_map, name, new_value) + offload_to_weights_map(weights_map, name, new_value, offload_device="cpu") assert weights_map[name] == new_value # Dict populated @@ -205,7 +220,9 @@ def test_offload_to_weights_map(): # OffloadedWeightsLoader[Dict] empty weights_map = OffloadedWeightsLoader({}) - offload_to_weights_map(weights_map, name, new_value) + with pytest.raises(ValueError): + offload_to_weights_map(weights_map, name, new_value) + offload_to_weights_map(weights_map, name, new_value, offload_device="cpu") assert weights_map[name] == new_value # OffloadedWeightsLoader[Dict] populated @@ -215,7 +232,9 @@ def test_offload_to_weights_map(): # PrefixedDataset[Dict] empty weights_map = PrefixedDataset({}, prefix) - offload_to_weights_map(weights_map, name, new_value) + with pytest.raises(ValueError): + offload_to_weights_map(weights_map, name, new_value) + offload_to_weights_map(weights_map, name, new_value, offload_device="cpu") assert weights_map[name] == new_value # PrefixedDataset[Dict] populated @@ -225,7 +244,9 @@ def test_offload_to_weights_map(): # PrefixedDataset[OffloadedWeightsLoader[Dict]] empty weights_map = PrefixedDataset(OffloadedWeightsLoader({}), prefix) - offload_to_weights_map(weights_map, name, new_value) + with pytest.raises(ValueError): + offload_to_weights_map(weights_map, name, new_value) + offload_to_weights_map(weights_map, name, new_value, offload_device="cpu") assert weights_map[name] == new_value # PrefixedDataset[OffloadedWeightsLoader[Dict]] populated