From bb8389164e8e4d865829506cba567dacd5022cd2 Mon Sep 17 00:00:00 2001
From: Kyle Sayers
Date: Tue, 17 Dec 2024 18:16:36 +0000
Subject: [PATCH] Remove qparams reshape workaround; add calculate_qparams shape tests

Drop the manual (1, 1) reshape of scale/zero-point in the test helper
(calculate_qparams now returns correctly-shaped tensors) and add a
parametrized test covering the scale/zero-point shapes produced for
each quantization strategy.
---
 tests/conftest.py                                  |  2 --
 .../test_utils/test_helpers.py                     | 36 ++++++++++++++++++++
 2 files changed, 36 insertions(+), 2 deletions(-)
 create mode 100644 tests/test_quantization/test_utils/test_helpers.py

diff --git a/tests/conftest.py b/tests/conftest.py
index a1c1d861..1db6308f 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -44,8 +44,6 @@ def update_scale_zp(module: torch.nn.Module, base_name: str, value: torch.Tensor
     min_val = torch.amin(value, dim=dim, keepdims=True)
     max_val = torch.amax(value, dim=dim, keepdims=True)
     scale, zp = calculate_qparams(min_val, max_val, args)
-    scale = scale.reshape((1, 1))
-    zp = zp.reshape((1, 1))
     update_parameter_data(module, scale, f"{base_name}_scale")
     update_parameter_data(module, zp, f"{base_name}_zero_point")
 
@@ -129,6 +127,6 @@
 
     # per tensor quantization just calls calculate_qparams directly
     min_val, max_val = torch.aminmax(value)
     scale, zp = calculate_qparams(min_val, max_val, args)
     update_parameter_data(module, scale, f"{base_name}_scale")
     update_parameter_data(module, zp, f"{base_name}_zero_point")
diff --git a/tests/test_quantization/test_utils/test_helpers.py b/tests/test_quantization/test_utils/test_helpers.py
new file mode 100644
index 00000000..b07d5e36
--- /dev/null
+++ b/tests/test_quantization/test_utils/test_helpers.py
@@ -0,0 +1,36 @@
+import pytest
+import torch
+
+from compressed_tensors.quantization import QuantizationArgs, QuantizationStrategy
+from compressed_tensors.quantization.utils import calculate_qparams
+
+_IN_DIMS = 5
+_OUT_DIMS = 14
+_GROUP_SIZE = 2
+
+
+@pytest.mark.parametrize(
+    "dim,keepdims,strategy,exp_shape",
+    [
+        (tuple(), False, QuantizationStrategy.TENSOR, torch.Size([1])),
+        (0, True, QuantizationStrategy.CHANNEL, torch.Size([_OUT_DIMS, 1])),
+        (tuple(), True, QuantizationStrategy.GROUP, torch.Size([_OUT_DIMS // _GROUP_SIZE, 1])),
+        (tuple(), False, QuantizationStrategy.BLOCK, torch.Size([1])),
+        (tuple(), True, QuantizationStrategy.TOKEN, torch.Size([1, 1])),
+    ],
+)
+def test_calculate_qparams(dim, keepdims, strategy, exp_shape):
+    # reduce a random weight-like tensor the same way observers do
+    value = torch.randn(_OUT_DIMS, _IN_DIMS)
+    min_val = torch.amin(value, dim=dim, keepdims=keepdims)
+    max_val = torch.amax(value, dim=dim, keepdims=keepdims)
+
+    # group quantization is the only strategy that needs extra args
+    if strategy == QuantizationStrategy.GROUP:
+        args = QuantizationArgs(strategy=strategy, group_size=_GROUP_SIZE)
+    else:
+        args = QuantizationArgs(strategy=strategy)
+
+    scale, zp = calculate_qparams(min_val, max_val, args)
+    assert scale.shape == exp_shape
+    assert zp.shape == exp_shape