Commit 8b89c15

Merge branch 'main' of github.com:neuralmagic/compressed-tensors

horheynm committed Oct 30, 2024
2 parents 065fd90 + 13b5c0b
Showing 25 changed files with 350 additions and 206 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/trigger-all.yml
@@ -35,6 +35,6 @@ jobs:
test_configs: '[{"python":"3.11.4","label":"ubuntu-22.04","timeout":"40"},
{"python":"3.10.12","label":"ubuntu-20.04","timeout":"40"},
{"python":"3.9.17","label":"k8s-a100-solo","timeout":"40"},
{"python":"3.8.17","label":"k8s-a100-duo","timeout":"40"}]'
{"python":"3.12.6","label":"k8s-a100-duo","timeout":"40"}]'

secrets: inherit
@@ -242,10 +242,6 @@ def __init__(
self.sparsity_compressor = None
self.quantization_compressor = None

if sparsity_config and sparsity_config.format == CompressionFormat.dense.value:
# ignore dense sparsity config
self.sparsity_config = None

if sparsity_config is not None:
self.sparsity_compressor = BaseCompressor.load_from_registry(
sparsity_config.format, config=sparsity_config
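With the early return for dense sparsity configs removed, every non-None sparsity config (dense included) is now handed to BaseCompressor.load_from_registry. A minimal control-flow sketch of the resulting constructor behavior; the enclosing class name and the registry call below are stand-ins, since the file itself is not named in this excerpt:

    # Hypothetical mirror of the constructor logic after this change.
    class CompressorInitSketch:
        def __init__(self, sparsity_config=None):
            self.sparsity_compressor = None
            # Dense-format configs are no longer discarded here; the registry
            # decides how to handle every provided format.
            if sparsity_config is not None:
                self.sparsity_compressor = self._load_from_registry(sparsity_config)

        @staticmethod
        def _load_from_registry(config):
            # Stand-in for BaseCompressor.load_from_registry(config.format, config=config)
            return ("compressor", getattr(config, "format", None))

    assert CompressorInitSketch(sparsity_config=None).sparsity_compressor is None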
62 changes: 60 additions & 2 deletions src/compressed_tensors/config/base.py
@@ -12,16 +12,17 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from enum import Enum
from enum import Enum, unique
from typing import List, Optional

from compressed_tensors.registry import RegistryMixin
from pydantic import BaseModel


__all__ = ["SparsityCompressionConfig", "CompressionFormat"]
__all__ = ["SparsityCompressionConfig", "CompressionFormat", "SparsityStructure"]


@unique
class CompressionFormat(Enum):
dense = "dense"
sparse_bitmask = "sparse-bitmask"
@@ -32,6 +33,63 @@ class CompressionFormat(Enum):
marlin_24 = "marlin-24"


@unique
class SparsityStructure(Enum):
"""
An enumeration to represent different sparsity structures.
Attributes
----------
TWO_FOUR : str
Represents a 2:4 sparsity structure.
ZERO_ZERO : str
Represents a 0:0 sparsity structure.
UNSTRUCTURED : str
Represents an unstructured sparsity structure.
Examples
--------
>>> SparsityStructure('2:4')
<SparsityStructure.TWO_FOUR: '2:4'>
>>> SparsityStructure('unstructured')
<SparsityStructure.UNSTRUCTURED: 'unstructured'>
>>> SparsityStructure('2:4') == SparsityStructure.TWO_FOUR
True
>>> SparsityStructure('UNSTRUCTURED') == SparsityStructure.UNSTRUCTURED
True
>>> SparsityStructure(None) == SparsityStructure.UNSTRUCTURED
True
>>> SparsityStructure('invalid')
Traceback (most recent call last):
...
ValueError: invalid is not a valid SparsityStructure
"""

TWO_FOUR = "2:4"
UNSTRUCTURED = "unstructured"
ZERO_ZERO = "0:0"

def __new__(cls, value):
obj = object.__new__(cls)
obj._value_ = value.lower() if value is not None else value
return obj

@classmethod
def _missing_(cls, value):
# Handle None and case-insensitive values
if value is None:
return cls.UNSTRUCTURED
for member in cls:
if member.value == value.lower():
return member
raise ValueError(f"{value} is not a valid {cls.__name__}")
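The __new__/_missing_ pair makes construction case-insensitive and maps None to UNSTRUCTURED, so raw config strings can be fed straight into the enum. A short usage sketch, assuming the enum is importable from the module this diff touches (src/compressed_tensors/config/base.py):

    from typing import Optional

    from compressed_tensors.config.base import SparsityStructure

    def is_24_sparse(structure: Optional[str]) -> bool:
        # None and mixed-case inputs are normalized by __new__/_missing_.
        return SparsityStructure(structure) == SparsityStructure.TWO_FOUR

    assert is_24_sparse("2:4")
    assert not is_24_sparse("Unstructured")
    assert not is_24_sparse(None)  # None maps to UNSTRUCTURED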


class SparsityCompressionConfig(RegistryMixin, BaseModel):
"""
Base data class for storing sparsity compression parameters
1 change: 0 additions & 1 deletion src/compressed_tensors/quantization/cache.py
@@ -28,7 +28,6 @@ class KVCacheScaleType(Enum):


class QuantizedKVParameterCache(HFDyanmicCache):

"""
Quantized KV cache used in the forward call based on HF's dynamic cache.
Quantization strategy (tensor, group, channel) set from Quantization arg's strategy
8 changes: 6 additions & 2 deletions src/compressed_tensors/quantization/lifecycle/apply.py
@@ -107,8 +107,8 @@ def load_pretrained_quantization(model: Module, model_name_or_path: str):


def apply_quantization_config(
model: Module, config: QuantizationConfig, run_compressed: bool = False
) -> Dict:
model: Module, config: Union[QuantizationConfig, None], run_compressed: bool = False
) -> OrderedDict:
"""
Initializes the model for quantization in-place based on the given config
@@ -117,6 +117,10 @@
:param run_compressed: Whether the model will be run in compressed mode or
decompressed fully on load
"""
# Workaround for when HF Quantizer passes None, see PR #180
if config is None:
return OrderedDict()

# remove reference to the original `config`
# argument. This function can mutate it, and we'd
# like to keep the original `config` as it is.
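apply_quantization_config now tolerates a missing config and always returns a mapping, so callers keep working when HF's quantizer passes None. A toy mirror of the new guard (not the library function itself; the returned entry is illustrative only):

    from collections import OrderedDict
    from typing import Optional

    def apply_quantization_config_sketch(model, config: Optional[object]) -> OrderedDict:
        # Mirror of the new guard: a None config becomes a harmless no-op.
        if config is None:
            return OrderedDict()
        # ... the real function walks the model and attaches quantization schemes ...
        return OrderedDict({"model.layers.0.q_proj": "some-scheme"})

    assert apply_quantization_config_sketch(object(), None) == OrderedDict()

Returning an empty OrderedDict rather than None keeps the return type stable for downstream code that expects a mapping of module names to schemes.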
14 changes: 12 additions & 2 deletions src/compressed_tensors/quantization/lifecycle/calibration.py
@@ -53,13 +53,23 @@ def set_module_for_calibration(module: Module, quantize_weights_upfront: bool =

if quantize_weights_upfront and module.quantization_scheme.weights is not None:
# set weight scale and zero_point up front, calibration data doesn't affect it
observer = module.weight_observer
g_idx = getattr(module, "weight_g_idx", None)
if not hasattr(module, "weight_observer"):
from compressed_tensors.quantization.lifecycle.initialize import (
initialize_observers,
)

initialize_observers(
module=module,
base_name="weight",
quantization_args=module.quantization_scheme.weights,
)

offloaded = is_module_offloaded(module)
if offloaded:
module._hf_hook.pre_forward(module)

observer = module.weight_observer
g_idx = getattr(module, "weight_g_idx", None)
scale, zero_point = observer(module.weight, g_idx=g_idx)
update_parameter_data(module, scale, "weight_scale")
update_parameter_data(module, zero_point, "weight_zero_point")
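set_module_for_calibration no longer assumes a weight observer was attached during initialization; if it is missing, the observer is created on demand before the scale and zero point are computed. A self-contained sketch of the same hasattr-guarded pattern with a toy min-max observer (the real initialize_observers builds observers from QuantizationArgs):

    import torch

    def minmax_observer(weight: torch.Tensor):
        # Toy stand-in for an observer: symmetric int8 scale, zero zero-point.
        scale = weight.abs().max() / 127.0
        return scale, torch.zeros_like(scale, dtype=torch.int8)

    def ensure_weight_observer(module: torch.nn.Module):
        # Mirror of the new guard: attach the observer lazily instead of
        # assuming initialization created it.
        if not hasattr(module, "weight_observer"):
            module.weight_observer = minmax_observer
        return module.weight_observer

    linear = torch.nn.Linear(16, 16)
    scale, zero_point = ensure_weight_observer(linear)(linear.weight)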
93 changes: 63 additions & 30 deletions src/compressed_tensors/quantization/lifecycle/forward.py
@@ -18,7 +18,10 @@

import torch
from compressed_tensors.quantization.cache import QuantizedKVParameterCache
from compressed_tensors.quantization.observers.helpers import calculate_range
from compressed_tensors.quantization.observers.helpers import (
calculate_range,
compute_dynamic_scales_and_zp,
)
from compressed_tensors.quantization.quant_args import (
QuantizationArgs,
QuantizationStrategy,
@@ -35,7 +38,8 @@
"dequantize",
"fake_quantize",
"wrap_module_forward_quantized",
"maybe_calibrate_or_quantize",
"forward_quantize",
"calibrate_activations",
]


@@ -273,14 +277,24 @@ def wrapped_forward(self, *args, **kwargs):

if scheme.input_activations is not None:
# calibrate and (fake) quantize input activations when applicable
input_ = maybe_calibrate_or_quantize(
module, input_, "input", scheme.input_activations
)
# NOTE: will be moved out of compressed-tensors
if (
module.quantization_status == QuantizationStatus.CALIBRATION
and not scheme.input_activations.dynamic
):
calibrate_activations(
module=module,
value=input_,
base_name="input",
quantization_args=scheme.input_activations,
)

input_ = forward_quantize(module, input_, "input", scheme.input_activations)

if scheme.weights is not None and not compressed:
# calibrate and (fake) quantize weights when applicable
unquantized_weight = self.weight.data.clone()
self.weight.data = maybe_calibrate_or_quantize(
self.weight.data = forward_quantize(
module, self.weight, "weight", scheme.weights
)

@@ -293,7 +307,19 @@ def wrapped_forward(self, *args, **kwargs):
# calibrate and (fake) quantize output activations when applicable
# kv_cache scales updated on model self_attn forward call in
# wrap_module_forward_quantized_attn
output = maybe_calibrate_or_quantize(

if (
module.quantization_status == QuantizationStatus.CALIBRATION
and not scheme.output_activations.dynamic
):
calibrate_activations(
module=module,
value=output,
base_name="output",
quantization_args=scheme.output_activations,
)

output = forward_quantize(
module, output, "output", scheme.output_activations
)
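Splitting maybe_calibrate_or_quantize into calibrate_activations plus forward_quantize makes the ordering explicit: observers update scale and zero point only while the module is in CALIBRATION status with static args, and fake quantization then always runs from the stored parameters. A self-contained toy mirror of that ordering for one activation tensor (the observer and the quantization math are simplified stand-ins, not the library's implementations):

    import torch

    class ToyObserver:
        # Running-max observer stand-in that produces an int8-style scale.
        def __init__(self):
            self.max_val = torch.tensor(0.0)

        def __call__(self, value):
            self.max_val = torch.maximum(self.max_val, value.abs().max())
            return self.max_val / 127.0, torch.tensor(0.0)

    def handle_input(module, value, status, dynamic=False):
        # calibrate_activations step: only during calibration, only for static args
        if status == "calibration" and not dynamic:
            scale, zp = module.input_observer(value)
            module.input_scale, module.input_zero_point = scale, zp
        # forward_quantize step: always fake-quantize from the stored params
        scale, zp = module.input_scale, module.input_zero_point
        q = torch.clamp(torch.round(value / scale) + zp, -128, 127)
        return (q - zp) * scale

    linear = torch.nn.Linear(4, 4)
    linear.input_observer = ToyObserver()
    out = handle_input(linear, torch.randn(2, 4), status="calibration")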

@@ -353,12 +379,36 @@ def wrapped_forward(self, *args, **kwargs):
setattr(module, "forward", bound_wrapped_forward)


def maybe_calibrate_or_quantize(
def calibrate_activations(
module: Module,
value: torch.Tensor,
base_name: str,
quantization_args: QuantizationArgs,
):
# If empty tensor, can't update zp/scale
# Case for MoEs
if value.numel() == 0:
return
# calibration mode - get new quant params from observer
if not hasattr(module, f"{base_name}_observer"):
from compressed_tensors.quantization.lifecycle import initialize_observers

initialize_observers(
module=module, base_name=base_name, quantization_args=quantization_args
)

observer = getattr(module, f"{base_name}_observer")

updated_scale, updated_zero_point = observer(value)

# update scale and zero point
update_parameter_data(module, updated_scale, f"{base_name}_scale")
update_parameter_data(module, updated_zero_point, f"{base_name}_zero_point")


def forward_quantize(
module: Module, value: torch.Tensor, base_name: str, args: "QuantizationArgs"
) -> torch.Tensor:
# don't run quantization if we haven't entered calibration mode
if module.quantization_status == QuantizationStatus.INITIALIZED:
return value

# in compressed mode, the weight is already compressed and quantized so we don't
# need to run fake quantization
@@ -376,30 +426,13 @@ def maybe_calibrate_or_quantize(
g_idx = getattr(module, "weight_g_idx", None)

if args.dynamic:
# dynamic quantization - get scale and zero point directly from observer
observer = getattr(module, f"{base_name}_observer")
scale, zero_point = observer(value, g_idx=g_idx)
# dynamic quantization - no need to invoke observer
scale, zero_point = compute_dynamic_scales_and_zp(value=value, args=args)
else:
# static quantization - get previous scale and zero point from layer
scale = getattr(module, f"{base_name}_scale")
zero_point = getattr(module, f"{base_name}_zero_point", None)

if (
module.quantization_status == QuantizationStatus.CALIBRATION
and base_name != "weight"
):
# calibration mode - get new quant params from observer
observer = getattr(module, f"{base_name}_observer")

updated_scale, updated_zero_point = observer(value, g_idx=g_idx)

# update scale and zero point
update_parameter_data(module, updated_scale, f"{base_name}_scale")
update_parameter_data(module, updated_zero_point, f"{base_name}_zero_point")

scale = updated_scale
zero_point = updated_zero_point

return fake_quantize(
x=value,
scale=scale,
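For dynamic QuantizationArgs, forward_quantize no longer routes through a persistent observer; compute_dynamic_scales_and_zp derives scale and zero point from the incoming value itself, while static args keep reading the stored {base_name}_scale and {base_name}_zero_point parameters. A small sketch of what per-token dynamic parameters can look like, assuming symmetric int8 quantization (the library derives the actual ranges from QuantizationArgs rather than hard-coding them):

    import torch

    def dynamic_scales_sketch(value: torch.Tensor, quant_max: int = 127):
        # Per-token dynamic scale: one scale per row of the activation,
        # recomputed on every forward call instead of being calibrated.
        scale = value.abs().amax(dim=-1, keepdim=True) / quant_max
        zero_point = torch.zeros_like(scale, dtype=torch.int8)
        return scale, zero_point

    x = torch.randn(2, 8)
    scale, zero_point = dynamic_scales_sketch(x)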
11 changes: 3 additions & 8 deletions src/compressed_tensors/quantization/lifecycle/frozen.py
@@ -14,7 +14,6 @@


from compressed_tensors.quantization.quant_config import QuantizationStatus
from compressed_tensors.quantization.utils import is_kv_cache_quant_scheme
from torch.nn import Module


@@ -41,15 +40,11 @@ def freeze_module_quantization(module: Module):
return

# delete observers from module if not dynamic
if scheme.input_activations and not scheme.input_activations.dynamic:
if hasattr(module, "input_observer") and not scheme.input_activations.dynamic:
delattr(module, "input_observer")
if scheme.weights and not scheme.weights.dynamic:
if hasattr(module, "weight_observer") and not scheme.weights.dynamic:
delattr(module, "weight_observer")
if (
scheme.output_activations
and not is_kv_cache_quant_scheme(scheme)
and not scheme.output_activations.dynamic
):
if hasattr(module, "output_observer") and not scheme.output_activations.dynamic:
delattr(module, "output_observer")

module.quantization_status = QuantizationStatus.FROZEN
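Freezing now deletes only observers that actually exist, which matters once observers can be attached lazily or skipped entirely (for example with dynamic activations). A short sketch of the guarded cleanup pattern:

    import torch

    def drop_observers(module: torch.nn.Module):
        # Mirror of the new guards in freeze_module_quantization: delete only
        # the observers that are present on this module.
        for name in ("input_observer", "weight_observer", "output_observer"):
            if hasattr(module, name):
                delattr(module, name)

    linear = torch.nn.Linear(4, 4)
    linear.weight_observer = object()  # only one observer was ever attached
    drop_observers(linear)             # no AttributeError for the missing ones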
47 changes: 0 additions & 47 deletions src/compressed_tensors/quantization/lifecycle/helpers.py
@@ -16,62 +16,15 @@
Miscellaneous helpers for the quantization lifecycle
"""

from typing import Optional

import torch
from torch.nn import Module


__all__ = [
"update_layer_weight_quant_params",
"enable_quantization",
"disable_quantization",
]


def update_layer_weight_quant_params(
layer: Module,
weight: Optional[torch.Tensor] = None,
g_idx: Optional[torch.Tensor] = None,
reset_obs: bool = False,
):
"""
Update quantization parameters on layer
:param layer: input layer
:param weight: weight to update quant params with, defaults to layer weight
:param g_idx: optional mapping from column index to group index
:param reset_obs: reset the observer before calculating quant params,
defaults to False
"""
attached_weight = getattr(layer, "weight", None)

if weight is None:
weight = attached_weight
scale = getattr(layer, "weight_scale", None)
zero_point = getattr(layer, "weight_zero_point", None)
if g_idx is None:
g_idx = getattr(layer, "weight_g_idx", None)
observer = getattr(layer, "weight_observer", None)

if weight is None or observer is None or scale is None or zero_point is None:
# scale, zp, or observer not calibratable or weight not available
return

if reset_obs:
observer.reset()

if attached_weight is not None:
weight = weight.to(attached_weight.dtype)

updated_scale, updated_zero_point = observer(weight)

# update scale and zero point
device = next(layer.parameters()).device
scale.data = updated_scale.to(device)
zero_point.data = updated_zero_point.to(device)


def enable_quantization(module: Module):
module.quantization_enabled = True

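With update_layer_weight_quant_params removed, this module is left with the enable/disable toggles shown above. A trivial usage sketch; a quantized forward wrapper would presumably consult this flag, though that check is not part of this excerpt:

    import torch

    def enable_quantization(module: torch.nn.Module):
        module.quantization_enabled = True

    def disable_quantization(module: torch.nn.Module):
        module.quantization_enabled = False

    linear = torch.nn.Linear(4, 4)
    disable_quantization(linear)  # e.g. to compare against an unquantized forward pass
    assert linear.quantization_enabled is False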