From 74f1aa6b215c99287a2439c22f6f71d87a2314c2 Mon Sep 17 00:00:00 2001 From: George Date: Wed, 25 Sep 2024 11:12:20 -0400 Subject: [PATCH] [KV-Cache] Make k_scale, v_scale as attributes of self_attn using HFCache (#148) * init * init * delete unnces file * pass if no seed * pre polish * clean up * post clean up, merge main * tests and get rid of iter_named_leaf_modules, use iter_named_modules * comments * pass tests * mgoin comments * frozen state for inference * only compute scale, zp, do not keep quantized_key|value_states * do calibration if there is kv_cache_scheme * pass test * comments --- .../compressors/model_compressor.py | 56 ----- .../quantization/__init__.py | 1 + src/compressed_tensors/quantization/cache.py | 201 ++++++++++++++++++ .../quantization/lifecycle/apply.py | 14 +- .../quantization/lifecycle/forward.py | 55 ++++- .../quantization/lifecycle/frozen.py | 7 +- .../quantization/lifecycle/initialize.py | 164 ++++++++------ .../quantization/quant_args.py | 6 + .../quantization/quant_config.py | 9 +- .../quantization/utils/helpers.py | 61 ++++-- .../lifecycle/test_forward.py | 19 ++ .../lifecycle/test_kv_cache.py | 26 --- tests/test_quantization/test_cache.py | 116 ++++++++++ 13 files changed, 566 insertions(+), 169 deletions(-) create mode 100644 src/compressed_tensors/quantization/cache.py create mode 100644 tests/test_quantization/test_cache.py diff --git a/src/compressed_tensors/compressors/model_compressor.py b/src/compressed_tensors/compressors/model_compressor.py index 28d1a7c3..46a2d708 100644 --- a/src/compressed_tensors/compressors/model_compressor.py +++ b/src/compressed_tensors/compressors/model_compressor.py @@ -252,62 +252,6 @@ def compress( compressed_state_dict ) - # HACK (mgoin): Post-process step for kv cache scales to take the - # k/v_proj module `output_scale` parameters, and store them in the - # parent attention module as `k_scale` and `v_scale` - # - # Example: - # Replace `model.layers.0.self_attn.k_proj.output_scale` - # with `model.layers.0.self_attn.k_scale` - if ( - self.quantization_config is not None - and self.quantization_config.kv_cache_scheme is not None - ): - # HACK (mgoin): We assume the quantized modules in question - # will be k_proj and v_proj since those are the default targets. - # We check that both of these modules have output activation - # quantization, and additionally check that q_proj doesn't. 
-            q_proj_has_no_quant_output = 0
-            k_proj_has_quant_output = 0
-            v_proj_has_quant_output = 0
-            for name, module in model.named_modules():
-                if not hasattr(module, "quantization_scheme"):
-                    # We still want to count non-quantized q_proj
-                    if name.endswith(".q_proj"):
-                        q_proj_has_no_quant_output += 1
-                    continue
-                out_act = module.quantization_scheme.output_activations
-                if name.endswith(".q_proj") and out_act is None:
-                    q_proj_has_no_quant_output += 1
-                elif name.endswith(".k_proj") and out_act is not None:
-                    k_proj_has_quant_output += 1
-                elif name.endswith(".v_proj") and out_act is not None:
-                    v_proj_has_quant_output += 1
-
-            assert (
-                q_proj_has_no_quant_output > 0
-                and k_proj_has_quant_output > 0
-                and v_proj_has_quant_output > 0
-            )
-            assert (
-                q_proj_has_no_quant_output
-                == k_proj_has_quant_output
-                == v_proj_has_quant_output
-            )
-
-            # Move all .k/v_proj.output_scale parameters to .k/v_scale
-            working_state_dict = {}
-            for key in compressed_state_dict.keys():
-                if key.endswith(".k_proj.output_scale"):
-                    new_key = key.replace(".k_proj.output_scale", ".k_scale")
-                    working_state_dict[new_key] = compressed_state_dict[key]
-                elif key.endswith(".v_proj.output_scale"):
-                    new_key = key.replace(".v_proj.output_scale", ".v_scale")
-                    working_state_dict[new_key] = compressed_state_dict[key]
-                else:
-                    working_state_dict[key] = compressed_state_dict[key]
-            compressed_state_dict = working_state_dict
-
         # HACK: Override the dtype_byte_size function in transformers to
         # support float8 types. Fix is posted upstream
         # https://github.com/huggingface/transformers/pull/30488
diff --git a/src/compressed_tensors/quantization/__init__.py b/src/compressed_tensors/quantization/__init__.py
index 9fde69a3..848a4458 100644
--- a/src/compressed_tensors/quantization/__init__.py
+++ b/src/compressed_tensors/quantization/__init__.py
@@ -19,3 +19,4 @@
 from .quant_config import *
 from .quant_scheme import *
 from .lifecycle import *
+from .cache import QuantizedKVParameterCache
diff --git a/src/compressed_tensors/quantization/cache.py b/src/compressed_tensors/quantization/cache.py
new file mode 100644
index 00000000..cc33a48a
--- /dev/null
+++ b/src/compressed_tensors/quantization/cache.py
@@ -0,0 +1,201 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from enum import Enum
+from typing import Any, Dict, List, Optional, Tuple
+
+from compressed_tensors.quantization.observers import Observer
+from compressed_tensors.quantization.quant_args import QuantizationArgs
+from torch import Tensor
+from transformers import DynamicCache as HFDynamicCache
+
+
+class KVCacheScaleType(Enum):
+    KEY = "k_scale"
+    VALUE = "v_scale"
+
+
+class QuantizedKVParameterCache(HFDynamicCache):
+    """
+    Quantized KV cache used in the forward call, based on HF's DynamicCache.
+    The quantization strategy (tensor, group, channel) is set from the
+    QuantizationArgs strategy.
+    Singleton, so that the same cache gets reused across all forward calls of self_attn.
+    Each time forward is called, .update() is called, and ._quantize() and
+    ._dequantize() are called as appropriate.
+    The size of the cache tensors is
+    `[batch_size, num_heads, seq_len - residual_length, head_dim]`.
+
+    Triggered by adding kv_cache_scheme to the recipe.
+
+    Example:
+
+    ```python3
+    recipe = '''
+    quant_stage:
+        quant_modifiers:
+            QuantizationModifier:
+                kv_cache_scheme:
+                    num_bits: 8
+                    type: float
+                    strategy: tensor
+                    dynamic: false
+                    symmetric: true
+    '''
+    ```
+    """
+
+    _instance = None
+    _initialized = False
+
+    def __new__(cls, *args, **kwargs):
+        """Singleton"""
+        if cls._instance is None:
+            cls._instance = super(QuantizedKVParameterCache, cls).__new__(cls)
+        return cls._instance
+
+    def __init__(self, quantization_args: QuantizationArgs):
+        if not self._initialized:
+            super().__init__()
+
+            self.quantization_args = quantization_args
+
+            self.k_observers: List[Observer] = []
+            self.v_observers: List[Observer] = []
+
+            # each index corresponds to layer_idx of the attention layer
+            self.k_scales: List[Tensor] = []
+            self.v_scales: List[Tensor] = []
+
+            self.k_zps: List[Tensor] = []
+            self.v_zps: List[Tensor] = []
+
+            self._initialized = True
+
+    def update(
+        self,
+        key_states: Tensor,
+        value_states: Tensor,
+        layer_idx: int,
+        cache_kwargs: Optional[Dict[str, Any]] = None,
+    ) -> Tuple[Tensor, Tensor]:
+        """
+        Get the k_scale and v_scale and output the
+        fake-quantized key_states and value_states
+        """
+
+        if len(self.k_observers) <= layer_idx:
+            k_observer = self.quantization_args.get_observer()
+            v_observer = self.quantization_args.get_observer()
+
+            self.k_observers.append(k_observer)
+            self.v_observers.append(v_observer)
+
+        q_key_states = self._quantize(
+            key_states.contiguous(), KVCacheScaleType.KEY, layer_idx
+        )
+        q_value_states = self._quantize(
+            value_states.contiguous(), KVCacheScaleType.VALUE, layer_idx
+        )
+
+        qdq_key_states = self._dequantize(q_key_states, KVCacheScaleType.KEY, layer_idx)
+        qdq_value_states = self._dequantize(
+            q_value_states, KVCacheScaleType.VALUE, layer_idx
+        )
+
+        keys_to_return, values_to_return = qdq_key_states, qdq_value_states
+
+        return keys_to_return, values_to_return
+
+    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
+        """
+        Returns the sequence length of the cached states.
+        A layer index can be optionally passed.
+        """
+        if len(self.key_cache) <= layer_idx:
+            return 0
+        # since we cannot get the seq_length of each layer directly and
+        # rely on `_seen_tokens` which is updated every "layer_idx" == 0,
+        # this is a hack to get the actual seq_length for the given layer_idx
+        # this part of code otherwise fails when used to
+        # verify attn_weight shape in some models
+        return self._seen_tokens if layer_idx == 0 else self._seen_tokens - 1
+
+    def reset_states(self):
+        """reset the kv states (used in calibration)"""
+        self.key_cache: List[Tensor] = []
+        self.value_cache: List[Tensor] = []
+        # Used in `generate` to keep a tally of how many tokens the cache has seen
+        self._seen_tokens = 0
+        self._quantized_key_cache: List[Tensor] = []
+        self._quantized_value_cache: List[Tensor] = []
+
+    def reset(self):
+        """
+        Reset the instantiation, create a new instance on the next init
+        """
+        QuantizedKVParameterCache._instance = None
+        QuantizedKVParameterCache._initialized = False
+
+    def _quantize(self, tensor, kv_type, layer_idx):
+        """Quantizes a key/value using the defined quantization method."""
+        from compressed_tensors.quantization.lifecycle.forward import quantize
+
+        if kv_type == KVCacheScaleType.KEY:  # key type
+            observer = self.k_observers[layer_idx]
+            scales = self.k_scales
+            zps = self.k_zps
+        else:
+            assert kv_type == KVCacheScaleType.VALUE
+            observer = self.v_observers[layer_idx]
+            scales = self.v_scales
+            zps = self.v_zps
+
+        scale, zp = observer(tensor)
+        if len(scales) <= layer_idx:
+            scales.append(scale)
+            zps.append(zp)
+        else:
+            scales[layer_idx] = scale
+            zps[layer_idx] = zp
+
+        q_tensor = quantize(
+            x=tensor,
+            scale=scale,
+            zero_point=zp,
+            args=self.quantization_args,
+        )
+        return q_tensor
+
+    def _dequantize(self, qtensor, kv_type, layer_idx):
+        """Dequantizes back the tensor that was quantized by `self._quantize()`"""
+        from compressed_tensors.quantization.lifecycle.forward import dequantize
+
+        if kv_type == KVCacheScaleType.KEY:
+            scale = self.k_scales[layer_idx]
+            zp = self.k_zps[layer_idx]
+        else:
+            assert kv_type == KVCacheScaleType.VALUE
+            scale = self.v_scales[layer_idx]
+            zp = self.v_zps[layer_idx]
+
+        qdq_tensor = dequantize(
+            x_q=qtensor,
+            scale=scale,
+            zero_point=zp,
+            args=self.quantization_args,
+        )
+        return qdq_tensor
diff --git a/src/compressed_tensors/quantization/lifecycle/apply.py b/src/compressed_tensors/quantization/lifecycle/apply.py
index 7b8e240e..a66dba92 100644
--- a/src/compressed_tensors/quantization/lifecycle/apply.py
+++ b/src/compressed_tensors/quantization/lifecycle/apply.py
@@ -43,6 +43,7 @@
     infer_quantization_status,
     is_kv_cache_quant_scheme,
     iter_named_leaf_modules,
+    iter_named_quantizable_modules,
 )
 from compressed_tensors.utils.helpers import fix_fsdp_module_name, replace_module
 from compressed_tensors.utils.offload import update_parameter_data
@@ -135,15 +136,23 @@ def apply_quantization_config(
     # list of submodules to ignore
     ignored_submodules = defaultdict(list)
     # mark appropriate layers for quantization by setting their quantization schemes
-    for name, submodule in iter_named_leaf_modules(model):
+    for name, submodule in iter_named_quantizable_modules(
+        model,
+        include_children=True,
+        include_attn=True,
+    ):  # child modules and attention modules
         # potentially fix module name to remove FSDP wrapper prefix
         name = fix_fsdp_module_name(name)
         if matches := find_name_or_class_matches(name, submodule, config.ignore):
             for match in matches:
                 ignored_submodules[match].append(name)
             continue  # layer matches ignore list, continue
+        targets = 
find_name_or_class_matches(name, submodule, target_to_scheme) + if targets: + # mark modules to be quantized by adding + # quant scheme to the matching layers scheme = _scheme_from_targets(target_to_scheme, targets, name) if run_compressed: format = config.format @@ -200,6 +209,9 @@ def process_kv_cache_config( :param config: the QuantizationConfig :return: the QuantizationConfig with additional "kv_cache" group """ + if targets == KV_CACHE_TARGETS: + _LOGGER.info(f"KV cache targets set to default value of: {KV_CACHE_TARGETS}") + kv_cache_dict = config.kv_cache_scheme.model_dump() kv_cache_scheme = QuantizationScheme( output_activations=QuantizationArgs(**kv_cache_dict), diff --git a/src/compressed_tensors/quantization/lifecycle/forward.py b/src/compressed_tensors/quantization/lifecycle/forward.py index d70d9acc..4dbe4a85 100644 --- a/src/compressed_tensors/quantization/lifecycle/forward.py +++ b/src/compressed_tensors/quantization/lifecycle/forward.py @@ -14,9 +14,10 @@ from functools import wraps from math import ceil -from typing import Optional +from typing import Callable, Optional import torch +from compressed_tensors.quantization.cache import QuantizedKVParameterCache from compressed_tensors.quantization.observers.helpers import calculate_range from compressed_tensors.quantization.quant_args import ( QuantizationArgs, @@ -62,6 +63,7 @@ def quantize( :param g_idx: optional mapping from column index to group index :return: fake quantized tensor """ + return _process_quantization( x=x, scale=scale, @@ -165,8 +167,8 @@ def _process_quantization( x: torch.Tensor, scale: torch.Tensor, zero_point: torch.Tensor, - g_idx: Optional[torch.Tensor], args: QuantizationArgs, + g_idx: Optional[torch.Tensor] = None, dtype: Optional[torch.dtype] = None, do_quantize: bool = True, do_dequantize: bool = True, @@ -266,6 +268,7 @@ def wrapped_forward(self, *args, **kwargs): return forward_func_orig.__get__(module, module.__class__)(*args, **kwargs) input_ = args[0] + compressed = module.quantization_status == QuantizationStatus.COMPRESSED if scheme.input_activations is not None: @@ -285,9 +288,11 @@ def wrapped_forward(self, *args, **kwargs): output = forward_func_orig.__get__(module, module.__class__)( input_, *args[1:], **kwargs ) - if scheme.output_activations is not None: + # calibrate and (fake) quantize output activations when applicable + # kv_cache scales updated on model self_attn forward call in + # wrap_module_forward_quantized_attn output = maybe_calibrate_or_quantize( module, output, "output", scheme.output_activations ) @@ -304,6 +309,50 @@ def wrapped_forward(self, *args, **kwargs): setattr(module, "forward", bound_wrapped_forward) +def wrap_module_forward_quantized_attn(module: Module, scheme: QuantizationScheme): + # expects a module already initialized and injected with the parameters in + # initialize_module_for_quantization + if hasattr(module.forward, "__func__"): + forward_func_orig = module.forward.__func__ + else: + forward_func_orig = module.forward.func + + @wraps(forward_func_orig) # ensures docstring, names, etc are propagated + def wrapped_forward(self, *args, **kwargs): + + # kv cache stored under weights + if module.quantization_status == QuantizationStatus.CALIBRATION: + quantization_args: QuantizationArgs = scheme.output_activations + past_key_value: QuantizedKVParameterCache = quantization_args.get_kv_cache() + kwargs["past_key_value"] = past_key_value + + # QuantizedKVParameterCache used for obtaining k_scale, v_scale only, + # does not store quantized_key_states and 
quantized_value_state + kwargs["use_cache"] = False + + attn_forward: Callable = forward_func_orig.__get__(module, module.__class__) + + past_key_value.reset_states() + + rtn = attn_forward(*args, **kwargs) + + update_parameter_data( + module, past_key_value.k_scales[module.layer_idx], "k_scale" + ) + update_parameter_data( + module, past_key_value.v_scales[module.layer_idx], "v_scale" + ) + + return rtn + + return forward_func_orig.__get__(module, module.__class__)(*args, **kwargs) + + # bind wrapped forward to module class so reference to `self` is correct + bound_wrapped_forward = wrapped_forward.__get__(module, module.__class__) + # set forward to wrapped forward + setattr(module, "forward", bound_wrapped_forward) + + def maybe_calibrate_or_quantize( module: Module, value: torch.Tensor, base_name: str, args: "QuantizationArgs" ) -> torch.Tensor: diff --git a/src/compressed_tensors/quantization/lifecycle/frozen.py b/src/compressed_tensors/quantization/lifecycle/frozen.py index 652f1c3a..66356cb7 100644 --- a/src/compressed_tensors/quantization/lifecycle/frozen.py +++ b/src/compressed_tensors/quantization/lifecycle/frozen.py @@ -14,6 +14,7 @@ from compressed_tensors.quantization.quant_config import QuantizationStatus +from compressed_tensors.quantization.utils import is_kv_cache_quant_scheme from torch.nn import Module @@ -44,7 +45,11 @@ def freeze_module_quantization(module: Module): delattr(module, "input_observer") if scheme.weights and not scheme.weights.dynamic: delattr(module, "weight_observer") - if scheme.output_activations and not scheme.output_activations.dynamic: + if ( + scheme.output_activations + and not is_kv_cache_quant_scheme(scheme) + and not scheme.output_activations.dynamic + ): delattr(module, "output_observer") module.quantization_status = QuantizationStatus.FROZEN diff --git a/src/compressed_tensors/quantization/lifecycle/initialize.py b/src/compressed_tensors/quantization/lifecycle/initialize.py index 3c9a1211..78f6fd4b 100644 --- a/src/compressed_tensors/quantization/lifecycle/initialize.py +++ b/src/compressed_tensors/quantization/lifecycle/initialize.py @@ -17,8 +17,10 @@ from typing import Optional import torch +from compressed_tensors.quantization.cache import KVCacheScaleType from compressed_tensors.quantization.lifecycle.forward import ( wrap_module_forward_quantized, + wrap_module_forward_quantized_attn, ) from compressed_tensors.quantization.quant_args import ( ActivationOrdering, @@ -27,6 +29,7 @@ ) from compressed_tensors.quantization.quant_config import QuantizationStatus from compressed_tensors.quantization.quant_scheme import QuantizationScheme +from compressed_tensors.quantization.utils import is_kv_cache_quant_scheme from compressed_tensors.utils import get_execution_device, is_module_offloaded from torch.nn import Module, Parameter @@ -62,72 +65,78 @@ def initialize_module_for_quantization( # no scheme passed and layer not targeted for quantization - skip return - if scheme.input_activations is not None: - _initialize_scale_zero_point_observer( - module, "input", scheme.input_activations, force_zero_point=force_zero_point - ) - if scheme.weights is not None: - if hasattr(module, "weight"): - weight_shape = module.weight.shape - _initialize_scale_zero_point_observer( - module, - "weight", - scheme.weights, - weight_shape=weight_shape, - force_zero_point=force_zero_point, - ) - else: - _LOGGER.warning( - f"module type {type(module)} targeted for weight quantization but " - "has no attribute weight, skipping weight quantization " - f"for 
{type(module)}" - ) - if scheme.output_activations is not None: - _initialize_scale_zero_point_observer( - module, - "output", - scheme.output_activations, - force_zero_point=force_zero_point, - ) + if is_attention_module(module): + # wrap forward call of module to perform + # quantized actions based on calltime status + wrap_module_forward_quantized_attn(module, scheme) + _initialize_attn_scales(module) - module.quantization_scheme = scheme - module.quantization_status = QuantizationStatus.INITIALIZED + else: - offloaded = False - if is_module_offloaded(module): - try: - from accelerate.hooks import add_hook_to_module, remove_hook_from_module - from accelerate.utils import PrefixedDataset - except ModuleNotFoundError: - raise ModuleNotFoundError( - "Offloaded model detected. To use CPU offloading with " - "compressed-tensors the `accelerate` package must be installed, " - "run `pip install compressed-tensors[accelerate]`" + if scheme.input_activations is not None: + _initialize_scale_zero_point_observer( + module, "input", scheme.input_activations ) - - offloaded = True - hook = module._hf_hook - prefix_dict = module._hf_hook.weights_map - new_prefix = {} - - # recreate the prefix dict (since it is immutable) - # and add quantization parameters - for key, data in module.named_parameters(): - if key not in prefix_dict: - new_prefix[f"{prefix_dict.prefix}{key}"] = data + if scheme.weights is not None: + if hasattr(module, "weight"): + weight_shape = None + if isinstance(module, torch.nn.Linear): + weight_shape = module.weight.shape + _initialize_scale_zero_point_observer( + module, "weight", scheme.weights, weight_shape=weight_shape + ) else: - new_prefix[f"{prefix_dict.prefix}{key}"] = prefix_dict[key] - new_prefix_dict = PrefixedDataset(new_prefix, prefix_dict.prefix) - remove_hook_from_module(module) - - # wrap forward call of module to perform quantized actions based on calltime status - wrap_module_forward_quantized(module, scheme) - - if offloaded: - # we need to re-add the hook for offloading now that we've wrapped forward - add_hook_to_module(module, hook) - if prefix_dict is not None: - module._hf_hook.weights_map = new_prefix_dict + _LOGGER.warning( + f"module type {type(module)} targeted for weight quantization but " + "has no attribute weight, skipping weight quantization " + f"for {type(module)}" + ) + + if scheme.output_activations is not None: + if not is_kv_cache_quant_scheme(scheme): + _initialize_scale_zero_point_observer( + module, "output", scheme.output_activations + ) + + module.quantization_scheme = scheme + module.quantization_status = QuantizationStatus.INITIALIZED + + offloaded = False + if is_module_offloaded(module): + try: + from accelerate.hooks import add_hook_to_module, remove_hook_from_module + from accelerate.utils import PrefixedDataset + except ModuleNotFoundError: + raise ModuleNotFoundError( + "Offloaded model detected. 
To use CPU offloading with " + "compressed-tensors the `accelerate` package must be installed, " + "run `pip install compressed-tensors[accelerate]`" + ) + + offloaded = True + hook = module._hf_hook + prefix_dict = module._hf_hook.weights_map + new_prefix = {} + + # recreate the prefix dict (since it is immutable) + # and add quantization parameters + for key, data in module.named_parameters(): + if key not in prefix_dict: + new_prefix[f"{prefix_dict.prefix}{key}"] = data + else: + new_prefix[f"{prefix_dict.prefix}{key}"] = prefix_dict[key] + new_prefix_dict = PrefixedDataset(new_prefix, prefix_dict.prefix) + remove_hook_from_module(module) + + # wrap forward call of module to perform + # quantized actions based on calltime status + wrap_module_forward_quantized(module, scheme) + + if offloaded: + # we need to re-add the hook for offloading now that we've wrapped forward + add_hook_to_module(module, hook) + if prefix_dict is not None: + module._hf_hook.weights_map = new_prefix_dict def _initialize_scale_zero_point_observer( @@ -189,3 +198,34 @@ def _initialize_scale_zero_point_observer( requires_grad=False, ) module.register_parameter(f"{base_name}_g_idx", init_g_idx) + + +def is_attention_module(module: Module): + return "attention" in module.__class__.__name__.lower() and ( + hasattr(module, "k_proj") + or hasattr(module, "v_proj") + or hasattr(module, "qkv_proj") + ) + + +def _initialize_attn_scales(module: Module) -> None: + """Initlaize k_scale, v_scale for self_attn""" + + expected_shape = 1 # per tensor + + param = next(module.parameters()) + scale_dtype = param.dtype + device = param.device + + init_scale = Parameter( + torch.empty(expected_shape, dtype=scale_dtype, device=device), + requires_grad=False, + ) + + module.register_parameter(KVCacheScaleType.KEY.value, init_scale) + + init_scale = Parameter( + torch.empty(expected_shape, dtype=scale_dtype, device=device), + requires_grad=False, + ) + module.register_parameter(KVCacheScaleType.VALUE.value, init_scale) diff --git a/src/compressed_tensors/quantization/quant_args.py b/src/compressed_tensors/quantization/quant_args.py index b502ebd2..54805c58 100644 --- a/src/compressed_tensors/quantization/quant_args.py +++ b/src/compressed_tensors/quantization/quant_args.py @@ -122,6 +122,12 @@ def get_observer(self): return Observer.load_from_registry(self.observer, quantization_args=self) + def get_kv_cache(self): + """Get the singleton KV Cache""" + from compressed_tensors.quantization.cache import QuantizedKVParameterCache + + return QuantizedKVParameterCache(self) + @field_validator("type", mode="before") def validate_type(cls, value) -> QuantizationType: if isinstance(value, str): diff --git a/src/compressed_tensors/quantization/quant_config.py b/src/compressed_tensors/quantization/quant_config.py index 01b43910..30785554 100644 --- a/src/compressed_tensors/quantization/quant_config.py +++ b/src/compressed_tensors/quantization/quant_config.py @@ -24,7 +24,7 @@ from compressed_tensors.quantization.utils import ( calculate_compression_ratio, is_module_quantized, - iter_named_leaf_modules, + iter_named_quantizable_modules, module_type, parse_out_kv_cache_args, ) @@ -177,7 +177,9 @@ def from_pretrained( quantization_status = None ignore = {} quantization_type_names = set() - for name, submodule in iter_named_leaf_modules(model): + for name, submodule in iter_named_quantizable_modules( + model, include_children=True, include_attn=True + ): layer_type = module_type(submodule) if not is_module_quantized(submodule): if layer_type not in 
ignore:
@@ -241,6 +243,9 @@ def from_pretrained(
         )
 
     def requires_calibration_data(self):
+        if self.kv_cache_scheme is not None:
+            return True
+
         for _, scheme in self.config_groups.items():
             if scheme.input_activations is not None:
                 if not scheme.input_activations.dynamic:
diff --git a/src/compressed_tensors/quantization/utils/helpers.py b/src/compressed_tensors/quantization/utils/helpers.py
index 3db7a711..8ebde09b 100644
--- a/src/compressed_tensors/quantization/utils/helpers.py
+++ b/src/compressed_tensors/quantization/utils/helpers.py
@@ -13,8 +13,7 @@
 # limitations under the License.
 
 import logging
-import re
-from typing import List, Optional, Tuple
+from typing import Generator, List, Optional, Tuple
 
 import torch
 from compressed_tensors.quantization.observers.base import Observer
@@ -28,7 +27,6 @@
     "infer_quantization_status",
     "is_module_quantized",
     "is_model_quantized",
-    "iter_named_leaf_modules",
     "module_type",
     "calculate_compression_ratio",
     "get_torch_bit_depth",
@@ -36,9 +34,14 @@
     "parse_out_kv_cache_args",
     "KV_CACHE_TARGETS",
     "is_kv_cache_quant_scheme",
+    "iter_named_leaf_modules",
+    "iter_named_quantizable_modules",
 ]
 
-KV_CACHE_TARGETS = ["re:.*k_proj", "re:.*v_proj"]
+# target the self_attn layer
+# QuantizedKVParameterCache is responsible for obtaining the k_scale and v_scale
+KV_CACHE_TARGETS = ["re:.*self_attn$"]
+
 _LOGGER: logging.Logger = logging.getLogger(__name__)
 
 
@@ -106,11 +109,10 @@ def module_type(module: Module) -> str:
     return type(module).__name__
 
 
-def iter_named_leaf_modules(model: Module) -> Tuple[str, Module]:
+def iter_named_leaf_modules(model: Module) -> Generator[Tuple[str, Module], None, None]:
     """
     Yields modules that do not have any submodules except observers. The observers
     themselves are not yielded
-
     :param model: model to get leaf modules of
     :returns: generator tuple of (name, leaf_submodule)
     """
@@ -128,6 +130,37 @@ def iter_named_leaf_modules(model: Module) -> Tuple[str, Module]:
         yield name, submodule
 
 
+def iter_named_quantizable_modules(
+    model: Module, include_children: bool = True, include_attn: bool = False
+) -> Generator[Tuple[str, Module], None, None]:
+    """
+    Yield the name and submodule of
+    - leaf modules, set by include_children
+    - attention modules, set by include_attn
+
+    :param model: model to get leaf modules of
+    :param include_children: flag to get the leaf modules
+    :param include_attn: flag to get the attention modules
+    :returns: generator tuple of (name, submodule)
+    """
+    for name, submodule in model.named_modules():
+        if include_children:
+            children = list(submodule.children())
+            if len(children) == 0 and not isinstance(submodule, Observer):
+                yield name, submodule
+            else:
+                has_non_observer_children = False
+                for child in children:
+                    if not isinstance(child, Observer):
+                        has_non_observer_children = True
+
+                if not has_non_observer_children:
+                    yield name, submodule
+        if include_attn:
+            if name.endswith("self_attn"):
+                yield name, submodule
+
+
 def get_torch_bit_depth(value: torch.Tensor) -> int:
     """
     Determine the number of bits used to represent the dtype of a tensor
@@ -204,19 +237,11 @@ def is_kv_cache_quant_scheme(scheme: QuantizationScheme) -> bool:
     :param scheme: The QuantizationScheme to investigate
     :return: boolean flag
     """
-    if len(scheme.targets) == 1:
-        # match on the KV_CACHE_TARGETS regex pattern
-        # if there is only one target
-        is_match_targets = any(
-            [re.match(pattern[3:], scheme.targets[0]) for pattern in KV_CACHE_TARGETS]
-        )
-    else:
-        # match on the exact KV_CACHE_TARGETS
-        # if there are multiple targets
-        
is_match_targets = set(KV_CACHE_TARGETS) == set(scheme.targets) + for target in scheme.targets: + if target in KV_CACHE_TARGETS: + return True - is_match_output_activations = scheme.output_activations is not None - return is_match_targets and is_match_output_activations + return False def parse_out_kv_cache_args( diff --git a/tests/test_quantization/lifecycle/test_forward.py b/tests/test_quantization/lifecycle/test_forward.py index a434cd6e..b9ee67ff 100644 --- a/tests/test_quantization/lifecycle/test_forward.py +++ b/tests/test_quantization/lifecycle/test_forward.py @@ -20,6 +20,7 @@ maybe_calibrate_or_quantize, quantize, wrap_module_forward_quantized, + wrap_module_forward_quantized_attn, ) from compressed_tensors.quantization.lifecycle.initialize import ( initialize_module_for_quantization, @@ -205,3 +206,21 @@ def test_dequantize(num_bits, type, strategy, group_size, scale, zero_point, g_i dtype=None, g_idx=g_idx, ) + + +def test_wrap_module_forward_quantized_attn(create_quantization_scheme): + num_bits = 8 + quantization_scheme = create_quantization_scheme( + targets=["self_attn"], + weights=QuantizationArgs(num_bits=num_bits, symmetric=True), + input_activations=QuantizationArgs(num_bits=num_bits, symmetric=False), + ) + + mock_attn_layer = Linear(4, 4) + + attn_forward = mock_attn_layer.forward.__func__ + + # check that the forward call is overwritten + wrap_module_forward_quantized_attn(mock_attn_layer, quantization_scheme) + + assert not attn_forward == mock_attn_layer.forward.__func__ diff --git a/tests/test_quantization/lifecycle/test_kv_cache.py b/tests/test_quantization/lifecycle/test_kv_cache.py index c7d92741..b72f5265 100644 --- a/tests/test_quantization/lifecycle/test_kv_cache.py +++ b/tests/test_quantization/lifecycle/test_kv_cache.py @@ -76,29 +76,3 @@ def test_kv_cache_quantization(config): assert list(config.kv_cache_scheme.model_dump().values()) == list( reloaded_config.kv_cache_scheme.model_dump().values() ) - - -@pytest.mark.parametrize("config", [config]) -def test_kv_cache_quantization_clashing_configs(config): - config["config_groups"]["group_1"]["output_activations"] = { - "num_bits": 8, - "type": "int", - "symmetric": True, - "strategy": "tensor", - } - - model = AutoModelForCausalLM.from_pretrained( - "HuggingFaceM4/tiny-random-LlamaForCausalLM", - torch_dtype="auto", - ) - model.eval() - - config = QuantizationConfig(**config) - config.quantization_status = QuantizationStatus.CALIBRATION - with pytest.raises(ValueError): - # raise ValueError, because there is a clash between the - # kv cache quantization arguments and the ordinary - # quantization arguments - # (they are both adding output activations to the - # re:.*k_proj and re:.*q_proj) - apply_quantization_config(model, config) diff --git a/tests/test_quantization/test_cache.py b/tests/test_quantization/test_cache.py new file mode 100644 index 00000000..941af70f --- /dev/null +++ b/tests/test_quantization/test_cache.py @@ -0,0 +1,116 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +from compressed_tensors.quantization.cache import QuantizedKVParameterCache +from compressed_tensors.quantization.quant_args import QuantizationArgs + + +def test_is_quantized_cache_singleton(): + """ + Check if quantized_cache is a singleton, used for + passing in QuantizedKVParameterCache to the forward call of + the model's self_attn + """ + + args = QuantizationArgs() + cache: QuantizedKVParameterCache = args.get_kv_cache() + observer = args.get_observer() + + tensor = torch.tensor([1, 2, 3]) + cache.k_scales.append(tensor) + cache.k_observers.append(observer) + + same_cache = args.get_kv_cache() + + assert len(cache.k_scales) == len(same_cache.k_scales) + assert torch.equal(cache.k_scales[0], same_cache.k_scales[0]) + + assert cache.k_observers == same_cache.k_observers + assert hex(id(cache.k_observers[0])) == hex(id(same_cache.k_observers[0])) + + cache.reset() + + +def test_update(): + + nbits = 8 + args = QuantizationArgs(nbits=nbits, symmetric=True) + cache: QuantizedKVParameterCache = args.get_kv_cache() + + max_key_states_val = 1.0 + max_value_states_val = 2.0 + key_states = torch.cat( + (max_key_states_val * torch.ones(1, 2, 2), torch.ones(1, 2, 2)), dim=0 + ) + value_states = torch.cat( + (max_value_states_val * torch.ones(1, 2, 2), torch.ones(1, 2, 2)), dim=0 + ) + layer_idx = 0 + + cache.update(key_states, value_states, layer_idx) + denom = (2 ** (nbits) - 1) / 2 + expected_k_scale = torch.tensor([max_key_states_val / denom]) + expected_v_scale = torch.tensor([max_value_states_val / denom]) + + assert cache.k_scales[0] == expected_k_scale + assert cache.v_scales[0] == expected_v_scale + + # new attn layer + layer_idx = 1 + cache.update(key_states, value_states, layer_idx) + + assert len(cache.k_scales) == 2 + assert len(cache.v_scales) == 2 + + assert len(cache.k_observers) == 2 + assert len(cache.v_observers) == 2 + + cache.reset() + + +def test_cache_reset(): + nbits = 8 + args = QuantizationArgs(nbits=nbits, symmetric=True) + cache: QuantizedKVParameterCache = args.get_kv_cache() + + max_key_states_val = 1.0 + max_value_states_val = 2.0 + key_states = torch.cat( + (max_key_states_val * torch.ones(1, 2, 2), torch.ones(1, 2, 2)), dim=0 + ) + value_states = torch.cat( + (max_value_states_val * torch.ones(1, 2, 2), torch.ones(1, 2, 2)), dim=0 + ) + layer_idx = 0 + + cache.update(key_states, value_states, layer_idx) + assert len(cache.k_scales) == 1 + assert len(cache.v_scales) == 1 + + assert len(cache.k_observers) == 1 + assert len(cache.v_observers) == 1 + + cache.reset() + + # new instance, different memory addr + different_cache: QuantizedKVParameterCache = args.get_kv_cache() + + assert len(different_cache.k_scales) == 0 + assert len(different_cache.v_scales) == 0 + + assert len(different_cache.k_observers) == 0 + assert len(different_cache.v_observers) == 0 + + assert hex(id(cache)) != hex(id(different_cache))
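
Reviewer note (not part of the patch): the calibration flow added above can be exercised in isolation with a minimal sketch. It assumes a compressed-tensors build that includes this change and uses only APIs introduced or touched in this PR (`QuantizationArgs.get_kv_cache()`, `QuantizedKVParameterCache.update()`, `reset()`); the tensor shapes are illustrative.

```python
import torch
from compressed_tensors.quantization.quant_args import QuantizationArgs
from compressed_tensors.quantization.cache import QuantizedKVParameterCache

# tensor-strategy 8-bit args; get_kv_cache() returns the singleton cache
args = QuantizationArgs(num_bits=8, symmetric=True)
cache: QuantizedKVParameterCache = args.get_kv_cache()

# dummy key/value states shaped [batch_size, num_heads, seq_len, head_dim]
key_states = torch.ones(1, 2, 4, 8)
value_states = 2.0 * torch.ones(1, 2, 4, 8)

# update() observes the states, records per-layer scales and zero points,
# and returns fake-quantized keys/values; no quantized states are cached
qdq_keys, qdq_values = cache.update(key_states, value_states, layer_idx=0)
print(cache.k_scales[0], cache.v_scales[0])  # later stored as k_scale / v_scale

# the cache is a singleton, so every self_attn forward reuses the same object
assert args.get_kv_cache() is cache
cache.reset()  # drop the instance so the next get_kv_cache() starts fresh
```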
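The config-level behavior (KV_CACHE_TARGETS now pointing at `re:.*self_attn$`, and calibration being required whenever a kv_cache_scheme is present) can be sanity-checked with a hypothetical config; the empty `config_groups` below is only for illustration and is not taken from this PR.

```python
from compressed_tensors.quantization.quant_config import QuantizationConfig

# a config whose only quantization request is the kv cache scheme
config = QuantizationConfig(
    config_groups={},
    kv_cache_scheme={
        "num_bits": 8,
        "type": "float",
        "strategy": "tensor",
        "dynamic": False,
        "symmetric": True,
    },
)

# kv cache scales are produced during calibration, so data is now required
assert config.requires_calibration_data()
```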