Compressed lifecycle implementation (INT8 only) #33

Merged
merged 17 commits into from May 7, 2024
1 change: 1 addition & 0 deletions src/compressed_tensors/quantization/lifecycle/__init__.py
@@ -19,4 +19,5 @@
from .forward import *
from .frozen import *
from .initialize import *
from .compressed import *
from .apply import *
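Because `lifecycle/__init__.py` re-exports its submodules via wildcard imports, the new entrypoint becomes importable from the lifecycle package itself. A quick sanity check (nothing beyond what this diff adds):

```python
# the wildcard import of .compressed re-exports the new entrypoint at the package level
from compressed_tensors.quantization.lifecycle import compress_quantized_weights
```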
24 changes: 21 additions & 3 deletions src/compressed_tensors/quantization/lifecycle/apply.py
@@ -19,6 +19,9 @@
from compressed_tensors.quantization.lifecycle.calibration import (
set_module_for_calibration,
)
from compressed_tensors.quantization.lifecycle.compressed import (
compress_quantized_weights,
)
from compressed_tensors.quantization.lifecycle.frozen import freeze_module_quantization
from compressed_tensors.quantization.lifecycle.initialize import (
initialize_module_for_quantization,
@@ -117,13 +120,20 @@ def apply_quantization_status(model: Module, status: QuantizationStatus):
:param model: model to apply quantization to
:param status: status to update the module to
"""
if status >= QuantizationStatus.INITIALIZED:
current_status = _infer_status(model)

if status >= QuantizationStatus.INITIALIZED > current_status:
model.apply(initialize_module_for_quantization)
if status >= QuantizationStatus.CALIBRATION:

if current_status < status >= QuantizationStatus.CALIBRATION > current_status:
model.apply(set_module_for_calibration)
if status >= QuantizationStatus.FROZEN:

if current_status < status >= QuantizationStatus.FROZEN > current_status:
model.apply(freeze_module_quantization)

if current_status < status >= QuantizationStatus.COMPRESSED > current_status:
model.apply(compress_quantized_weights)


def _find_first_name_or_class_match(
name: str,
@@ -151,6 +161,14 @@ def _find_first_match(value: str, targets: Iterable[str]) -> Optional[str]:
return None


def _infer_status(model: Module) -> Optional[QuantizationStatus]:
for module in model.modules():
status = getattr(module, "quantization_status", None)
if status is not None:
return status
return None


def _load_quant_args_from_state_dict(
base_name: str, module_name: str, module: Module, state_dict: Dict
):
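For context on the new guards in `apply_quantization_status`: `_infer_status` returns `None` on a model that has never been quantized, and the comparison overloads added in `quant_config.py` (below) treat `None` as preceding every concrete status, so each transition fires at most once. A standalone sketch of the guard expressions, assuming the lifecycle order INITIALIZED < CALIBRATION < FROZEN < COMPRESSED (the tail of `LIFECYCLE_ORDER` is cut off in this view):

```python
from compressed_tensors.quantization.quant_config import QuantizationStatus

# walkthrough of apply_quantization_status(model, COMPRESSED) on a model whose
# modules are already at CALIBRATION, i.e. what _infer_status(model) would return
current = QuantizationStatus.CALIBRATION
target = QuantizationStatus.COMPRESSED

# the INITIALIZED and CALIBRATION guards are skipped: current is not behind them
assert not (target >= QuantizationStatus.INITIALIZED > current)
assert not (current < target >= QuantizationStatus.CALIBRATION > current)

# the FROZEN and COMPRESSED guards fire, so freeze_module_quantization and
# compress_quantized_weights each run exactly once
assert current < target >= QuantizationStatus.FROZEN > current
assert current < target >= QuantizationStatus.COMPRESSED > current
```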
65 changes: 65 additions & 0 deletions src/compressed_tensors/quantization/lifecycle/compressed.py
@@ -0,0 +1,65 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import logging

import torch
from compressed_tensors.quantization.lifecycle.forward import quantize
from compressed_tensors.quantization.quant_config import QuantizationStatus
from torch.nn import Module


__all__ = [
"compress_quantized_weights",
]


_LOGGER = logging.getLogger(__name__)


def compress_quantized_weights(module: Module):
"""
Quantizes the module weight representation to use fewer bits in memory

apply to full model with `model.apply(compress_quantized_weights)`

:param module: module to compress to quantized representation
"""
scheme = getattr(module, "quantization_scheme", None)
if not scheme or not scheme.weights:
# no quantization scheme or weights not quantized, nothing to do
return

weight = getattr(module, "weight", None)
scale = getattr(module, "weight_scale", None)
zero_point = getattr(module, "weight_zero_point", None)

if weight is None or scale is None or zero_point is None:
# no weight, scale, or ZP, nothing to do

# TODO: Should we mark as compressed anyway here to maintain consistent
# status throughout the model?
return

module.weight.requires_grad = False # cannot use auto grad after compression
module.weight.data = quantize(
x=weight,
scale=scale,
zero_point=zero_point,
args=scheme.weights,
dtype=torch.int8,
)

module.quantization_status = QuantizationStatus.COMPRESSED
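The arithmetic `compress_quantized_weights` performs, written out with plain torch on a standalone tensor (an illustrative per-tensor 8-bit scheme; in the module the scale and zero point come from the calibrated `weight_scale` / `weight_zero_point`):

```python
import torch

weight = torch.randn(16, 32)
scale = weight.abs().max() / 127      # illustrative per-tensor scale
zero_point = torch.tensor(0)

# same clamp-round-cast that quantize(..., dtype=torch.int8) performs for num_bits=8
q = torch.clamp(torch.round(weight / scale + zero_point), -128, 127).to(torch.int8)
assert q.dtype is torch.int8

# at forward time, dequantize() recovers an approximation of the original weight
recovered = (q - zero_point) * scale
assert torch.allclose(recovered, weight, atol=float(scale))
```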
24 changes: 15 additions & 9 deletions src/compressed_tensors/quantization/lifecycle/forward.py
@@ -13,6 +13,7 @@
# limitations under the License.

from functools import wraps
from typing import Optional

import torch
from compressed_tensors.quantization.quant_args import QuantizationArgs
@@ -29,17 +30,26 @@ def quantize(
x: torch.Tensor,
scale: torch.Tensor,
zero_point: torch.Tensor,
q_min: torch.Tensor,
q_max: torch.Tensor,
args: QuantizationArgs,
dtype: Optional[torch.dtype] = None,
) -> torch.Tensor:
return torch.clamp(
bit_range = 2**args.num_bits
q_min = torch.tensor(-bit_range / 2, device=x.device)
q_max = torch.tensor(bit_range / 2 - 1, device=x.device)

quantized_value = torch.clamp(
torch.round(
x / scale + zero_point,
),
q_min,
q_max,
)

if dtype is not None:
quantized_value = quantized_value.to(dtype)

return quantized_value


@torch.no_grad()
def dequantize(
@@ -57,12 +67,8 @@ def fake_quantize(
zero_point: torch.Tensor,
args: QuantizationArgs,
) -> torch.Tensor:
bit_range = 2**args.num_bits
max_q = torch.tensor(bit_range / 2 - 1, device=x.device)
min_q = torch.tensor(-bit_range / 2, device=x.device)
Q = torch.zeros_like(x)
Q = quantize(x, scale, zero_point, min_q, max_q)
return dequantize(Q, scale, zero_point)
x_quant = quantize(x, scale, zero_point, args)
return dequantize(x_quant, scale, zero_point)


def wrap_module_forward_quantized(module: Module, scheme: QuantizationScheme):
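A short round-trip sketch of the reworked `quantize` signature, assuming `QuantizationArgs` can be constructed with just `num_bits` (its other fields are not shown in this diff):

```python
import torch
from compressed_tensors.quantization.lifecycle.forward import (
    dequantize,
    fake_quantize,
    quantize,
)
from compressed_tensors.quantization.quant_args import QuantizationArgs

args = QuantizationArgs(num_bits=8)   # assumed constructor; other fields left at defaults
x = torch.randn(4, 8)
scale = x.abs().max() / 127
zero_point = torch.tensor(0.0)

# passing dtype stores the clamped/rounded values as int8 (the new compressed path)
x_int8 = quantize(x, scale, zero_point, args, dtype=torch.int8)
assert x_int8.dtype is torch.int8

# without dtype, quantize stays in the input dtype, which is what fake_quantize relies on
assert torch.equal(fake_quantize(x, scale, zero_point, args),
                   dequantize(x_int8, scale, zero_point))
```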
23 changes: 23 additions & 0 deletions src/compressed_tensors/quantization/quant_config.py
@@ -62,10 +62,33 @@ def lifecycle_order(cls) -> List["QuantizationStatus"]:
return

def __ge__(self, other):
if other is None:
return True
if not isinstance(other, self.__class__):
raise NotImplementedError
return LIFECYCLE_ORDER.index(self) >= LIFECYCLE_ORDER.index(other)

def __gt__(self, other):
if other is None:
return True
if not isinstance(other, self.__class__):
raise NotImplementedError
return LIFECYCLE_ORDER.index(self) > LIFECYCLE_ORDER.index(other)

def __lt__(self, other):
if other is None:
return False
if not isinstance(other, self.__class__):
raise NotImplementedError
return LIFECYCLE_ORDER.index(self) < LIFECYCLE_ORDER.index(other)

def __le__(self, other):
if other is None:
return False
if not isinstance(other, self.__class__):
raise NotImplementedError
return LIFECYCLE_ORDER.index(self) <= LIFECYCLE_ORDER.index(other)


LIFECYCLE_ORDER = [
QuantizationStatus.INITIALIZED,
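The practical effect of the new rich comparisons, in isolation (again assuming the INITIALIZED < CALIBRATION < FROZEN < COMPRESSED ordering of `LIFECYCLE_ORDER`):

```python
from compressed_tensors.quantization.quant_config import QuantizationStatus

# None, i.e. "no status recorded yet", sorts before every concrete status
assert QuantizationStatus.INITIALIZED > None
assert QuantizationStatus.COMPRESSED >= None
assert not (QuantizationStatus.INITIALIZED < None)

# concrete statuses compare by their position in LIFECYCLE_ORDER
assert QuantizationStatus.CALIBRATION < QuantizationStatus.FROZEN
assert QuantizationStatus.COMPRESSED >= QuantizationStatus.FROZEN
```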
42 changes: 37 additions & 5 deletions tests/quantization/lifecycle/test_apply.py
@@ -12,8 +12,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional

from compressed_tensors.quantization.lifecycle import apply_quantization_config
import torch
from compressed_tensors.quantization.lifecycle import (
apply_quantization_config,
apply_quantization_status,
)
from compressed_tensors.quantization.quant_config import (
QuantizationConfig,
QuantizationStatus,
@@ -22,7 +27,7 @@


def test_apply_quantization_config_tinyllama():
quant_config = get_sample_tinyllama_quant_config()
quant_config = get_sample_tinyllama_quant_config(status="calibration")
model = get_tinyllama_model()

# check that model is not already quantized
@@ -55,6 +60,23 @@ def test_apply_quantization_config_tinyllama():
assert num_embeddings == 1
assert num_rotary_embeddings == 22

# test quantization compression
# sample forward pass to fill scales, zps
model(torch.zeros((1, 1), dtype=int), torch.zeros((1, 1), dtype=int))
apply_quantization_status(model, QuantizationStatus.COMPRESSED)
for name, module in model.named_modules():
if name in quant_config.ignore:
continue
module_type = module.__class__.__name__
if module_type == "Linear":
_test_layer_quantization_status(
module,
inputs=True,
weights=True,
expected_status=QuantizationStatus.COMPRESSED,
expected_dtype=torch.int8,
)


def test_serialize_config_tinyllama():
quant_config = get_sample_tinyllama_quant_config()
@@ -81,11 +103,19 @@ def test_serialize_config_tinyllama():
assert serialized_config.global_compression_ratio < 8.0


def _test_layer_quantization_status(module, inputs: bool, weights: bool):
def _test_layer_quantization_status(
module,
inputs: bool,
weights: bool,
expected_status: Optional[QuantizationStatus] = None,
expected_dtype: Optional[torch.dtype] = None,
):
# check if quantization is applied at all (true if inputs or weights targeted)
quantized = inputs or weights
assert hasattr(module, "quantization_scheme") == quantized
assert hasattr(module, "quantization_status") == quantized
if expected_status is not None:
assert module.quantization_status is expected_status

# check inputs matches expected
assert hasattr(module, "input_scale") == inputs
@@ -94,6 +124,8 @@ def _test_layer_quantization_status(module, inputs: bool, weights: bool):
# check weights matches expected
assert hasattr(module, "weight_scale") == weights
assert hasattr(module, "weight_zero_point") == weights
if weights and expected_dtype is not None:
assert module.weight.dtype is expected_dtype


def get_tinyllama_model():
@@ -102,11 +134,11 @@ def get_tinyllama_model():
)


def get_sample_tinyllama_quant_config():
def get_sample_tinyllama_quant_config(status: str = "frozen"):
config_dict = {
"quant_method": "sparseml",
"format": "fakequant",
"quantization_status": "frozen",
"quantization_status": status,
"global_compression_ratio": None,
"config_groups": {
"group_1": {