Skip to content

Commit

Permalink
add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
kylesayrs committed Nov 19, 2024
1 parent cf99e09 commit 14795ad
Show file tree
Hide file tree
Showing 2 changed files with 97 additions and 12 deletions.
89 changes: 87 additions & 2 deletions tests/test_quantization/lifecycle/test_initialize.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,25 @@


import pytest
from compressed_tensors.quantization import (
ActivationOrdering,
QuantizationArgs,
QuantizationScheme,
QuantizationStatus,
QuantizationStrategy,
)
from compressed_tensors.quantization.lifecycle.initialize import (
initialize_module_for_quantization,
)
from compressed_tensors.quantization.quant_args import QuantizationArgs
from compressed_tensors.quantization.quant_config import QuantizationStatus
from torch.nn import Linear


# Bit width shared by tests in this module — presumably the default for
# QuantizationArgs; confirm against the args' default.
NUM_BITS = 8
# Maps QuantizationScheme attribute names to the prefix of the parameters
# registered on the quantized module (e.g. "weights" -> "weight_scale",
# "weight_zero_point", ...).
Q_PARAM_NAMES = {
    "input_activations": "input",
    "weights": "weight",
    "output_activations": "output",
}


@pytest.mark.parametrize(
Expand Down Expand Up @@ -77,3 +87,78 @@ def test_initialize_module_for_quantization(
assert hasattr(layer, "quantization_status")

assert layer.quantization_status == QuantizationStatus.INITIALIZED


@pytest.mark.parametrize(
    "weights,input_activations",
    [
        (
            QuantizationArgs(strategy="tensor"),
            QuantizationArgs(strategy="tensor"),
        ),
        (
            QuantizationArgs(strategy="channel"),
            None,
        ),
        (
            QuantizationArgs(strategy="group", group_size=2),
            None,
        ),
        (
            QuantizationArgs(strategy="group", group_size=2, actorder="group"),
            None,
        ),
        (
            QuantizationArgs(strategy="group", group_size=2, actorder="weight"),
            None,
        ),
        (
            QuantizationArgs(strategy="block"),
            QuantizationArgs(strategy="block"),
        ),
        (
            QuantizationArgs(strategy="token"),
            QuantizationArgs(strategy="token"),
        ),
    ],
)
def test_initialize_quantization_parameters(weights, input_activations):
    """
    Verify that initialize_module_for_quantization registers scale and
    zero-point parameters whose shapes match each quantization strategy,
    and that a g_idx parameter is created for grouped activation ordering.
    """
    quantization_scheme = QuantizationScheme(
        targets=["*"],
        weights=weights,
        input_activations=input_activations,
    )
    layer = Linear(7, 8)  # weight shape is (out_features=8, in_features=7)
    initialize_module_for_quantization(layer, quantization_scheme)

    for q_type in ("input_activations", "weights"):
        args = getattr(quantization_scheme, q_type)
        if args is None:
            continue
        q_param_name = Q_PARAM_NAMES[q_type]

        # Expected scale / zero-point shape depends on the strategy
        if args.strategy == QuantizationStrategy.TENSOR:
            expected_shape = (1,)

        elif args.strategy == QuantizationStrategy.CHANNEL:  # only weight
            # one value per output channel
            expected_shape = (layer.weight.shape[0], 1)

        elif args.strategy == QuantizationStrategy.GROUP:  # only weight
            # groups partition the input dimension; at least one group
            num_groups = layer.weight.shape[1] // args.group_size
            expected_shape = (layer.weight.shape[0], max(num_groups, 1))

        elif args.strategy == QuantizationStrategy.BLOCK:
            expected_shape = (1,)

        elif args.strategy == QuantizationStrategy.TOKEN:
            expected_shape = (1, 1)

        else:
            # Fail with a clear message instead of hitting an unbound
            # `expected_shape` below if a new strategy is ever parametrized
            pytest.fail(f"unhandled quantization strategy: {args.strategy}")

        assert getattr(layer, f"{q_param_name}_scale").shape == expected_shape
        assert getattr(layer, f"{q_param_name}_zero_point").shape == expected_shape

        # g_idx is only registered for grouped activation ordering
        if args.actorder == ActivationOrdering.GROUP:
            assert getattr(layer, f"{q_param_name}_g_idx").shape == (
                layer.weight.shape[1],
            )
20 changes: 10 additions & 10 deletions tests/test_quantization/test_configs/test_strategies.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,8 +67,8 @@ def test_channelwise(
if input_symmetry is not None:
mock_per_channel_calibration(model, base_name="input", value=inputs)

assert list(model.weight_scale.shape) == [model_shape[1], 1]
assert list(model.weight_zero_point.shape) == [model_shape[1], 1]
assert model.weight_scale.shape == (model_shape[1], 1)
assert model.weight_zero_point.shape == (model_shape[1], 1)


@torch.no_grad
Expand Down Expand Up @@ -97,14 +97,14 @@ def test_group(
model, base_name="input", value=inputs, group_size=group_size
)

assert list(model.weight_scale.shape) == [
assert model.weight_scale.shape == (
model_shape[1],
int(model_shape[0] / group_size),
]
assert list(model.weight_zero_point.shape) == [
)
assert model.weight_zero_point.shape == (
model_shape[1],
int(model_shape[0] / group_size),
]
)


@torch.no_grad
Expand All @@ -131,8 +131,8 @@ def test_token(
mock_per_channel_calibration(model, base_name="weight", value=model.weight)
mock_per_token_calibration(model, base_name="input", value=inputs)

assert list(model.input_scale.shape) == [1, 1]
assert list(model.input_zero_point.shape) == [1, 1]
assert model.input_scale.shape == (1, 1)
assert model.input_zero_point.shape == (1, 1)

assert list(model.weight_scale.shape) == [256, 1]
assert list(model.weight_zero_point.shape) == [256, 1]
assert model.weight_scale.shape == (256, 1)
assert model.weight_zero_point.shape == (256, 1)

0 comments on commit 14795ad

Please sign in to comment.