diff --git a/.github/workflows/trigger-all.yml b/.github/workflows/trigger-all.yml index 02427999..1e53ff30 100644 --- a/.github/workflows/trigger-all.yml +++ b/.github/workflows/trigger-all.yml @@ -35,6 +35,6 @@ jobs: test_configs: '[{"python":"3.11.4","label":"ubuntu-22.04","timeout":"40"}, {"python":"3.10.12","label":"ubuntu-20.04","timeout":"40"}, {"python":"3.9.17","label":"k8s-a100-solo","timeout":"40"}, - {"python":"3.8.17","label":"k8s-a100-duo","timeout":"40"}]' + {"python":"3.12.6","label":"k8s-a100-duo","timeout":"40"}]' secrets: inherit diff --git a/src/compressed_tensors/config/base.py b/src/compressed_tensors/config/base.py index ccc3e649..79a4fcdd 100644 --- a/src/compressed_tensors/config/base.py +++ b/src/compressed_tensors/config/base.py @@ -12,16 +12,17 @@ # See the License for the specific language governing permissions and # limitations under the License. -from enum import Enum +from enum import Enum, unique from typing import List, Optional from compressed_tensors.registry import RegistryMixin from pydantic import BaseModel -__all__ = ["SparsityCompressionConfig", "CompressionFormat"] +__all__ = ["SparsityCompressionConfig", "CompressionFormat", "SparsityStructure"] +@unique class CompressionFormat(Enum): dense = "dense" sparse_bitmask = "sparse-bitmask" @@ -32,6 +33,63 @@ class CompressionFormat(Enum): marlin_24 = "marlin-24" +@unique +class SparsityStructure(Enum): + """ + An enumeration to represent different sparsity structures. + + Attributes + ---------- + TWO_FOUR : str + Represents a 2:4 sparsity structure. + ZERO_ZERO : str + Represents a 0:0 sparsity structure. + UNSTRUCTURED : str + Represents an unstructured sparsity structure. + + Examples + -------- + >>> SparsityStructure('2:4') + + + >>> SparsityStructure('unstructured') + + + >>> SparsityStructure('2:4') == SparsityStructure.TWO_FOUR + True + + >>> SparsityStructure('UNSTRUCTURED') == SparsityStructure.UNSTRUCTURED + True + + >>> SparsityStructure(None) == SparsityStructure.UNSTRUCTURED + True + + >>> SparsityStructure('invalid') + Traceback (most recent call last): + ... + ValueError: invalid is not a valid SparsityStructure + """ + + TWO_FOUR = "2:4" + UNSTRUCTURED = "unstructured" + ZERO_ZERO = "0:0" + + def __new__(cls, value): + obj = object.__new__(cls) + obj._value_ = value.lower() if value is not None else value + return obj + + @classmethod + def _missing_(cls, value): + # Handle None and case-insensitive values + if value is None: + return cls.UNSTRUCTURED + for member in cls: + if member.value == value.lower(): + return member + raise ValueError(f"{value} is not a valid {cls.__name__}") + + class SparsityCompressionConfig(RegistryMixin, BaseModel): """ Base data class for storing sparsity compression parameters diff --git a/src/compressed_tensors/quantization/__init__.py b/src/compressed_tensors/quantization/__init__.py index 848a4458..9fde69a3 100644 --- a/src/compressed_tensors/quantization/__init__.py +++ b/src/compressed_tensors/quantization/__init__.py @@ -19,4 +19,3 @@ from .quant_config import * from .quant_scheme import * from .lifecycle import * -from .cache import QuantizedKVParameterCache diff --git a/src/compressed_tensors/quantization/cache.py b/src/compressed_tensors/quantization/cache.py deleted file mode 100644 index cc33a48a..00000000 --- a/src/compressed_tensors/quantization/cache.py +++ /dev/null @@ -1,201 +0,0 @@ -# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. 
-# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -from enum import Enum -from typing import Any, Dict, List, Optional, Tuple - -from compressed_tensors.quantization.observers import Observer -from compressed_tensors.quantization.quant_args import QuantizationArgs -from torch import Tensor -from transformers import DynamicCache as HFDyanmicCache - - -class KVCacheScaleType(Enum): - KEY = "k_scale" - VALUE = "v_scale" - - -class QuantizedKVParameterCache(HFDyanmicCache): - - """ - Quantized KV cache used in the forward call based on HF's dynamic cache. - Quantization strategy (tensor, group, channel) set from Quantization arg's strategy - Singleton, so that the same cache gets reused in all forward call of self_attn. - Each time forward is called, .update() is called, and ._quantize(), ._dequantize() - gets called appropriately. - The size of tensor is - `[batch_size, num_heads, seq_len - residual_length, head_dim]`. - - - Triggered by adding kv_cache_scheme in the recipe. - - Example: - - ```python3 - recipe = ''' - quant_stage: - quant_modifiers: - QuantizationModifier: - kv_cache_scheme: - num_bits: 8 - type: float - strategy: tensor - dynamic: false - symmetric: true - ''' - - """ - - _instance = None - _initialized = False - - def __new__(cls, *args, **kwargs): - """Singleton""" - if cls._instance is None: - cls._instance = super(QuantizedKVParameterCache, cls).__new__(cls) - return cls._instance - - def __init__(self, quantization_args: QuantizationArgs): - if not self._initialized: - super().__init__() - - self.quantization_args = quantization_args - - self.k_observers: List[Observer] = [] - self.v_observers: List[Observer] = [] - - # each index corresponds to layer_idx of the attention layer - self.k_scales: List[Tensor] = [] - self.v_scales: List[Tensor] = [] - - self.k_zps: List[Tensor] = [] - self.v_zps: List[Tensor] = [] - - self._initialized = True - - def update( - self, - key_states: Tensor, - value_states: Tensor, - layer_idx: int, - cache_kwargs: Optional[Dict[str, Any]] = None, - ) -> Tuple[Tensor, Tensor]: - """ - Get the k_scale and v_scale and output the - fakequant-ed key_states and value_states - """ - - if len(self.k_observers) <= layer_idx: - k_observer = self.quantization_args.get_observer() - v_observer = self.quantization_args.get_observer() - - self.k_observers.append(k_observer) - self.v_observers.append(v_observer) - - q_key_states = self._quantize( - key_states.contiguous(), KVCacheScaleType.KEY, layer_idx - ) - q_value_states = self._quantize( - value_states.contiguous(), KVCacheScaleType.VALUE, layer_idx - ) - - qdq_key_states = self._dequantize(q_key_states, KVCacheScaleType.KEY, layer_idx) - qdq_value_states = self._dequantize( - q_value_states, KVCacheScaleType.VALUE, layer_idx - ) - - keys_to_return, values_to_return = qdq_key_states, qdq_value_states - - return keys_to_return, values_to_return - - def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: - """ - Returns the sequence length of the cached states. 
- A layer index can be optionally passed. - """ - if len(self.key_cache) <= layer_idx: - return 0 - # since we cannot get the seq_length of each layer directly and - # rely on `_seen_tokens` which is updated every "layer_idx" == 0, - # this is a hack to get the actual seq_length for the given layer_idx - # this part of code otherwise fails when used to - # verify attn_weight shape in some models - return self._seen_tokens if layer_idx == 0 else self._seen_tokens - 1 - - def reset_states(self): - """reset the kv states (used in calibration)""" - self.key_cache: List[Tensor] = [] - self.value_cache: List[Tensor] = [] - # Used in `generate` to keep tally of how many tokens the cache has seen - self._seen_tokens = 0 - self._quantized_key_cache: List[Tensor] = [] - self._quantized_value_cache: List[Tensor] = [] - - def reset(self): - """ - Reset the instantiation, create new instance on init - """ - QuantizedKVParameterCache._instance = None - QuantizedKVParameterCache._initialized = False - - def _quantize(self, tensor, kv_type, layer_idx): - """Quantizes a key/value using a defined quantization method.""" - from compressed_tensors.quantization.lifecycle.forward import quantize - - if kv_type == KVCacheScaleType.KEY: # key type - observer = self.k_observers[layer_idx] - scales = self.k_scales - zps = self.k_zps - else: - assert kv_type == KVCacheScaleType.VALUE - observer = self.v_observers[layer_idx] - scales = self.v_scales - zps = self.v_zps - - scale, zp = observer(tensor) - if len(scales) <= layer_idx: - scales.append(scale) - zps.append(zp) - else: - scales[layer_idx] = scale - zps[layer_idx] = scale - - q_tensor = quantize( - x=tensor, - scale=scale, - zero_point=zp, - args=self.quantization_args, - ) - return q_tensor - - def _dequantize(self, qtensor, kv_type, layer_idx): - """Dequantizes back the tensor that was quantized by `self._quantize()`""" - from compressed_tensors.quantization.lifecycle.forward import dequantize - - if kv_type == KVCacheScaleType.KEY: - scale = self.k_scales[layer_idx] - zp = self.k_zps[layer_idx] - else: - assert kv_type == KVCacheScaleType.VALUE - scale = self.v_scales[layer_idx] - zp = self.v_zps[layer_idx] - - qdq_tensor = dequantize( - x_q=qtensor, - scale=scale, - zero_point=zp, - args=self.quantization_args, - ) - return qdq_tensor diff --git a/src/compressed_tensors/quantization/lifecycle/__init__.py b/src/compressed_tensors/quantization/lifecycle/__init__.py index 98bc6630..6acab255 100644 --- a/src/compressed_tensors/quantization/lifecycle/__init__.py +++ b/src/compressed_tensors/quantization/lifecycle/__init__.py @@ -15,9 +15,7 @@ # flake8: noqa # isort: skip_file -from .calibration import * from .forward import * -from .frozen import * from .initialize import * from .compressed import * from .apply import * diff --git a/src/compressed_tensors/quantization/lifecycle/apply.py b/src/compressed_tensors/quantization/lifecycle/apply.py index 09281528..7c498787 100644 --- a/src/compressed_tensors/quantization/lifecycle/apply.py +++ b/src/compressed_tensors/quantization/lifecycle/apply.py @@ -22,13 +22,9 @@ import torch from compressed_tensors.config import CompressionFormat -from compressed_tensors.quantization.lifecycle.calibration import ( - set_module_for_calibration, -) from compressed_tensors.quantization.lifecycle.compressed import ( compress_quantized_weights, ) -from compressed_tensors.quantization.lifecycle.frozen import freeze_module_quantization from compressed_tensors.quantization.lifecycle.initialize import ( 
initialize_module_for_quantization, ) @@ -233,6 +229,7 @@ def apply_quantization_status(model: Module, status: QuantizationStatus): :param model: model to apply quantization to :param status: status to update the module to """ + current_status = infer_quantization_status(model) if status >= QuantizationStatus.INITIALIZED > current_status: @@ -243,18 +240,6 @@ def apply_quantization_status(model: Module, status: QuantizationStatus): ) ) - if current_status < status >= QuantizationStatus.CALIBRATION > current_status: - # only quantize weights up front when our end goal state is calibration, - # weight quantization parameters are already loaded for frozen/compressed - quantize_weights_upfront = status == QuantizationStatus.CALIBRATION - model.apply( - lambda module: set_module_for_calibration( - module, quantize_weights_upfront=quantize_weights_upfront - ) - ) - if current_status < status >= QuantizationStatus.FROZEN > current_status: - model.apply(freeze_module_quantization) - if current_status < status >= QuantizationStatus.COMPRESSED > current_status: model.apply(compress_quantized_weights) diff --git a/src/compressed_tensors/quantization/lifecycle/calibration.py b/src/compressed_tensors/quantization/lifecycle/calibration.py deleted file mode 100644 index c67844fa..00000000 --- a/src/compressed_tensors/quantization/lifecycle/calibration.py +++ /dev/null @@ -1,70 +0,0 @@ -# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -import logging - -from compressed_tensors.quantization.quant_config import QuantizationStatus -from compressed_tensors.utils import has_offloaded_params, update_parameter_data -from torch.nn import Module - - -__all__ = [ - "set_module_for_calibration", -] - - -_LOGGER = logging.getLogger(__name__) - - -def set_module_for_calibration(module: Module, quantize_weights_upfront: bool = True): - """ - marks a layer as ready for calibration which activates observers - to update scales and zero points on each forward pass - - apply to full model with `model.apply(set_module_for_calibration)` - - :param module: module to set for calibration - :param quantize_weights_upfront: whether to automatically - run weight quantization at the start of calibration - """ - if not getattr(module, "quantization_scheme", None): - # no quantization scheme nothing to do - return - status = getattr(module, "quantization_status", None) - if not status or status != QuantizationStatus.INITIALIZED: - _LOGGER.warning( - f"Attempting set module with status {status} to calibration mode. 
" - f"but status is not {QuantizationStatus.INITIALIZED} - you may " - "be calibrating an uninitialized module which may fail or attempting " - "to re-calibrate a frozen module" - ) - - if quantize_weights_upfront and module.quantization_scheme.weights is not None: - # set weight scale and zero_point up front, calibration data doesn't affect it - observer = module.weight_observer - g_idx = getattr(module, "weight_g_idx", None) - - offloaded = has_offloaded_params(module) - if offloaded: - module._hf_hook.pre_forward(module) - - scale, zero_point = observer(module.weight, g_idx=g_idx) - update_parameter_data(module, scale, "weight_scale") - update_parameter_data(module, zero_point, "weight_zero_point") - - if offloaded: - module._hf_hook.post_forward(module, None) - - module.quantization_status = QuantizationStatus.CALIBRATION diff --git a/src/compressed_tensors/quantization/lifecycle/forward.py b/src/compressed_tensors/quantization/lifecycle/forward.py index d685c0c0..dcab122a 100644 --- a/src/compressed_tensors/quantization/lifecycle/forward.py +++ b/src/compressed_tensors/quantization/lifecycle/forward.py @@ -14,14 +14,9 @@ from functools import wraps from math import ceil -from typing import Callable, Optional +from typing import Optional import torch -from compressed_tensors.quantization.cache import QuantizedKVParameterCache -from compressed_tensors.quantization.observers.helpers import ( - calculate_range, - compute_dynamic_scales_and_zp, -) from compressed_tensors.quantization.quant_args import ( QuantizationArgs, QuantizationStrategy, @@ -29,7 +24,11 @@ ) from compressed_tensors.quantization.quant_config import QuantizationStatus from compressed_tensors.quantization.quant_scheme import QuantizationScheme -from compressed_tensors.utils import safe_permute, update_parameter_data +from compressed_tensors.quantization.utils import ( + calculate_range, + compute_dynamic_scales_and_zp, +) +from compressed_tensors.utils import safe_permute from torch.nn import Module @@ -38,7 +37,7 @@ "dequantize", "fake_quantize", "wrap_module_forward_quantized", - "maybe_calibrate_or_quantize", + "forward_quantize", ] @@ -275,15 +274,13 @@ def wrapped_forward(self, *args, **kwargs): compressed = module.quantization_status == QuantizationStatus.COMPRESSED if scheme.input_activations is not None: - # calibrate and (fake) quantize input activations when applicable - input_ = maybe_calibrate_or_quantize( - module, input_, "input", scheme.input_activations - ) + # prehook should calibrate activations before forward call + input_ = forward_quantize(module, input_, "input", scheme.input_activations) if scheme.weights is not None and not compressed: # calibrate and (fake) quantize weights when applicable unquantized_weight = self.weight.data.clone() - self.weight.data = maybe_calibrate_or_quantize( + self.weight.data = forward_quantize( module, self.weight, "weight", scheme.weights ) @@ -291,64 +288,23 @@ def wrapped_forward(self, *args, **kwargs): output = forward_func_orig.__get__(module, module.__class__)( input_, *args[1:], **kwargs ) - if scheme.output_activations is not None: - - # calibrate and (fake) quantize output activations when applicable - # kv_cache scales updated on model self_attn forward call in - # wrap_module_forward_quantized_attn - output = maybe_calibrate_or_quantize( - module, output, "output", scheme.output_activations - ) # restore back to unquantized_value if scheme.weights is not None and not compressed: self.weight.data = unquantized_weight - return output - - # bind wrapped forward 
to module class so reference to `self` is correct - bound_wrapped_forward = wrapped_forward.__get__(module, module.__class__) - # set forward to wrapped forward - setattr(module, "forward", bound_wrapped_forward) - - -def wrap_module_forward_quantized_attn(module: Module, scheme: QuantizationScheme): - # expects a module already initialized and injected with the parameters in - # initialize_module_for_quantization - if hasattr(module.forward, "__func__"): - forward_func_orig = module.forward.__func__ - else: - forward_func_orig = module.forward.func - - @wraps(forward_func_orig) # ensures docstring, names, etc are propagated - def wrapped_forward(self, *args, **kwargs): - - # kv cache stored under weights - if module.quantization_status == QuantizationStatus.CALIBRATION: - quantization_args: QuantizationArgs = scheme.output_activations - past_key_value: QuantizedKVParameterCache = quantization_args.get_kv_cache() - kwargs["past_key_value"] = past_key_value - - # QuantizedKVParameterCache used for obtaining k_scale, v_scale only, - # does not store quantized_key_states and quantized_value_state - kwargs["use_cache"] = False - - attn_forward: Callable = forward_func_orig.__get__(module, module.__class__) - - past_key_value.reset_states() - - rtn = attn_forward(*args, **kwargs) - - update_parameter_data( - module, past_key_value.k_scales[module.layer_idx], "k_scale" - ) - update_parameter_data( - module, past_key_value.v_scales[module.layer_idx], "v_scale" + if scheme.output_activations is not None: + # forward-hook should calibrate/forward_quantize + if ( + module.quantization_status == QuantizationStatus.CALIBRATION + and not scheme.output_activations.dynamic + ): + return output + + output = forward_quantize( + module, output, "output", scheme.output_activations ) - - return rtn - - return forward_func_orig.__get__(module, module.__class__)(*args, **kwargs) + return output # bind wrapped forward to module class so reference to `self` is correct bound_wrapped_forward = wrapped_forward.__get__(module, module.__class__) @@ -356,12 +312,9 @@ def wrapped_forward(self, *args, **kwargs): setattr(module, "forward", bound_wrapped_forward) -def maybe_calibrate_or_quantize( +def forward_quantize( module: Module, value: torch.Tensor, base_name: str, args: "QuantizationArgs" ) -> torch.Tensor: - # don't run quantization if we haven't entered calibration mode - if module.quantization_status == QuantizationStatus.INITIALIZED: - return value # in compressed mode, the weight is already compressed and quantized so we don't # need to run fake quantization @@ -379,29 +332,13 @@ def maybe_calibrate_or_quantize( g_idx = getattr(module, "weight_g_idx", None) if args.dynamic: - # dynamic quantization - no need to invoke observer + # dynamic quantization - determine the scale/zp on the fly scale, zero_point = compute_dynamic_scales_and_zp(value=value, args=args) else: - # static quantization - get previous scale and zero point from layer + # static quantization - get scale and zero point from layer scale = getattr(module, f"{base_name}_scale") zero_point = getattr(module, f"{base_name}_zero_point", None) - if ( - module.quantization_status == QuantizationStatus.CALIBRATION - and base_name != "weight" - ): - # calibration mode - get new quant params from observer - observer = getattr(module, f"{base_name}_observer") - - updated_scale, updated_zero_point = observer(value, g_idx=g_idx) - - # update scale and zero point - update_parameter_data(module, updated_scale, f"{base_name}_scale") - 
update_parameter_data(module, updated_zero_point, f"{base_name}_zero_point") - - scale = updated_scale - zero_point = updated_zero_point - return fake_quantize( x=value, scale=scale, diff --git a/src/compressed_tensors/quantization/lifecycle/frozen.py b/src/compressed_tensors/quantization/lifecycle/frozen.py deleted file mode 100644 index 66356cb7..00000000 --- a/src/compressed_tensors/quantization/lifecycle/frozen.py +++ /dev/null @@ -1,55 +0,0 @@ -# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -from compressed_tensors.quantization.quant_config import QuantizationStatus -from compressed_tensors.quantization.utils import is_kv_cache_quant_scheme -from torch.nn import Module - - -__all__ = [ - "freeze_module_quantization", -] - - -def freeze_module_quantization(module: Module): - """ - deletes observers so static quantization is completed. - - apply to full model with `model.apply(freeze_module_quantization)` - - :param module: module to freeze quantization for - """ - scheme = getattr(module, "quantization_scheme", None) - if not scheme: - # no quantization scheme nothing to do - return - - if module.quantization_status == QuantizationStatus.FROZEN: - # nothing to do, already frozen - return - - # delete observers from module if not dynamic - if scheme.input_activations and not scheme.input_activations.dynamic: - delattr(module, "input_observer") - if scheme.weights and not scheme.weights.dynamic: - delattr(module, "weight_observer") - if ( - scheme.output_activations - and not is_kv_cache_quant_scheme(scheme) - and not scheme.output_activations.dynamic - ): - delattr(module, "output_observer") - - module.quantization_status = QuantizationStatus.FROZEN diff --git a/src/compressed_tensors/quantization/lifecycle/initialize.py b/src/compressed_tensors/quantization/lifecycle/initialize.py index 27b6a803..2a1efccb 100644 --- a/src/compressed_tensors/quantization/lifecycle/initialize.py +++ b/src/compressed_tensors/quantization/lifecycle/initialize.py @@ -14,13 +14,12 @@ import logging +from enum import Enum from typing import Optional import torch -from compressed_tensors.quantization.cache import KVCacheScaleType from compressed_tensors.quantization.lifecycle.forward import ( wrap_module_forward_quantized, - wrap_module_forward_quantized_attn, ) from compressed_tensors.quantization.quant_args import ( ActivationOrdering, @@ -36,12 +35,19 @@ __all__ = [ "initialize_module_for_quantization", + "is_attention_module", + "KVCacheScaleType", ] _LOGGER = logging.getLogger(__name__) +class KVCacheScaleType(Enum): + KEY = "k_scale" + VALUE = "v_scale" + + def initialize_module_for_quantization( module: Module, scheme: Optional[QuantizationScheme] = None, @@ -66,15 +72,13 @@ def initialize_module_for_quantization( return if is_attention_module(module): - # wrap forward call of module to perform # quantized actions based on calltime status - wrap_module_forward_quantized_attn(module, scheme) _initialize_attn_scales(module) else: if 
scheme.input_activations is not None: - _initialize_scale_zero_point_observer( + _initialize_scale_zero_point( module, "input", scheme.input_activations, @@ -85,7 +89,7 @@ def initialize_module_for_quantization( weight_shape = None if isinstance(module, torch.nn.Linear): weight_shape = module.weight.shape - _initialize_scale_zero_point_observer( + _initialize_scale_zero_point( module, "weight", scheme.weights, @@ -101,7 +105,7 @@ def initialize_module_for_quantization( if scheme.output_activations is not None: if not is_kv_cache_quant_scheme(scheme): - _initialize_scale_zero_point_observer( + _initialize_scale_zero_point( module, "output", scheme.output_activations ) @@ -146,21 +150,21 @@ def initialize_module_for_quantization( module._hf_hook.weights_map = new_prefix_dict -def _initialize_scale_zero_point_observer( +def is_attention_module(module: Module): + return "attention" in module.__class__.__name__.lower() and ( + hasattr(module, "k_proj") + or hasattr(module, "v_proj") + or hasattr(module, "qkv_proj") + ) + + +def _initialize_scale_zero_point( module: Module, base_name: str, quantization_args: QuantizationArgs, weight_shape: Optional[torch.Size] = None, force_zero_point: bool = True, ): - - # initialize observer module and attach as submodule - observer = quantization_args.get_observer() - # no need to register an observer for dynamic quantization - if observer: - module.register_module(f"{base_name}_observer", observer) - - # no need to register a scale and zero point for a dynamic quantization if quantization_args.dynamic: return @@ -209,14 +213,6 @@ def _initialize_scale_zero_point_observer( register_offload_parameter(module, f"{base_name}_g_idx", init_g_idx) -def is_attention_module(module: Module): - return "attention" in module.__class__.__name__.lower() and ( - hasattr(module, "k_proj") - or hasattr(module, "v_proj") - or hasattr(module, "qkv_proj") - ) - - def _initialize_attn_scales(module: Module) -> None: """Initlaize k_scale, v_scale for self_attn""" diff --git a/src/compressed_tensors/quantization/observers/__init__.py b/src/compressed_tensors/quantization/observers/__init__.py deleted file mode 100644 index 05b6b367..00000000 --- a/src/compressed_tensors/quantization/observers/__init__.py +++ /dev/null @@ -1,21 +0,0 @@ -# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# flake8: noqa -# isort: skip_file - -from .helpers import * -from .base import * -from .min_max import * -from .mse import * diff --git a/src/compressed_tensors/quantization/observers/base.py b/src/compressed_tensors/quantization/observers/base.py deleted file mode 100644 index d9d646f4..00000000 --- a/src/compressed_tensors/quantization/observers/base.py +++ /dev/null @@ -1,213 +0,0 @@ -# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import logging -from math import ceil -from typing import Any, Iterable, Optional, Tuple, Union - -import torch -from compressed_tensors.quantization.quant_args import ( - QuantizationArgs, - QuantizationStrategy, -) -from compressed_tensors.registry.registry import RegistryMixin -from compressed_tensors.utils import safe_permute -from torch import FloatTensor, IntTensor, Tensor -from torch.nn import Module - - -_LOGGER = logging.getLogger(__name__) - - -__all__ = ["Observer"] - - -class Observer(Module, RegistryMixin): - """ - Base Observer class to be subclassed for specific implementation. - Subclasses should override `calculate_qparams` to return a scale, zero_point - pair - """ - - def __init__(self, quantization_args: QuantizationArgs): - self.quantization_args: QuantizationArgs = quantization_args - super().__init__() - self._scale = None - self._zero_point = None - self._num_observed_tokens = None - - @torch.no_grad() - def forward( - self, observed: Tensor, g_idx: Optional[Tensor] = None - ) -> Tuple[FloatTensor, IntTensor]: - """ - maps directly to get_qparams - :param observed: optional observed tensor from which to calculate - quantization parameters - :param g_idx: optional mapping from column index to group index - :return: tuple of scale and zero point based on last observed value - """ - self.record_observed_tokens(observed) - return self.get_qparams(observed=observed, g_idx=g_idx) - - def calculate_qparams( - self, - observed: Tensor, - reduce_dims: Optional[Tuple[int]] = None, - ) -> Tuple[FloatTensor, IntTensor]: - """ - :param observed: observed tensor to calculate quantization parameters for - :param reduce_dims: optional tuple of dimensions to reduce along, - returned scale and zero point will be shaped (1,) along the - reduced dimensions - :return: tuple of scale and zero point derived from the observed tensor - """ - raise NotImplementedError(f"{self.__class__} must implement calculate_qparams") - - def post_calculate_qparams(self) -> None: - """ - Run any logic specific to its observers after running calculate_qparams - """ - ... 
- - def get_qparams( - self, - observed: Optional[Tensor] = None, - g_idx: Optional[Tensor] = None, - ) -> Tuple[FloatTensor, IntTensor]: - """ - Convenience function to wrap overwritten calculate_qparams - adds support to make observed tensor optional and support for tracking latest - calculated scale and zero point - - :param observed: optional observed tensor to calculate quantization parameters - from - :param g_idx: optional mapping from column index to group index - :return: tuple of scale and zero point based on last observed value - """ - if observed is not None: - group_size = self.quantization_args.group_size - - if self.quantization_args.strategy == QuantizationStrategy.TENSOR: - - # re-calculate scale and zero point, update the stored value - self._scale, self._zero_point = self.calculate_qparams(observed) - - elif self.quantization_args.strategy == QuantizationStrategy.GROUP: - rows = observed.shape[0] - columns = observed.shape[1] - num_groups = int(ceil(columns / group_size)) - self._scale = torch.empty( - (rows, num_groups), dtype=observed.dtype, device=observed.device - ) - zp_dtype = self.quantization_args.pytorch_dtype() - self._zero_point = torch.empty( - (rows, num_groups), dtype=zp_dtype, device=observed.device - ) - - # support column-order (default) quantization as well as other orderings - # such as activation ordering. Below checks if g_idx has initialized - is_column_order = g_idx is None or -1 in g_idx - if is_column_order: - group_sizes = torch.full((num_groups,), group_size, dtype=torch.int) - else: - group_indices, group_sizes = torch.unique(g_idx, return_counts=True) - group_sizes = group_sizes[torch.argsort(group_indices)] - - perm = torch.argsort(g_idx) - observed = safe_permute(observed, perm, dim=1) - - # TODO: experiment with vectorizing for loop for performance - end = 0 - for group_index, group_count in enumerate(group_sizes): - start = end - end = start + group_count - scale, zero_point = self.get_qparams_along_dim( - observed[:, start:end], - 0, - tensor_id=group_index, - ) - - self._scale[:, group_index] = scale.squeeze(1) - self._zero_point[:, group_index] = zero_point.squeeze(1) - - elif self.quantization_args.strategy == QuantizationStrategy.CHANNEL: - # assume observed is transposed, because its the output, hence use dim 0 - self._scale, self._zero_point = self.get_qparams_along_dim(observed, 0) - - elif self.quantization_args.strategy == QuantizationStrategy.TOKEN: - # use dim 1, assume the obsersed.shape = [batch, token, hidden] - # should be batch, token - self._scale, self._zero_point = self.get_qparams_along_dim( - observed, - dim={0, 1}, - ) - - return self._scale, self._zero_point - - def get_qparams_along_dim( - self, - observed, - dim: Union[int, Iterable[int]], - tensor_id: Optional[Any] = None, - ): - if isinstance(dim, int): - dim = [dim] - dim = set(dim) - - reduce_dims = tuple(idx for idx in range(observed.ndim) if idx not in dim) - return self.calculate_qparams( - observed, reduce_dims=reduce_dims, tensor_id=tensor_id - ) - - def record_observed_tokens(self, batch_tensor: Tensor): - """ - Counts the number of tokens observed during the - forward passes. The count is aggregated in the - _num_observed_tokens attribute of the class. - - Note: The batch_tensor is expected to have two dimensions - (batch_size * sequence_length, num_features). This is the - general shape expected by the forward pass of the expert - layers in a MOE model. 
If the input tensor does not have - two dimensions, the _num_observed_tokens attribute will be set - to None. - """ - if not isinstance(batch_tensor, Tensor): - raise ValueError(f"Expected value to be a tensor, got {type(batch_tensor)}") - - if batch_tensor.ndim != 2: - _LOGGER.debug( - "The input tensor is expected to have two dimensions " - "(batch_size * sequence_length, num_features). " - f"The input tensor has {batch_tensor.ndim} dimensions." - ) - return - - if self._num_observed_tokens is None: - # initialize the count - self._num_observed_tokens = 0 - - # batch_tensor (batch_size * sequence_length, num_features) - # observed_tokens (batch_size * sequence_length) - observed_tokens, _ = batch_tensor.shape - self._num_observed_tokens += observed_tokens - - def reset(self): - """ - Reset the state of the observer - """ - self._num_observed_tokens = None - self._scale = None - self._zero_point = None diff --git a/src/compressed_tensors/quantization/observers/helpers.py b/src/compressed_tensors/quantization/observers/helpers.py deleted file mode 100644 index ec474303..00000000 --- a/src/compressed_tensors/quantization/observers/helpers.py +++ /dev/null @@ -1,149 +0,0 @@ -# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from collections import Counter -from typing import Tuple - -import torch -from compressed_tensors.quantization.quant_args import ( - FP8_DTYPE, - QuantizationArgs, - QuantizationStrategy, - QuantizationType, -) -from torch import FloatTensor, IntTensor, Tensor - - -__all__ = [ - "calculate_qparams", - "get_observer_token_count", - "calculate_range", - "compute_dynamic_scales_and_zp", -] - - -def compute_dynamic_scales_and_zp(value: Tensor, args: QuantizationArgs): - """ - Returns the computed scales and zero points for dynamic activation - qunatization. - - :param value: tensor to calculate quantization parameters for - :param args: quantization args - :param reduce_dims: optional tuple of dimensions to reduce along, - returned scale and zero point will be shaped (1,) along the - reduced dimensions - :return: tuple of scale and zero point derived from the observed tensor - """ - if args.strategy == QuantizationStrategy.TOKEN: - dim = {1, 2} - reduce_dims = tuple(idx for idx in range(value.ndim) if idx not in dim) - elif args.strategy == QuantizationStrategy.TENSOR: - reduce_dims = None - else: - raise ValueError( - f"One of {QuantizationStrategy.TOKEN} or {QuantizationStrategy.TENSOR} ", - "must be used for dynamic quantization", - ) - - if not reduce_dims: - min_val, max_val = torch.aminmax(value) - else: - min_val = torch.amin(value, dim=reduce_dims, keepdims=True) - max_val = torch.amax(value, dim=reduce_dims, keepdims=True) - - return calculate_qparams(min_val, max_val, args) - - -def get_observer_token_count(module: torch.nn.Module) -> Counter: - """ - Parse the module and return the number of tokens observed by - each module's observer. 
- - :param module: module to parse - :return: counter with the number of tokens observed by each observer - """ - token_counts = Counter() - for name, module in module.named_modules(): - if name.endswith(".input_observer"): - token_counts[ - name.replace(".input_observer", "") - ] = module._num_observed_tokens - return token_counts - - -def calculate_qparams( - min_vals: Tensor, max_vals: Tensor, quantization_args: QuantizationArgs -) -> Tuple[FloatTensor, IntTensor]: - """ - :param min_vals: tensor of min value(s) to calculate scale(s) and zero point(s) - from - :param max_vals: tensor of max value(s) to calculate scale(s) and zero point(s) - from - :param quantization_args: settings to quantization - :return: tuple of the calculated scale(s) and zero point(s) - """ - min_vals = torch.min(min_vals, torch.zeros_like(min_vals)) - max_vals = torch.max(max_vals, torch.zeros_like(max_vals)) - device = min_vals.device - - bit_min, bit_max = calculate_range(quantization_args, device) - bit_range = bit_max - bit_min - zp_dtype = quantization_args.pytorch_dtype() - - if quantization_args.symmetric: - max_val_pos = torch.max(torch.abs(min_vals), torch.abs(max_vals)) - scales = max_val_pos / (float(bit_range) / 2) - scales = torch.clamp(scales, min=torch.finfo(torch.float32).eps) - zero_points = torch.zeros(scales.shape, device=device, dtype=min_vals.dtype) - else: - scales = (max_vals - min_vals) / float(bit_range) - scales = torch.clamp(scales, min=torch.finfo(torch.float32).eps) - zero_points = bit_min - (min_vals / scales) - zero_points = torch.clamp(zero_points, bit_min, bit_max) - - # match zero-points to quantized type - zero_points = zero_points.to(zp_dtype) - - if scales.ndim == 0: - scales = scales.reshape(1) - zero_points = zero_points.reshape(1) - - return scales, zero_points - - -def calculate_range(quantization_args: QuantizationArgs, device: str) -> Tuple: - """ - Calculated the effective quantization range for the given Quantization Args - - :param quantization_args: quantization args to get range of - :param device: device to store the range to - :return: tuple endpoints for the given quantization range - """ - if quantization_args.type == QuantizationType.INT: - bit_range = 2**quantization_args.num_bits - q_max = torch.tensor(bit_range / 2 - 1, device=device) - q_min = torch.tensor(-bit_range / 2, device=device) - elif quantization_args.type == QuantizationType.FLOAT: - if quantization_args.num_bits != 8: - raise ValueError( - "Floating point quantization is only supported for 8 bits," - f"got {quantization_args.num_bits}" - ) - fp_range_info = torch.finfo(FP8_DTYPE) - q_max = torch.tensor(fp_range_info.max, device=device) - q_min = torch.tensor(fp_range_info.min, device=device) - else: - raise ValueError(f"Invalid quantization type {quantization_args.type}") - - return q_min, q_max diff --git a/src/compressed_tensors/quantization/observers/min_max.py b/src/compressed_tensors/quantization/observers/min_max.py deleted file mode 100644 index a5b12906..00000000 --- a/src/compressed_tensors/quantization/observers/min_max.py +++ /dev/null @@ -1,104 +0,0 @@ -# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from typing import Any, Optional, Tuple - -import torch -from compressed_tensors.quantization.observers.base import Observer -from compressed_tensors.quantization.observers.helpers import calculate_qparams -from compressed_tensors.quantization.quant_args import QuantizationArgs -from torch import FloatTensor, IntTensor, Tensor - - -__all__ = ["MovingAverageMinMaxObserver"] - - -@Observer.register("minmax") -class MovingAverageMinMaxObserver(Observer): - """ - Implements a dynamic quantization observer that sets the scale and - zero point based on a moving average of the overall min and max observed values - """ - - def __init__( - self, quantization_args: QuantizationArgs, averaging_constant: float = 0.01 - ): - super().__init__(quantization_args=quantization_args) - - self.min_val = {} - self.max_val = {} - self.averaging_constant = averaging_constant - - def calculate_qparams( - self, - observed: Tensor, - reduce_dims: Optional[Tuple[int]] = None, - tensor_id: Optional[Any] = None, - ) -> Tuple[FloatTensor, IntTensor]: - """ - Updates the observed min and max using a moving average smoothed by the - averaging_constant - - :param observed: observed tensor to calculate quantization parameters for - :param reduce_dims: optional tuple of dimensions to reduce along, - returned scale and zero point will be shaped (1,) along the - reduced dimensions - :param tensor_id: Optional id if different ranges of observed tensors are - passed, useful for sharding tensors by group_size - :return: tuple of scale and zero point derived from the observed tensor - """ - tensor_id = tensor_id or "default" - - if not reduce_dims: - min_val, max_val = torch.aminmax(observed) - else: - min_val = torch.amin(observed, dim=reduce_dims, keepdims=True) - max_val = torch.amax(observed, dim=reduce_dims, keepdims=True) - - running_min_val = self.min_val.get(tensor_id, None) - running_max_val = self.max_val.get(tensor_id, None) - - if running_min_val is None or running_max_val is None: - updated_min_val = min_val - updated_max_val = max_val - else: - updated_min_val = running_min_val + self.averaging_constant * ( - min_val - running_min_val - ) - updated_max_val = running_max_val + self.averaging_constant * ( - max_val - running_max_val - ) - - self.min_val[tensor_id] = updated_min_val - self.max_val[tensor_id] = updated_max_val - - return calculate_qparams( - updated_min_val, updated_max_val, self.quantization_args - ) - - def get_qparams_along_dim( - self, observed, dim: int, tensor_id: Optional[Any] = None - ): - reduce_dims = tuple(idx for idx in range(observed.ndim) if idx != dim) - return self.calculate_qparams( - observed, reduce_dims=reduce_dims, tensor_id=tensor_id - ) - - def reset(self): - """ - Reset the state of the observer, including min and maximum values - """ - super().reset() - self.min_val = {} - self.max_val = {} diff --git a/src/compressed_tensors/quantization/observers/mse.py b/src/compressed_tensors/quantization/observers/mse.py deleted file mode 100644 index ff122a78..00000000 --- a/src/compressed_tensors/quantization/observers/mse.py +++ /dev/null @@ -1,162 +0,0 @@ -# Copyright (c) 
2021 - present / Neuralmagic, Inc. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from typing import Any, Optional, Tuple - -import torch -from compressed_tensors.quantization.observers.base import Observer -from compressed_tensors.quantization.observers.helpers import calculate_qparams -from compressed_tensors.quantization.quant_args import QuantizationArgs -from torch import FloatTensor, IntTensor, Tensor - - -__all__ = ["MovingAverageMSEObserver"] - - -@Observer.register("mse") -class MovingAverageMSEObserver(Observer): - """ - Implements a dynamic quantization observer that sets the scale and - zero point based on a moving average of the mse-clipped min and max observed values - """ - - def __init__( - self, - quantization_args: QuantizationArgs, - averaging_constant: float = 0.01, - grid: float = 100.0, - maxshrink: float = 0.80, - norm: float = 2.4, - ): - super().__init__(quantization_args=quantization_args) - - self.min_val = {} - self.max_val = {} - self.averaging_constant = averaging_constant - self.grid = grid - self.maxshrink = maxshrink - self.norm = norm - - def calculate_mse_min_max( - self, - observed: Tensor, - reduce_dims: Optional[Tuple[int]] = None, - ): - """ - Computes the mse-clipped min and max values of the observed tensor by - optimizing for quantization error - - :param observed: observed tensor to calculate quantization parameters for - :param reduce_dims: optional tuple of dimensions to reduce along, - returned values will be shaped (1,) along the reduced dimensions - :return: tuple of min and max values derived from the observed tensor - """ - from compressed_tensors.quantization.lifecycle import fake_quantize - - if not reduce_dims: - absolute_min_val, absolute_max_val = torch.aminmax(observed) - else: - absolute_min_val = torch.amin(observed, dim=reduce_dims, keepdims=True) - absolute_max_val = torch.amax(observed, dim=reduce_dims, keepdims=True) - - best = torch.full(absolute_min_val.shape, float("inf")) - min_val = torch.ones(absolute_min_val.shape) - max_val = torch.zeros(absolute_max_val.shape) - for i in range(int(self.maxshrink * self.grid)): - p = 1 - i / self.grid - shrinked_min_val = p * absolute_min_val - shrinked_max_val = p * absolute_max_val - - candidate_scales, candidate_zero_points = calculate_qparams( - shrinked_min_val, shrinked_max_val, self.quantization_args - ) - q = fake_quantize( - observed, - candidate_scales, - candidate_zero_points, - self.quantization_args, - ) - - q -= observed - q.abs_() - q.pow_(self.norm) - if not reduce_dims: - err = torch.sum(q) - else: - err = torch.sum(q, reduce_dims, keepdims=True) - - tmp = err < best - if torch.any(tmp): - best[tmp] = err[tmp] - min_val[tmp] = shrinked_min_val[tmp] - max_val[tmp] = shrinked_max_val[tmp] - return min_val, max_val - - def calculate_qparams( - self, - observed: Tensor, - reduce_dims: Optional[Tuple[int]] = None, - tensor_id: Optional[Any] = None, - ) -> Tuple[FloatTensor, IntTensor]: - """ - Updates the mse-clipped min and max values of the observed 
tensor using - a moving average smoothed by the averaging_constant - - :param observed: observed tensor to calculate quantization parameters for - :param reduce_dims: optional tuple of dimensions to reduce along, - returned scale and zero point will be shaped (1,) along the - reduced dimensions - :param tensor_id: Optional id if different ranges of observed tensors are - passed, useful for sharding tensors by group_size - :return: tuple of scale and zero point derived from the observed tensor - """ - min_val, max_val = self.calculate_mse_min_max(observed, reduce_dims) - - running_min_val = self.min_val.get(tensor_id, None) - running_max_val = self.max_val.get(tensor_id, None) - - if running_min_val is None or running_max_val is None: - updated_min_val = min_val - updated_max_val = max_val - else: - updated_min_val = running_min_val + self.averaging_constant * ( - min_val - running_min_val - ) - updated_max_val = running_max_val + self.averaging_constant * ( - max_val - running_max_val - ) - - tensor_id = tensor_id or "default" - self.min_val[tensor_id] = updated_min_val - self.max_val[tensor_id] = updated_max_val - - return calculate_qparams( - updated_min_val, updated_max_val, self.quantization_args - ) - - def get_qparams_along_dim( - self, observed, dim: int, tensor_id: Optional[Any] = None - ): - reduce_dims = tuple(idx for idx in range(observed.ndim) if idx != dim) - return self.calculate_qparams( - observed, reduce_dims=reduce_dims, tensor_id=tensor_id - ) - - def reset(self): - """ - Reset the state of the observer, including min and maximum values - """ - super().reset() - self.min_val = {} - self.max_val = {} diff --git a/src/compressed_tensors/quantization/quant_args.py b/src/compressed_tensors/quantization/quant_args.py index c2fc0b6a..4619d581 100644 --- a/src/compressed_tensors/quantization/quant_args.py +++ b/src/compressed_tensors/quantization/quant_args.py @@ -114,20 +114,7 @@ def get_observer(self): """ :return: torch quantization FakeQuantize built based on these QuantizationArgs """ - from compressed_tensors.quantization.observers.base import Observer - - # No observer required for the dynamic case - if self.dynamic: - self.observer = None - return self.observer - - return Observer.load_from_registry(self.observer, quantization_args=self) - - def get_kv_cache(self): - """Get the singleton KV Cache""" - from compressed_tensors.quantization.cache import QuantizedKVParameterCache - - return QuantizedKVParameterCache(self) + return self.observer @field_validator("type", mode="before") def validate_type(cls, value) -> QuantizationType: @@ -210,6 +197,7 @@ def validate_model_after(model: "QuantizationArgs") -> Dict[str, Any]: "activation ordering" ) + # infer observer w.r.t. 
dynamic if dynamic: if strategy not in ( QuantizationStrategy.TOKEN, @@ -221,18 +209,19 @@ def validate_model_after(model: "QuantizationArgs") -> Dict[str, Any]: "quantization", ) if observer is not None: - warnings.warn( - "No observer is used for dynamic quantization, setting to None" - ) - model.observer = None + if observer != "memoryless": # avoid annoying users with old configs + warnings.warn( + "No observer is used for dynamic quantization, setting to None" + ) + observer = None - # if we have not set an observer and we - # are running static quantization, use minmax - if not observer and not dynamic: - model.observer = "minmax" + elif observer is None: + # default to minmax for non-dynamic cases + observer = "minmax" # write back modified values model.strategy = strategy + model.observer = observer return model def pytorch_dtype(self) -> torch.dtype: diff --git a/src/compressed_tensors/quantization/utils/helpers.py b/src/compressed_tensors/quantization/utils/helpers.py index 8ebde09b..9f65ee33 100644 --- a/src/compressed_tensors/quantization/utils/helpers.py +++ b/src/compressed_tensors/quantization/utils/helpers.py @@ -16,9 +16,14 @@ from typing import Generator, List, Optional, Tuple import torch -from compressed_tensors.quantization.observers.base import Observer -from compressed_tensors.quantization.quant_args import QuantizationArgs +from compressed_tensors.quantization.quant_args import ( + FP8_DTYPE, + QuantizationArgs, + QuantizationStrategy, + QuantizationType, +) from compressed_tensors.quantization.quant_scheme import QuantizationScheme +from torch import FloatTensor, IntTensor, Tensor from torch.nn import Module from tqdm import tqdm @@ -36,6 +41,9 @@ "is_kv_cache_quant_scheme", "iter_named_leaf_modules", "iter_named_quantizable_modules", + "compute_dynamic_scales_and_zp", + "calculate_range", + "calculate_qparams", ] # target the self_attn layer @@ -45,6 +53,105 @@ _LOGGER: logging.Logger = logging.getLogger(__name__) +def calculate_qparams( + min_vals: Tensor, max_vals: Tensor, quantization_args: QuantizationArgs +) -> Tuple[FloatTensor, IntTensor]: + """ + :param min_vals: tensor of min value(s) to calculate scale(s) and zero point(s) + from + :param max_vals: tensor of max value(s) to calculate scale(s) and zero point(s) + from + :param quantization_args: settings to quantization + :return: tuple of the calculated scale(s) and zero point(s) + """ + min_vals = torch.min(min_vals, torch.zeros_like(min_vals)) + max_vals = torch.max(max_vals, torch.zeros_like(max_vals)) + device = min_vals.device + + bit_min, bit_max = calculate_range(quantization_args, device) + bit_range = bit_max - bit_min + zp_dtype = quantization_args.pytorch_dtype() + + if quantization_args.symmetric: + max_val_pos = torch.max(torch.abs(min_vals), torch.abs(max_vals)) + scales = max_val_pos / (float(bit_range) / 2) + scales = torch.clamp(scales, min=torch.finfo(torch.float32).eps) + zero_points = torch.zeros(scales.shape, device=device, dtype=min_vals.dtype) + else: + scales = (max_vals - min_vals) / float(bit_range) + scales = torch.clamp(scales, min=torch.finfo(torch.float32).eps) + zero_points = bit_min - (min_vals / scales) + zero_points = torch.clamp(zero_points, bit_min, bit_max) + + # match zero-points to quantized type + zero_points = zero_points.to(zp_dtype) + + if scales.ndim == 0: + scales = scales.reshape(1) + zero_points = zero_points.reshape(1) + + return scales, zero_points + + +def compute_dynamic_scales_and_zp(value: Tensor, args: QuantizationArgs): + """ + Returns the computed 
scales and zero points for dynamic activation + qunatization. + + :param value: tensor to calculate quantization parameters for + :param args: quantization args + :param reduce_dims: optional tuple of dimensions to reduce along, + returned scale and zero point will be shaped (1,) along the + reduced dimensions + :return: tuple of scale and zero point derived from the observed tensor + """ + if args.strategy == QuantizationStrategy.TOKEN: + dim = {1, 2} + reduce_dims = tuple(idx for idx in range(value.ndim) if idx not in dim) + elif args.strategy == QuantizationStrategy.TENSOR: + reduce_dims = None + else: + raise ValueError( + f"One of {QuantizationStrategy.TOKEN} or {QuantizationStrategy.TENSOR} ", + "must be used for dynamic quantization", + ) + + if not reduce_dims: + min_val, max_val = torch.aminmax(value) + else: + min_val = torch.amin(value, dim=reduce_dims, keepdims=True) + max_val = torch.amax(value, dim=reduce_dims, keepdims=True) + + return calculate_qparams(min_val, max_val, args) + + +def calculate_range(quantization_args: QuantizationArgs, device: str) -> Tuple: + """ + Calculated the effective quantization range for the given Quantization Args + + :param quantization_args: quantization args to get range of + :param device: device to store the range to + :return: tuple endpoints for the given quantization range + """ + if quantization_args.type == QuantizationType.INT: + bit_range = 2**quantization_args.num_bits + q_max = torch.tensor(bit_range / 2 - 1, device=device) + q_min = torch.tensor(-bit_range / 2, device=device) + elif quantization_args.type == QuantizationType.FLOAT: + if quantization_args.num_bits != 8: + raise ValueError( + "Floating point quantization is only supported for 8 bits," + f"got {quantization_args.num_bits}" + ) + fp_range_info = torch.finfo(FP8_DTYPE) + q_max = torch.tensor(fp_range_info.max, device=device) + q_min = torch.tensor(fp_range_info.min, device=device) + else: + raise ValueError(f"Invalid quantization type {quantization_args.type}") + + return q_min, q_max + + def infer_quantization_status(model: Module) -> Optional["QuantizationStatus"]: # noqa """ Checks the quantization status of a model. 
Assumes all modules in the model have @@ -118,12 +225,17 @@ def iter_named_leaf_modules(model: Module) -> Generator[Tuple[str, Module], None """ for name, submodule in model.named_modules(): children = list(submodule.children()) - if len(children) == 0 and not isinstance(submodule, Observer): + # TODO: verify if an observer would ever be attached in this case/remove check + if len(children) == 0 and "observer" in name: yield name, submodule else: + if len(children) > 0: + named_children, children = zip(*list(submodule.named_children())) has_non_observer_children = False - for child in children: - if not isinstance(child, Observer): + for i in range(len(children)): + child_name = named_children[i] + + if "observer" not in child_name: has_non_observer_children = True if not has_non_observer_children: @@ -144,14 +256,19 @@ def iter_named_quantizable_modules( :returns: generator tuple of (name, submodule) """ for name, submodule in model.named_modules(): + # TODO: verify if an observer would ever be attached in this case/remove check if include_children: children = list(submodule.children()) - if len(children) == 0 and not isinstance(submodule, Observer): + if len(children) == 0 and "observer" not in name: yield name, submodule else: + if len(children) > 0: + named_children, children = zip(*list(submodule.named_children())) has_non_observer_children = False - for child in children: - if not isinstance(child, Observer): + for i in range(len(children)): + child_name = named_children[i] + + if "observer" not in child_name: has_non_observer_children = True if not has_non_observer_children: diff --git a/src/compressed_tensors/registry/registry.py b/src/compressed_tensors/registry/registry.py index d8d8bc6d..76026313 100644 --- a/src/compressed_tensors/registry/registry.py +++ b/src/compressed_tensors/registry/registry.py @@ -258,7 +258,7 @@ def get_from_registry( retrieved_value = _import_and_get_value_from_module(module_path, value_name) else: # look up name in alias registry - name = _ALIAS_REGISTRY[parent_class].get(name) + name = _ALIAS_REGISTRY[parent_class].get(name, name) # look up name in registry retrieved_value = _REGISTRY[parent_class].get(name) if retrieved_value is None: diff --git a/src/compressed_tensors/version.py b/src/compressed_tensors/version.py index 1d241270..d0ba0363 100644 --- a/src/compressed_tensors/version.py +++ b/src/compressed_tensors/version.py @@ -17,7 +17,7 @@ """ -version_base = "0.7.0" +version_base = "0.7.1" is_release = True # change to True to set the generated version as a release version diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 00000000..a1c1d861 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,136 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from math import ceil + +import pytest +import torch +from compressed_tensors.quantization.utils import calculate_qparams +from compressed_tensors.utils.offload import update_parameter_data + + +def _get_dim(dim: int, value: torch.Tensor): + if isinstance(dim, int): + dim = [dim] + dim = set(dim) + + reduce_dims = tuple(idx for idx in range(value.ndim) if idx not in dim) + return reduce_dims + + +@pytest.fixture +def mock_per_token_calibration(): + def update_scale_zp(module: torch.nn.Module, base_name: str, value: torch.Tensor): + quantization_scheme = getattr(module, "quantization_scheme", None) + if not quantization_scheme: + # no quantization scheme nothing to do + return + + arg_name = "weights" if base_name == "weight" else f"{base_name}_activations" + args = getattr(quantization_scheme, arg_name, None) + + dim = _get_dim({0, 1}, value) + min_val = torch.amin(value, dim=dim, keepdims=True) + max_val = torch.amax(value, dim=dim, keepdims=True) + scale, zp = calculate_qparams(min_val, max_val, args) + scale = scale.reshape((1, 1)) + zp = zp.reshape((1, 1)) + update_parameter_data(module, scale, f"{base_name}_scale") + update_parameter_data(module, zp, f"{base_name}_zero_point") + + return update_scale_zp + + +@pytest.fixture +def mock_per_group_calibration(): + def update_scale_zp( + module: torch.nn.Module, base_name: str, value: torch.Tensor, group_size: int + ): + quantization_scheme = getattr(module, "quantization_scheme", None) + if not quantization_scheme: + # no quantization scheme nothing to do + return + + arg_name = "weights" if base_name == "weight" else f"{base_name}_activations" + args = getattr(quantization_scheme, arg_name, None) + + rows = value.shape[0] + columns = value.shape[1] + num_groups = int(ceil(columns / group_size)) + + scale = torch.zeros((rows, num_groups), dtype=value.dtype, device=value.device) + zp_dtype = args.pytorch_dtype() + zp = torch.zeros((rows, num_groups), dtype=zp_dtype, device=value.device) + + group_sizes = torch.full((num_groups,), group_size, dtype=torch.int) + end = 0 + for group_index, group_count in enumerate(group_sizes): + start = end + end = start + group_count + dim = _get_dim( + 0, + value[:, start:end], + ) + min_val = torch.amin(value, dim=dim, keepdims=True) + max_val = torch.amax(value, dim=dim, keepdims=True) + scale_out, zp_out = calculate_qparams(min_val, max_val, args) + + scale[:, group_index] = scale_out.squeeze(1) + zp[:, group_index] = zp_out.squeeze(1) + + update_parameter_data(module, scale, f"{base_name}_scale") + update_parameter_data(module, zp, f"{base_name}_zero_point") + + return update_scale_zp + + +@pytest.fixture +def mock_per_channel_calibration(): + def update_scale_zp(module: torch.nn.Module, base_name: str, value: torch.Tensor): + quantization_scheme = getattr(module, "quantization_scheme", None) + if not quantization_scheme: + # no quantization scheme nothing to do + return + + arg_name = "weights" if base_name == "weight" else f"{base_name}_activations" + + args = getattr(quantization_scheme, arg_name, None) + dim = _get_dim(0, value) + min_val = torch.amin(value, dim=dim, keepdims=True) + max_val = torch.amax(value, dim=dim, keepdims=True) + scale, zp = calculate_qparams(min_val, max_val, args) + update_parameter_data(module, scale, f"{base_name}_scale") + update_parameter_data(module, zp, f"{base_name}_zero_point") + + return update_scale_zp + + +@pytest.fixture +def mock_per_tensor_calibration(): + def update_scale_zp(module: torch.nn.Module, base_name: str, value: torch.Tensor): + 
quantization_scheme = getattr(module, "quantization_scheme", None) + if not quantization_scheme: + # no quantization scheme nothing to do + return + + arg_name = "weights" if base_name == "weight" else f"{base_name}_activations" + args = getattr(quantization_scheme, arg_name, None) + + # per tensor quantization just calls calculate_qparams directly + min_val, max_val = torch.aminmax(value) + scale, zp = calculate_qparams(min_val, max_val, args) + update_parameter_data(module, scale, f"{base_name}_scale") + update_parameter_data(module, zp, f"{base_name}_zero_point") + + return update_scale_zp diff --git a/tests/test_compressors/quantized_compressors/test_fp8_quant.py b/tests/test_compressors/quantized_compressors/test_fp8_quant.py index ba8a5df3..c131e216 100644 --- a/tests/test_compressors/quantized_compressors/test_fp8_quant.py +++ b/tests/test_compressors/quantized_compressors/test_fp8_quant.py @@ -25,7 +25,6 @@ QuantizationStatus, QuantizationStrategy, apply_quantization_config, - apply_quantization_status, ) from compressed_tensors.quantization.lifecycle.forward import fake_quantize from safetensors.torch import save_file @@ -114,7 +113,13 @@ def test_quant_format(strategy, group_size, sc, zp): # Note that group quantization is not supported ], ) -def test_reload_match(strategy, group_size, tmp_path): +def test_reload_match( + mock_per_group_calibration, + mock_per_channel_calibration, + strategy, + group_size, + tmp_path, +): model = Sequential( OrderedDict( [ @@ -124,11 +129,15 @@ def test_reload_match(strategy, group_size, tmp_path): ) quant_config = get_dummy_quant_config(strategy=strategy, group_size=group_size) apply_quantization_config(model, quant_config) - apply_quantization_status(model, QuantizationStatus.CALIBRATION) - - for _ in range(16): - inputs = torch.rand((512, 512)) - _ = model(inputs) + model.dummy.quantization_status = QuantizationStatus.CALIBRATION + if strategy == QuantizationStrategy.GROUP: + mock_per_group_calibration( + model.dummy, base_name="weight", value=model.dummy.weight, group_size=128 + ) + if strategy == QuantizationStrategy.CHANNEL: + mock_per_channel_calibration( + model.dummy, base_name="weight", value=model.dummy.weight + ) compressor = FloatQuantizationCompressor(config=quant_config) quantized_modules_to_args = { diff --git a/tests/test_compressors/quantized_compressors/test_pack_quant.py b/tests/test_compressors/quantized_compressors/test_pack_quant.py index 496e8304..fde57c4b 100644 --- a/tests/test_compressors/quantized_compressors/test_pack_quant.py +++ b/tests/test_compressors/quantized_compressors/test_pack_quant.py @@ -30,7 +30,6 @@ QuantizationScheme, QuantizationStatus, apply_quantization_config, - apply_quantization_status, ) from compressed_tensors.quantization.lifecycle.forward import fake_quantize from compressed_tensors.quantization.quant_args import ActivationOrdering @@ -205,7 +204,7 @@ def test_reload_match(tmp_path, num_bits): None, ], ) -def test_actorder_reload_match(actorder, tmp_path): +def test_actorder_reload_match(actorder, tmp_path, mock_per_group_calibration): model = Sequential(OrderedDict([("dummy", Linear(512, 1024, bias=None))])) group_size = 128 quant_config = get_dummy_quant_config( @@ -214,12 +213,10 @@ def test_actorder_reload_match(actorder, tmp_path): apply_quantization_config(model, quant_config) # run calibration - apply_quantization_status(model, QuantizationStatus.CALIBRATION) - for _ in range(16): - inputs = torch.rand((512, 512)) - _ = model(inputs) - apply_quantization_status(model, 
QuantizationStatus.FROZEN) - + model.quantization_status = QuantizationStatus.CALIBRATION + mock_per_group_calibration( + model.dummy, base_name="weight", value=model.dummy.weight, group_size=group_size + ) # apply gptq if actorder == ActivationOrdering.GROUP: init_g_idx = make_dummy_g_idx(512, group_size) diff --git a/tests/test_compressors/sparse_quantized_compressors/test_marlin_24.py b/tests/test_compressors/sparse_quantized_compressors/test_marlin_24.py index f12accb8..590b839c 100644 --- a/tests/test_compressors/sparse_quantized_compressors/test_marlin_24.py +++ b/tests/test_compressors/sparse_quantized_compressors/test_marlin_24.py @@ -29,7 +29,6 @@ QuantizationStatus, QuantizationStrategy, apply_quantization_config, - apply_quantization_status, ) from compressed_tensors.utils import mask_creator, merge_names from torch.nn.modules import Linear, Sequential @@ -54,7 +53,13 @@ def test_marlin_registered(): "strategy", [QuantizationStrategy.GROUP, QuantizationStrategy.CHANNEL] ) @pytest.mark.parametrize("layer_shape", [(512, 128), (1024, 1024), (128, 256)]) -def test_marlin24_format(num_bits, strategy, layer_shape): +def test_marlin24_format( + mock_per_group_calibration, + mock_per_channel_calibration, + num_bits, + strategy, + layer_shape, +): QUANT_NAME = "quant" NOT_QUANT_NAME = "not_quant" model = Sequential( @@ -70,11 +75,17 @@ def test_marlin24_format(num_bits, strategy, layer_shape): model.quant.weight.data *= mask apply_quantization_config(model, config) - apply_quantization_status(model, QuantizationStatus.CALIBRATION) + model.quantization_status = QuantizationStatus.CALIBRATION # runs observer to get scale and zero point - input = torch.rand((64, layer_shape[0])) - _ = model(input) + if strategy == QuantizationStrategy.GROUP: + mock_per_group_calibration( + model.quant, base_name="weight", value=model.quant.weight, group_size=128 + ) + if strategy == QuantizationStrategy.CHANNEL: + mock_per_channel_calibration( + model.quant, base_name="weight", value=model.quant.weight + ) state_dict = model.state_dict() assert len(state_dict) == 4 diff --git a/tests/test_quantization/test_observers/__init__.py b/tests/test_configs/__init__.py similarity index 100% rename from tests/test_quantization/test_observers/__init__.py rename to tests/test_configs/__init__.py diff --git a/tests/test_configs/test_base.py b/tests/test_configs/test_base.py new file mode 100644 index 00000000..5334ef38 --- /dev/null +++ b/tests/test_configs/test_base.py @@ -0,0 +1,57 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
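Before the new config tests below, a note on the pattern recurring in the test hunks above: instead of setting the model to calibration status and running forward passes through observers, the tests now set `quantization_status` directly and compute scales and zero points from the weights with the mock calibration fixtures. A simplified sketch of what a per-channel mock computes (symmetric int8, dividing by 127 rather than the exact `calculate_qparams` formula):

```python
import torch

def sketch_per_channel_qparams(weight: torch.Tensor, num_bits: int = 8):
    # reduce over the input dimension so each output channel (row) gets its own scale
    min_val = torch.amin(weight, dim=1, keepdim=True)
    max_val = torch.amax(weight, dim=1, keepdim=True)
    q_max = 2 ** (num_bits - 1) - 1                            # 127 for int8
    scale = torch.maximum(max_val.abs(), min_val.abs()) / q_max
    zero_point = torch.zeros_like(scale, dtype=torch.int8)     # symmetric -> zp = 0
    return scale, zero_point

scale, zp = sketch_per_channel_qparams(torch.randn(1024, 512))
print(scale.shape, zp.shape)  # torch.Size([1024, 1]) torch.Size([1024, 1])
```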
+ +import pytest +from compressed_tensors.config import SparsityStructure + + +def test_sparsity_structure_valid_cases(): + assert ( + SparsityStructure("2:4") == SparsityStructure.TWO_FOUR + ), "Failed to match '2:4' with TWO_FOUR" + assert ( + SparsityStructure("unstructured") == SparsityStructure.UNSTRUCTURED + ), "Failed to match 'unstructured' with UNSTRUCTURED" + assert ( + SparsityStructure("UNSTRUCTURED") == SparsityStructure.UNSTRUCTURED + ), "Failed to match 'UNSTRUCTURED' with UNSTRUCTURED" + assert ( + SparsityStructure(None) == SparsityStructure.UNSTRUCTURED + ), "Failed to match None with UNSTRUCTURED" + + +def test_sparsity_structure_invalid_case(): + with pytest.raises(ValueError, match="invalid is not a valid SparsityStructure"): + SparsityStructure("invalid") + + +def test_sparsity_structure_case_insensitivity(): + assert ( + SparsityStructure("2:4") == SparsityStructure.TWO_FOUR + ), "Failed to match '2:4' with TWO_FOUR" + assert ( + SparsityStructure("2:4".upper()) == SparsityStructure.TWO_FOUR + ), "Failed to match '2:4'.upper() with TWO_FOUR" + assert ( + SparsityStructure("unstructured".upper()) == SparsityStructure.UNSTRUCTURED + ), "Failed to match 'unstructured'.upper() with UNSTRUCTURED" + assert ( + SparsityStructure("UNSTRUCTURED".lower()) == SparsityStructure.UNSTRUCTURED + ), "Failed to match 'UNSTRUCTURED'.lower() with UNSTRUCTURED" + + +def test_sparsity_structure_default_case(): + assert ( + SparsityStructure(None) == SparsityStructure.UNSTRUCTURED + ), "Failed to match None with UNSTRUCTURED" diff --git a/tests/test_quantization/lifecycle/conftest.py b/tests/test_quantization/lifecycle/conftest.py index 97bf8b0c..49b5eda3 100644 --- a/tests/test_quantization/lifecycle/conftest.py +++ b/tests/test_quantization/lifecycle/conftest.py @@ -15,7 +15,9 @@ from typing import List, Optional import pytest +import torch from compressed_tensors.quantization.quant_args import QuantizationArgs +from compressed_tensors.quantization.quant_config import QuantizationStatus from compressed_tensors.quantization.quant_scheme import QuantizationScheme @@ -35,3 +37,11 @@ def quantization_scheme( ) return quantization_scheme + + +@pytest.fixture +def mock_frozen(): + def update_status(model: torch.nn.Module): + model.quantization_status = QuantizationStatus.FROZEN + + return update_status diff --git a/tests/test_quantization/lifecycle/test_apply.py b/tests/test_quantization/lifecycle/test_apply.py index 511cfdf7..7268ca27 100644 --- a/tests/test_quantization/lifecycle/test_apply.py +++ b/tests/test_quantization/lifecycle/test_apply.py @@ -22,7 +22,6 @@ DEFAULT_QUANTIZATION_METHOD, QuantizationConfig, QuantizationStatus, - freeze_module_quantization, ) from compressed_tensors.quantization.lifecycle import ( apply_quantization_config, @@ -33,7 +32,7 @@ from transformers import AutoModelForCausalLM -def test_target_prioritization(): +def test_target_prioritization(mock_frozen): # tests that the config_groups are applied in the correct order # of priority, where exact layer name > regex > module name config = { @@ -69,7 +68,7 @@ def test_target_prioritization(): config = QuantizationConfig(**config) config.quantization_status = QuantizationStatus.CALIBRATION apply_quantization_config(model, config) - model.apply(freeze_module_quantization) + mock_frozen(model) for name, module in iter_named_leaf_modules(model): if name == "model.layers.0.mlp.down_proj": @@ -149,7 +148,6 @@ def test_serialize_config_tinyllama(): assert serialized_config.config_groups["group_0"].input_activations is 
None assert serialized_config.config_groups["group_1"].targets == ["Linear"] assert serialized_config.config_groups["group_1"].input_activations is not None - assert serialized_config.quantization_status == QuantizationStatus.FROZEN assert serialized_config.format == CompressionFormat.dense.value assert serialized_config.quant_method == DEFAULT_QUANTIZATION_METHOD assert serialized_config.ignore == ["model.layers.1.mlp.down_proj"] @@ -238,14 +236,8 @@ def get_sample_tinyllama_quant_config(status: str = "frozen"): def test_apply_quantization_status(caplog, ignore, should_raise_warning): import logging - from transformers import AutoModelForCausalLM - # load a dense, unquantized tiny llama model - model_name = "TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T" - model = AutoModelForCausalLM.from_pretrained( - model_name, device_map="cpu", torch_dtype="auto" - ) - + model = get_tinyllama_model() quantization_config_dict = { "quant_method": "sparseml", "format": "pack-quantized", diff --git a/tests/test_quantization/lifecycle/test_dynamic_lifecycle.py b/tests/test_quantization/lifecycle/test_dynamic_lifecycle.py index 45d49370..dd700637 100644 --- a/tests/test_quantization/lifecycle/test_dynamic_lifecycle.py +++ b/tests/test_quantization/lifecycle/test_dynamic_lifecycle.py @@ -14,15 +14,13 @@ import torch -from compressed_tensors.quantization.lifecycle import ( - apply_quantization_config, - freeze_module_quantization, -) +from compressed_tensors.quantization.lifecycle import apply_quantization_config from compressed_tensors.quantization.quant_config import QuantizationConfig from transformers import AutoModelForCausalLM def test_apply_tinyllama_dynamic_activations(): + # NOTE: should not calibrate dynamic quant quant_config = get_sample_dynamic_tinyllama_quant_config() model = get_tinyllama_model() @@ -38,8 +36,6 @@ def test_apply_tinyllama_dynamic_activations(): # verify forward works w/ dynamic during calibration model(torch.zeros((1, 1), dtype=int), torch.zeros((1, 1), dtype=int)) - # freeze and test that only weight observers are deleted - model.apply(freeze_module_quantization) _test_linears_dynamic_quantization_status(model, quant_config, frozen=True) # verify forward works w/ dynamic after freeze model(torch.zeros((1, 1), dtype=int), torch.zeros((1, 1), dtype=int)) @@ -78,7 +74,6 @@ def _test_layer_dynamic_quantization_status( # check weights always have scale/zp and observer only if not frozen assert hasattr(module, "weight_scale") == weights assert hasattr(module, "weight_zero_point") == weights - assert hasattr(module, "weight_observer") == (weights and not frozen) def get_tinyllama_model(): diff --git a/tests/test_quantization/lifecycle/test_forward.py b/tests/test_quantization/lifecycle/test_forward.py index b9ee67ff..542cd8b9 100644 --- a/tests/test_quantization/lifecycle/test_forward.py +++ b/tests/test_quantization/lifecycle/test_forward.py @@ -17,10 +17,9 @@ import torch from compressed_tensors.quantization.lifecycle.forward import ( dequantize, - maybe_calibrate_or_quantize, + forward_quantize, quantize, wrap_module_forward_quantized, - wrap_module_forward_quantized_attn, ) from compressed_tensors.quantization.lifecycle.initialize import ( initialize_module_for_quantization, @@ -55,10 +54,10 @@ def test_wrap_module_forward_quantized(create_quantization_scheme): assert not func_forward == layer.forward.__func__ -@pytest.mark.parametrize( - "quantization_status", ["initialized", "calibration", "frozen"] -) -def test_maybe_calibrate_or_quantize(create_quantization_scheme, 
quantization_status): +@pytest.mark.parametrize("quantization_status", ["initialized", "calibration"]) +def test_forward_quantize( + mock_per_tensor_calibration, create_quantization_scheme, quantization_status +): num_bits = 8 quantization_scheme = create_quantization_scheme( targets=["*"], @@ -72,26 +71,24 @@ def test_maybe_calibrate_or_quantize(create_quantization_scheme, quantization_st dummy_tensor = torch.randn(8, 4) # (num_tokens, num_features) layer.quantization_status = QuantizationStatus(quantization_status) - initialize_module_for_quantization(layer, quantization_scheme) - # only calibration updates the scale and zero-point if layer.quantization_status == QuantizationStatus.INITIALIZED: - out = maybe_calibrate_or_quantize( - layer, dummy_tensor, "input", quantization_args - ) - assert torch.allclose(out, dummy_tensor) + # Init zp and scales + initialize_module_for_quantization(layer, quantization_scheme) + # mock weight calibration + mock_per_tensor_calibration(layer, "weight", value=layer.weight.data) + # call quant/dequant on weights + out = forward_quantize(layer, layer.weight, "weight", quantization_args) + assert torch.allclose(out, layer.weight.data, atol=0.2) elif layer.quantization_status == QuantizationStatus.CALIBRATION: - out = maybe_calibrate_or_quantize( - layer, dummy_tensor, "input", quantization_args - ) + # init zp/scales + initialize_module_for_quantization(layer, quantization_scheme) + # run weight and input calibration + mock_per_tensor_calibration(layer, "weight", value=layer.weight.data) + mock_per_tensor_calibration(layer, "input", value=dummy_tensor) + # call quant/dequant on inputs + out = forward_quantize(layer, dummy_tensor, "input", quantization_args) assert torch.allclose(out, dummy_tensor, atol=0.2) - assert layer.input_observer._num_observed_tokens == dummy_tensor.shape[0] - elif layer.quantization_status == QuantizationStatus.FROZEN: - # scale and zero points are empty -- cannot quantize - with pytest.raises(Exception): - out = maybe_calibrate_or_quantize( - layer, layer.weight.data, "input", quantization_args - ) @pytest.mark.parametrize( @@ -206,21 +203,3 @@ def test_dequantize(num_bits, type, strategy, group_size, scale, zero_point, g_i dtype=None, g_idx=g_idx, ) - - -def test_wrap_module_forward_quantized_attn(create_quantization_scheme): - num_bits = 8 - quantization_scheme = create_quantization_scheme( - targets=["self_attn"], - weights=QuantizationArgs(num_bits=num_bits, symmetric=True), - input_activations=QuantizationArgs(num_bits=num_bits, symmetric=False), - ) - - mock_attn_layer = Linear(4, 4) - - attn_forward = mock_attn_layer.forward.__func__ - - # check that the forward call is overwritten - wrap_module_forward_quantized_attn(mock_attn_layer, quantization_scheme) - - assert not attn_forward == mock_attn_layer.forward.__func__ diff --git a/tests/test_quantization/lifecycle/test_frozen.py b/tests/test_quantization/lifecycle/test_frozen.py deleted file mode 100644 index 056c6089..00000000 --- a/tests/test_quantization/lifecycle/test_frozen.py +++ /dev/null @@ -1,47 +0,0 @@ -# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -from compressed_tensors.quantization.lifecycle.frozen import freeze_module_quantization -from compressed_tensors.quantization.lifecycle.initialize import ( - initialize_module_for_quantization, -) -from compressed_tensors.quantization.quant_args import QuantizationArgs -from compressed_tensors.quantization.quant_config import QuantizationStatus -from torch.nn import Linear - - -def test_set_module_for_calibration(create_quantization_scheme): - num_bits = 8 - quantization_scheme = create_quantization_scheme( - targets=["*"], - weights=QuantizationArgs(num_bits=num_bits, symmetric=True), - input_activations=QuantizationArgs(num_bits=num_bits, symmetric=False), - ) - - layer = Linear(4, 4) - - initialize_module_for_quantization(layer, quantization_scheme) - layer.quantization_status = QuantizationStatus("calibration") - - # should have both input and weight observer after initalizing - assert hasattr(layer, "input_observer") - assert hasattr(layer, "weight_observer") - - # observers should get deleted after freezing - freeze_module_quantization(layer) - assert not hasattr(layer, "input_observer") - assert not hasattr(layer, "weight_observer") - - assert layer.quantization_status == QuantizationStatus("frozen") diff --git a/tests/test_quantization/lifecycle/test_kv_cache.py b/tests/test_quantization/lifecycle/test_kv_cache.py deleted file mode 100644 index b72f5265..00000000 --- a/tests/test_quantization/lifecycle/test_kv_cache.py +++ /dev/null @@ -1,78 +0,0 @@ -# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
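As a counterpoint to the frozen and KV-cache tests being deleted here, the forward path that remains under test (see the `test_forward.py` hunk above) can be exercised end to end without observers. A hedged sketch of that flow, assuming the signatures shown elsewhere in this diff (`initialize_module_for_quantization(layer, scheme)`, `forward_quantize(module, value, base_name, args)`, `calculate_qparams`, `update_parameter_data`):

```python
import torch
from torch.nn import Linear
from compressed_tensors.quantization.lifecycle.forward import forward_quantize
from compressed_tensors.quantization.lifecycle.initialize import (
    initialize_module_for_quantization,
)
from compressed_tensors.quantization.quant_args import QuantizationArgs
from compressed_tensors.quantization.quant_scheme import QuantizationScheme
from compressed_tensors.quantization.utils import calculate_qparams
from compressed_tensors.utils.offload import update_parameter_data

args = QuantizationArgs(num_bits=8, symmetric=True)
scheme = QuantizationScheme(targets=["Linear"], weights=args)

layer = Linear(4, 4)
initialize_module_for_quantization(layer, scheme)  # attaches weight_scale / weight_zero_point

# stand-in for the mock_per_tensor_calibration fixture
min_val, max_val = torch.aminmax(layer.weight.data)
scale, zp = calculate_qparams(min_val, max_val, args)
update_parameter_data(layer, scale, "weight_scale")
update_parameter_data(layer, zp, "weight_zero_point")

# quant/dequant round trip on the weights, as in the rewritten test_forward_quantize
out = forward_quantize(layer, layer.weight, "weight", args)
assert torch.allclose(out, layer.weight.data, atol=0.2)
```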
-import pytest -import torch -from compressed_tensors.quantization import ( - QuantizationConfig, - QuantizationStatus, - apply_quantization_config, - freeze_module_quantization, -) -from transformers import AutoModelForCausalLM - - -config = { - "quant_method": "compressed-tensors", - "format": "fakequant", - "kv_cache_scheme": { - "num_bits": 8, - "type": "int", - "symmetric": True, - "strategy": "tensor", - }, - "config_groups": { - "group_1": { - "weights": { - "num_bits": 4, - "type": "int", - "symmetric": True, - "strategy": "tensor", - }, - "targets": ["Linear"], - }, - }, -} - - -@pytest.mark.parametrize("config", [config]) -def test_kv_cache_quantization(config): - - sample = { - name: torch.ones((1, 32)).long() - for name in ["input_ids", "attention_mask", "labels"] - } - model = AutoModelForCausalLM.from_pretrained( - "HuggingFaceM4/tiny-random-LlamaForCausalLM", - torch_dtype="auto", - ) - model.eval() - - config = QuantizationConfig(**config) - config.quantization_status = QuantizationStatus.CALIBRATION - apply_quantization_config(model, config) - - with torch.no_grad(): - _ = model(**sample) - - model.apply(freeze_module_quantization) - - reloaded_config = QuantizationConfig.from_pretrained(model) - - assert ( - config.kv_cache_scheme.model_dump().keys() - == reloaded_config.kv_cache_scheme.model_dump().keys() - ) - assert list(config.kv_cache_scheme.model_dump().values()) == list( - reloaded_config.kv_cache_scheme.model_dump().values() - ) diff --git a/tests/test_quantization/lifecycle/test_lifecycle.py b/tests/test_quantization/lifecycle/test_lifecycle.py index 970e4de9..8f3e2dd0 100644 --- a/tests/test_quantization/lifecycle/test_lifecycle.py +++ b/tests/test_quantization/lifecycle/test_lifecycle.py @@ -15,10 +15,6 @@ from copy import deepcopy import torch -from compressed_tensors.quantization.lifecycle.calibration import ( - set_module_for_calibration, -) -from compressed_tensors.quantization.lifecycle.frozen import freeze_module_quantization from compressed_tensors.quantization.lifecycle.initialize import ( initialize_module_for_quantization, ) @@ -27,7 +23,7 @@ from torch.nn import Linear -def test_lifecyle(create_quantization_scheme): +def test_lifecyle(mock_per_tensor_calibration, create_quantization_scheme): num_bits = 8 quantization_scheme = create_quantization_scheme( @@ -59,18 +55,10 @@ def test_lifecyle(create_quantization_scheme): expected_layer_keys.remove(key) assert len(expected_layer_keys) == 0 - # should have both input and weight observer after initalizing - assert hasattr(layer, "input_observer") - assert hasattr(layer, "weight_observer") - assert hasattr(layer, "quantization_scheme") assert hasattr(layer, "quantization_status") assert layer.quantization_status == QuantizationStatus.INITIALIZED - set_module_for_calibration(layer) - assert layer.quantization_status == QuantizationStatus.CALIBRATION - - # do a calibration step assert torch.numel(layer.input_zero_point.data) == 1 assert torch.numel(layer.input_scale) == 1 assert torch.numel(layer.weight_scale) == 1 @@ -78,7 +66,10 @@ def test_lifecyle(create_quantization_scheme): random_input = torch.randn(4, 4) random_input[0][0] = 42 # skew distribution to force non-zero zp - layer(random_input) + + # do a calibration step + mock_per_tensor_calibration(layer, "weight", value=layer.weight) + mock_per_tensor_calibration(layer, "input", value=random_input) # zero-points and scale should be updated after forward pass assert torch.numel(layer.input_zero_point.data) > 0 @@ -99,7 +90,9 @@ def 
test_lifecyle(create_quantization_scheme): for _ in range(10): random_input = torch.randn(4, 4) random_input[0][0] = 42 # skew distribution to force non-zero zp - layer(random_input) + + mock_per_tensor_calibration(layer, "weight", value=layer.weight) + mock_per_tensor_calibration(layer, "input", value=random_input) assert initialized_layer_input_zero_point != 0 assert initialized_layer_input_scale != layer.input_scale @@ -113,9 +106,6 @@ def test_lifecyle(create_quantization_scheme): layer_before_freeze_input_scale = deepcopy(layer.input_scale) layer_before_freeze_weight_scale = deepcopy(layer.weight_scale) - # Freeze, no update after any forward pass - freeze_module_quantization(layer) - for _ in range(10): layer(torch.randn(4, 4)) assert layer_before_freeze_input_zero_point == layer.input_zero_point diff --git a/tests/test_quantization/test_cache.py b/tests/test_quantization/test_cache.py deleted file mode 100644 index 941af70f..00000000 --- a/tests/test_quantization/test_cache.py +++ /dev/null @@ -1,116 +0,0 @@ -# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import torch -from compressed_tensors.quantization.cache import QuantizedKVParameterCache -from compressed_tensors.quantization.quant_args import QuantizationArgs - - -def test_is_quantized_cache_singleton(): - """ - Check if quantized_cache is a singleton, used for - passing in QuantizedKVParameterCache to the forward call of - the model's self_attn - """ - - args = QuantizationArgs() - cache: QuantizedKVParameterCache = args.get_kv_cache() - observer = args.get_observer() - - tensor = torch.tensor([1, 2, 3]) - cache.k_scales.append(tensor) - cache.k_observers.append(observer) - - same_cache = args.get_kv_cache() - - assert len(cache.k_scales) == len(same_cache.k_scales) - assert torch.equal(cache.k_scales[0], same_cache.k_scales[0]) - - assert cache.k_observers == same_cache.k_observers - assert hex(id(cache.k_observers[0])) == hex(id(same_cache.k_observers[0])) - - cache.reset() - - -def test_update(): - - nbits = 8 - args = QuantizationArgs(nbits=nbits, symmetric=True) - cache: QuantizedKVParameterCache = args.get_kv_cache() - - max_key_states_val = 1.0 - max_value_states_val = 2.0 - key_states = torch.cat( - (max_key_states_val * torch.ones(1, 2, 2), torch.ones(1, 2, 2)), dim=0 - ) - value_states = torch.cat( - (max_value_states_val * torch.ones(1, 2, 2), torch.ones(1, 2, 2)), dim=0 - ) - layer_idx = 0 - - cache.update(key_states, value_states, layer_idx) - denom = (2 ** (nbits) - 1) / 2 - expected_k_scale = torch.tensor([max_key_states_val / denom]) - expected_v_scale = torch.tensor([max_value_states_val / denom]) - - assert cache.k_scales[0] == expected_k_scale - assert cache.v_scales[0] == expected_v_scale - - # new attn layer - layer_idx = 1 - cache.update(key_states, value_states, layer_idx) - - assert len(cache.k_scales) == 2 - assert len(cache.v_scales) == 2 - - assert len(cache.k_observers) == 2 - assert len(cache.v_observers) == 2 - - 
cache.reset() - - -def test_cache_reset(): - nbits = 8 - args = QuantizationArgs(nbits=nbits, symmetric=True) - cache: QuantizedKVParameterCache = args.get_kv_cache() - - max_key_states_val = 1.0 - max_value_states_val = 2.0 - key_states = torch.cat( - (max_key_states_val * torch.ones(1, 2, 2), torch.ones(1, 2, 2)), dim=0 - ) - value_states = torch.cat( - (max_value_states_val * torch.ones(1, 2, 2), torch.ones(1, 2, 2)), dim=0 - ) - layer_idx = 0 - - cache.update(key_states, value_states, layer_idx) - assert len(cache.k_scales) == 1 - assert len(cache.v_scales) == 1 - - assert len(cache.k_observers) == 1 - assert len(cache.v_observers) == 1 - - cache.reset() - - # new instance, different memory addr - different_cache: QuantizedKVParameterCache = args.get_kv_cache() - - assert len(different_cache.k_scales) == 0 - assert len(different_cache.v_scales) == 0 - - assert len(different_cache.k_observers) == 0 - assert len(different_cache.v_observers) == 0 - - assert hex(id(cache)) != hex(id(different_cache)) diff --git a/tests/test_quantization/test_configs/test_bit_depths.py b/tests/test_quantization/test_configs/test_bit_depths.py index c3bf39c9..4fa63306 100644 --- a/tests/test_quantization/test_configs/test_bit_depths.py +++ b/tests/test_quantization/test_configs/test_bit_depths.py @@ -26,7 +26,6 @@ def create_config(bit_depth, quant_type, input_symmetry, weight_symmetry): - print(quant_type) weights = QuantizationArgs( num_bits=bit_depth, type=quant_type, symmetric=weight_symmetry ) @@ -53,7 +52,9 @@ def create_config(bit_depth, quant_type, input_symmetry, weight_symmetry): @pytest.mark.parametrize("quant_type", ["int"]) @pytest.mark.parametrize("input_symmetry", [True, False, None]) @pytest.mark.parametrize("weight_symmetry", [True, False]) -def test_bit_depths(bit_depth, quant_type, input_symmetry, weight_symmetry): +def test_bit_depths( + mock_per_tensor_calibration, bit_depth, quant_type, input_symmetry, weight_symmetry +): model = Linear(64, 64) quant_config = create_config(bit_depth, quant_type, input_symmetry, weight_symmetry) apply_quantization_config(model, quant_config) @@ -62,8 +63,17 @@ def test_bit_depths(bit_depth, quant_type, input_symmetry, weight_symmetry): max = int(2**bit_depth / 2) - 1 inputs = torch.randn(32, 64) - model(inputs) + model.apply( + lambda module: mock_per_tensor_calibration( + module, base_name="weight", value=model.weight + ) + ) if input_symmetry is not None: + model.apply( + lambda module: mock_per_tensor_calibration( + module, base_name="input", value=inputs + ) + ) assert model.input_zero_point >= min assert model.input_zero_point <= max @@ -105,7 +115,9 @@ def test_bit_depths(bit_depth, quant_type, input_symmetry, weight_symmetry): @pytest.mark.parametrize("quant_type", ["float"]) @pytest.mark.parametrize("input_symmetry", [True, False, None]) @pytest.mark.parametrize("weight_symmetry", [True, False]) -def test_fp8(bit_depth, quant_type, input_symmetry, weight_symmetry): +def test_fp8( + mock_per_tensor_calibration, bit_depth, quant_type, input_symmetry, weight_symmetry +): model = Linear(64, 64) quant_config = create_config(bit_depth, quant_type, input_symmetry, weight_symmetry) apply_quantization_config(model, quant_config) @@ -115,10 +127,19 @@ def test_fp8(bit_depth, quant_type, input_symmetry, weight_symmetry): max = dtype_info.max inputs = torch.randn(32, 64) - model(inputs) + model.apply( + lambda module: mock_per_tensor_calibration( + module, base_name="weight", value=model.weight + ) + ) assert model.weight_zero_point.dtype == 
torch.float8_e4m3fn model.weight_zero_point.data = model.weight_zero_point.to(model.weight.dtype) if input_symmetry is not None: + model.apply( + lambda module: mock_per_tensor_calibration( + module, base_name="input", value=inputs + ) + ) assert model.input_zero_point.dtype == torch.float8_e4m3fn model.input_zero_point.data = model.input_zero_point.to(model.weight.dtype) assert model.input_zero_point >= min diff --git a/tests/test_quantization/test_configs/test_strategies.py b/tests/test_quantization/test_configs/test_strategies.py index a6f424ce..94201463 100644 --- a/tests/test_quantization/test_configs/test_strategies.py +++ b/tests/test_quantization/test_configs/test_strategies.py @@ -53,7 +53,9 @@ def create_config( @pytest.mark.parametrize("input_symmetry", [None]) @pytest.mark.parametrize("weight_symmetry", [True, False]) @pytest.mark.parametrize("model_shape", [(64, 128), (300, 200), (400, 400)]) -def test_channelwise(input_symmetry, weight_symmetry, model_shape): +def test_channelwise( + mock_per_channel_calibration, input_symmetry, weight_symmetry, model_shape +): model = Linear(model_shape[0], model_shape[1]) quant_config = create_config( input_symmetry, weight_symmetry, w_strategy=QuantizationStrategy.CHANNEL @@ -61,7 +63,9 @@ def test_channelwise(input_symmetry, weight_symmetry, model_shape): apply_quantization_config(model, quant_config) inputs = torch.randn(32, model_shape[0]) - model(inputs) + mock_per_channel_calibration(model, base_name="weight", value=model.weight) + if input_symmetry is not None: + mock_per_channel_calibration(model, base_name="input", value=inputs) assert list(model.weight_scale.shape) == [model_shape[1], 1] assert list(model.weight_zero_point.shape) == [model_shape[1], 1] @@ -72,7 +76,9 @@ def test_channelwise(input_symmetry, weight_symmetry, model_shape): @pytest.mark.parametrize("weight_symmetry", [True, False]) @pytest.mark.parametrize("model_shape", [(128, 256), (256, 512), (512, 1024)]) @pytest.mark.parametrize("group_size", [32, 128]) -def test_group(input_symmetry, weight_symmetry, model_shape, group_size): +def test_group( + mock_per_group_calibration, input_symmetry, weight_symmetry, model_shape, group_size +): model = Linear(model_shape[0], model_shape[1]) quant_config = create_config( input_symmetry, @@ -83,7 +89,13 @@ def test_group(input_symmetry, weight_symmetry, model_shape, group_size): apply_quantization_config(model, quant_config) inputs = torch.randn(128, model_shape[0]) - model(inputs) + mock_per_group_calibration( + model, base_name="weight", value=model.weight, group_size=group_size + ) + if input_symmetry is not None: + mock_per_group_calibration( + model, base_name="input", value=inputs, group_size=group_size + ) assert list(model.weight_scale.shape) == [ model_shape[1], @@ -99,7 +111,13 @@ def test_group(input_symmetry, weight_symmetry, model_shape, group_size): @pytest.mark.parametrize("input_symmetry", [True, False]) @pytest.mark.parametrize("weight_symmetry", [True, False]) @pytest.mark.parametrize("input_shape", [(32, 256), (300, 200), (400, 400)]) -def test_token(input_symmetry, weight_symmetry, input_shape): +def test_token( + mock_per_channel_calibration, + mock_per_token_calibration, + input_symmetry, + weight_symmetry, + input_shape, +): model = Linear(input_shape[1], 256) quant_config = create_config( input_symmetry, @@ -110,7 +128,8 @@ def test_token(input_symmetry, weight_symmetry, input_shape): apply_quantization_config(model, quant_config) inputs = torch.randn(input_shape) - model(inputs) + 
mock_per_channel_calibration(model, base_name="weight", value=model.weight) + mock_per_token_calibration(model, base_name="input", value=inputs) assert list(model.input_scale.shape) == [1, 1] assert list(model.input_zero_point.shape) == [1, 1] diff --git a/tests/test_quantization/test_observers/test_helpers.py b/tests/test_quantization/test_observers/test_helpers.py deleted file mode 100644 index 0b976a25..00000000 --- a/tests/test_quantization/test_observers/test_helpers.py +++ /dev/null @@ -1,91 +0,0 @@ -# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from compressed_tensors.quantization import ( - QuantizationConfig, - apply_quantization_config, -) -from compressed_tensors.quantization.observers.helpers import get_observer_token_count -from transformers import AutoModelForCausalLM, AutoTokenizer - - -def test_get_observer_token_count(): - model = AutoModelForCausalLM.from_pretrained("Isotonic/TinyMixtral-4x248M-MoE") - tokenizer = AutoTokenizer.from_pretrained("Isotonic/TinyMixtral-4x248M-MoE") - model.eval() - config = QuantizationConfig( - format="fakequant", - quantization_status="calibration", - config_groups={ - "group_1": { - "input_activations": { - "num_bits": 8, - "type": "int", - "symmetric": False, - "strategy": "tensor", - }, - "targets": ["Linear"], - }, - }, - ) - apply_quantization_config(model, config) - - # start calibration - calib_list = [ - "I am a string that", - "is used for calibration so", - "that your model is", - "quantized properly.", - ] - - total_num_tokens_observed = 0 - for calib_sample in calib_list: - calib_tensor = tokenizer(calib_sample, return_tensors="pt") - _ = model(**calib_tensor) - total_num_tokens_observed += len(calib_tensor.input_ids.flatten()) - - counter = get_observer_token_count(model) - - # filter out the None values - # (tokens, in the appropriate format, that were not observed by the model) - counter = {k: v for k, v in counter.items() if v is not None} - - # iterate over all the layers in the model where the token count in the proper - # format is has been observed - for i in range(model.config.num_hidden_layers): - # fetch the tokens observed by the router - tokens_observed_by_router = counter.pop( - f"model.layers.{i}.block_sparse_moe.gate" - ) - assert tokens_observed_by_router == total_num_tokens_observed - - # fetch the sum of tokens observed by all the experts - sum_tokens_observed_by_experts = 0 - keys_for_this_layer = [ - k - for k in counter.keys() - if f"model.layers.{i}.block_sparse_moe.experts" in k - ] - for key in keys_for_this_layer: - sum_tokens_observed_by_experts += counter.pop(key) - - # each Mixtral expert is comprised of 3 linear layers, - # so we need to multiply by 3 - assert ( - sum_tokens_observed_by_experts - == total_num_tokens_observed * model.config.num_experts_per_tok * 3 - ) - - # there are no more information in the counter - assert len(counter) == 0 diff --git a/tests/test_quantization/test_observers/test_min_max.py 
b/tests/test_quantization/test_observers/test_min_max.py deleted file mode 100644 index b9333332..00000000 --- a/tests/test_quantization/test_observers/test_min_max.py +++ /dev/null @@ -1,113 +0,0 @@ -# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -import pytest -import torch -from compressed_tensors.quantization.quant_args import QuantizationArgs - - -def make_dummy_g_idx(columns: int, group_size: int) -> torch.Tensor: - perm = torch.randperm(columns) - return torch.tensor([index // group_size for index in range(columns)])[perm] - - -@pytest.mark.parametrize( - "symmetric,expected_scale,expected_zero_point", - [ - (True, 0.0078, 0), - (False, 0.0039, -128), - ], -) -def test_min_max_observer(symmetric, expected_scale, expected_zero_point): - tensor = torch.tensor([1, 1, 1, 1, 1]) - num_bits = 8 - weights = QuantizationArgs(num_bits=num_bits, symmetric=symmetric) - - observer = weights.get_observer() - scale, zero_point = observer(tensor) - - assert round(scale.item(), 4) == expected_scale - assert round(zero_point.item(), 4) == expected_zero_point - - -def test_min_max_observer_symmetric_scale_range(): - tensor = torch.rand(4, 4) - tensor *= 127 - - num_bits = 8 - weights = QuantizationArgs(num_bits=num_bits, symmetric=True) - - observer = weights.get_observer() - scale, zero_point = observer(tensor) - - # if symmetric, max symmetric_range = abs(-128) / 255 - assert round(scale.item(), 4) <= 1.0039 - assert round(zero_point.item(), 4) == 0 - - -def test_min_max_observer_value_update(): - inp = torch.tensor([1, 1, 1, 1, 1]) - inp_update_max = torch.tensor([127, 1, 1, 1, 1]) - inp_update_min = torch.tensor([-128, 1, 1, 1, 1]) - - delta = 1e-6 - - # update the min, max twice total - tensors = [ - inp, - inp, - inp_update_max, # update max - inp, - inp_update_min, # update min - ] - - tensor = inp - num_bits = 8 - weights = QuantizationArgs(num_bits=num_bits, symmetric=True) - - observer = weights.get_observer() - curr_max = 1 - curr_min = 1 - for i, tensor in enumerate(tensors): - observer(tensor) - curr_max = max(observer.max_val.get("default"), curr_max) - curr_min = min(observer.min_val.get("default"), curr_max) - - if i < 2: - assert curr_max == 1 - assert curr_min == 1 - elif i < 4: - assert abs(curr_max - 2.2600) < delta - assert curr_min == 1 - else: - assert abs(curr_max - 2.2600) < delta - assert abs(curr_min - (-0.2900)) < delta - - -def test_g_idx(): - group_size = 2 - input_shape = (128, 512) - tensor = torch.rand(input_shape) - weights = QuantizationArgs(num_bits=8, group_size=group_size) - g_idx = make_dummy_g_idx(tensor.shape[1], group_size) - - observer = weights.get_observer() - scale_g_idx, zero_point_g_idx = observer(tensor, g_idx=g_idx) - - observer.reset() - scale, zero_point = observer(tensor[:, torch.argsort(g_idx)]) - - assert scale_g_idx == pytest.approx(scale) - assert zero_point_g_idx == pytest.approx(zero_point) diff --git a/tests/test_quantization/test_observers/test_mse.py 
b/tests/test_quantization/test_observers/test_mse.py deleted file mode 100644 index 76a317f4..00000000 --- a/tests/test_quantization/test_observers/test_mse.py +++ /dev/null @@ -1,54 +0,0 @@ -# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -import pytest -import torch -from compressed_tensors.quantization.observers import MovingAverageMSEObserver -from compressed_tensors.quantization.quant_args import QuantizationArgs - - -@pytest.mark.parametrize( - "symmetric,expected_scale,expected_zero_point", - [ - (True, 0.0078, 0), - (False, 0.0039, -128), - ], -) -def test_mse_observer(symmetric, expected_scale, expected_zero_point): - tensor = torch.tensor([1, 1, 1, 1, 1]) - num_bits = 8 - weights = QuantizationArgs(num_bits=num_bits, symmetric=symmetric, observer="mse") - - observer = weights.get_observer() - scale, zero_point = observer(tensor) - - assert isinstance(observer, MovingAverageMSEObserver) - assert round(scale.item(), 4) == expected_scale - assert round(zero_point.item(), 4) == expected_zero_point - - -def test_mse_observer_symmetric_scale_range(): - tensor = torch.rand(4, 4) - tensor *= 127 - - num_bits = 8 - weights = QuantizationArgs(num_bits=num_bits, symmetric=True) - - observer = weights.get_observer() - scale, zero_point = observer(tensor) - - # if symmetric, max symmetric_range = abs(-128) / 255 - assert round(scale.item(), 4) <= 1.0039 - assert round(zero_point.item(), 4) == 0
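With the observer tests deleted, the scale and zero-point math left under test in this repository goes through `calculate_qparams` directly, which is also what the new conftest fixtures do. A hedged sketch of replicating the deleted min-max observer check without an observer object; the 0.0078 expectation is taken from the deleted test above, and the import paths are the ones used elsewhere in this diff:

```python
import torch
from compressed_tensors.quantization.quant_args import QuantizationArgs
from compressed_tensors.quantization.utils import calculate_qparams

tensor = torch.tensor([1.0, 1.0, 1.0, 1.0, 1.0])
args = QuantizationArgs(num_bits=8, symmetric=True)

# per-tensor min/max in place of the removed MinMaxObserver
min_val, max_val = torch.aminmax(tensor)
scale, zero_point = calculate_qparams(min_val, max_val, args)

assert round(scale.item(), 4) == 0.0078  # same expectation as the deleted min-max test
assert zero_point.item() == 0
```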