[Bugfix] Update expected shape for per token strategy (#210)
* update expected shape for per token strategy

* add tests

* wip

* add helpers test

Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>

* remove breakpoint

Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>

* remove unnecessary arg

---------

Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>
kylesayrs authored Dec 19, 2024
1 parent 1fa514a commit 975cb22
Showing 5 changed files with 159 additions and 15 deletions.
src/compressed_tensors/quantization/lifecycle/initialize.py (4 additions, 1 deletion)
@@ -174,7 +174,10 @@ def _initialize_scale_zero_point(
     device = get_execution_device(module)
 
     # infer expected scale/zero point shape
-    expected_shape = 1  # per tensor
+    if quantization_args.strategy == QuantizationStrategy.TOKEN:
+        expected_shape = (1, 1)
+    else:
+        expected_shape = 1
 
     if base_name == "weight" and weight_shape is not None:
         if quantization_args.strategy == QuantizationStrategy.CHANNEL:
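
The fix matters because per-token quantization is dynamic: scales are recomputed for each token at runtime, so the statically registered parameter is only a broadcastable placeholder. A minimal sketch of how a (1, 1) placeholder broadcasts over token activations during a quantize/dequantize round trip (the tensor names and the math here are illustrative, not the library's implementation):

import torch

# activations arranged as (num_tokens, hidden_dim)
activations = torch.randn(4, 8)

# the (1, 1) placeholder shape this commit initializes for the token strategy
scale = torch.ones(1, 1)
zero_point = torch.zeros(1, 1)

# symmetric round trip; the (1, 1) scale broadcasts across every token
quantized = torch.clamp(torch.round(activations / scale) + zero_point, -128, 127)
dequantized = (quantized - zero_point) * scale
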
tests/conftest.py (0 additions, 2 deletions)
@@ -44,8 +44,6 @@ def update_scale_zp(module: torch.nn.Module, base_name: str, value: torch.Tensor
     min_val = torch.amin(value, dim=dim, keepdims=True)
     max_val = torch.amax(value, dim=dim, keepdims=True)
     scale, zp = calculate_qparams(min_val, max_val, args)
-    scale = scale.reshape((1, 1))
-    zp = zp.reshape((1, 1))
     update_parameter_data(module, scale, f"{base_name}_scale")
     update_parameter_data(module, zp, f"{base_name}_zero_point")

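
The reshape calls removed above became redundant once initialization produces correctly shaped parameters: reductions taken over all dimensions of a 2-D tensor with keepdims already yield (1, 1) outputs. A small demonstration of that shape behavior in plain torch, independent of the library:

import torch

value = torch.randn(14, 5)

# reducing over every dimension while keeping them yields a (1, 1) tensor,
# so no follow-up reshape to (1, 1) is needed
min_val = torch.amin(value, dim=(0, 1), keepdim=True)
max_val = torch.amax(value, dim=(0, 1), keepdim=True)
print(min_val.shape, max_val.shape)  # torch.Size([1, 1]) torch.Size([1, 1])
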
tests/test_quantization/lifecycle/test_initialize.py (87 additions, 2 deletions)
@@ -14,15 +14,25 @@
 
 
 import pytest
+from compressed_tensors.quantization import (
+    ActivationOrdering,
+    QuantizationArgs,
+    QuantizationScheme,
+    QuantizationStatus,
+    QuantizationStrategy,
+)
 from compressed_tensors.quantization.lifecycle.initialize import (
     initialize_module_for_quantization,
 )
-from compressed_tensors.quantization.quant_args import QuantizationArgs
-from compressed_tensors.quantization.quant_config import QuantizationStatus
 from torch.nn import Linear
 
 
 NUM_BITS = 8
+Q_PARAM_NAMES = {
+    "input_activations": "input",
+    "weights": "weight",
+    "output_activations": "output",
+}
 
 
 @pytest.mark.parametrize(
@@ -77,3 +87,78 @@ def test_initialize_module_for_quantization(
     assert hasattr(layer, "quantization_status")
 
     assert layer.quantization_status == QuantizationStatus.INITIALIZED
+
+
+@pytest.mark.parametrize(
+    "weights,input_activations",
+    [
+        (
+            QuantizationArgs(strategy="tensor"),
+            QuantizationArgs(strategy="tensor"),
+        ),
+        (
+            QuantizationArgs(strategy="channel"),
+            None,
+        ),
+        (
+            QuantizationArgs(strategy="group", group_size=2),
+            None,
+        ),
+        (
+            QuantizationArgs(strategy="group", group_size=2, actorder="group"),
+            None,
+        ),
+        (
+            QuantizationArgs(strategy="group", group_size=2, actorder="weight"),
+            None,
+        ),
+        (
+            QuantizationArgs(strategy="block"),
+            QuantizationArgs(strategy="block"),
+        ),
+        (
+            QuantizationArgs(strategy="token"),
+            QuantizationArgs(strategy="token"),
+        ),
+    ],
+)
+def test_initialize_quantization_parameters(weights, input_activations):
+    quantization_scheme = QuantizationScheme(
+        targets=["*"],
+        weights=weights,
+        input_activations=input_activations,
+    )
+    layer = Linear(7, 8)
+    initialize_module_for_quantization(layer, quantization_scheme)
+
+    for q_type in ("input_activations", "weights"):
+        args = getattr(quantization_scheme, q_type)
+        if args is None:
+            continue
+        q_param_name = Q_PARAM_NAMES[q_type]
+
+        # scale and zero point
+        if args.strategy == QuantizationStrategy.TENSOR:
+            expected_shape = (1,)
+
+        elif args.strategy == QuantizationStrategy.CHANNEL:  # only weight
+            expected_shape = (layer.weight.shape[0], 1)
+
+        elif args.strategy == QuantizationStrategy.GROUP:  # only weight
+            num_groups = layer.weight.shape[1] // args.group_size
+            expected_shape = (layer.weight.shape[0], max(num_groups, 1))
+
+        elif args.strategy == QuantizationStrategy.BLOCK:
+            expected_shape = (1,)
+
+        elif args.strategy == QuantizationStrategy.TOKEN:
+            expected_shape = (1, 1)
+
+        assert getattr(layer, f"{q_param_name}_scale").shape == expected_shape
+        assert getattr(layer, f"{q_param_name}_zero_point").shape == expected_shape
+
+        # g_idx
+        if args.actorder == ActivationOrdering.GROUP:
+            assert getattr(layer, f"{q_param_name}_g_idx").shape == (
+                layer.weight.shape[1],
+            )
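
The new test doubles as documentation of the expected parameter shape for each strategy. A quick manual check mirroring it, assuming compressed-tensors and torch are installed:

from torch.nn import Linear
from compressed_tensors.quantization import QuantizationArgs, QuantizationScheme
from compressed_tensors.quantization.lifecycle.initialize import (
    initialize_module_for_quantization,
)

layer = Linear(7, 8)
scheme = QuantizationScheme(
    targets=["*"],
    weights=QuantizationArgs(strategy="channel"),
    input_activations=QuantizationArgs(strategy="token"),
)
initialize_module_for_quantization(layer, scheme)

print(layer.weight_scale.shape)  # (8, 1): one scale per output channel
print(layer.input_scale.shape)   # (1, 1): per-token placeholder after this fix
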
tests/test_quantization/test_configs/test_strategies.py (10 additions, 10 deletions)
@@ -67,8 +67,8 @@ def test_channelwise(
     if input_symmetry is not None:
         mock_per_channel_calibration(model, base_name="input", value=inputs)
 
-    assert list(model.weight_scale.shape) == [model_shape[1], 1]
-    assert list(model.weight_zero_point.shape) == [model_shape[1], 1]
+    assert model.weight_scale.shape == (model_shape[1], 1)
+    assert model.weight_zero_point.shape == (model_shape[1], 1)
 
 
 @torch.no_grad
@@ -97,14 +97,14 @@ def test_group(
             model, base_name="input", value=inputs, group_size=group_size
         )
 
-    assert list(model.weight_scale.shape) == [
+    assert model.weight_scale.shape == (
         model_shape[1],
         int(model_shape[0] / group_size),
-    ]
-    assert list(model.weight_zero_point.shape) == [
+    )
+    assert model.weight_zero_point.shape == (
         model_shape[1],
         int(model_shape[0] / group_size),
-    ]
+    )
 
 
 @torch.no_grad
@@ -131,8 +131,8 @@ def test_token(
     mock_per_channel_calibration(model, base_name="weight", value=model.weight)
     mock_per_token_calibration(model, base_name="input", value=inputs)
 
-    assert list(model.input_scale.shape) == [1, 1]
-    assert list(model.input_zero_point.shape) == [1, 1]
+    assert model.input_scale.shape == (1, 1)
+    assert model.input_zero_point.shape == (1, 1)
 
-    assert list(model.weight_scale.shape) == [256, 1]
-    assert list(model.weight_zero_point.shape) == [256, 1]
+    assert model.weight_scale.shape == (256, 1)
+    assert model.weight_zero_point.shape == (256, 1)
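
These assertions can drop the list() conversions because torch.Size subclasses tuple, so shapes compare directly against tuple literals:

import torch

shape = torch.zeros(1, 1).shape
assert isinstance(shape, tuple)  # torch.Size is a tuple subclass
assert shape == (1, 1)           # direct tuple comparison, no list() needed
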
tests/test_quantization/test_utils/test_helpers.py (58 additions, 0 deletions)
@@ -0,0 +1,58 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import pytest
+import torch
+from compressed_tensors.quantization import QuantizationArgs, QuantizationStrategy
+from compressed_tensors.quantization.utils import calculate_qparams
+
+
+@pytest.mark.parametrize(
+    "keepdims,strategy,exp_shape",
+    [
+        (
+            False,
+            QuantizationStrategy.TENSOR,
+            torch.Size(
+                [
+                    1,
+                ]
+            ),
+        ),
+        (True, QuantizationStrategy.CHANNEL, torch.Size([1, 1])),
+        (True, QuantizationStrategy.GROUP, torch.Size([1, 1])),
+        (
+            False,
+            QuantizationStrategy.BLOCK,
+            torch.Size(
+                [
+                    1,
+                ]
+            ),
+        ),
+        (True, QuantizationStrategy.TOKEN, torch.Size([1, 1])),
+    ],
+)
+def test_calculate_qparams(keepdims, strategy, exp_shape):
+    value = torch.randn(14, 5)
+    min_val = torch.amin(value, dim=tuple(), keepdims=keepdims)
+    max_val = torch.amax(value, dim=tuple(), keepdims=keepdims)
+
+    if strategy == QuantizationStrategy.GROUP:
+        args = QuantizationArgs(strategy=strategy, group_size=2)
+    else:
+        args = QuantizationArgs(strategy=strategy)
+    scale, zp = calculate_qparams(min_val, max_val, args)
+    assert scale.shape == exp_shape
+    assert zp.shape == exp_shape
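
The property this test exercises is that scale and zero point inherit the shape of the min/max statistics they are computed from. An illustrative asymmetric-quantization sketch of that property (toy_qparams is a made-up stand-in, not the library's calculate_qparams):

import torch

def toy_qparams(min_val, max_val, num_bits=8):
    # outputs inherit the broadcast shape of min_val/max_val, which is why
    # the keepdim reductions upstream determine the final qparam shapes
    qmin, qmax = -(2 ** (num_bits - 1)), 2 ** (num_bits - 1) - 1
    scale = ((max_val - min_val) / (qmax - qmin)).clamp(min=1e-8)
    zero_point = torch.clamp(torch.round(qmin - min_val / scale), qmin, qmax)
    return scale, zero_point

value = torch.randn(14, 5)
min_val = torch.amin(value, dim=(0, 1), keepdim=True)
max_val = torch.amax(value, dim=(0, 1), keepdim=True)
scale, zp = toy_qparams(min_val, max_val)
print(scale.shape, zp.shape)  # torch.Size([1, 1]) torch.Size([1, 1])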
