From eb1c705cd0eb3a644f20fa0ab8c821aa89e0fbec Mon Sep 17 00:00:00 2001 From: Kyle Sayers <kylesayrs@gmail.com> Date: Fri, 4 Oct 2024 13:07:58 -0400 Subject: [PATCH 01/12] remove function (#156) Co-authored-by: Kyle Sayers <kyle@neuralmagic.com> --- .../quantization/lifecycle/helpers.py | 47 ------------------- 1 file changed, 47 deletions(-) diff --git a/src/compressed_tensors/quantization/lifecycle/helpers.py b/src/compressed_tensors/quantization/lifecycle/helpers.py index 497a9921..9d755328 100644 --- a/src/compressed_tensors/quantization/lifecycle/helpers.py +++ b/src/compressed_tensors/quantization/lifecycle/helpers.py @@ -16,62 +16,15 @@ Miscelaneous helpers for the quantization lifecycle """ -from typing import Optional - -import torch from torch.nn import Module __all__ = [ - "update_layer_weight_quant_params", "enable_quantization", "disable_quantization", ] -def update_layer_weight_quant_params( - layer: Module, - weight: Optional[torch.Tensor] = None, - g_idx: Optional[torch.Tensor] = None, - reset_obs: bool = False, -): - """ - Update quantization parameters on layer - - :param layer: input layer - :param weight: weight to update quant params with, defaults to layer weight - :param g_idx: optional mapping from column index to group index - :param reset_obs: reset the observer before calculating quant params, - defaults to False - """ - attached_weight = getattr(layer, "weight", None) - - if weight is None: - weight = attached_weight - scale = getattr(layer, "weight_scale", None) - zero_point = getattr(layer, "weight_zero_point", None) - if g_idx is None: - g_idx = getattr(layer, "weight_g_idx", None) - observer = getattr(layer, "weight_observer", None) - - if weight is None or observer is None or scale is None or zero_point is None: - # scale, zp, or observer not calibratable or weight not available - return - - if reset_obs: - observer.reset() - - if attached_weight is not None: - weight = weight.to(attached_weight.dtype) - - updated_scale, updated_zero_point = observer(weight) - - # update scale and zero point - device = next(layer.parameters()).device - scale.data = updated_scale.to(device) - zero_point.data = updated_zero_point.to(device) - - def enable_quantization(module: Module): module.quantization_enabled = True From c2455b7c5040adc0e7fc6a197144971b2597d607 Mon Sep 17 00:00:00 2001 From: Rahul Tuli <rahul@neuralmagic.com> Date: Fri, 4 Oct 2024 16:07:11 -0400 Subject: [PATCH 02/12] Revert "Ignore Dense sparsity config (#169)" (#181) --- .../compressors/model_compressors/model_compressor.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/src/compressed_tensors/compressors/model_compressors/model_compressor.py b/src/compressed_tensors/compressors/model_compressors/model_compressor.py index ac15fdaa..6473554d 100644 --- a/src/compressed_tensors/compressors/model_compressors/model_compressor.py +++ b/src/compressed_tensors/compressors/model_compressors/model_compressor.py @@ -242,10 +242,6 @@ def __init__( self.sparsity_compressor = None self.quantization_compressor = None - if sparsity_config and sparsity_config.format == CompressionFormat.dense.value: - # ignore dense sparsity config - self.sparsity_config = None - if sparsity_config is not None: self.sparsity_compressor = BaseCompressor.load_from_registry( sparsity_config.format, config=sparsity_config From d6d823cec0d5bde2151b98ba15b5771c92fb390e Mon Sep 17 00:00:00 2001 From: Kyle Sayers <kylesayrs@gmail.com> Date: Mon, 7 Oct 2024 12:55:10 -0400 Subject: [PATCH 03/12] Workaround HF Quantizer `apply_quantization_config` misuse (#180) * workaround hf quantizer apply none * Add usage comment --- src/compressed_tensors/quantization/lifecycle/apply.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/src/compressed_tensors/quantization/lifecycle/apply.py b/src/compressed_tensors/quantization/lifecycle/apply.py index a66dba92..09281528 100644 --- a/src/compressed_tensors/quantization/lifecycle/apply.py +++ b/src/compressed_tensors/quantization/lifecycle/apply.py @@ -107,8 +107,8 @@ def load_pretrained_quantization(model: Module, model_name_or_path: str): def apply_quantization_config( - model: Module, config: QuantizationConfig, run_compressed: bool = False -) -> Dict: + model: Module, config: Union[QuantizationConfig, None], run_compressed: bool = False +) -> OrderedDict: """ Initializes the model for quantization in-place based on the given config @@ -117,6 +117,10 @@ def apply_quantization_config( :param run_compressed: Whether the model will be run in compressed mode or decompressed fully on load """ + # Workaround for when HF Quantizer passes None, see PR #180 + if config is None: + return OrderedDict() + # remove reference to the original `config` # argument. This function can mutate it, and we'd # like to keep the original `config` as it is. From b876a600cf2264e38e505cb8853b89d54110d027 Mon Sep 17 00:00:00 2001 From: dhuangnm <74931910+dhuangnm@users.noreply.github.com> Date: Tue, 8 Oct 2024 10:41:24 -0400 Subject: [PATCH 04/12] bump up version to 0.7.0 (#186) Co-authored-by: dhuangnm <dhuang@MacBook-Pro-2.local> --- src/compressed_tensors/version.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/compressed_tensors/version.py b/src/compressed_tensors/version.py index 17e7b592..1d241270 100644 --- a/src/compressed_tensors/version.py +++ b/src/compressed_tensors/version.py @@ -17,7 +17,7 @@ """ -version_base = "0.6.0" +version_base = "0.7.0" is_release = True # change to True to set the generated version as a release version From b2abe724a5b97821b90fdaa602eebdb1fa8a5eec Mon Sep 17 00:00:00 2001 From: Dipika Sikka <dipikasikka1@gmail.com> Date: Fri, 11 Oct 2024 14:14:53 -0400 Subject: [PATCH 05/12] [Observer Restructure]: Remove MemoryLess Observer; use helper function for dynamic quantization (#187) * remove memoryless observer; use helper function for dynamic quantization * update init * clean-up * update test case * fix arg * validation + update name * update preset schemes; swap condition check --- .../quantization/lifecycle/forward.py | 10 ++-- .../quantization/lifecycle/initialize.py | 13 +++-- .../quantization/observers/__init__.py | 1 - .../quantization/observers/helpers.py | 42 +++++++++++++- .../quantization/observers/memoryless.py | 56 ------------------- .../quantization/quant_args.py | 32 +++++++++-- .../quantization/quant_scheme.py | 3 + .../lifecycle/test_dynamic_lifecycle.py | 2 +- 8 files changed, 85 insertions(+), 74 deletions(-) delete mode 100644 src/compressed_tensors/quantization/observers/memoryless.py diff --git a/src/compressed_tensors/quantization/lifecycle/forward.py b/src/compressed_tensors/quantization/lifecycle/forward.py index 4dbe4a85..d685c0c0 100644 --- a/src/compressed_tensors/quantization/lifecycle/forward.py +++ b/src/compressed_tensors/quantization/lifecycle/forward.py @@ -18,7 +18,10 @@ import torch from compressed_tensors.quantization.cache import QuantizedKVParameterCache -from compressed_tensors.quantization.observers.helpers import calculate_range +from compressed_tensors.quantization.observers.helpers import ( + calculate_range, + compute_dynamic_scales_and_zp, +) from compressed_tensors.quantization.quant_args import ( QuantizationArgs, QuantizationStrategy, @@ -376,9 +379,8 @@ def maybe_calibrate_or_quantize( g_idx = getattr(module, "weight_g_idx", None) if args.dynamic: - # dynamic quantization - get scale and zero point directly from observer - observer = getattr(module, f"{base_name}_observer") - scale, zero_point = observer(value, g_idx=g_idx) + # dynamic quantization - no need to invoke observer + scale, zero_point = compute_dynamic_scales_and_zp(value=value, args=args) else: # static quantization - get previous scale and zero point from layer scale = getattr(module, f"{base_name}_scale") diff --git a/src/compressed_tensors/quantization/lifecycle/initialize.py b/src/compressed_tensors/quantization/lifecycle/initialize.py index 49e7b1a9..9b98da33 100644 --- a/src/compressed_tensors/quantization/lifecycle/initialize.py +++ b/src/compressed_tensors/quantization/lifecycle/initialize.py @@ -153,12 +153,16 @@ def _initialize_scale_zero_point_observer( weight_shape: Optional[torch.Size] = None, force_zero_point: bool = True, ): + # initialize observer module and attach as submodule observer = quantization_args.get_observer() - module.register_module(f"{base_name}_observer", observer) + # no need to register an observer for dynamic quantization + if observer: + module.register_module(f"{base_name}_observer", observer) + # no need to register a scale and zero point for a dynamic quantization if quantization_args.dynamic: - return # no need to register a scale and zero point for a dynamic observer + return device = next(module.parameters()).device if is_module_offloaded(module): @@ -173,10 +177,7 @@ def _initialize_scale_zero_point_observer( expected_shape = (weight_shape[0], 1) elif quantization_args.strategy == QuantizationStrategy.GROUP: num_groups = weight_shape[1] // quantization_args.group_size - expected_shape = ( - weight_shape[0], - max(num_groups, 1) - ) + expected_shape = (weight_shape[0], max(num_groups, 1)) scale_dtype = module.weight.dtype if scale_dtype not in [torch.float16, torch.bfloat16, torch.float32]: diff --git a/src/compressed_tensors/quantization/observers/__init__.py b/src/compressed_tensors/quantization/observers/__init__.py index 2cb15c96..05b6b367 100644 --- a/src/compressed_tensors/quantization/observers/__init__.py +++ b/src/compressed_tensors/quantization/observers/__init__.py @@ -17,6 +17,5 @@ from .helpers import * from .base import * -from .memoryless import * from .min_max import * from .mse import * diff --git a/src/compressed_tensors/quantization/observers/helpers.py b/src/compressed_tensors/quantization/observers/helpers.py index 13c05991..875a05b3 100644 --- a/src/compressed_tensors/quantization/observers/helpers.py +++ b/src/compressed_tensors/quantization/observers/helpers.py @@ -13,18 +13,56 @@ # limitations under the License. from collections import Counter -from typing import Tuple +from typing import Optional, Tuple import torch from compressed_tensors.quantization.quant_args import ( FP8_DTYPE, QuantizationArgs, + QuantizationStrategy, QuantizationType, ) from torch import FloatTensor, IntTensor, Tensor -__all__ = ["calculate_qparams", "get_observer_token_count", "calculate_range"] +__all__ = [ + "calculate_qparams", + "get_observer_token_count", + "calculate_range", + "compute_dynamic_scales_and_zp", +] + + +def compute_dynamic_scales_and_zp(value: Tensor, args: QuantizationArgs): + """ + Returns the computed scales and zero points for dynamic activation + qunatization. + + :param value: tensor to calculate quantization parameters for + :param args: quantization args + :param reduce_dims: optional tuple of dimensions to reduce along, + returned scale and zero point will be shaped (1,) along the + reduced dimensions + :return: tuple of scale and zero point derived from the observed tensor + """ + if args.strategy == QuantizationStrategy.TOKEN: + dim = {1, 2} + reduce_dims = tuple(idx for idx in range(value.ndim) if idx not in dim) + elif args.strategy == QuantizationStrategy.TENSOR: + reduce_dims = None + else: + raise ValueError( + f"One of {QuantizationStrategy.TOKEN} or {QuantizationStrategy.TENSOR} ", + "must be used for dynamic quantization", + ) + + if not reduce_dims: + min_val, max_val = torch.aminmax(value) + else: + min_val = torch.amin(value, dim=reduce_dims, keepdims=True) + max_val = torch.amax(value, dim=reduce_dims, keepdims=True) + + return calculate_qparams(min_val, max_val, args) def get_observer_token_count(module: torch.nn.Module) -> Counter: diff --git a/src/compressed_tensors/quantization/observers/memoryless.py b/src/compressed_tensors/quantization/observers/memoryless.py deleted file mode 100644 index ea8a7a01..00000000 --- a/src/compressed_tensors/quantization/observers/memoryless.py +++ /dev/null @@ -1,56 +0,0 @@ -# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from typing import Any, Optional, Tuple - -import torch -from compressed_tensors.quantization.observers.base import Observer -from compressed_tensors.quantization.observers.helpers import calculate_qparams -from torch import FloatTensor, IntTensor, Tensor - - -__all__ = ["MemorylessObserver"] - - -@Observer.register("memoryless", alias=["dynamic"]) -class MemorylessObserver(Observer): - """ - Implements a quantization observer that sets the scale and - zero point based on the latest observed value without tracking state - """ - - def calculate_qparams( - self, - observed: Tensor, - tensor_id: Optional[Any] = None, - reduce_dims: Optional[Tuple[int]] = None, - ) -> Tuple[FloatTensor, IntTensor]: - """ - Returns the min and max values of observed tensor - - :param observed: observed tensor to calculate quantization parameters for - :param tensor_id: optional id for tensor; not used for memoryless - :param reduce_dims: optional tuple of dimensions to reduce along, - returned scale and zero point will be shaped (1,) along the - reduced dimensions - :return: tuple of scale and zero point derived from the observed tensor - """ - - if not reduce_dims: - min_val, max_val = torch.aminmax(observed) - else: - min_val = torch.amin(observed, dim=reduce_dims, keepdims=True) - max_val = torch.amax(observed, dim=reduce_dims, keepdims=True) - - return calculate_qparams(min_val, max_val, self.quantization_args) diff --git a/src/compressed_tensors/quantization/quant_args.py b/src/compressed_tensors/quantization/quant_args.py index 54805c58..c2fc0b6a 100644 --- a/src/compressed_tensors/quantization/quant_args.py +++ b/src/compressed_tensors/quantization/quant_args.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import warnings from enum import Enum from typing import Any, Dict, Optional, Union @@ -94,7 +95,7 @@ class QuantizationArgs(BaseModel, use_enum_values=True): block_structure: Optional[str] = None dynamic: bool = False actorder: Union[ActivationOrdering, bool, None] = None - observer: str = Field( + observer: Optional[str] = Field( default="minmax", description=( "The class to use to compute the quantization param - " @@ -115,10 +116,10 @@ def get_observer(self): """ from compressed_tensors.quantization.observers.base import Observer + # No observer required for the dynamic case if self.dynamic: - # override defualt observer for dynamic, you never want minmax which - # keeps state across samples for dynamic - self.observer = "memoryless" + self.observer = None + return self.observer return Observer.load_from_registry(self.observer, quantization_args=self) @@ -171,6 +172,8 @@ def validate_model_after(model: "QuantizationArgs") -> Dict[str, Any]: strategy = model.strategy group_size = model.group_size actorder = model.actorder + dynamic = model.dynamic + observer = model.observer # infer strategy if strategy is None: @@ -207,6 +210,27 @@ def validate_model_after(model: "QuantizationArgs") -> Dict[str, Any]: "activation ordering" ) + if dynamic: + if strategy not in ( + QuantizationStrategy.TOKEN, + QuantizationStrategy.TENSOR, + ): + raise ValueError( + f"One of {QuantizationStrategy.TOKEN} or " + f"{QuantizationStrategy.TENSOR} must be used for dynamic ", + "quantization", + ) + if observer is not None: + warnings.warn( + "No observer is used for dynamic quantization, setting to None" + ) + model.observer = None + + # if we have not set an observer and we + # are running static quantization, use minmax + if not observer and not dynamic: + model.observer = "minmax" + # write back modified values model.strategy = strategy return model diff --git a/src/compressed_tensors/quantization/quant_scheme.py b/src/compressed_tensors/quantization/quant_scheme.py index b41eaafb..180d0f26 100644 --- a/src/compressed_tensors/quantization/quant_scheme.py +++ b/src/compressed_tensors/quantization/quant_scheme.py @@ -122,6 +122,7 @@ def is_preset_scheme(name: str) -> bool: strategy=QuantizationStrategy.TOKEN, symmetric=True, dynamic=True, + observer=None, ), ) @@ -164,6 +165,7 @@ def is_preset_scheme(name: str) -> bool: strategy=QuantizationStrategy.TOKEN, symmetric=True, dynamic=True, + observer=None, ), ) @@ -200,6 +202,7 @@ def is_preset_scheme(name: str) -> bool: strategy=QuantizationStrategy.TOKEN, symmetric=True, dynamic=True, + observer=None, ), ) diff --git a/tests/test_quantization/lifecycle/test_dynamic_lifecycle.py b/tests/test_quantization/lifecycle/test_dynamic_lifecycle.py index be228451..45d49370 100644 --- a/tests/test_quantization/lifecycle/test_dynamic_lifecycle.py +++ b/tests/test_quantization/lifecycle/test_dynamic_lifecycle.py @@ -73,7 +73,7 @@ def _test_layer_dynamic_quantization_status( # check inputs always have an observer if quantized but never scale/zp assert not hasattr(module, "input_scale") assert not hasattr(module, "input_zero_point") - assert hasattr(module, "input_observer") == inputs + assert not hasattr(module, "input_observer") # check weights always have scale/zp and observer only if not frozen assert hasattr(module, "weight_scale") == weights From 506cd36829f6fd386a0a2c75ed8a51cc6dcab619 Mon Sep 17 00:00:00 2001 From: dhuangnm <74931910+dhuangnm@users.noreply.github.com> Date: Wed, 16 Oct 2024 15:44:43 -0400 Subject: [PATCH 06/12] bump up to 0.7.1 for patch release (#192) Co-authored-by: dhuangnm <dhuang@MacBook-Pro-2.local> --- src/compressed_tensors/version.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/compressed_tensors/version.py b/src/compressed_tensors/version.py index 1d241270..d0ba0363 100644 --- a/src/compressed_tensors/version.py +++ b/src/compressed_tensors/version.py @@ -17,7 +17,7 @@ """ -version_base = "0.7.0" +version_base = "0.7.1" is_release = True # change to True to set the generated version as a release version From 955e9068d4bd4cd632cde212117d1bf8c880e2ed Mon Sep 17 00:00:00 2001 From: Dipika Sikka <dipikasikka1@gmail.com> Date: Fri, 18 Oct 2024 11:15:34 -0400 Subject: [PATCH 07/12] [Observer Restructure]: Separate out scale/zp and observer init; separate out calibration from forward pass (#188) * separate out scale/zp and observer init * temporary workaround * separate out calibration from forward pass * Fix typo * fix missing import * fix tests * update all other tests * Fix an incorrectly working test * clean * update * clean-up * fix test case * remove commented code * remove TODOs * remove TODO --- src/compressed_tensors/quantization/cache.py | 1 - .../quantization/lifecycle/calibration.py | 12 +++ .../quantization/lifecycle/forward.py | 83 +++++++++++++------ .../quantization/lifecycle/frozen.py | 10 +-- .../quantization/lifecycle/initialize.py | 28 +++---- .../lifecycle/test_dynamic_lifecycle.py | 1 + .../lifecycle/test_forward.py | 50 +++++++---- .../lifecycle/test_frozen.py | 5 +- .../lifecycle/test_lifecycle.py | 5 +- 9 files changed, 127 insertions(+), 68 deletions(-) diff --git a/src/compressed_tensors/quantization/cache.py b/src/compressed_tensors/quantization/cache.py index cc33a48a..312f1c9d 100644 --- a/src/compressed_tensors/quantization/cache.py +++ b/src/compressed_tensors/quantization/cache.py @@ -28,7 +28,6 @@ class KVCacheScaleType(Enum): class QuantizedKVParameterCache(HFDyanmicCache): - """ Quantized KV cache used in the forward call based on HF's dynamic cache. Quantization strategy (tensor, group, channel) set from Quantization arg's strategy diff --git a/src/compressed_tensors/quantization/lifecycle/calibration.py b/src/compressed_tensors/quantization/lifecycle/calibration.py index d444694d..66dc35a0 100644 --- a/src/compressed_tensors/quantization/lifecycle/calibration.py +++ b/src/compressed_tensors/quantization/lifecycle/calibration.py @@ -53,7 +53,19 @@ def set_module_for_calibration(module: Module, quantize_weights_upfront: bool = if quantize_weights_upfront and module.quantization_scheme.weights is not None: # set weight scale and zero_point up front, calibration data doesn't affect it + if not hasattr(module, "weight_observer"): + from compressed_tensors.quantization.lifecycle.initialize import ( + initialize_observers, + ) + + initialize_observers( + module=module, + base_name="weight", + quantization_args=module.quantization_scheme.weights, + ) + observer = module.weight_observer + g_idx = getattr(module, "weight_g_idx", None) offloaded = is_module_offloaded(module) diff --git a/src/compressed_tensors/quantization/lifecycle/forward.py b/src/compressed_tensors/quantization/lifecycle/forward.py index d685c0c0..eae641f8 100644 --- a/src/compressed_tensors/quantization/lifecycle/forward.py +++ b/src/compressed_tensors/quantization/lifecycle/forward.py @@ -38,7 +38,8 @@ "dequantize", "fake_quantize", "wrap_module_forward_quantized", - "maybe_calibrate_or_quantize", + "forward_quantize", + "calibrate_activations", ] @@ -276,14 +277,24 @@ def wrapped_forward(self, *args, **kwargs): if scheme.input_activations is not None: # calibrate and (fake) quantize input activations when applicable - input_ = maybe_calibrate_or_quantize( - module, input_, "input", scheme.input_activations - ) + # NOTE: will be moved out of compressed-tensors + if ( + module.quantization_status == QuantizationStatus.CALIBRATION + and not scheme.input_activations.dynamic + ): + calibrate_activations( + module=module, + value=input_, + base_name="input", + quantization_args=scheme.input_activations, + ) + + input_ = forward_quantize(module, input_, "input", scheme.input_activations) if scheme.weights is not None and not compressed: # calibrate and (fake) quantize weights when applicable unquantized_weight = self.weight.data.clone() - self.weight.data = maybe_calibrate_or_quantize( + self.weight.data = forward_quantize( module, self.weight, "weight", scheme.weights ) @@ -296,7 +307,19 @@ def wrapped_forward(self, *args, **kwargs): # calibrate and (fake) quantize output activations when applicable # kv_cache scales updated on model self_attn forward call in # wrap_module_forward_quantized_attn - output = maybe_calibrate_or_quantize( + + if ( + module.quantization_status == QuantizationStatus.CALIBRATION + and not scheme.output_activations.dynamic + ): + calibrate_activations( + module=module, + value=output, + base_name="output", + quantization_args=scheme.ouput_activations, + ) + + output = forward_quantize( module, output, "output", scheme.output_activations ) @@ -356,12 +379,36 @@ def wrapped_forward(self, *args, **kwargs): setattr(module, "forward", bound_wrapped_forward) -def maybe_calibrate_or_quantize( +def calibrate_activations( + module: Module, + value: torch.Tensor, + base_name: str, + quantization_args: QuantizationArgs, +): + # If empty tensor, can't update zp/scale + # Case for MoEs + if value.numel() == 0: + return + # calibration mode - get new quant params from observer + if not hasattr(module, f"{base_name}_observer"): + from compressed_tensors.quantization.lifecycle import initialize_observers + + initialize_observers( + module=module, base_name=base_name, quantization_args=quantization_args + ) + + observer = getattr(module, f"{base_name}_observer") + + updated_scale, updated_zero_point = observer(value) + + # update scale and zero point + update_parameter_data(module, updated_scale, f"{base_name}_scale") + update_parameter_data(module, updated_zero_point, f"{base_name}_zero_point") + + +def forward_quantize( module: Module, value: torch.Tensor, base_name: str, args: "QuantizationArgs" ) -> torch.Tensor: - # don't run quantization if we haven't entered calibration mode - if module.quantization_status == QuantizationStatus.INITIALIZED: - return value # in compressed mode, the weight is already compressed and quantized so we don't # need to run fake quantization @@ -386,22 +433,6 @@ def maybe_calibrate_or_quantize( scale = getattr(module, f"{base_name}_scale") zero_point = getattr(module, f"{base_name}_zero_point", None) - if ( - module.quantization_status == QuantizationStatus.CALIBRATION - and base_name != "weight" - ): - # calibration mode - get new quant params from observer - observer = getattr(module, f"{base_name}_observer") - - updated_scale, updated_zero_point = observer(value, g_idx=g_idx) - - # update scale and zero point - update_parameter_data(module, updated_scale, f"{base_name}_scale") - update_parameter_data(module, updated_zero_point, f"{base_name}_zero_point") - - scale = updated_scale - zero_point = updated_zero_point - return fake_quantize( x=value, scale=scale, diff --git a/src/compressed_tensors/quantization/lifecycle/frozen.py b/src/compressed_tensors/quantization/lifecycle/frozen.py index 66356cb7..e4723431 100644 --- a/src/compressed_tensors/quantization/lifecycle/frozen.py +++ b/src/compressed_tensors/quantization/lifecycle/frozen.py @@ -41,15 +41,11 @@ def freeze_module_quantization(module: Module): return # delete observers from module if not dynamic - if scheme.input_activations and not scheme.input_activations.dynamic: + if hasattr(module, "input_observer") and not scheme.input_activations.dynamic: delattr(module, "input_observer") - if scheme.weights and not scheme.weights.dynamic: + if hasattr(module, "weight_observer") and not scheme.weights.dynamic: delattr(module, "weight_observer") - if ( - scheme.output_activations - and not is_kv_cache_quant_scheme(scheme) - and not scheme.output_activations.dynamic - ): + if hasattr(module, "output_observer") and not scheme.output_activations.dynamic: delattr(module, "output_observer") module.quantization_status = QuantizationStatus.FROZEN diff --git a/src/compressed_tensors/quantization/lifecycle/initialize.py b/src/compressed_tensors/quantization/lifecycle/initialize.py index 9b98da33..68157cb1 100644 --- a/src/compressed_tensors/quantization/lifecycle/initialize.py +++ b/src/compressed_tensors/quantization/lifecycle/initialize.py @@ -34,9 +34,7 @@ from torch.nn import Module, Parameter -__all__ = [ - "initialize_module_for_quantization", -] +__all__ = ["initialize_module_for_quantization", "initialize_observers"] _LOGGER = logging.getLogger(__name__) @@ -74,7 +72,7 @@ def initialize_module_for_quantization( else: if scheme.input_activations is not None: - _initialize_scale_zero_point_observer( + _initialize_scale_zero_point( module, "input", scheme.input_activations, @@ -85,7 +83,7 @@ def initialize_module_for_quantization( weight_shape = None if isinstance(module, torch.nn.Linear): weight_shape = module.weight.shape - _initialize_scale_zero_point_observer( + _initialize_scale_zero_point( module, "weight", scheme.weights, @@ -101,7 +99,7 @@ def initialize_module_for_quantization( if scheme.output_activations is not None: if not is_kv_cache_quant_scheme(scheme): - _initialize_scale_zero_point_observer( + _initialize_scale_zero_point( module, "output", scheme.output_activations ) @@ -146,21 +144,23 @@ def initialize_module_for_quantization( module._hf_hook.weights_map = new_prefix_dict -def _initialize_scale_zero_point_observer( +def initialize_observers( module: Module, base_name: str, quantization_args: QuantizationArgs, - weight_shape: Optional[torch.Size] = None, - force_zero_point: bool = True, ): - # initialize observer module and attach as submodule observer = quantization_args.get_observer() - # no need to register an observer for dynamic quantization - if observer: - module.register_module(f"{base_name}_observer", observer) + module.register_module(f"{base_name}_observer", observer) - # no need to register a scale and zero point for a dynamic quantization + +def _initialize_scale_zero_point( + module: Module, + base_name: str, + quantization_args: QuantizationArgs, + weight_shape: Optional[torch.Size] = None, + force_zero_point: bool = True, +): if quantization_args.dynamic: return diff --git a/tests/test_quantization/lifecycle/test_dynamic_lifecycle.py b/tests/test_quantization/lifecycle/test_dynamic_lifecycle.py index 45d49370..1f88626e 100644 --- a/tests/test_quantization/lifecycle/test_dynamic_lifecycle.py +++ b/tests/test_quantization/lifecycle/test_dynamic_lifecycle.py @@ -23,6 +23,7 @@ def test_apply_tinyllama_dynamic_activations(): + # NOTE: should not calibrate dynamic quant quant_config = get_sample_dynamic_tinyllama_quant_config() model = get_tinyllama_model() diff --git a/tests/test_quantization/lifecycle/test_forward.py b/tests/test_quantization/lifecycle/test_forward.py index b9ee67ff..0730c991 100644 --- a/tests/test_quantization/lifecycle/test_forward.py +++ b/tests/test_quantization/lifecycle/test_forward.py @@ -15,13 +15,18 @@ import pytest import torch +from compressed_tensors.quantization.lifecycle.calibration import ( + set_module_for_calibration, +) from compressed_tensors.quantization.lifecycle.forward import ( + calibrate_activations, dequantize, - maybe_calibrate_or_quantize, + forward_quantize, quantize, wrap_module_forward_quantized, wrap_module_forward_quantized_attn, ) +from compressed_tensors.quantization.lifecycle.frozen import freeze_module_quantization from compressed_tensors.quantization.lifecycle.initialize import ( initialize_module_for_quantization, ) @@ -58,7 +63,7 @@ def test_wrap_module_forward_quantized(create_quantization_scheme): @pytest.mark.parametrize( "quantization_status", ["initialized", "calibration", "frozen"] ) -def test_maybe_calibrate_or_quantize(create_quantization_scheme, quantization_status): +def test_forward_quantize(create_quantization_scheme, quantization_status): num_bits = 8 quantization_scheme = create_quantization_scheme( targets=["*"], @@ -72,26 +77,41 @@ def test_maybe_calibrate_or_quantize(create_quantization_scheme, quantization_st dummy_tensor = torch.randn(8, 4) # (num_tokens, num_features) layer.quantization_status = QuantizationStatus(quantization_status) - initialize_module_for_quantization(layer, quantization_scheme) - # only calibration updates the scale and zero-point if layer.quantization_status == QuantizationStatus.INITIALIZED: - out = maybe_calibrate_or_quantize( - layer, dummy_tensor, "input", quantization_args - ) - assert torch.allclose(out, dummy_tensor) + # Init zp and scales + initialize_module_for_quantization(layer, quantization_scheme) + # init weight observers; update weight scales/zp + set_module_for_calibration(layer) + # call quant/dequant on weights + out = forward_quantize(layer, layer.weight, "weight", quantization_args) + assert torch.allclose(out, layer.weight.data, atol=0.2) elif layer.quantization_status == QuantizationStatus.CALIBRATION: - out = maybe_calibrate_or_quantize( - layer, dummy_tensor, "input", quantization_args + # init zp/scales + initialize_module_for_quantization(layer, quantization_scheme) + # init weight observers; update weight scales/zp + set_module_for_calibration(layer) + # init input observers, update input scales/zp + calibrate_activations( + module=layer, + value=dummy_tensor, + base_name="input", + quantization_args=quantization_args, ) + # call quant/dequant on inputs + out = forward_quantize(layer, dummy_tensor, "input", quantization_args) assert torch.allclose(out, dummy_tensor, atol=0.2) assert layer.input_observer._num_observed_tokens == dummy_tensor.shape[0] elif layer.quantization_status == QuantizationStatus.FROZEN: - # scale and zero points are empty -- cannot quantize - with pytest.raises(Exception): - out = maybe_calibrate_or_quantize( - layer, layer.weight.data, "input", quantization_args - ) + # init weight observers + initialize_module_for_quantization(layer, quantization_scheme) + # init weight observers; update weight scales/zp + set_module_for_calibration(layer) + # remove weight observers and any input observers + freeze_module_quantization(layer) + # call quant/dequant on weights + out = forward_quantize(layer, layer.weight.data, "weight", quantization_args) + assert torch.allclose(out, layer.weight.data, atol=0.2) @pytest.mark.parametrize( diff --git a/tests/test_quantization/lifecycle/test_frozen.py b/tests/test_quantization/lifecycle/test_frozen.py index 056c6089..dddff117 100644 --- a/tests/test_quantization/lifecycle/test_frozen.py +++ b/tests/test_quantization/lifecycle/test_frozen.py @@ -13,6 +13,9 @@ # limitations under the License. +from compressed_tensors.quantization.lifecycle.calibration import ( + set_module_for_calibration, +) from compressed_tensors.quantization.lifecycle.frozen import freeze_module_quantization from compressed_tensors.quantization.lifecycle.initialize import ( initialize_module_for_quantization, @@ -34,9 +37,9 @@ def test_set_module_for_calibration(create_quantization_scheme): initialize_module_for_quantization(layer, quantization_scheme) layer.quantization_status = QuantizationStatus("calibration") + set_module_for_calibration(layer) # should have both input and weight observer after initalizing - assert hasattr(layer, "input_observer") assert hasattr(layer, "weight_observer") # observers should get deleted after freezing diff --git a/tests/test_quantization/lifecycle/test_lifecycle.py b/tests/test_quantization/lifecycle/test_lifecycle.py index 970e4de9..41c56789 100644 --- a/tests/test_quantization/lifecycle/test_lifecycle.py +++ b/tests/test_quantization/lifecycle/test_lifecycle.py @@ -59,15 +59,12 @@ def test_lifecyle(create_quantization_scheme): expected_layer_keys.remove(key) assert len(expected_layer_keys) == 0 - # should have both input and weight observer after initalizing - assert hasattr(layer, "input_observer") - assert hasattr(layer, "weight_observer") - assert hasattr(layer, "quantization_scheme") assert hasattr(layer, "quantization_status") assert layer.quantization_status == QuantizationStatus.INITIALIZED set_module_for_calibration(layer) + assert hasattr(layer, "weight_observer") assert layer.quantization_status == QuantizationStatus.CALIBRATION # do a calibration step From 232e4944b84798bd05fddc18a7752ae2b5d460da Mon Sep 17 00:00:00 2001 From: Alexandre Marques <alexandre@neuralmagic.com> Date: Fri, 18 Oct 2024 17:00:11 -0400 Subject: [PATCH 08/12] Fix device allocation for MSE observer (#190) * make sure all tensors are in the same device * fix initial assignment * Fix testing --- src/compressed_tensors/quantization/observers/mse.py | 6 +++--- tests/test_quantization/test_observers/test_mse.py | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/compressed_tensors/quantization/observers/mse.py b/src/compressed_tensors/quantization/observers/mse.py index ff122a78..24d80584 100644 --- a/src/compressed_tensors/quantization/observers/mse.py +++ b/src/compressed_tensors/quantization/observers/mse.py @@ -70,9 +70,9 @@ def calculate_mse_min_max( absolute_min_val = torch.amin(observed, dim=reduce_dims, keepdims=True) absolute_max_val = torch.amax(observed, dim=reduce_dims, keepdims=True) - best = torch.full(absolute_min_val.shape, float("inf")) - min_val = torch.ones(absolute_min_val.shape) - max_val = torch.zeros(absolute_max_val.shape) + best = torch.full_like(absolute_min_val, torch.finfo(absolute_min_val.dtype).max) + min_val = torch.ones_like(absolute_min_val) + max_val = torch.zeros_like(absolute_max_val) for i in range(int(self.maxshrink * self.grid)): p = 1 - i / self.grid shrinked_min_val = p * absolute_min_val diff --git a/tests/test_quantization/test_observers/test_mse.py b/tests/test_quantization/test_observers/test_mse.py index 76a317f4..659fc573 100644 --- a/tests/test_quantization/test_observers/test_mse.py +++ b/tests/test_quantization/test_observers/test_mse.py @@ -27,7 +27,7 @@ ], ) def test_mse_observer(symmetric, expected_scale, expected_zero_point): - tensor = torch.tensor([1, 1, 1, 1, 1]) + tensor = torch.tensor([1., 1., 1., 1., 1.]) num_bits = 8 weights = QuantizationArgs(num_bits=num_bits, symmetric=symmetric, observer="mse") From d3216bc8d27e90da19b9c05913895fd83c0ac4a3 Mon Sep 17 00:00:00 2001 From: dhuangnm <74931910+dhuangnm@users.noreply.github.com> Date: Wed, 23 Oct 2024 10:38:28 -0400 Subject: [PATCH 09/12] drop 3.8 and add 3.12 to testing (#196) Co-authored-by: dhuangnm <dhuang@MacBook-Pro-2.local> --- .github/workflows/trigger-all.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/trigger-all.yml b/.github/workflows/trigger-all.yml index 02427999..1e53ff30 100644 --- a/.github/workflows/trigger-all.yml +++ b/.github/workflows/trigger-all.yml @@ -35,6 +35,6 @@ jobs: test_configs: '[{"python":"3.11.4","label":"ubuntu-22.04","timeout":"40"}, {"python":"3.10.12","label":"ubuntu-20.04","timeout":"40"}, {"python":"3.9.17","label":"k8s-a100-solo","timeout":"40"}, - {"python":"3.8.17","label":"k8s-a100-duo","timeout":"40"}]' + {"python":"3.12.6","label":"k8s-a100-duo","timeout":"40"}]' secrets: inherit From d3dea3ffd3f81d9840e99dddf242bdde6f22d6af Mon Sep 17 00:00:00 2001 From: Kyle Sayers <kylesayrs@gmail.com> Date: Wed, 23 Oct 2024 15:18:32 -0400 Subject: [PATCH 10/12] fix test which required accelerate, apply style (#194) --- src/compressed_tensors/quantization/lifecycle/frozen.py | 1 - src/compressed_tensors/quantization/observers/helpers.py | 2 +- src/compressed_tensors/quantization/observers/mse.py | 4 +++- tests/test_quantization/lifecycle/test_apply.py | 8 +------- tests/test_quantization/test_observers/test_mse.py | 2 +- 5 files changed, 6 insertions(+), 11 deletions(-) diff --git a/src/compressed_tensors/quantization/lifecycle/frozen.py b/src/compressed_tensors/quantization/lifecycle/frozen.py index e4723431..4a65482c 100644 --- a/src/compressed_tensors/quantization/lifecycle/frozen.py +++ b/src/compressed_tensors/quantization/lifecycle/frozen.py @@ -14,7 +14,6 @@ from compressed_tensors.quantization.quant_config import QuantizationStatus -from compressed_tensors.quantization.utils import is_kv_cache_quant_scheme from torch.nn import Module diff --git a/src/compressed_tensors/quantization/observers/helpers.py b/src/compressed_tensors/quantization/observers/helpers.py index 875a05b3..ec474303 100644 --- a/src/compressed_tensors/quantization/observers/helpers.py +++ b/src/compressed_tensors/quantization/observers/helpers.py @@ -13,7 +13,7 @@ # limitations under the License. from collections import Counter -from typing import Optional, Tuple +from typing import Tuple import torch from compressed_tensors.quantization.quant_args import ( diff --git a/src/compressed_tensors/quantization/observers/mse.py b/src/compressed_tensors/quantization/observers/mse.py index 24d80584..739e921f 100644 --- a/src/compressed_tensors/quantization/observers/mse.py +++ b/src/compressed_tensors/quantization/observers/mse.py @@ -70,7 +70,9 @@ def calculate_mse_min_max( absolute_min_val = torch.amin(observed, dim=reduce_dims, keepdims=True) absolute_max_val = torch.amax(observed, dim=reduce_dims, keepdims=True) - best = torch.full_like(absolute_min_val, torch.finfo(absolute_min_val.dtype).max) + best = torch.full_like( + absolute_min_val, torch.finfo(absolute_min_val.dtype).max + ) min_val = torch.ones_like(absolute_min_val) max_val = torch.zeros_like(absolute_max_val) for i in range(int(self.maxshrink * self.grid)): diff --git a/tests/test_quantization/lifecycle/test_apply.py b/tests/test_quantization/lifecycle/test_apply.py index 5f0bd093..dcb980f2 100644 --- a/tests/test_quantization/lifecycle/test_apply.py +++ b/tests/test_quantization/lifecycle/test_apply.py @@ -236,14 +236,8 @@ def get_sample_tinyllama_quant_config(status: str = "frozen"): def test_apply_quantization_status(caplog, ignore, should_raise_warning): import logging - from transformers import AutoModelForCausalLM - # load a dense, unquantized tiny llama model - model_name = "TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T" - model = AutoModelForCausalLM.from_pretrained( - model_name, device_map="cpu", torch_dtype="auto" - ) - + model = get_tinyllama_model() quantization_config_dict = { "quant_method": "sparseml", "format": "pack-quantized", diff --git a/tests/test_quantization/test_observers/test_mse.py b/tests/test_quantization/test_observers/test_mse.py index 659fc573..098551a2 100644 --- a/tests/test_quantization/test_observers/test_mse.py +++ b/tests/test_quantization/test_observers/test_mse.py @@ -27,7 +27,7 @@ ], ) def test_mse_observer(symmetric, expected_scale, expected_zero_point): - tensor = torch.tensor([1., 1., 1., 1., 1.]) + tensor = torch.tensor([1.0, 1.0, 1.0, 1.0, 1.0]) num_bits = 8 weights = QuantizationArgs(num_bits=num_bits, symmetric=symmetric, observer="mse") From 07abbf3b9d9a7ec1f497a192241030181c480d66 Mon Sep 17 00:00:00 2001 From: Kyle Sayers <kylesayrs@gmail.com> Date: Wed, 23 Oct 2024 15:25:05 -0400 Subject: [PATCH 11/12] [Bugfix] Move observer and g_idx until after module in onloaded (#195) --- .../quantization/lifecycle/calibration.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/src/compressed_tensors/quantization/lifecycle/calibration.py b/src/compressed_tensors/quantization/lifecycle/calibration.py index 66dc35a0..c9e51813 100644 --- a/src/compressed_tensors/quantization/lifecycle/calibration.py +++ b/src/compressed_tensors/quantization/lifecycle/calibration.py @@ -64,14 +64,12 @@ def set_module_for_calibration(module: Module, quantize_weights_upfront: bool = quantization_args=module.quantization_scheme.weights, ) - observer = module.weight_observer - - g_idx = getattr(module, "weight_g_idx", None) - offloaded = is_module_offloaded(module) if offloaded: module._hf_hook.pre_forward(module) + observer = module.weight_observer + g_idx = getattr(module, "weight_g_idx", None) scale, zero_point = observer(module.weight, g_idx=g_idx) update_parameter_data(module, scale, "weight_scale") update_parameter_data(module, zero_point, "weight_zero_point") From 13b5c0ba049163d6f7d310867c4ea283a116bcce Mon Sep 17 00:00:00 2001 From: Rahul Tuli <rahul@neuralmagic.com> Date: Thu, 24 Oct 2024 08:29:24 -0400 Subject: [PATCH 12/12] Add sparsity structure enum (#197) --- src/compressed_tensors/config/base.py | 62 ++++++++++++++++++++++++++- tests/test_configs/__init__.py | 13 ++++++ tests/test_configs/test_base.py | 57 ++++++++++++++++++++++++ 3 files changed, 130 insertions(+), 2 deletions(-) create mode 100644 tests/test_configs/__init__.py create mode 100644 tests/test_configs/test_base.py diff --git a/src/compressed_tensors/config/base.py b/src/compressed_tensors/config/base.py index ccc3e649..79a4fcdd 100644 --- a/src/compressed_tensors/config/base.py +++ b/src/compressed_tensors/config/base.py @@ -12,16 +12,17 @@ # See the License for the specific language governing permissions and # limitations under the License. -from enum import Enum +from enum import Enum, unique from typing import List, Optional from compressed_tensors.registry import RegistryMixin from pydantic import BaseModel -__all__ = ["SparsityCompressionConfig", "CompressionFormat"] +__all__ = ["SparsityCompressionConfig", "CompressionFormat", "SparsityStructure"] +@unique class CompressionFormat(Enum): dense = "dense" sparse_bitmask = "sparse-bitmask" @@ -32,6 +33,63 @@ class CompressionFormat(Enum): marlin_24 = "marlin-24" +@unique +class SparsityStructure(Enum): + """ + An enumeration to represent different sparsity structures. + + Attributes + ---------- + TWO_FOUR : str + Represents a 2:4 sparsity structure. + ZERO_ZERO : str + Represents a 0:0 sparsity structure. + UNSTRUCTURED : str + Represents an unstructured sparsity structure. + + Examples + -------- + >>> SparsityStructure('2:4') + <SparsityStructure.TWO_FOUR: '2:4'> + + >>> SparsityStructure('unstructured') + <SparsityStructure.UNSTRUCTURED: 'unstructured'> + + >>> SparsityStructure('2:4') == SparsityStructure.TWO_FOUR + True + + >>> SparsityStructure('UNSTRUCTURED') == SparsityStructure.UNSTRUCTURED + True + + >>> SparsityStructure(None) == SparsityStructure.UNSTRUCTURED + True + + >>> SparsityStructure('invalid') + Traceback (most recent call last): + ... + ValueError: invalid is not a valid SparsityStructure + """ + + TWO_FOUR = "2:4" + UNSTRUCTURED = "unstructured" + ZERO_ZERO = "0:0" + + def __new__(cls, value): + obj = object.__new__(cls) + obj._value_ = value.lower() if value is not None else value + return obj + + @classmethod + def _missing_(cls, value): + # Handle None and case-insensitive values + if value is None: + return cls.UNSTRUCTURED + for member in cls: + if member.value == value.lower(): + return member + raise ValueError(f"{value} is not a valid {cls.__name__}") + + class SparsityCompressionConfig(RegistryMixin, BaseModel): """ Base data class for storing sparsity compression parameters diff --git a/tests/test_configs/__init__.py b/tests/test_configs/__init__.py new file mode 100644 index 00000000..0c44f887 --- /dev/null +++ b/tests/test_configs/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/test_configs/test_base.py b/tests/test_configs/test_base.py new file mode 100644 index 00000000..5334ef38 --- /dev/null +++ b/tests/test_configs/test_base.py @@ -0,0 +1,57 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest +from compressed_tensors.config import SparsityStructure + + +def test_sparsity_structure_valid_cases(): + assert ( + SparsityStructure("2:4") == SparsityStructure.TWO_FOUR + ), "Failed to match '2:4' with TWO_FOUR" + assert ( + SparsityStructure("unstructured") == SparsityStructure.UNSTRUCTURED + ), "Failed to match 'unstructured' with UNSTRUCTURED" + assert ( + SparsityStructure("UNSTRUCTURED") == SparsityStructure.UNSTRUCTURED + ), "Failed to match 'UNSTRUCTURED' with UNSTRUCTURED" + assert ( + SparsityStructure(None) == SparsityStructure.UNSTRUCTURED + ), "Failed to match None with UNSTRUCTURED" + + +def test_sparsity_structure_invalid_case(): + with pytest.raises(ValueError, match="invalid is not a valid SparsityStructure"): + SparsityStructure("invalid") + + +def test_sparsity_structure_case_insensitivity(): + assert ( + SparsityStructure("2:4") == SparsityStructure.TWO_FOUR + ), "Failed to match '2:4' with TWO_FOUR" + assert ( + SparsityStructure("2:4".upper()) == SparsityStructure.TWO_FOUR + ), "Failed to match '2:4'.upper() with TWO_FOUR" + assert ( + SparsityStructure("unstructured".upper()) == SparsityStructure.UNSTRUCTURED + ), "Failed to match 'unstructured'.upper() with UNSTRUCTURED" + assert ( + SparsityStructure("UNSTRUCTURED".lower()) == SparsityStructure.UNSTRUCTURED + ), "Failed to match 'UNSTRUCTURED'.lower() with UNSTRUCTURED" + + +def test_sparsity_structure_default_case(): + assert ( + SparsityStructure(None) == SparsityStructure.UNSTRUCTURED + ), "Failed to match None with UNSTRUCTURED"