From 2b790565310833d630452b88c530004f142b82b2 Mon Sep 17 00:00:00 2001 From: Dipika Sikka Date: Thu, 31 Oct 2024 10:15:04 -0400 Subject: [PATCH] Observer Restructure: Remove Observers, `calibration`, and applying `frozen` steps from lifecycle (#189) * temporary workaround * separate out calibration from forward pass * fix missing import * fix tests * update all other tests * clean * update * clean-up * fix test case * remove calibration and init observer steps * update * update * clean-up/fix * cleanup * cleanup * remove cache * clean-up * remove frozen * more clean-up * remove observer, cache, and frozen state * update more test cases * fix bit_depth test * fix more tests * clean-up remaining tests * clean-up * dont skip * more clean-up * fix --- .../quantization/__init__.py | 1 - src/compressed_tensors/quantization/cache.py | 200 ---------------- .../quantization/lifecycle/__init__.py | 2 - .../quantization/lifecycle/apply.py | 17 +- .../quantization/lifecycle/calibration.py | 80 ------- .../quantization/lifecycle/forward.py | 120 ++-------- .../quantization/lifecycle/frozen.py | 50 ---- .../quantization/lifecycle/initialize.py | 39 ++-- .../quantization/observers/__init__.py | 21 -- .../quantization/observers/base.py | 213 ------------------ .../quantization/observers/helpers.py | 149 ------------ .../quantization/observers/min_max.py | 104 --------- .../quantization/observers/mse.py | 164 -------------- .../quantization/quant_args.py | 9 +- .../quantization/utils/helpers.py | 135 ++++++++++- tests/conftest.py | 141 ++++++++++++ .../quantized_compressors/test_fp8_quant.py | 22 +- .../quantized_compressors/test_pack_quant.py | 12 +- .../test_marlin_24.py | 20 +- tests/test_quantization/lifecycle/conftest.py | 10 + .../test_quantization/lifecycle/test_apply.py | 6 +- .../lifecycle/test_dynamic_lifecycle.py | 8 +- .../lifecycle/test_forward.py | 59 +---- .../lifecycle/test_frozen.py | 50 ---- .../lifecycle/test_kv_cache.py | 78 ------- .../lifecycle/test_lifecycle.py | 24 +- tests/test_quantization/test_cache.py | 116 ---------- .../test_configs/test_bit_depths.py | 31 ++- .../test_configs/test_strategies.py | 31 ++- .../test_observers/__init__.py | 13 -- .../test_observers/test_helpers.py | 91 -------- .../test_observers/test_min_max.py | 113 ---------- .../test_observers/test_mse.py | 54 ----- 33 files changed, 420 insertions(+), 1763 deletions(-) delete mode 100644 src/compressed_tensors/quantization/cache.py delete mode 100644 src/compressed_tensors/quantization/lifecycle/calibration.py delete mode 100644 src/compressed_tensors/quantization/lifecycle/frozen.py delete mode 100644 src/compressed_tensors/quantization/observers/__init__.py delete mode 100644 src/compressed_tensors/quantization/observers/base.py delete mode 100644 src/compressed_tensors/quantization/observers/helpers.py delete mode 100644 src/compressed_tensors/quantization/observers/min_max.py delete mode 100644 src/compressed_tensors/quantization/observers/mse.py create mode 100644 tests/conftest.py delete mode 100644 tests/test_quantization/lifecycle/test_frozen.py delete mode 100644 tests/test_quantization/lifecycle/test_kv_cache.py delete mode 100644 tests/test_quantization/test_cache.py delete mode 100644 tests/test_quantization/test_observers/__init__.py delete mode 100644 tests/test_quantization/test_observers/test_helpers.py delete mode 100644 tests/test_quantization/test_observers/test_min_max.py delete mode 100644 tests/test_quantization/test_observers/test_mse.py diff --git 
a/src/compressed_tensors/quantization/__init__.py b/src/compressed_tensors/quantization/__init__.py index 848a4458..9fde69a3 100644 --- a/src/compressed_tensors/quantization/__init__.py +++ b/src/compressed_tensors/quantization/__init__.py @@ -19,4 +19,3 @@ from .quant_config import * from .quant_scheme import * from .lifecycle import * -from .cache import QuantizedKVParameterCache diff --git a/src/compressed_tensors/quantization/cache.py b/src/compressed_tensors/quantization/cache.py deleted file mode 100644 index 312f1c9d..00000000 --- a/src/compressed_tensors/quantization/cache.py +++ /dev/null @@ -1,200 +0,0 @@ -# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -from enum import Enum -from typing import Any, Dict, List, Optional, Tuple - -from compressed_tensors.quantization.observers import Observer -from compressed_tensors.quantization.quant_args import QuantizationArgs -from torch import Tensor -from transformers import DynamicCache as HFDyanmicCache - - -class KVCacheScaleType(Enum): - KEY = "k_scale" - VALUE = "v_scale" - - -class QuantizedKVParameterCache(HFDyanmicCache): - """ - Quantized KV cache used in the forward call based on HF's dynamic cache. - Quantization strategy (tensor, group, channel) set from Quantization arg's strategy - Singleton, so that the same cache gets reused in all forward call of self_attn. - Each time forward is called, .update() is called, and ._quantize(), ._dequantize() - gets called appropriately. - The size of tensor is - `[batch_size, num_heads, seq_len - residual_length, head_dim]`. - - - Triggered by adding kv_cache_scheme in the recipe. 
- - Example: - - ```python3 - recipe = ''' - quant_stage: - quant_modifiers: - QuantizationModifier: - kv_cache_scheme: - num_bits: 8 - type: float - strategy: tensor - dynamic: false - symmetric: true - ''' - - """ - - _instance = None - _initialized = False - - def __new__(cls, *args, **kwargs): - """Singleton""" - if cls._instance is None: - cls._instance = super(QuantizedKVParameterCache, cls).__new__(cls) - return cls._instance - - def __init__(self, quantization_args: QuantizationArgs): - if not self._initialized: - super().__init__() - - self.quantization_args = quantization_args - - self.k_observers: List[Observer] = [] - self.v_observers: List[Observer] = [] - - # each index corresponds to layer_idx of the attention layer - self.k_scales: List[Tensor] = [] - self.v_scales: List[Tensor] = [] - - self.k_zps: List[Tensor] = [] - self.v_zps: List[Tensor] = [] - - self._initialized = True - - def update( - self, - key_states: Tensor, - value_states: Tensor, - layer_idx: int, - cache_kwargs: Optional[Dict[str, Any]] = None, - ) -> Tuple[Tensor, Tensor]: - """ - Get the k_scale and v_scale and output the - fakequant-ed key_states and value_states - """ - - if len(self.k_observers) <= layer_idx: - k_observer = self.quantization_args.get_observer() - v_observer = self.quantization_args.get_observer() - - self.k_observers.append(k_observer) - self.v_observers.append(v_observer) - - q_key_states = self._quantize( - key_states.contiguous(), KVCacheScaleType.KEY, layer_idx - ) - q_value_states = self._quantize( - value_states.contiguous(), KVCacheScaleType.VALUE, layer_idx - ) - - qdq_key_states = self._dequantize(q_key_states, KVCacheScaleType.KEY, layer_idx) - qdq_value_states = self._dequantize( - q_value_states, KVCacheScaleType.VALUE, layer_idx - ) - - keys_to_return, values_to_return = qdq_key_states, qdq_value_states - - return keys_to_return, values_to_return - - def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: - """ - Returns the sequence length of the cached states. - A layer index can be optionally passed. 
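The deleted `update()` above fake-quantizes each layer's key and value states: an observer produces a scale and zero point, the states are quantized, then immediately dequantized before being handed back to attention. A minimal per-tensor sketch of that round trip in plain torch, assuming a simple symmetric int8 scheme (the scale and clamping choices here are illustrative, not the library's exact observer logic):

```python
import torch


def fake_quant_kv(states: torch.Tensor, num_bits: int = 8):
    """Symmetric per-tensor fake quantization: quantize, then dequantize."""
    q_max = 2 ** (num_bits - 1) - 1                      # 127 for int8
    scale = states.abs().amax() / q_max                  # per-tensor scale from the observed max
    scale = torch.clamp(scale, min=torch.finfo(torch.float32).eps)
    q = torch.clamp(torch.round(states / scale), -q_max - 1, q_max)
    return q * scale, scale                              # qdq states and the k_scale / v_scale


key_states = torch.randn(1, 8, 16, 64)                   # [batch, num_heads, seq_len, head_dim]
qdq_keys, k_scale = fake_quant_kv(key_states)
print(k_scale.item(), (qdq_keys - key_states).abs().max().item())
```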
- """ - if len(self.key_cache) <= layer_idx: - return 0 - # since we cannot get the seq_length of each layer directly and - # rely on `_seen_tokens` which is updated every "layer_idx" == 0, - # this is a hack to get the actual seq_length for the given layer_idx - # this part of code otherwise fails when used to - # verify attn_weight shape in some models - return self._seen_tokens if layer_idx == 0 else self._seen_tokens - 1 - - def reset_states(self): - """reset the kv states (used in calibration)""" - self.key_cache: List[Tensor] = [] - self.value_cache: List[Tensor] = [] - # Used in `generate` to keep tally of how many tokens the cache has seen - self._seen_tokens = 0 - self._quantized_key_cache: List[Tensor] = [] - self._quantized_value_cache: List[Tensor] = [] - - def reset(self): - """ - Reset the instantiation, create new instance on init - """ - QuantizedKVParameterCache._instance = None - QuantizedKVParameterCache._initialized = False - - def _quantize(self, tensor, kv_type, layer_idx): - """Quantizes a key/value using a defined quantization method.""" - from compressed_tensors.quantization.lifecycle.forward import quantize - - if kv_type == KVCacheScaleType.KEY: # key type - observer = self.k_observers[layer_idx] - scales = self.k_scales - zps = self.k_zps - else: - assert kv_type == KVCacheScaleType.VALUE - observer = self.v_observers[layer_idx] - scales = self.v_scales - zps = self.v_zps - - scale, zp = observer(tensor) - if len(scales) <= layer_idx: - scales.append(scale) - zps.append(zp) - else: - scales[layer_idx] = scale - zps[layer_idx] = scale - - q_tensor = quantize( - x=tensor, - scale=scale, - zero_point=zp, - args=self.quantization_args, - ) - return q_tensor - - def _dequantize(self, qtensor, kv_type, layer_idx): - """Dequantizes back the tensor that was quantized by `self._quantize()`""" - from compressed_tensors.quantization.lifecycle.forward import dequantize - - if kv_type == KVCacheScaleType.KEY: - scale = self.k_scales[layer_idx] - zp = self.k_zps[layer_idx] - else: - assert kv_type == KVCacheScaleType.VALUE - scale = self.v_scales[layer_idx] - zp = self.v_zps[layer_idx] - - qdq_tensor = dequantize( - x_q=qtensor, - scale=scale, - zero_point=zp, - args=self.quantization_args, - ) - return qdq_tensor diff --git a/src/compressed_tensors/quantization/lifecycle/__init__.py b/src/compressed_tensors/quantization/lifecycle/__init__.py index 98bc6630..6acab255 100644 --- a/src/compressed_tensors/quantization/lifecycle/__init__.py +++ b/src/compressed_tensors/quantization/lifecycle/__init__.py @@ -15,9 +15,7 @@ # flake8: noqa # isort: skip_file -from .calibration import * from .forward import * -from .frozen import * from .initialize import * from .compressed import * from .apply import * diff --git a/src/compressed_tensors/quantization/lifecycle/apply.py b/src/compressed_tensors/quantization/lifecycle/apply.py index 09281528..7c498787 100644 --- a/src/compressed_tensors/quantization/lifecycle/apply.py +++ b/src/compressed_tensors/quantization/lifecycle/apply.py @@ -22,13 +22,9 @@ import torch from compressed_tensors.config import CompressionFormat -from compressed_tensors.quantization.lifecycle.calibration import ( - set_module_for_calibration, -) from compressed_tensors.quantization.lifecycle.compressed import ( compress_quantized_weights, ) -from compressed_tensors.quantization.lifecycle.frozen import freeze_module_quantization from compressed_tensors.quantization.lifecycle.initialize import ( initialize_module_for_quantization, ) @@ -233,6 +229,7 @@ def 
apply_quantization_status(model: Module, status: QuantizationStatus): :param model: model to apply quantization to :param status: status to update the module to """ + current_status = infer_quantization_status(model) if status >= QuantizationStatus.INITIALIZED > current_status: @@ -243,18 +240,6 @@ def apply_quantization_status(model: Module, status: QuantizationStatus): ) ) - if current_status < status >= QuantizationStatus.CALIBRATION > current_status: - # only quantize weights up front when our end goal state is calibration, - # weight quantization parameters are already loaded for frozen/compressed - quantize_weights_upfront = status == QuantizationStatus.CALIBRATION - model.apply( - lambda module: set_module_for_calibration( - module, quantize_weights_upfront=quantize_weights_upfront - ) - ) - if current_status < status >= QuantizationStatus.FROZEN > current_status: - model.apply(freeze_module_quantization) - if current_status < status >= QuantizationStatus.COMPRESSED > current_status: model.apply(compress_quantized_weights) diff --git a/src/compressed_tensors/quantization/lifecycle/calibration.py b/src/compressed_tensors/quantization/lifecycle/calibration.py deleted file mode 100644 index c9e51813..00000000 --- a/src/compressed_tensors/quantization/lifecycle/calibration.py +++ /dev/null @@ -1,80 +0,0 @@ -# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -import logging - -from compressed_tensors.quantization.quant_config import QuantizationStatus -from compressed_tensors.utils import is_module_offloaded, update_parameter_data -from torch.nn import Module - - -__all__ = [ - "set_module_for_calibration", -] - - -_LOGGER = logging.getLogger(__name__) - - -def set_module_for_calibration(module: Module, quantize_weights_upfront: bool = True): - """ - marks a layer as ready for calibration which activates observers - to update scales and zero points on each forward pass - - apply to full model with `model.apply(set_module_for_calibration)` - - :param module: module to set for calibration - :param quantize_weights_upfront: whether to automatically - run weight quantization at the start of calibration - """ - if not getattr(module, "quantization_scheme", None): - # no quantization scheme nothing to do - return - status = getattr(module, "quantization_status", None) - if not status or status != QuantizationStatus.INITIALIZED: - _LOGGER.warning( - f"Attempting set module with status {status} to calibration mode. 
" - f"but status is not {QuantizationStatus.INITIALIZED} - you may " - "be calibrating an uninitialized module which may fail or attempting " - "to re-calibrate a frozen module" - ) - - if quantize_weights_upfront and module.quantization_scheme.weights is not None: - # set weight scale and zero_point up front, calibration data doesn't affect it - if not hasattr(module, "weight_observer"): - from compressed_tensors.quantization.lifecycle.initialize import ( - initialize_observers, - ) - - initialize_observers( - module=module, - base_name="weight", - quantization_args=module.quantization_scheme.weights, - ) - - offloaded = is_module_offloaded(module) - if offloaded: - module._hf_hook.pre_forward(module) - - observer = module.weight_observer - g_idx = getattr(module, "weight_g_idx", None) - scale, zero_point = observer(module.weight, g_idx=g_idx) - update_parameter_data(module, scale, "weight_scale") - update_parameter_data(module, zero_point, "weight_zero_point") - - if offloaded: - module._hf_hook.post_forward(module, None) - - module.quantization_status = QuantizationStatus.CALIBRATION diff --git a/src/compressed_tensors/quantization/lifecycle/forward.py b/src/compressed_tensors/quantization/lifecycle/forward.py index eae641f8..19a22a39 100644 --- a/src/compressed_tensors/quantization/lifecycle/forward.py +++ b/src/compressed_tensors/quantization/lifecycle/forward.py @@ -17,11 +17,6 @@ from typing import Callable, Optional import torch -from compressed_tensors.quantization.cache import QuantizedKVParameterCache -from compressed_tensors.quantization.observers.helpers import ( - calculate_range, - compute_dynamic_scales_and_zp, -) from compressed_tensors.quantization.quant_args import ( QuantizationArgs, QuantizationStrategy, @@ -29,6 +24,10 @@ ) from compressed_tensors.quantization.quant_config import QuantizationStatus from compressed_tensors.quantization.quant_scheme import QuantizationScheme +from compressed_tensors.quantization.utils import ( + calculate_range, + compute_dynamic_scales_and_zp, +) from compressed_tensors.utils import safe_permute, update_parameter_data from torch.nn import Module @@ -39,7 +38,6 @@ "fake_quantize", "wrap_module_forward_quantized", "forward_quantize", - "calibrate_activations", ] @@ -276,19 +274,7 @@ def wrapped_forward(self, *args, **kwargs): compressed = module.quantization_status == QuantizationStatus.COMPRESSED if scheme.input_activations is not None: - # calibrate and (fake) quantize input activations when applicable - # NOTE: will be moved out of compressed-tensors - if ( - module.quantization_status == QuantizationStatus.CALIBRATION - and not scheme.input_activations.dynamic - ): - calibrate_activations( - module=module, - value=input_, - base_name="input", - quantization_args=scheme.input_activations, - ) - + # prehook should calibrate activations before forward call input_ = forward_quantize(module, input_, "input", scheme.input_activations) if scheme.weights is not None and not compressed: @@ -302,31 +288,22 @@ def wrapped_forward(self, *args, **kwargs): output = forward_func_orig.__get__(module, module.__class__)( input_, *args[1:], **kwargs ) - if scheme.output_activations is not None: - # calibrate and (fake) quantize output activations when applicable - # kv_cache scales updated on model self_attn forward call in - # wrap_module_forward_quantized_attn + # restore back to unquantized_value + if scheme.weights is not None and not compressed: + self.weight.data = unquantized_weight + if scheme.output_activations is not None: + # forward-hook 
should calibrate/forward_quantize if ( module.quantization_status == QuantizationStatus.CALIBRATION and not scheme.output_activations.dynamic ): - calibrate_activations( - module=module, - value=output, - base_name="output", - quantization_args=scheme.ouput_activations, - ) + return output output = forward_quantize( module, output, "output", scheme.output_activations ) - - # restore back to unquantized_value - if scheme.weights is not None and not compressed: - self.weight.data = unquantized_weight - return output # bind wrapped forward to module class so reference to `self` is correct @@ -335,77 +312,6 @@ def wrapped_forward(self, *args, **kwargs): setattr(module, "forward", bound_wrapped_forward) -def wrap_module_forward_quantized_attn(module: Module, scheme: QuantizationScheme): - # expects a module already initialized and injected with the parameters in - # initialize_module_for_quantization - if hasattr(module.forward, "__func__"): - forward_func_orig = module.forward.__func__ - else: - forward_func_orig = module.forward.func - - @wraps(forward_func_orig) # ensures docstring, names, etc are propagated - def wrapped_forward(self, *args, **kwargs): - - # kv cache stored under weights - if module.quantization_status == QuantizationStatus.CALIBRATION: - quantization_args: QuantizationArgs = scheme.output_activations - past_key_value: QuantizedKVParameterCache = quantization_args.get_kv_cache() - kwargs["past_key_value"] = past_key_value - - # QuantizedKVParameterCache used for obtaining k_scale, v_scale only, - # does not store quantized_key_states and quantized_value_state - kwargs["use_cache"] = False - - attn_forward: Callable = forward_func_orig.__get__(module, module.__class__) - - past_key_value.reset_states() - - rtn = attn_forward(*args, **kwargs) - - update_parameter_data( - module, past_key_value.k_scales[module.layer_idx], "k_scale" - ) - update_parameter_data( - module, past_key_value.v_scales[module.layer_idx], "v_scale" - ) - - return rtn - - return forward_func_orig.__get__(module, module.__class__)(*args, **kwargs) - - # bind wrapped forward to module class so reference to `self` is correct - bound_wrapped_forward = wrapped_forward.__get__(module, module.__class__) - # set forward to wrapped forward - setattr(module, "forward", bound_wrapped_forward) - - -def calibrate_activations( - module: Module, - value: torch.Tensor, - base_name: str, - quantization_args: QuantizationArgs, -): - # If empty tensor, can't update zp/scale - # Case for MoEs - if value.numel() == 0: - return - # calibration mode - get new quant params from observer - if not hasattr(module, f"{base_name}_observer"): - from compressed_tensors.quantization.lifecycle import initialize_observers - - initialize_observers( - module=module, base_name=base_name, quantization_args=quantization_args - ) - - observer = getattr(module, f"{base_name}_observer") - - updated_scale, updated_zero_point = observer(value) - - # update scale and zero point - update_parameter_data(module, updated_scale, f"{base_name}_scale") - update_parameter_data(module, updated_zero_point, f"{base_name}_zero_point") - - def forward_quantize( module: Module, value: torch.Tensor, base_name: str, args: "QuantizationArgs" ) -> torch.Tensor: @@ -426,10 +332,10 @@ def forward_quantize( g_idx = getattr(module, "weight_g_idx", None) if args.dynamic: - # dynamic quantization - no need to invoke observer + # dynamic quantization - determine the scale/zp on the fly scale, zero_point = compute_dynamic_scales_and_zp(value=value, args=args) else: - # 
static quantization - get previous scale and zero point from layer + # static quantization - get scale and zero point from layer scale = getattr(module, f"{base_name}_scale") zero_point = getattr(module, f"{base_name}_zero_point", None) diff --git a/src/compressed_tensors/quantization/lifecycle/frozen.py b/src/compressed_tensors/quantization/lifecycle/frozen.py deleted file mode 100644 index 4a65482c..00000000 --- a/src/compressed_tensors/quantization/lifecycle/frozen.py +++ /dev/null @@ -1,50 +0,0 @@ -# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -from compressed_tensors.quantization.quant_config import QuantizationStatus -from torch.nn import Module - - -__all__ = [ - "freeze_module_quantization", -] - - -def freeze_module_quantization(module: Module): - """ - deletes observers so static quantization is completed. - - apply to full model with `model.apply(freeze_module_quantization)` - - :param module: module to freeze quantization for - """ - scheme = getattr(module, "quantization_scheme", None) - if not scheme: - # no quantization scheme nothing to do - return - - if module.quantization_status == QuantizationStatus.FROZEN: - # nothing to do, already frozen - return - - # delete observers from module if not dynamic - if hasattr(module, "input_observer") and not scheme.input_activations.dynamic: - delattr(module, "input_observer") - if hasattr(module, "weight_observer") and not scheme.weights.dynamic: - delattr(module, "weight_observer") - if hasattr(module, "output_observer") and not scheme.output_activations.dynamic: - delattr(module, "output_observer") - - module.quantization_status = QuantizationStatus.FROZEN diff --git a/src/compressed_tensors/quantization/lifecycle/initialize.py b/src/compressed_tensors/quantization/lifecycle/initialize.py index 68157cb1..eb4d6b18 100644 --- a/src/compressed_tensors/quantization/lifecycle/initialize.py +++ b/src/compressed_tensors/quantization/lifecycle/initialize.py @@ -14,13 +14,12 @@ import logging +from enum import Enum from typing import Optional import torch -from compressed_tensors.quantization.cache import KVCacheScaleType from compressed_tensors.quantization.lifecycle.forward import ( wrap_module_forward_quantized, - wrap_module_forward_quantized_attn, ) from compressed_tensors.quantization.quant_args import ( ActivationOrdering, @@ -34,12 +33,21 @@ from torch.nn import Module, Parameter -__all__ = ["initialize_module_for_quantization", "initialize_observers"] +__all__ = [ + "initialize_module_for_quantization", + "is_attention_module", + "KVCacheScaleType", +] _LOGGER = logging.getLogger(__name__) +class KVCacheScaleType(Enum): + KEY = "k_scale" + VALUE = "v_scale" + + def initialize_module_for_quantization( module: Module, scheme: Optional[QuantizationScheme] = None, @@ -64,9 +72,7 @@ def initialize_module_for_quantization( return if is_attention_module(module): - # wrap forward call of module to perform # quantized actions based on calltime status - 
wrap_module_forward_quantized_attn(module, scheme) _initialize_attn_scales(module) else: @@ -107,6 +113,7 @@ def initialize_module_for_quantization( module.quantization_status = QuantizationStatus.INITIALIZED offloaded = False + # What is this doing/why isn't this in the attn case? if is_module_offloaded(module): try: from accelerate.hooks import add_hook_to_module, remove_hook_from_module @@ -144,14 +151,12 @@ def initialize_module_for_quantization( module._hf_hook.weights_map = new_prefix_dict -def initialize_observers( - module: Module, - base_name: str, - quantization_args: QuantizationArgs, -): - # initialize observer module and attach as submodule - observer = quantization_args.get_observer() - module.register_module(f"{base_name}_observer", observer) +def is_attention_module(module: Module): + return "attention" in module.__class__.__name__.lower() and ( + hasattr(module, "k_proj") + or hasattr(module, "v_proj") + or hasattr(module, "qkv_proj") + ) def _initialize_scale_zero_point( @@ -209,14 +214,6 @@ def _initialize_scale_zero_point( module.register_parameter(f"{base_name}_g_idx", init_g_idx) -def is_attention_module(module: Module): - return "attention" in module.__class__.__name__.lower() and ( - hasattr(module, "k_proj") - or hasattr(module, "v_proj") - or hasattr(module, "qkv_proj") - ) - - def _initialize_attn_scales(module: Module) -> None: """Initlaize k_scale, v_scale for self_attn""" diff --git a/src/compressed_tensors/quantization/observers/__init__.py b/src/compressed_tensors/quantization/observers/__init__.py deleted file mode 100644 index 05b6b367..00000000 --- a/src/compressed_tensors/quantization/observers/__init__.py +++ /dev/null @@ -1,21 +0,0 @@ -# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# flake8: noqa -# isort: skip_file - -from .helpers import * -from .base import * -from .min_max import * -from .mse import * diff --git a/src/compressed_tensors/quantization/observers/base.py b/src/compressed_tensors/quantization/observers/base.py deleted file mode 100644 index d9d646f4..00000000 --- a/src/compressed_tensors/quantization/observers/base.py +++ /dev/null @@ -1,213 +0,0 @@ -# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
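With the attention forward wrapper removed, `KVCacheScaleType` and `is_attention_module` now live in `initialize.py`, and attention modules only get `k_scale`/`v_scale` parameters registered via `_initialize_attn_scales`; calibrating those scales is expected to happen outside compressed-tensors. A rough sketch of what that registration amounts to, with the parameter shape, dtype, and initial value as assumptions:

```python
import torch
from enum import Enum
from torch.nn import Module, Parameter


class KVCacheScaleType(Enum):
    KEY = "k_scale"
    VALUE = "v_scale"


def initialize_attn_scales(module: Module) -> None:
    """Register k_scale / v_scale parameters on a self-attention module."""
    for scale_type in (KVCacheScaleType.KEY, KVCacheScaleType.VALUE):
        # one scalar (per-tensor) scale per module is an assumption for illustration
        module.register_parameter(
            scale_type.value,
            Parameter(torch.tensor(1.0), requires_grad=False),
        )
```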
- -import logging -from math import ceil -from typing import Any, Iterable, Optional, Tuple, Union - -import torch -from compressed_tensors.quantization.quant_args import ( - QuantizationArgs, - QuantizationStrategy, -) -from compressed_tensors.registry.registry import RegistryMixin -from compressed_tensors.utils import safe_permute -from torch import FloatTensor, IntTensor, Tensor -from torch.nn import Module - - -_LOGGER = logging.getLogger(__name__) - - -__all__ = ["Observer"] - - -class Observer(Module, RegistryMixin): - """ - Base Observer class to be subclassed for specific implementation. - Subclasses should override `calculate_qparams` to return a scale, zero_point - pair - """ - - def __init__(self, quantization_args: QuantizationArgs): - self.quantization_args: QuantizationArgs = quantization_args - super().__init__() - self._scale = None - self._zero_point = None - self._num_observed_tokens = None - - @torch.no_grad() - def forward( - self, observed: Tensor, g_idx: Optional[Tensor] = None - ) -> Tuple[FloatTensor, IntTensor]: - """ - maps directly to get_qparams - :param observed: optional observed tensor from which to calculate - quantization parameters - :param g_idx: optional mapping from column index to group index - :return: tuple of scale and zero point based on last observed value - """ - self.record_observed_tokens(observed) - return self.get_qparams(observed=observed, g_idx=g_idx) - - def calculate_qparams( - self, - observed: Tensor, - reduce_dims: Optional[Tuple[int]] = None, - ) -> Tuple[FloatTensor, IntTensor]: - """ - :param observed: observed tensor to calculate quantization parameters for - :param reduce_dims: optional tuple of dimensions to reduce along, - returned scale and zero point will be shaped (1,) along the - reduced dimensions - :return: tuple of scale and zero point derived from the observed tensor - """ - raise NotImplementedError(f"{self.__class__} must implement calculate_qparams") - - def post_calculate_qparams(self) -> None: - """ - Run any logic specific to its observers after running calculate_qparams - """ - ... - - def get_qparams( - self, - observed: Optional[Tensor] = None, - g_idx: Optional[Tensor] = None, - ) -> Tuple[FloatTensor, IntTensor]: - """ - Convenience function to wrap overwritten calculate_qparams - adds support to make observed tensor optional and support for tracking latest - calculated scale and zero point - - :param observed: optional observed tensor to calculate quantization parameters - from - :param g_idx: optional mapping from column index to group index - :return: tuple of scale and zero point based on last observed value - """ - if observed is not None: - group_size = self.quantization_args.group_size - - if self.quantization_args.strategy == QuantizationStrategy.TENSOR: - - # re-calculate scale and zero point, update the stored value - self._scale, self._zero_point = self.calculate_qparams(observed) - - elif self.quantization_args.strategy == QuantizationStrategy.GROUP: - rows = observed.shape[0] - columns = observed.shape[1] - num_groups = int(ceil(columns / group_size)) - self._scale = torch.empty( - (rows, num_groups), dtype=observed.dtype, device=observed.device - ) - zp_dtype = self.quantization_args.pytorch_dtype() - self._zero_point = torch.empty( - (rows, num_groups), dtype=zp_dtype, device=observed.device - ) - - # support column-order (default) quantization as well as other orderings - # such as activation ordering. 
Below checks if g_idx has initialized - is_column_order = g_idx is None or -1 in g_idx - if is_column_order: - group_sizes = torch.full((num_groups,), group_size, dtype=torch.int) - else: - group_indices, group_sizes = torch.unique(g_idx, return_counts=True) - group_sizes = group_sizes[torch.argsort(group_indices)] - - perm = torch.argsort(g_idx) - observed = safe_permute(observed, perm, dim=1) - - # TODO: experiment with vectorizing for loop for performance - end = 0 - for group_index, group_count in enumerate(group_sizes): - start = end - end = start + group_count - scale, zero_point = self.get_qparams_along_dim( - observed[:, start:end], - 0, - tensor_id=group_index, - ) - - self._scale[:, group_index] = scale.squeeze(1) - self._zero_point[:, group_index] = zero_point.squeeze(1) - - elif self.quantization_args.strategy == QuantizationStrategy.CHANNEL: - # assume observed is transposed, because its the output, hence use dim 0 - self._scale, self._zero_point = self.get_qparams_along_dim(observed, 0) - - elif self.quantization_args.strategy == QuantizationStrategy.TOKEN: - # use dim 1, assume the obsersed.shape = [batch, token, hidden] - # should be batch, token - self._scale, self._zero_point = self.get_qparams_along_dim( - observed, - dim={0, 1}, - ) - - return self._scale, self._zero_point - - def get_qparams_along_dim( - self, - observed, - dim: Union[int, Iterable[int]], - tensor_id: Optional[Any] = None, - ): - if isinstance(dim, int): - dim = [dim] - dim = set(dim) - - reduce_dims = tuple(idx for idx in range(observed.ndim) if idx not in dim) - return self.calculate_qparams( - observed, reduce_dims=reduce_dims, tensor_id=tensor_id - ) - - def record_observed_tokens(self, batch_tensor: Tensor): - """ - Counts the number of tokens observed during the - forward passes. The count is aggregated in the - _num_observed_tokens attribute of the class. - - Note: The batch_tensor is expected to have two dimensions - (batch_size * sequence_length, num_features). This is the - general shape expected by the forward pass of the expert - layers in a MOE model. If the input tensor does not have - two dimensions, the _num_observed_tokens attribute will be set - to None. - """ - if not isinstance(batch_tensor, Tensor): - raise ValueError(f"Expected value to be a tensor, got {type(batch_tensor)}") - - if batch_tensor.ndim != 2: - _LOGGER.debug( - "The input tensor is expected to have two dimensions " - "(batch_size * sequence_length, num_features). " - f"The input tensor has {batch_tensor.ndim} dimensions." - ) - return - - if self._num_observed_tokens is None: - # initialize the count - self._num_observed_tokens = 0 - - # batch_tensor (batch_size * sequence_length, num_features) - # observed_tokens (batch_size * sequence_length) - observed_tokens, _ = batch_tensor.shape - self._num_observed_tokens += observed_tokens - - def reset(self): - """ - Reset the state of the observer - """ - self._num_observed_tokens = None - self._scale = None - self._zero_point = None diff --git a/src/compressed_tensors/quantization/observers/helpers.py b/src/compressed_tensors/quantization/observers/helpers.py deleted file mode 100644 index ec474303..00000000 --- a/src/compressed_tensors/quantization/observers/helpers.py +++ /dev/null @@ -1,149 +0,0 @@ -# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from collections import Counter -from typing import Tuple - -import torch -from compressed_tensors.quantization.quant_args import ( - FP8_DTYPE, - QuantizationArgs, - QuantizationStrategy, - QuantizationType, -) -from torch import FloatTensor, IntTensor, Tensor - - -__all__ = [ - "calculate_qparams", - "get_observer_token_count", - "calculate_range", - "compute_dynamic_scales_and_zp", -] - - -def compute_dynamic_scales_and_zp(value: Tensor, args: QuantizationArgs): - """ - Returns the computed scales and zero points for dynamic activation - qunatization. - - :param value: tensor to calculate quantization parameters for - :param args: quantization args - :param reduce_dims: optional tuple of dimensions to reduce along, - returned scale and zero point will be shaped (1,) along the - reduced dimensions - :return: tuple of scale and zero point derived from the observed tensor - """ - if args.strategy == QuantizationStrategy.TOKEN: - dim = {1, 2} - reduce_dims = tuple(idx for idx in range(value.ndim) if idx not in dim) - elif args.strategy == QuantizationStrategy.TENSOR: - reduce_dims = None - else: - raise ValueError( - f"One of {QuantizationStrategy.TOKEN} or {QuantizationStrategy.TENSOR} ", - "must be used for dynamic quantization", - ) - - if not reduce_dims: - min_val, max_val = torch.aminmax(value) - else: - min_val = torch.amin(value, dim=reduce_dims, keepdims=True) - max_val = torch.amax(value, dim=reduce_dims, keepdims=True) - - return calculate_qparams(min_val, max_val, args) - - -def get_observer_token_count(module: torch.nn.Module) -> Counter: - """ - Parse the module and return the number of tokens observed by - each module's observer. 
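For dynamic quantization, scales and zero points are computed from the live activation tensor at call time rather than read from stored parameters (see `compute_dynamic_scales_and_zp` above, which this patch relocates into `quantization/utils`). A simplified per-token illustration that reduces over the hidden dimension only; the library's actual reduction axes depend on the configured `QuantizationStrategy`, so treat these axes as an assumption:

```python
import torch


def dynamic_token_minmax(value: torch.Tensor):
    """Per-token min/max for a [batch, seq_len, hidden] activation tensor."""
    min_val = torch.amin(value, dim=-1, keepdim=True)   # -> [batch, seq_len, 1]
    max_val = torch.amax(value, dim=-1, keepdim=True)
    return min_val, max_val


x = torch.randn(2, 4, 8)
min_val, max_val = dynamic_token_minmax(x)
# these min/max pairs would then go through calculate_qparams(...) to get scale / zero point
print(min_val.shape, max_val.shape)
```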
- - :param module: module to parse - :return: counter with the number of tokens observed by each observer - """ - token_counts = Counter() - for name, module in module.named_modules(): - if name.endswith(".input_observer"): - token_counts[ - name.replace(".input_observer", "") - ] = module._num_observed_tokens - return token_counts - - -def calculate_qparams( - min_vals: Tensor, max_vals: Tensor, quantization_args: QuantizationArgs -) -> Tuple[FloatTensor, IntTensor]: - """ - :param min_vals: tensor of min value(s) to calculate scale(s) and zero point(s) - from - :param max_vals: tensor of max value(s) to calculate scale(s) and zero point(s) - from - :param quantization_args: settings to quantization - :return: tuple of the calculated scale(s) and zero point(s) - """ - min_vals = torch.min(min_vals, torch.zeros_like(min_vals)) - max_vals = torch.max(max_vals, torch.zeros_like(max_vals)) - device = min_vals.device - - bit_min, bit_max = calculate_range(quantization_args, device) - bit_range = bit_max - bit_min - zp_dtype = quantization_args.pytorch_dtype() - - if quantization_args.symmetric: - max_val_pos = torch.max(torch.abs(min_vals), torch.abs(max_vals)) - scales = max_val_pos / (float(bit_range) / 2) - scales = torch.clamp(scales, min=torch.finfo(torch.float32).eps) - zero_points = torch.zeros(scales.shape, device=device, dtype=min_vals.dtype) - else: - scales = (max_vals - min_vals) / float(bit_range) - scales = torch.clamp(scales, min=torch.finfo(torch.float32).eps) - zero_points = bit_min - (min_vals / scales) - zero_points = torch.clamp(zero_points, bit_min, bit_max) - - # match zero-points to quantized type - zero_points = zero_points.to(zp_dtype) - - if scales.ndim == 0: - scales = scales.reshape(1) - zero_points = zero_points.reshape(1) - - return scales, zero_points - - -def calculate_range(quantization_args: QuantizationArgs, device: str) -> Tuple: - """ - Calculated the effective quantization range for the given Quantization Args - - :param quantization_args: quantization args to get range of - :param device: device to store the range to - :return: tuple endpoints for the given quantization range - """ - if quantization_args.type == QuantizationType.INT: - bit_range = 2**quantization_args.num_bits - q_max = torch.tensor(bit_range / 2 - 1, device=device) - q_min = torch.tensor(-bit_range / 2, device=device) - elif quantization_args.type == QuantizationType.FLOAT: - if quantization_args.num_bits != 8: - raise ValueError( - "Floating point quantization is only supported for 8 bits," - f"got {quantization_args.num_bits}" - ) - fp_range_info = torch.finfo(FP8_DTYPE) - q_max = torch.tensor(fp_range_info.max, device=device) - q_min = torch.tensor(fp_range_info.min, device=device) - else: - raise ValueError(f"Invalid quantization type {quantization_args.type}") - - return q_min, q_max diff --git a/src/compressed_tensors/quantization/observers/min_max.py b/src/compressed_tensors/quantization/observers/min_max.py deleted file mode 100644 index a5b12906..00000000 --- a/src/compressed_tensors/quantization/observers/min_max.py +++ /dev/null @@ -1,104 +0,0 @@ -# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
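`calculate_range` above maps the quantization args to representable endpoints: signed int gives [-2^(b-1), 2^(b-1)-1], while 8-bit float falls back to `torch.finfo` of the FP8 dtype. A small worked check of those endpoints, assuming a torch build that exposes `float8_e4m3fn` (the `FP8_DTYPE` used by this package):

```python
import torch

num_bits = 8
bit_range = 2 ** num_bits        # 256
q_max = bit_range / 2 - 1        # 127.0
q_min = -bit_range / 2           # -128.0

fp8 = torch.finfo(torch.float8_e4m3fn)
print(q_min, q_max)              # -128.0 127.0
print(fp8.min, fp8.max)          # -448.0 448.0
```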
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from typing import Any, Optional, Tuple - -import torch -from compressed_tensors.quantization.observers.base import Observer -from compressed_tensors.quantization.observers.helpers import calculate_qparams -from compressed_tensors.quantization.quant_args import QuantizationArgs -from torch import FloatTensor, IntTensor, Tensor - - -__all__ = ["MovingAverageMinMaxObserver"] - - -@Observer.register("minmax") -class MovingAverageMinMaxObserver(Observer): - """ - Implements a dynamic quantization observer that sets the scale and - zero point based on a moving average of the overall min and max observed values - """ - - def __init__( - self, quantization_args: QuantizationArgs, averaging_constant: float = 0.01 - ): - super().__init__(quantization_args=quantization_args) - - self.min_val = {} - self.max_val = {} - self.averaging_constant = averaging_constant - - def calculate_qparams( - self, - observed: Tensor, - reduce_dims: Optional[Tuple[int]] = None, - tensor_id: Optional[Any] = None, - ) -> Tuple[FloatTensor, IntTensor]: - """ - Updates the observed min and max using a moving average smoothed by the - averaging_constant - - :param observed: observed tensor to calculate quantization parameters for - :param reduce_dims: optional tuple of dimensions to reduce along, - returned scale and zero point will be shaped (1,) along the - reduced dimensions - :param tensor_id: Optional id if different ranges of observed tensors are - passed, useful for sharding tensors by group_size - :return: tuple of scale and zero point derived from the observed tensor - """ - tensor_id = tensor_id or "default" - - if not reduce_dims: - min_val, max_val = torch.aminmax(observed) - else: - min_val = torch.amin(observed, dim=reduce_dims, keepdims=True) - max_val = torch.amax(observed, dim=reduce_dims, keepdims=True) - - running_min_val = self.min_val.get(tensor_id, None) - running_max_val = self.max_val.get(tensor_id, None) - - if running_min_val is None or running_max_val is None: - updated_min_val = min_val - updated_max_val = max_val - else: - updated_min_val = running_min_val + self.averaging_constant * ( - min_val - running_min_val - ) - updated_max_val = running_max_val + self.averaging_constant * ( - max_val - running_max_val - ) - - self.min_val[tensor_id] = updated_min_val - self.max_val[tensor_id] = updated_max_val - - return calculate_qparams( - updated_min_val, updated_max_val, self.quantization_args - ) - - def get_qparams_along_dim( - self, observed, dim: int, tensor_id: Optional[Any] = None - ): - reduce_dims = tuple(idx for idx in range(observed.ndim) if idx != dim) - return self.calculate_qparams( - observed, reduce_dims=reduce_dims, tensor_id=tensor_id - ) - - def reset(self): - """ - Reset the state of the observer, including min and maximum values - """ - super().reset() - self.min_val = {} - self.max_val = {} diff --git a/src/compressed_tensors/quantization/observers/mse.py b/src/compressed_tensors/quantization/observers/mse.py deleted file mode 100644 index 739e921f..00000000 --- a/src/compressed_tensors/quantization/observers/mse.py +++ /dev/null @@ -1,164 +0,0 @@ -# Copyright (c) 
2021 - present / Neuralmagic, Inc. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from typing import Any, Optional, Tuple - -import torch -from compressed_tensors.quantization.observers.base import Observer -from compressed_tensors.quantization.observers.helpers import calculate_qparams -from compressed_tensors.quantization.quant_args import QuantizationArgs -from torch import FloatTensor, IntTensor, Tensor - - -__all__ = ["MovingAverageMSEObserver"] - - -@Observer.register("mse") -class MovingAverageMSEObserver(Observer): - """ - Implements a dynamic quantization observer that sets the scale and - zero point based on a moving average of the mse-clipped min and max observed values - """ - - def __init__( - self, - quantization_args: QuantizationArgs, - averaging_constant: float = 0.01, - grid: float = 100.0, - maxshrink: float = 0.80, - norm: float = 2.4, - ): - super().__init__(quantization_args=quantization_args) - - self.min_val = {} - self.max_val = {} - self.averaging_constant = averaging_constant - self.grid = grid - self.maxshrink = maxshrink - self.norm = norm - - def calculate_mse_min_max( - self, - observed: Tensor, - reduce_dims: Optional[Tuple[int]] = None, - ): - """ - Computes the mse-clipped min and max values of the observed tensor by - optimizing for quantization error - - :param observed: observed tensor to calculate quantization parameters for - :param reduce_dims: optional tuple of dimensions to reduce along, - returned values will be shaped (1,) along the reduced dimensions - :return: tuple of min and max values derived from the observed tensor - """ - from compressed_tensors.quantization.lifecycle import fake_quantize - - if not reduce_dims: - absolute_min_val, absolute_max_val = torch.aminmax(observed) - else: - absolute_min_val = torch.amin(observed, dim=reduce_dims, keepdims=True) - absolute_max_val = torch.amax(observed, dim=reduce_dims, keepdims=True) - - best = torch.full_like( - absolute_min_val, torch.finfo(absolute_min_val.dtype).max - ) - min_val = torch.ones_like(absolute_min_val) - max_val = torch.zeros_like(absolute_max_val) - for i in range(int(self.maxshrink * self.grid)): - p = 1 - i / self.grid - shrinked_min_val = p * absolute_min_val - shrinked_max_val = p * absolute_max_val - - candidate_scales, candidate_zero_points = calculate_qparams( - shrinked_min_val, shrinked_max_val, self.quantization_args - ) - q = fake_quantize( - observed, - candidate_scales, - candidate_zero_points, - self.quantization_args, - ) - - q -= observed - q.abs_() - q.pow_(self.norm) - if not reduce_dims: - err = torch.sum(q) - else: - err = torch.sum(q, reduce_dims, keepdims=True) - - tmp = err < best - if torch.any(tmp): - best[tmp] = err[tmp] - min_val[tmp] = shrinked_min_val[tmp] - max_val[tmp] = shrinked_max_val[tmp] - return min_val, max_val - - def calculate_qparams( - self, - observed: Tensor, - reduce_dims: Optional[Tuple[int]] = None, - tensor_id: Optional[Any] = None, - ) -> Tuple[FloatTensor, IntTensor]: - """ - Updates the mse-clipped min and 
max values of the observed tensor using - a moving average smoothed by the averaging_constant - - :param observed: observed tensor to calculate quantization parameters for - :param reduce_dims: optional tuple of dimensions to reduce along, - returned scale and zero point will be shaped (1,) along the - reduced dimensions - :param tensor_id: Optional id if different ranges of observed tensors are - passed, useful for sharding tensors by group_size - :return: tuple of scale and zero point derived from the observed tensor - """ - min_val, max_val = self.calculate_mse_min_max(observed, reduce_dims) - - running_min_val = self.min_val.get(tensor_id, None) - running_max_val = self.max_val.get(tensor_id, None) - - if running_min_val is None or running_max_val is None: - updated_min_val = min_val - updated_max_val = max_val - else: - updated_min_val = running_min_val + self.averaging_constant * ( - min_val - running_min_val - ) - updated_max_val = running_max_val + self.averaging_constant * ( - max_val - running_max_val - ) - - tensor_id = tensor_id or "default" - self.min_val[tensor_id] = updated_min_val - self.max_val[tensor_id] = updated_max_val - - return calculate_qparams( - updated_min_val, updated_max_val, self.quantization_args - ) - - def get_qparams_along_dim( - self, observed, dim: int, tensor_id: Optional[Any] = None - ): - reduce_dims = tuple(idx for idx in range(observed.ndim) if idx != dim) - return self.calculate_qparams( - observed, reduce_dims=reduce_dims, tensor_id=tensor_id - ) - - def reset(self): - """ - Reset the state of the observer, including min and maximum values - """ - super().reset() - self.min_val = {} - self.max_val = {} diff --git a/src/compressed_tensors/quantization/quant_args.py b/src/compressed_tensors/quantization/quant_args.py index c2fc0b6a..3259976c 100644 --- a/src/compressed_tensors/quantization/quant_args.py +++ b/src/compressed_tensors/quantization/quant_args.py @@ -114,20 +114,13 @@ def get_observer(self): """ :return: torch quantization FakeQuantize built based on these QuantizationArgs """ - from compressed_tensors.quantization.observers.base import Observer # No observer required for the dynamic case if self.dynamic: self.observer = None return self.observer - return Observer.load_from_registry(self.observer, quantization_args=self) - - def get_kv_cache(self): - """Get the singleton KV Cache""" - from compressed_tensors.quantization.cache import QuantizedKVParameterCache - - return QuantizedKVParameterCache(self) + return self.observer @field_validator("type", mode="before") def validate_type(cls, value) -> QuantizationType: diff --git a/src/compressed_tensors/quantization/utils/helpers.py b/src/compressed_tensors/quantization/utils/helpers.py index 8ebde09b..6f05524a 100644 --- a/src/compressed_tensors/quantization/utils/helpers.py +++ b/src/compressed_tensors/quantization/utils/helpers.py @@ -16,9 +16,14 @@ from typing import Generator, List, Optional, Tuple import torch -from compressed_tensors.quantization.observers.base import Observer -from compressed_tensors.quantization.quant_args import QuantizationArgs +from compressed_tensors.quantization.quant_args import ( + FP8_DTYPE, + QuantizationArgs, + QuantizationStrategy, + QuantizationType, +) from compressed_tensors.quantization.quant_scheme import QuantizationScheme +from torch import FloatTensor, IntTensor, Tensor from torch.nn import Module from tqdm import tqdm @@ -36,6 +41,9 @@ "is_kv_cache_quant_scheme", "iter_named_leaf_modules", "iter_named_quantizable_modules", + 
"compute_dynamic_scales_and_zp", + "calculate_range", + "calculate_qparams", ] # target the self_attn layer @@ -45,6 +53,105 @@ _LOGGER: logging.Logger = logging.getLogger(__name__) +def calculate_qparams( + min_vals: Tensor, max_vals: Tensor, quantization_args: QuantizationArgs +) -> Tuple[FloatTensor, IntTensor]: + """ + :param min_vals: tensor of min value(s) to calculate scale(s) and zero point(s) + from + :param max_vals: tensor of max value(s) to calculate scale(s) and zero point(s) + from + :param quantization_args: settings to quantization + :return: tuple of the calculated scale(s) and zero point(s) + """ + min_vals = torch.min(min_vals, torch.zeros_like(min_vals)) + max_vals = torch.max(max_vals, torch.zeros_like(max_vals)) + device = min_vals.device + + bit_min, bit_max = calculate_range(quantization_args, device) + bit_range = bit_max - bit_min + zp_dtype = quantization_args.pytorch_dtype() + + if quantization_args.symmetric: + max_val_pos = torch.max(torch.abs(min_vals), torch.abs(max_vals)) + scales = max_val_pos / (float(bit_range) / 2) + scales = torch.clamp(scales, min=torch.finfo(torch.float32).eps) + zero_points = torch.zeros(scales.shape, device=device, dtype=min_vals.dtype) + else: + scales = (max_vals - min_vals) / float(bit_range) + scales = torch.clamp(scales, min=torch.finfo(torch.float32).eps) + zero_points = bit_min - (min_vals / scales) + zero_points = torch.clamp(zero_points, bit_min, bit_max) + + # match zero-points to quantized type + zero_points = zero_points.to(zp_dtype) + + if scales.ndim == 0: + scales = scales.reshape(1) + zero_points = zero_points.reshape(1) + + return scales, zero_points + + +def compute_dynamic_scales_and_zp(value: Tensor, args: QuantizationArgs): + """ + Returns the computed scales and zero points for dynamic activation + qunatization. 
+ + :param value: tensor to calculate quantization parameters for + :param args: quantization args + :param reduce_dims: optional tuple of dimensions to reduce along, + returned scale and zero point will be shaped (1,) along the + reduced dimensions + :return: tuple of scale and zero point derived from the observed tensor + """ + if args.strategy == QuantizationStrategy.TOKEN: + dim = {1, 2} + reduce_dims = tuple(idx for idx in range(value.ndim) if idx not in dim) + elif args.strategy == QuantizationStrategy.TENSOR: + reduce_dims = None + else: + raise ValueError( + f"One of {QuantizationStrategy.TOKEN} or {QuantizationStrategy.TENSOR} ", + "must be used for dynamic quantization", + ) + + if not reduce_dims: + min_val, max_val = torch.aminmax(value) + else: + min_val = torch.amin(value, dim=reduce_dims, keepdims=True) + max_val = torch.amax(value, dim=reduce_dims, keepdims=True) + + return calculate_qparams(min_val, max_val, args) + + +def calculate_range(quantization_args: QuantizationArgs, device: str) -> Tuple: + """ + Calculated the effective quantization range for the given Quantization Args + + :param quantization_args: quantization args to get range of + :param device: device to store the range to + :return: tuple endpoints for the given quantization range + """ + if quantization_args.type == QuantizationType.INT: + bit_range = 2**quantization_args.num_bits + q_max = torch.tensor(bit_range / 2 - 1, device=device) + q_min = torch.tensor(-bit_range / 2, device=device) + elif quantization_args.type == QuantizationType.FLOAT: + if quantization_args.num_bits != 8: + raise ValueError( + "Floating point quantization is only supported for 8 bits," + f"got {quantization_args.num_bits}" + ) + fp_range_info = torch.finfo(FP8_DTYPE) + q_max = torch.tensor(fp_range_info.max, device=device) + q_min = torch.tensor(fp_range_info.min, device=device) + else: + raise ValueError(f"Invalid quantization type {quantization_args.type}") + + return q_min, q_max + + def infer_quantization_status(model: Module) -> Optional["QuantizationStatus"]: # noqa """ Checks the quantization status of a model. 
Assumes all modules in the model have @@ -118,12 +225,18 @@ def iter_named_leaf_modules(model: Module) -> Generator[Tuple[str, Module], None """ for name, submodule in model.named_modules(): children = list(submodule.children()) - if len(children) == 0 and not isinstance(submodule, Observer): + # TODO: verify if an observer would ever be attached in this case/remove check + if len(children) == 0 and "observer" in name: yield name, submodule else: + if len(children) > 0: + named_children, children = zip(*list(submodule.named_children())) has_non_observer_children = False - for child in children: - if not isinstance(child, Observer): + for i in range(len(children)): + child = children[i] + child_name = named_children[i] + + if "observer" not in child_name: has_non_observer_children = True if not has_non_observer_children: @@ -144,14 +257,20 @@ def iter_named_quantizable_modules( :returns: generator tuple of (name, submodule) """ for name, submodule in model.named_modules(): + # TODO: verify if an observer would ever be attached in this case/remove check if include_children: children = list(submodule.children()) - if len(children) == 0 and not isinstance(submodule, Observer): + if len(children) == 0 and "observer" not in name: yield name, submodule else: + if len(children) > 0: + named_children, children = zip(*list(submodule.named_children())) has_non_observer_children = False - for child in children: - if not isinstance(child, Observer): + for i in range(len(children)): + child_name = named_children[i] + child = children[i] + + if "observer" not in child_name: has_non_observer_children = True if not has_non_observer_children: diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 00000000..97790458 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,141 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from math import ceil +from typing import Any, Iterable, Optional, Union + +import pytest +import torch +from compressed_tensors.quantization.quant_args import ( + QuantizationArgs, + QuantizationStrategy, +) +from compressed_tensors.quantization.utils import calculate_qparams +from compressed_tensors.utils.offload import update_parameter_data + + +def _get_dim(dim: int, value: torch.Tensor): + if isinstance(dim, int): + dim = [dim] + dim = set(dim) + + reduce_dims = tuple(idx for idx in range(value.ndim) if idx not in dim) + return reduce_dims + + +@pytest.fixture +def mock_per_token_calibration(): + def update_scale_zp(module: torch.nn.Module, base_name: str, value: torch.Tensor): + quantization_scheme = getattr(module, "quantization_scheme", None) + if not quantization_scheme: + # no quantization scheme nothing to do + return + + arg_name = "weights" if base_name == "weight" else f"{base_name}_activations" + args = getattr(quantization_scheme, arg_name, None) + + dim = _get_dim({0, 1}, value) + min_val = torch.amin(value, dim=dim, keepdims=True) + max_val = torch.amax(value, dim=dim, keepdims=True) + scale, zp = calculate_qparams(min_val, max_val, args) + scale = scale.reshape((1, 1)) + zp = zp.reshape((1, 1)) + update_parameter_data(module, scale, f"{base_name}_scale") + update_parameter_data(module, zp, f"{base_name}_zero_point") + + return update_scale_zp + + +@pytest.fixture +def mock_per_group_calibration(): + def update_scale_zp( + module: torch.nn.Module, base_name: str, value: torch.Tensor, group_size: int + ): + quantization_scheme = getattr(module, "quantization_scheme", None) + if not quantization_scheme: + # no quantization scheme nothing to do + return + + arg_name = "weights" if base_name == "weight" else f"{base_name}_activations" + args = getattr(quantization_scheme, arg_name, None) + + rows = value.shape[0] + columns = value.shape[1] + num_groups = int(ceil(columns / group_size)) + + scale = torch.zeros((rows, num_groups), dtype=value.dtype, device=value.device) + zp_dtype = args.pytorch_dtype() + zp = torch.zeros((rows, num_groups), dtype=zp_dtype, device=value.device) + + group_sizes = torch.full((num_groups,), group_size, dtype=torch.int) + end = 0 + for group_index, group_count in enumerate(group_sizes): + start = end + end = start + group_count + dim = _get_dim( + 0, + value[:, start:end], + ) + min_val = torch.amin(value, dim=dim, keepdims=True) + max_val = torch.amax(value, dim=dim, keepdims=True) + scale_out, zp_out = calculate_qparams(min_val, max_val, args) + + scale[:, group_index] = scale_out.squeeze(1) + zp[:, group_index] = zp_out.squeeze(1) + + update_parameter_data(module, scale, f"{base_name}_scale") + update_parameter_data(module, zp, f"{base_name}_zero_point") + + return update_scale_zp + + +@pytest.fixture +def mock_per_channel_calibration(): + def update_scale_zp(module: torch.nn.Module, base_name: str, value: torch.Tensor): + quantization_scheme = getattr(module, "quantization_scheme", None) + if not quantization_scheme: + # no quantization scheme nothing to do + return + + arg_name = "weights" if base_name == "weight" else f"{base_name}_activations" + + args = getattr(quantization_scheme, arg_name, None) + dim = _get_dim(0, value) + min_val = torch.amin(value, dim=dim, keepdims=True) + max_val = torch.amax(value, dim=dim, keepdims=True) + scale, zp = calculate_qparams(min_val, max_val, args) + update_parameter_data(module, scale, f"{base_name}_scale") + update_parameter_data(module, zp, f"{base_name}_zero_point") + + return update_scale_zp + + 
+@pytest.fixture +def mock_per_tensor_calibration(): + def update_scale_zp(module: torch.nn.Module, base_name: str, value: torch.Tensor): + quantization_scheme = getattr(module, "quantization_scheme", None) + if not quantization_scheme: + # no quantization scheme nothing to do + return + + arg_name = "weights" if base_name == "weight" else f"{base_name}_activations" + args = getattr(quantization_scheme, arg_name, None) + + # per tensor quantization just calls calculate_qparams directly + min_val, max_val = torch.aminmax(value) + scale, zp = calculate_qparams(min_val, max_val, args) + update_parameter_data(module, scale, f"{base_name}_scale") + update_parameter_data(module, zp, f"{base_name}_zero_point") + + return update_scale_zp diff --git a/tests/test_compressors/quantized_compressors/test_fp8_quant.py b/tests/test_compressors/quantized_compressors/test_fp8_quant.py index ba8a5df3..4a253711 100644 --- a/tests/test_compressors/quantized_compressors/test_fp8_quant.py +++ b/tests/test_compressors/quantized_compressors/test_fp8_quant.py @@ -114,7 +114,13 @@ def test_quant_format(strategy, group_size, sc, zp): # Note that group quantization is not supported ], ) -def test_reload_match(strategy, group_size, tmp_path): +def test_reload_match( + mock_per_group_calibration, + mock_per_channel_calibration, + strategy, + group_size, + tmp_path, +): model = Sequential( OrderedDict( [ @@ -124,11 +130,15 @@ def test_reload_match(strategy, group_size, tmp_path): ) quant_config = get_dummy_quant_config(strategy=strategy, group_size=group_size) apply_quantization_config(model, quant_config) - apply_quantization_status(model, QuantizationStatus.CALIBRATION) - - for _ in range(16): - inputs = torch.rand((512, 512)) - _ = model(inputs) + model.dummy.quantization_status = QuantizationStatus.CALIBRATION + if strategy == QuantizationStrategy.GROUP: + mock_per_group_calibration( + model.dummy, base_name="weight", value=model.dummy.weight, group_size=128 + ) + if strategy == QuantizationStrategy.CHANNEL: + mock_per_channel_calibration( + model.dummy, base_name="weight", value=model.dummy.weight + ) compressor = FloatQuantizationCompressor(config=quant_config) quantized_modules_to_args = { diff --git a/tests/test_compressors/quantized_compressors/test_pack_quant.py b/tests/test_compressors/quantized_compressors/test_pack_quant.py index 496e8304..834b482c 100644 --- a/tests/test_compressors/quantized_compressors/test_pack_quant.py +++ b/tests/test_compressors/quantized_compressors/test_pack_quant.py @@ -205,7 +205,7 @@ def test_reload_match(tmp_path, num_bits): None, ], ) -def test_actorder_reload_match(actorder, tmp_path): +def test_actorder_reload_match(actorder, tmp_path, mock_per_group_calibration): model = Sequential(OrderedDict([("dummy", Linear(512, 1024, bias=None))])) group_size = 128 quant_config = get_dummy_quant_config( @@ -214,12 +214,10 @@ def test_actorder_reload_match(actorder, tmp_path): apply_quantization_config(model, quant_config) # run calibration - apply_quantization_status(model, QuantizationStatus.CALIBRATION) - for _ in range(16): - inputs = torch.rand((512, 512)) - _ = model(inputs) - apply_quantization_status(model, QuantizationStatus.FROZEN) - + model.quantization_status = QuantizationStatus.CALIBRATION + mock_per_group_calibration( + model.dummy, base_name="weight", value=model.dummy.weight, group_size=group_size + ) # apply gptq if actorder == ActivationOrdering.GROUP: init_g_idx = make_dummy_g_idx(512, group_size) diff --git 
a/tests/test_compressors/sparse_quantized_compressors/test_marlin_24.py b/tests/test_compressors/sparse_quantized_compressors/test_marlin_24.py index f12accb8..c9fc678b 100644 --- a/tests/test_compressors/sparse_quantized_compressors/test_marlin_24.py +++ b/tests/test_compressors/sparse_quantized_compressors/test_marlin_24.py @@ -54,7 +54,13 @@ def test_marlin_registered(): "strategy", [QuantizationStrategy.GROUP, QuantizationStrategy.CHANNEL] ) @pytest.mark.parametrize("layer_shape", [(512, 128), (1024, 1024), (128, 256)]) -def test_marlin24_format(num_bits, strategy, layer_shape): +def test_marlin24_format( + mock_per_group_calibration, + mock_per_channel_calibration, + num_bits, + strategy, + layer_shape, +): QUANT_NAME = "quant" NOT_QUANT_NAME = "not_quant" model = Sequential( @@ -70,11 +76,17 @@ def test_marlin24_format(num_bits, strategy, layer_shape): model.quant.weight.data *= mask apply_quantization_config(model, config) - apply_quantization_status(model, QuantizationStatus.CALIBRATION) + model.quantization_status = QuantizationStatus.CALIBRATION # runs observer to get scale and zero point - input = torch.rand((64, layer_shape[0])) - _ = model(input) + if strategy == QuantizationStrategy.GROUP: + mock_per_group_calibration( + model.quant, base_name="weight", value=model.quant.weight, group_size=128 + ) + if strategy == QuantizationStrategy.CHANNEL: + mock_per_channel_calibration( + model.quant, base_name="weight", value=model.quant.weight + ) state_dict = model.state_dict() assert len(state_dict) == 4 diff --git a/tests/test_quantization/lifecycle/conftest.py b/tests/test_quantization/lifecycle/conftest.py index 97bf8b0c..49b5eda3 100644 --- a/tests/test_quantization/lifecycle/conftest.py +++ b/tests/test_quantization/lifecycle/conftest.py @@ -15,7 +15,9 @@ from typing import List, Optional import pytest +import torch from compressed_tensors.quantization.quant_args import QuantizationArgs +from compressed_tensors.quantization.quant_config import QuantizationStatus from compressed_tensors.quantization.quant_scheme import QuantizationScheme @@ -35,3 +37,11 @@ def quantization_scheme( ) return quantization_scheme + + +@pytest.fixture +def mock_frozen(): + def update_status(model: torch.nn.Module): + model.quantization_status = QuantizationStatus.FROZEN + + return update_status diff --git a/tests/test_quantization/lifecycle/test_apply.py b/tests/test_quantization/lifecycle/test_apply.py index dcb980f2..4e9839b9 100644 --- a/tests/test_quantization/lifecycle/test_apply.py +++ b/tests/test_quantization/lifecycle/test_apply.py @@ -22,7 +22,6 @@ DEFAULT_QUANTIZATION_METHOD, QuantizationConfig, QuantizationStatus, - freeze_module_quantization, ) from compressed_tensors.quantization.lifecycle import ( apply_quantization_config, @@ -32,7 +31,7 @@ from transformers import AutoModelForCausalLM -def test_target_prioritization(): +def test_target_prioritization(mock_frozen): # tests that the config_groups are applied in the correct order # of priority, where exact layer name > regex > module name config = { @@ -68,7 +67,7 @@ def test_target_prioritization(): config = QuantizationConfig(**config) config.quantization_status = QuantizationStatus.CALIBRATION apply_quantization_config(model, config) - model.apply(freeze_module_quantization) + mock_frozen(model) for name, module in iter_named_leaf_modules(model): if name == "model.layers.0.mlp.down_proj": @@ -148,7 +147,6 @@ def test_serialize_config_tinyllama(): assert serialized_config.config_groups["group_0"].input_activations is None assert 
serialized_config.config_groups["group_1"].targets == ["Linear"] assert serialized_config.config_groups["group_1"].input_activations is not None - assert serialized_config.quantization_status == QuantizationStatus.FROZEN assert serialized_config.format == CompressionFormat.dense.value assert serialized_config.quant_method == DEFAULT_QUANTIZATION_METHOD assert serialized_config.ignore == ["model.layers.1.mlp.down_proj"] diff --git a/tests/test_quantization/lifecycle/test_dynamic_lifecycle.py b/tests/test_quantization/lifecycle/test_dynamic_lifecycle.py index 1f88626e..dd700637 100644 --- a/tests/test_quantization/lifecycle/test_dynamic_lifecycle.py +++ b/tests/test_quantization/lifecycle/test_dynamic_lifecycle.py @@ -14,10 +14,7 @@ import torch -from compressed_tensors.quantization.lifecycle import ( - apply_quantization_config, - freeze_module_quantization, -) +from compressed_tensors.quantization.lifecycle import apply_quantization_config from compressed_tensors.quantization.quant_config import QuantizationConfig from transformers import AutoModelForCausalLM @@ -39,8 +36,6 @@ def test_apply_tinyllama_dynamic_activations(): # verify forward works w/ dynamic during calibration model(torch.zeros((1, 1), dtype=int), torch.zeros((1, 1), dtype=int)) - # freeze and test that only weight observers are deleted - model.apply(freeze_module_quantization) _test_linears_dynamic_quantization_status(model, quant_config, frozen=True) # verify forward works w/ dynamic after freeze model(torch.zeros((1, 1), dtype=int), torch.zeros((1, 1), dtype=int)) @@ -79,7 +74,6 @@ def _test_layer_dynamic_quantization_status( # check weights always have scale/zp and observer only if not frozen assert hasattr(module, "weight_scale") == weights assert hasattr(module, "weight_zero_point") == weights - assert hasattr(module, "weight_observer") == (weights and not frozen) def get_tinyllama_model(): diff --git a/tests/test_quantization/lifecycle/test_forward.py b/tests/test_quantization/lifecycle/test_forward.py index 0730c991..542cd8b9 100644 --- a/tests/test_quantization/lifecycle/test_forward.py +++ b/tests/test_quantization/lifecycle/test_forward.py @@ -15,18 +15,12 @@ import pytest import torch -from compressed_tensors.quantization.lifecycle.calibration import ( - set_module_for_calibration, -) from compressed_tensors.quantization.lifecycle.forward import ( - calibrate_activations, dequantize, forward_quantize, quantize, wrap_module_forward_quantized, - wrap_module_forward_quantized_attn, ) -from compressed_tensors.quantization.lifecycle.frozen import freeze_module_quantization from compressed_tensors.quantization.lifecycle.initialize import ( initialize_module_for_quantization, ) @@ -60,10 +54,10 @@ def test_wrap_module_forward_quantized(create_quantization_scheme): assert not func_forward == layer.forward.__func__ -@pytest.mark.parametrize( - "quantization_status", ["initialized", "calibration", "frozen"] -) -def test_forward_quantize(create_quantization_scheme, quantization_status): +@pytest.mark.parametrize("quantization_status", ["initialized", "calibration"]) +def test_forward_quantize( + mock_per_tensor_calibration, create_quantization_scheme, quantization_status +): num_bits = 8 quantization_scheme = create_quantization_scheme( targets=["*"], @@ -81,37 +75,20 @@ def test_forward_quantize(create_quantization_scheme, quantization_status): if layer.quantization_status == QuantizationStatus.INITIALIZED: # Init zp and scales initialize_module_for_quantization(layer, quantization_scheme) - # init weight observers; 
update weight scales/zp - set_module_for_calibration(layer) + # mock weight calibration + mock_per_tensor_calibration(layer, "weight", value=layer.weight.data) # call quant/dequant on weights out = forward_quantize(layer, layer.weight, "weight", quantization_args) assert torch.allclose(out, layer.weight.data, atol=0.2) elif layer.quantization_status == QuantizationStatus.CALIBRATION: # init zp/scales initialize_module_for_quantization(layer, quantization_scheme) - # init weight observers; update weight scales/zp - set_module_for_calibration(layer) - # init input observers, update input scales/zp - calibrate_activations( - module=layer, - value=dummy_tensor, - base_name="input", - quantization_args=quantization_args, - ) + # run weight and input calibration + mock_per_tensor_calibration(layer, "weight", value=layer.weight.data) + mock_per_tensor_calibration(layer, "input", value=dummy_tensor) # call quant/dequant on inputs out = forward_quantize(layer, dummy_tensor, "input", quantization_args) assert torch.allclose(out, dummy_tensor, atol=0.2) - assert layer.input_observer._num_observed_tokens == dummy_tensor.shape[0] - elif layer.quantization_status == QuantizationStatus.FROZEN: - # init weight observers - initialize_module_for_quantization(layer, quantization_scheme) - # init weight observers; update weight scales/zp - set_module_for_calibration(layer) - # remove weight observers and any input observers - freeze_module_quantization(layer) - # call quant/dequant on weights - out = forward_quantize(layer, layer.weight.data, "weight", quantization_args) - assert torch.allclose(out, layer.weight.data, atol=0.2) @pytest.mark.parametrize( @@ -226,21 +203,3 @@ def test_dequantize(num_bits, type, strategy, group_size, scale, zero_point, g_i dtype=None, g_idx=g_idx, ) - - -def test_wrap_module_forward_quantized_attn(create_quantization_scheme): - num_bits = 8 - quantization_scheme = create_quantization_scheme( - targets=["self_attn"], - weights=QuantizationArgs(num_bits=num_bits, symmetric=True), - input_activations=QuantizationArgs(num_bits=num_bits, symmetric=False), - ) - - mock_attn_layer = Linear(4, 4) - - attn_forward = mock_attn_layer.forward.__func__ - - # check that the forward call is overwritten - wrap_module_forward_quantized_attn(mock_attn_layer, quantization_scheme) - - assert not attn_forward == mock_attn_layer.forward.__func__ diff --git a/tests/test_quantization/lifecycle/test_frozen.py b/tests/test_quantization/lifecycle/test_frozen.py deleted file mode 100644 index dddff117..00000000 --- a/tests/test_quantization/lifecycle/test_frozen.py +++ /dev/null @@ -1,50 +0,0 @@ -# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- - -from compressed_tensors.quantization.lifecycle.calibration import ( - set_module_for_calibration, -) -from compressed_tensors.quantization.lifecycle.frozen import freeze_module_quantization -from compressed_tensors.quantization.lifecycle.initialize import ( - initialize_module_for_quantization, -) -from compressed_tensors.quantization.quant_args import QuantizationArgs -from compressed_tensors.quantization.quant_config import QuantizationStatus -from torch.nn import Linear - - -def test_set_module_for_calibration(create_quantization_scheme): - num_bits = 8 - quantization_scheme = create_quantization_scheme( - targets=["*"], - weights=QuantizationArgs(num_bits=num_bits, symmetric=True), - input_activations=QuantizationArgs(num_bits=num_bits, symmetric=False), - ) - - layer = Linear(4, 4) - - initialize_module_for_quantization(layer, quantization_scheme) - layer.quantization_status = QuantizationStatus("calibration") - set_module_for_calibration(layer) - - # should have both input and weight observer after initalizing - assert hasattr(layer, "weight_observer") - - # observers should get deleted after freezing - freeze_module_quantization(layer) - assert not hasattr(layer, "input_observer") - assert not hasattr(layer, "weight_observer") - - assert layer.quantization_status == QuantizationStatus("frozen") diff --git a/tests/test_quantization/lifecycle/test_kv_cache.py b/tests/test_quantization/lifecycle/test_kv_cache.py deleted file mode 100644 index b72f5265..00000000 --- a/tests/test_quantization/lifecycle/test_kv_cache.py +++ /dev/null @@ -1,78 +0,0 @@ -# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-import pytest -import torch -from compressed_tensors.quantization import ( - QuantizationConfig, - QuantizationStatus, - apply_quantization_config, - freeze_module_quantization, -) -from transformers import AutoModelForCausalLM - - -config = { - "quant_method": "compressed-tensors", - "format": "fakequant", - "kv_cache_scheme": { - "num_bits": 8, - "type": "int", - "symmetric": True, - "strategy": "tensor", - }, - "config_groups": { - "group_1": { - "weights": { - "num_bits": 4, - "type": "int", - "symmetric": True, - "strategy": "tensor", - }, - "targets": ["Linear"], - }, - }, -} - - -@pytest.mark.parametrize("config", [config]) -def test_kv_cache_quantization(config): - - sample = { - name: torch.ones((1, 32)).long() - for name in ["input_ids", "attention_mask", "labels"] - } - model = AutoModelForCausalLM.from_pretrained( - "HuggingFaceM4/tiny-random-LlamaForCausalLM", - torch_dtype="auto", - ) - model.eval() - - config = QuantizationConfig(**config) - config.quantization_status = QuantizationStatus.CALIBRATION - apply_quantization_config(model, config) - - with torch.no_grad(): - _ = model(**sample) - - model.apply(freeze_module_quantization) - - reloaded_config = QuantizationConfig.from_pretrained(model) - - assert ( - config.kv_cache_scheme.model_dump().keys() - == reloaded_config.kv_cache_scheme.model_dump().keys() - ) - assert list(config.kv_cache_scheme.model_dump().values()) == list( - reloaded_config.kv_cache_scheme.model_dump().values() - ) diff --git a/tests/test_quantization/lifecycle/test_lifecycle.py b/tests/test_quantization/lifecycle/test_lifecycle.py index 41c56789..cabfee76 100644 --- a/tests/test_quantization/lifecycle/test_lifecycle.py +++ b/tests/test_quantization/lifecycle/test_lifecycle.py @@ -14,11 +14,8 @@ from copy import deepcopy +import pytest import torch -from compressed_tensors.quantization.lifecycle.calibration import ( - set_module_for_calibration, -) -from compressed_tensors.quantization.lifecycle.frozen import freeze_module_quantization from compressed_tensors.quantization.lifecycle.initialize import ( initialize_module_for_quantization, ) @@ -27,7 +24,7 @@ from torch.nn import Linear -def test_lifecyle(create_quantization_scheme): +def test_lifecyle(mock_per_tensor_calibration, create_quantization_scheme): num_bits = 8 quantization_scheme = create_quantization_scheme( @@ -63,11 +60,6 @@ def test_lifecyle(create_quantization_scheme): assert hasattr(layer, "quantization_status") assert layer.quantization_status == QuantizationStatus.INITIALIZED - set_module_for_calibration(layer) - assert hasattr(layer, "weight_observer") - assert layer.quantization_status == QuantizationStatus.CALIBRATION - - # do a calibration step assert torch.numel(layer.input_zero_point.data) == 1 assert torch.numel(layer.input_scale) == 1 assert torch.numel(layer.weight_scale) == 1 @@ -75,7 +67,10 @@ def test_lifecyle(create_quantization_scheme): random_input = torch.randn(4, 4) random_input[0][0] = 42 # skew distribution to force non-zero zp - layer(random_input) + + # do a calibration step + mock_per_tensor_calibration(layer, "weight", value=layer.weight) + mock_per_tensor_calibration(layer, "input", value=random_input) # zero-points and scale should be updated after forward pass assert torch.numel(layer.input_zero_point.data) > 0 @@ -96,7 +91,9 @@ def test_lifecyle(create_quantization_scheme): for _ in range(10): random_input = torch.randn(4, 4) random_input[0][0] = 42 # skew distribution to force non-zero zp - layer(random_input) + + mock_per_tensor_calibration(layer, 
"weight", value=layer.weight) + mock_per_tensor_calibration(layer, "input", value=random_input) assert initialized_layer_input_zero_point != 0 assert initialized_layer_input_scale != layer.input_scale @@ -110,9 +107,6 @@ def test_lifecyle(create_quantization_scheme): layer_before_freeze_input_scale = deepcopy(layer.input_scale) layer_before_freeze_weight_scale = deepcopy(layer.weight_scale) - # Freeze, no update after any forward pass - freeze_module_quantization(layer) - for _ in range(10): layer(torch.randn(4, 4)) assert layer_before_freeze_input_zero_point == layer.input_zero_point diff --git a/tests/test_quantization/test_cache.py b/tests/test_quantization/test_cache.py deleted file mode 100644 index 941af70f..00000000 --- a/tests/test_quantization/test_cache.py +++ /dev/null @@ -1,116 +0,0 @@ -# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import torch -from compressed_tensors.quantization.cache import QuantizedKVParameterCache -from compressed_tensors.quantization.quant_args import QuantizationArgs - - -def test_is_quantized_cache_singleton(): - """ - Check if quantized_cache is a singleton, used for - passing in QuantizedKVParameterCache to the forward call of - the model's self_attn - """ - - args = QuantizationArgs() - cache: QuantizedKVParameterCache = args.get_kv_cache() - observer = args.get_observer() - - tensor = torch.tensor([1, 2, 3]) - cache.k_scales.append(tensor) - cache.k_observers.append(observer) - - same_cache = args.get_kv_cache() - - assert len(cache.k_scales) == len(same_cache.k_scales) - assert torch.equal(cache.k_scales[0], same_cache.k_scales[0]) - - assert cache.k_observers == same_cache.k_observers - assert hex(id(cache.k_observers[0])) == hex(id(same_cache.k_observers[0])) - - cache.reset() - - -def test_update(): - - nbits = 8 - args = QuantizationArgs(nbits=nbits, symmetric=True) - cache: QuantizedKVParameterCache = args.get_kv_cache() - - max_key_states_val = 1.0 - max_value_states_val = 2.0 - key_states = torch.cat( - (max_key_states_val * torch.ones(1, 2, 2), torch.ones(1, 2, 2)), dim=0 - ) - value_states = torch.cat( - (max_value_states_val * torch.ones(1, 2, 2), torch.ones(1, 2, 2)), dim=0 - ) - layer_idx = 0 - - cache.update(key_states, value_states, layer_idx) - denom = (2 ** (nbits) - 1) / 2 - expected_k_scale = torch.tensor([max_key_states_val / denom]) - expected_v_scale = torch.tensor([max_value_states_val / denom]) - - assert cache.k_scales[0] == expected_k_scale - assert cache.v_scales[0] == expected_v_scale - - # new attn layer - layer_idx = 1 - cache.update(key_states, value_states, layer_idx) - - assert len(cache.k_scales) == 2 - assert len(cache.v_scales) == 2 - - assert len(cache.k_observers) == 2 - assert len(cache.v_observers) == 2 - - cache.reset() - - -def test_cache_reset(): - nbits = 8 - args = QuantizationArgs(nbits=nbits, symmetric=True) - cache: QuantizedKVParameterCache = args.get_kv_cache() - - max_key_states_val = 1.0 - max_value_states_val = 2.0 - 
key_states = torch.cat( - (max_key_states_val * torch.ones(1, 2, 2), torch.ones(1, 2, 2)), dim=0 - ) - value_states = torch.cat( - (max_value_states_val * torch.ones(1, 2, 2), torch.ones(1, 2, 2)), dim=0 - ) - layer_idx = 0 - - cache.update(key_states, value_states, layer_idx) - assert len(cache.k_scales) == 1 - assert len(cache.v_scales) == 1 - - assert len(cache.k_observers) == 1 - assert len(cache.v_observers) == 1 - - cache.reset() - - # new instance, different memory addr - different_cache: QuantizedKVParameterCache = args.get_kv_cache() - - assert len(different_cache.k_scales) == 0 - assert len(different_cache.v_scales) == 0 - - assert len(different_cache.k_observers) == 0 - assert len(different_cache.v_observers) == 0 - - assert hex(id(cache)) != hex(id(different_cache)) diff --git a/tests/test_quantization/test_configs/test_bit_depths.py b/tests/test_quantization/test_configs/test_bit_depths.py index c3bf39c9..4fa63306 100644 --- a/tests/test_quantization/test_configs/test_bit_depths.py +++ b/tests/test_quantization/test_configs/test_bit_depths.py @@ -26,7 +26,6 @@ def create_config(bit_depth, quant_type, input_symmetry, weight_symmetry): - print(quant_type) weights = QuantizationArgs( num_bits=bit_depth, type=quant_type, symmetric=weight_symmetry ) @@ -53,7 +52,9 @@ def create_config(bit_depth, quant_type, input_symmetry, weight_symmetry): @pytest.mark.parametrize("quant_type", ["int"]) @pytest.mark.parametrize("input_symmetry", [True, False, None]) @pytest.mark.parametrize("weight_symmetry", [True, False]) -def test_bit_depths(bit_depth, quant_type, input_symmetry, weight_symmetry): +def test_bit_depths( + mock_per_tensor_calibration, bit_depth, quant_type, input_symmetry, weight_symmetry +): model = Linear(64, 64) quant_config = create_config(bit_depth, quant_type, input_symmetry, weight_symmetry) apply_quantization_config(model, quant_config) @@ -62,8 +63,17 @@ def test_bit_depths(bit_depth, quant_type, input_symmetry, weight_symmetry): max = int(2**bit_depth / 2) - 1 inputs = torch.randn(32, 64) - model(inputs) + model.apply( + lambda module: mock_per_tensor_calibration( + module, base_name="weight", value=model.weight + ) + ) if input_symmetry is not None: + model.apply( + lambda module: mock_per_tensor_calibration( + module, base_name="input", value=inputs + ) + ) assert model.input_zero_point >= min assert model.input_zero_point <= max @@ -105,7 +115,9 @@ def test_bit_depths(bit_depth, quant_type, input_symmetry, weight_symmetry): @pytest.mark.parametrize("quant_type", ["float"]) @pytest.mark.parametrize("input_symmetry", [True, False, None]) @pytest.mark.parametrize("weight_symmetry", [True, False]) -def test_fp8(bit_depth, quant_type, input_symmetry, weight_symmetry): +def test_fp8( + mock_per_tensor_calibration, bit_depth, quant_type, input_symmetry, weight_symmetry +): model = Linear(64, 64) quant_config = create_config(bit_depth, quant_type, input_symmetry, weight_symmetry) apply_quantization_config(model, quant_config) @@ -115,10 +127,19 @@ def test_fp8(bit_depth, quant_type, input_symmetry, weight_symmetry): max = dtype_info.max inputs = torch.randn(32, 64) - model(inputs) + model.apply( + lambda module: mock_per_tensor_calibration( + module, base_name="weight", value=model.weight + ) + ) assert model.weight_zero_point.dtype == torch.float8_e4m3fn model.weight_zero_point.data = model.weight_zero_point.to(model.weight.dtype) if input_symmetry is not None: + model.apply( + lambda module: mock_per_tensor_calibration( + module, base_name="input", value=inputs + ) + ) 
assert model.input_zero_point.dtype == torch.float8_e4m3fn model.input_zero_point.data = model.input_zero_point.to(model.weight.dtype) assert model.input_zero_point >= min diff --git a/tests/test_quantization/test_configs/test_strategies.py b/tests/test_quantization/test_configs/test_strategies.py index a6f424ce..94201463 100644 --- a/tests/test_quantization/test_configs/test_strategies.py +++ b/tests/test_quantization/test_configs/test_strategies.py @@ -53,7 +53,9 @@ def create_config( @pytest.mark.parametrize("input_symmetry", [None]) @pytest.mark.parametrize("weight_symmetry", [True, False]) @pytest.mark.parametrize("model_shape", [(64, 128), (300, 200), (400, 400)]) -def test_channelwise(input_symmetry, weight_symmetry, model_shape): +def test_channelwise( + mock_per_channel_calibration, input_symmetry, weight_symmetry, model_shape +): model = Linear(model_shape[0], model_shape[1]) quant_config = create_config( input_symmetry, weight_symmetry, w_strategy=QuantizationStrategy.CHANNEL @@ -61,7 +63,9 @@ def test_channelwise(input_symmetry, weight_symmetry, model_shape): apply_quantization_config(model, quant_config) inputs = torch.randn(32, model_shape[0]) - model(inputs) + mock_per_channel_calibration(model, base_name="weight", value=model.weight) + if input_symmetry is not None: + mock_per_channel_calibration(model, base_name="input", value=inputs) assert list(model.weight_scale.shape) == [model_shape[1], 1] assert list(model.weight_zero_point.shape) == [model_shape[1], 1] @@ -72,7 +76,9 @@ def test_channelwise(input_symmetry, weight_symmetry, model_shape): @pytest.mark.parametrize("weight_symmetry", [True, False]) @pytest.mark.parametrize("model_shape", [(128, 256), (256, 512), (512, 1024)]) @pytest.mark.parametrize("group_size", [32, 128]) -def test_group(input_symmetry, weight_symmetry, model_shape, group_size): +def test_group( + mock_per_group_calibration, input_symmetry, weight_symmetry, model_shape, group_size +): model = Linear(model_shape[0], model_shape[1]) quant_config = create_config( input_symmetry, @@ -83,7 +89,13 @@ def test_group(input_symmetry, weight_symmetry, model_shape, group_size): apply_quantization_config(model, quant_config) inputs = torch.randn(128, model_shape[0]) - model(inputs) + mock_per_group_calibration( + model, base_name="weight", value=model.weight, group_size=group_size + ) + if input_symmetry is not None: + mock_per_group_calibration( + model, base_name="input", value=inputs, group_size=group_size + ) assert list(model.weight_scale.shape) == [ model_shape[1], @@ -99,7 +111,13 @@ def test_group(input_symmetry, weight_symmetry, model_shape, group_size): @pytest.mark.parametrize("input_symmetry", [True, False]) @pytest.mark.parametrize("weight_symmetry", [True, False]) @pytest.mark.parametrize("input_shape", [(32, 256), (300, 200), (400, 400)]) -def test_token(input_symmetry, weight_symmetry, input_shape): +def test_token( + mock_per_channel_calibration, + mock_per_token_calibration, + input_symmetry, + weight_symmetry, + input_shape, +): model = Linear(input_shape[1], 256) quant_config = create_config( input_symmetry, @@ -110,7 +128,8 @@ def test_token(input_symmetry, weight_symmetry, input_shape): apply_quantization_config(model, quant_config) inputs = torch.randn(input_shape) - model(inputs) + mock_per_channel_calibration(model, base_name="weight", value=model.weight) + mock_per_token_calibration(model, base_name="input", value=inputs) assert list(model.input_scale.shape) == [1, 1] assert list(model.input_zero_point.shape) == [1, 1] diff --git 
a/tests/test_quantization/test_observers/__init__.py b/tests/test_quantization/test_observers/__init__.py deleted file mode 100644 index 0c44f887..00000000 --- a/tests/test_quantization/test_observers/__init__.py +++ /dev/null @@ -1,13 +0,0 @@ -# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. diff --git a/tests/test_quantization/test_observers/test_helpers.py b/tests/test_quantization/test_observers/test_helpers.py deleted file mode 100644 index 0b976a25..00000000 --- a/tests/test_quantization/test_observers/test_helpers.py +++ /dev/null @@ -1,91 +0,0 @@ -# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from compressed_tensors.quantization import ( - QuantizationConfig, - apply_quantization_config, -) -from compressed_tensors.quantization.observers.helpers import get_observer_token_count -from transformers import AutoModelForCausalLM, AutoTokenizer - - -def test_get_observer_token_count(): - model = AutoModelForCausalLM.from_pretrained("Isotonic/TinyMixtral-4x248M-MoE") - tokenizer = AutoTokenizer.from_pretrained("Isotonic/TinyMixtral-4x248M-MoE") - model.eval() - config = QuantizationConfig( - format="fakequant", - quantization_status="calibration", - config_groups={ - "group_1": { - "input_activations": { - "num_bits": 8, - "type": "int", - "symmetric": False, - "strategy": "tensor", - }, - "targets": ["Linear"], - }, - }, - ) - apply_quantization_config(model, config) - - # start calibration - calib_list = [ - "I am a string that", - "is used for calibration so", - "that your model is", - "quantized properly.", - ] - - total_num_tokens_observed = 0 - for calib_sample in calib_list: - calib_tensor = tokenizer(calib_sample, return_tensors="pt") - _ = model(**calib_tensor) - total_num_tokens_observed += len(calib_tensor.input_ids.flatten()) - - counter = get_observer_token_count(model) - - # filter out the None values - # (tokens, in the appropriate format, that were not observed by the model) - counter = {k: v for k, v in counter.items() if v is not None} - - # iterate over all the layers in the model where the token count in the proper - # format is has been observed - for i in range(model.config.num_hidden_layers): - # fetch the tokens observed by the router - tokens_observed_by_router = counter.pop( - f"model.layers.{i}.block_sparse_moe.gate" - ) - assert tokens_observed_by_router == total_num_tokens_observed - - # fetch the sum of tokens observed by 
all the experts - sum_tokens_observed_by_experts = 0 - keys_for_this_layer = [ - k - for k in counter.keys() - if f"model.layers.{i}.block_sparse_moe.experts" in k - ] - for key in keys_for_this_layer: - sum_tokens_observed_by_experts += counter.pop(key) - - # each Mixtral expert is comprised of 3 linear layers, - # so we need to multiply by 3 - assert ( - sum_tokens_observed_by_experts - == total_num_tokens_observed * model.config.num_experts_per_tok * 3 - ) - - # there are no more information in the counter - assert len(counter) == 0 diff --git a/tests/test_quantization/test_observers/test_min_max.py b/tests/test_quantization/test_observers/test_min_max.py deleted file mode 100644 index b9333332..00000000 --- a/tests/test_quantization/test_observers/test_min_max.py +++ /dev/null @@ -1,113 +0,0 @@ -# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -import pytest -import torch -from compressed_tensors.quantization.quant_args import QuantizationArgs - - -def make_dummy_g_idx(columns: int, group_size: int) -> torch.Tensor: - perm = torch.randperm(columns) - return torch.tensor([index // group_size for index in range(columns)])[perm] - - -@pytest.mark.parametrize( - "symmetric,expected_scale,expected_zero_point", - [ - (True, 0.0078, 0), - (False, 0.0039, -128), - ], -) -def test_min_max_observer(symmetric, expected_scale, expected_zero_point): - tensor = torch.tensor([1, 1, 1, 1, 1]) - num_bits = 8 - weights = QuantizationArgs(num_bits=num_bits, symmetric=symmetric) - - observer = weights.get_observer() - scale, zero_point = observer(tensor) - - assert round(scale.item(), 4) == expected_scale - assert round(zero_point.item(), 4) == expected_zero_point - - -def test_min_max_observer_symmetric_scale_range(): - tensor = torch.rand(4, 4) - tensor *= 127 - - num_bits = 8 - weights = QuantizationArgs(num_bits=num_bits, symmetric=True) - - observer = weights.get_observer() - scale, zero_point = observer(tensor) - - # if symmetric, max symmetric_range = abs(-128) / 255 - assert round(scale.item(), 4) <= 1.0039 - assert round(zero_point.item(), 4) == 0 - - -def test_min_max_observer_value_update(): - inp = torch.tensor([1, 1, 1, 1, 1]) - inp_update_max = torch.tensor([127, 1, 1, 1, 1]) - inp_update_min = torch.tensor([-128, 1, 1, 1, 1]) - - delta = 1e-6 - - # update the min, max twice total - tensors = [ - inp, - inp, - inp_update_max, # update max - inp, - inp_update_min, # update min - ] - - tensor = inp - num_bits = 8 - weights = QuantizationArgs(num_bits=num_bits, symmetric=True) - - observer = weights.get_observer() - curr_max = 1 - curr_min = 1 - for i, tensor in enumerate(tensors): - observer(tensor) - curr_max = max(observer.max_val.get("default"), curr_max) - curr_min = min(observer.min_val.get("default"), curr_max) - - if i < 2: - assert curr_max == 1 - assert curr_min == 1 - elif i < 4: - assert abs(curr_max - 2.2600) < delta - assert curr_min == 1 - else: - assert abs(curr_max - 2.2600) < delta - assert abs(curr_min - 
(-0.2900)) < delta - - -def test_g_idx(): - group_size = 2 - input_shape = (128, 512) - tensor = torch.rand(input_shape) - weights = QuantizationArgs(num_bits=8, group_size=group_size) - g_idx = make_dummy_g_idx(tensor.shape[1], group_size) - - observer = weights.get_observer() - scale_g_idx, zero_point_g_idx = observer(tensor, g_idx=g_idx) - - observer.reset() - scale, zero_point = observer(tensor[:, torch.argsort(g_idx)]) - - assert scale_g_idx == pytest.approx(scale) - assert zero_point_g_idx == pytest.approx(zero_point) diff --git a/tests/test_quantization/test_observers/test_mse.py b/tests/test_quantization/test_observers/test_mse.py deleted file mode 100644 index 098551a2..00000000 --- a/tests/test_quantization/test_observers/test_mse.py +++ /dev/null @@ -1,54 +0,0 @@ -# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -import pytest -import torch -from compressed_tensors.quantization.observers import MovingAverageMSEObserver -from compressed_tensors.quantization.quant_args import QuantizationArgs - - -@pytest.mark.parametrize( - "symmetric,expected_scale,expected_zero_point", - [ - (True, 0.0078, 0), - (False, 0.0039, -128), - ], -) -def test_mse_observer(symmetric, expected_scale, expected_zero_point): - tensor = torch.tensor([1.0, 1.0, 1.0, 1.0, 1.0]) - num_bits = 8 - weights = QuantizationArgs(num_bits=num_bits, symmetric=symmetric, observer="mse") - - observer = weights.get_observer() - scale, zero_point = observer(tensor) - - assert isinstance(observer, MovingAverageMSEObserver) - assert round(scale.item(), 4) == expected_scale - assert round(zero_point.item(), 4) == expected_zero_point - - -def test_mse_observer_symmetric_scale_range(): - tensor = torch.rand(4, 4) - tensor *= 127 - - num_bits = 8 - weights = QuantizationArgs(num_bits=num_bits, symmetric=True) - - observer = weights.get_observer() - scale, zero_point = observer(tensor) - - # if symmetric, max symmetric_range = abs(-128) / 255 - assert round(scale.item(), 4) <= 1.0039 - assert round(zero_point.item(), 4) == 0
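
Note (illustrative, not part of the patch): the mock calibration fixtures added in tests/conftest.py stand in for the removed observer flow by computing quantization parameters directly from a tensor's statistics. A minimal per-tensor sketch of that pattern, assuming a layer that has already been set up by initialize_module_for_quantization so that the weight_scale / weight_zero_point parameters exist, might look like:

    import torch
    from torch.nn import Linear
    from compressed_tensors.quantization.quant_args import QuantizationArgs
    from compressed_tensors.quantization.utils import calculate_qparams
    from compressed_tensors.utils.offload import update_parameter_data

    # Illustrative sketch only: per-tensor calibration without an Observer.
    # Assumes `layer` already carries a quantization_scheme and the
    # weight_scale / weight_zero_point parameters (e.g. after
    # initialize_module_for_quantization has been applied).
    layer = Linear(4, 4)
    args = QuantizationArgs(num_bits=8, symmetric=True)

    # Global min/max of the weight tensor -> scale and zero point
    min_val, max_val = torch.aminmax(layer.weight.data)
    scale, zero_point = calculate_qparams(min_val, max_val, args)

    # Write the results onto the module's quantization parameters
    update_parameter_data(layer, scale, "weight_scale")
    update_parameter_data(layer, zero_point, "weight_zero_point")

This mirrors what mock_per_tensor_calibration does in the new conftest.py; the per-channel, per-group, and per-token fixtures follow the same pattern, only reducing min/max over different dimensions before calling calculate_qparams.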