diff --git a/src/compressed_tensors/quantization/lifecycle/initialize.py b/src/compressed_tensors/quantization/lifecycle/initialize.py
index eb4d6b18..5aecae0d 100644
--- a/src/compressed_tensors/quantization/lifecycle/initialize.py
+++ b/src/compressed_tensors/quantization/lifecycle/initialize.py
@@ -174,7 +174,10 @@ def _initialize_scale_zero_point(
     device = get_execution_device(module)
 
     # infer expected scale/zero point shape
-    expected_shape = 1  # per tensor
+    if quantization_args.strategy == QuantizationStrategy.TOKEN:
+        expected_shape = (1, 1)
+    else:
+        expected_shape = 1
 
     if base_name == "weight" and weight_shape is not None:
         if quantization_args.strategy == QuantizationStrategy.CHANNEL:
diff --git a/tests/conftest.py b/tests/conftest.py
index a1c1d861..492f7af0 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -44,8 +44,6 @@ def update_scale_zp(module: torch.nn.Module, base_name: str, value: torch.Tensor
     min_val = torch.amin(value, dim=dim, keepdims=True)
     max_val = torch.amax(value, dim=dim, keepdims=True)
     scale, zp = calculate_qparams(min_val, max_val, args)
-    scale = scale.reshape((1, 1))
-    zp = zp.reshape((1, 1))
     update_parameter_data(module, scale, f"{base_name}_scale")
     update_parameter_data(module, zp, f"{base_name}_zero_point")
 
diff --git a/tests/test_quantization/lifecycle/test_initialize.py b/tests/test_quantization/lifecycle/test_initialize.py
index 987b2ae2..215f2130 100644
--- a/tests/test_quantization/lifecycle/test_initialize.py
+++ b/tests/test_quantization/lifecycle/test_initialize.py
@@ -14,15 +14,25 @@
 import pytest
+from compressed_tensors.quantization import (
+    ActivationOrdering,
+    QuantizationArgs,
+    QuantizationScheme,
+    QuantizationStatus,
+    QuantizationStrategy,
+)
 from compressed_tensors.quantization.lifecycle.initialize import (
     initialize_module_for_quantization,
 )
-from compressed_tensors.quantization.quant_args import QuantizationArgs
-from compressed_tensors.quantization.quant_config import QuantizationStatus
 from torch.nn import Linear
 
 NUM_BITS = 8
+Q_PARAM_NAMES = {
+    "input_activations": "input",
+    "weights": "weight",
+    "output_activations": "output",
+}
 
 
 @pytest.mark.parametrize(
@@ -77,3 +87,78 @@ def test_initialize_module_for_quantization(
 
     assert hasattr(layer, "quantization_status")
     assert layer.quantization_status == QuantizationStatus.INITIALIZED
+
+
+@pytest.mark.parametrize(
+    "weights,input_activations",
+    [
+        (
+            QuantizationArgs(strategy="tensor"),
+            QuantizationArgs(strategy="tensor"),
+        ),
+        (
+            QuantizationArgs(strategy="channel"),
+            None,
+        ),
+        (
+            QuantizationArgs(strategy="group", group_size=2),
+            None,
+        ),
+        (
+            QuantizationArgs(strategy="group", group_size=2, actorder="group"),
+            None,
+        ),
+        (
+            QuantizationArgs(strategy="group", group_size=2, actorder="weight"),
+            None,
+        ),
+        (
+            QuantizationArgs(strategy="block"),
+            QuantizationArgs(strategy="block"),
+        ),
+        (
+            QuantizationArgs(strategy="token"),
+            QuantizationArgs(strategy="token"),
+        ),
+    ],
+)
+def test_initialize_quantization_parameters(weights, input_activations):
+    quantization_scheme = QuantizationScheme(
+        targets=["*"],
+        weights=weights,
+        input_activations=input_activations,
+    )
+    layer = Linear(7, 8)
+    initialize_module_for_quantization(layer, quantization_scheme)
+
+    for q_type in ("input_activations", "weights"):
+        args = getattr(quantization_scheme, q_type)
+        if args is None:
+            continue
+        q_param_name = Q_PARAM_NAMES[q_type]
+
+        # scale and zero point
+        if args.strategy == QuantizationStrategy.TENSOR:
+            expected_shape = (1,)
+
+        elif args.strategy == QuantizationStrategy.CHANNEL:  # only weight
+            expected_shape = (layer.weight.shape[0], 1)
+
+        elif args.strategy == QuantizationStrategy.GROUP:  # only weight
+            num_groups = layer.weight.shape[1] // args.group_size
+            expected_shape = (layer.weight.shape[0], max(num_groups, 1))
+
+        elif args.strategy == QuantizationStrategy.BLOCK:
+            expected_shape = (1,)
+
+        elif args.strategy == QuantizationStrategy.TOKEN:
+            expected_shape = (1, 1)
+
+        assert getattr(layer, f"{q_param_name}_scale").shape == expected_shape
+        assert getattr(layer, f"{q_param_name}_zero_point").shape == expected_shape
+
+        # g_idx
+        if args.actorder == ActivationOrdering.GROUP:
+            assert getattr(layer, f"{q_param_name}_g_idx").shape == (
+                layer.weight.shape[1],
+            )
diff --git a/tests/test_quantization/test_configs/test_strategies.py b/tests/test_quantization/test_configs/test_strategies.py
index 94201463..6605daf0 100644
--- a/tests/test_quantization/test_configs/test_strategies.py
+++ b/tests/test_quantization/test_configs/test_strategies.py
@@ -67,8 +67,8 @@ def test_channelwise(
     if input_symmetry is not None:
         mock_per_channel_calibration(model, base_name="input", value=inputs)
 
-    assert list(model.weight_scale.shape) == [model_shape[1], 1]
-    assert list(model.weight_zero_point.shape) == [model_shape[1], 1]
+    assert model.weight_scale.shape == (model_shape[1], 1)
+    assert model.weight_zero_point.shape == (model_shape[1], 1)
 
 
 @torch.no_grad
@@ -97,14 +97,14 @@ def test_group(
         model, base_name="input", value=inputs, group_size=group_size
     )
 
-    assert list(model.weight_scale.shape) == [
+    assert model.weight_scale.shape == (
         model_shape[1],
         int(model_shape[0] / group_size),
-    ]
-    assert list(model.weight_zero_point.shape) == [
+    )
+    assert model.weight_zero_point.shape == (
         model_shape[1],
         int(model_shape[0] / group_size),
-    ]
+    )
 
 
 @torch.no_grad
@@ -131,8 +131,8 @@ def test_token(
     mock_per_channel_calibration(model, base_name="weight", value=model.weight)
     mock_per_token_calibration(model, base_name="input", value=inputs)
 
-    assert list(model.input_scale.shape) == [1, 1]
-    assert list(model.input_zero_point.shape) == [1, 1]
+    assert model.input_scale.shape == (1, 1)
+    assert model.input_zero_point.shape == (1, 1)
 
-    assert list(model.weight_scale.shape) == [256, 1]
-    assert list(model.weight_zero_point.shape) == [256, 1]
+    assert model.weight_scale.shape == (256, 1)
+    assert model.weight_zero_point.shape == (256, 1)
diff --git a/tests/test_quantization/test_utils/test_helpers.py b/tests/test_quantization/test_utils/test_helpers.py
new file mode 100644
index 00000000..b106ee2d
--- /dev/null
+++ b/tests/test_quantization/test_utils/test_helpers.py
@@ -0,0 +1,58 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import pytest
+import torch
+from compressed_tensors.quantization import QuantizationArgs, QuantizationStrategy
+from compressed_tensors.quantization.utils import calculate_qparams
+
+
+@pytest.mark.parametrize(
+    "keepdims,strategy,exp_shape",
+    [
+        (
+            False,
+            QuantizationStrategy.TENSOR,
+            torch.Size(
+                [
+                    1,
+                ]
+            ),
+        ),
+        (True, QuantizationStrategy.CHANNEL, torch.Size([1, 1])),
+        (True, QuantizationStrategy.GROUP, torch.Size([1, 1])),
+        (
+            False,
+            QuantizationStrategy.BLOCK,
+            torch.Size(
+                [
+                    1,
+                ]
+            ),
+        ),
+        (True, QuantizationStrategy.TOKEN, torch.Size([1, 1])),
+    ],
+)
+def test_calculate_qparams(keepdims, strategy, exp_shape):
+    value = torch.randn(14, 5)
+    min_val = torch.amin(value, dim=tuple(), keepdims=keepdims)
+    max_val = torch.amax(value, dim=tuple(), keepdims=keepdims)
+
+    if strategy == QuantizationStrategy.GROUP:
+        args = QuantizationArgs(strategy=strategy, group_size=2)
+    else:
+        args = QuantizationArgs(strategy=strategy)
+    scale, zp = calculate_qparams(min_val, max_val, args)
+    assert scale.shape == exp_shape
+    assert zp.shape == exp_shape
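Note on the core change, with a minimal usage sketch (not part of the patch; it uses only APIs that appear in the diff above): token-strategy scale and zero point parameters are now created with shape (1, 1) at initialization time, matching what calculate_qparams produces during calibration, which is why the reshape calls in tests/conftest.py could be dropped.

    from torch.nn import Linear
    from compressed_tensors.quantization import QuantizationArgs, QuantizationScheme
    from compressed_tensors.quantization.lifecycle.initialize import (
        initialize_module_for_quantization,
    )

    # The 7x8 layer mirrors the new test; any Linear behaves the same.
    layer = Linear(7, 8)
    scheme = QuantizationScheme(
        targets=["*"],
        weights=QuantizationArgs(strategy="token"),
        input_activations=QuantizationArgs(strategy="token"),
    )
    initialize_module_for_quantization(layer, scheme)

    # With this patch, token qparams are (1, 1)-shaped from the start,
    # so calibration can update them in place without reshaping.
    assert layer.input_scale.shape == (1, 1)
    assert layer.input_zero_point.shape == (1, 1)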