diff --git a/tests/test_quantization/test_utils/test_helpers.py b/tests/test_quantization/test_utils/test_helpers.py index b07d5e36..d89a9c7f 100644 --- a/tests/test_quantization/test_utils/test_helpers.py +++ b/tests/test_quantization/test_utils/test_helpers.py @@ -1,38 +1,60 @@ -import torch -import pytest +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. -from compressed_tensors.quantization.utils import calculate_qparams +import pytest +import torch from compressed_tensors.quantization import QuantizationArgs, QuantizationStrategy +from compressed_tensors.quantization.utils import calculate_qparams -_IN_DIMS = 5 -_OUT_DIMS = 14 -_GROUP_SIZE = 2 - @pytest.mark.parametrize( "dim,keepdims,strategy,exp_shape", [ - (tuple(), False, QuantizationStrategy.TENSOR, torch.Size([1,])), - (0, True, QuantizationStrategy.CHANNEL, torch.Size([_OUT_DIMS, 1])), - (tuple(), True, QuantizationStrategy.GROUP, torch.Size([_OUT_DIMS // _GROUP_SIZE , 1])), - (tuple(), False, QuantizationStrategy.BLOCK, torch.Size([1, ])), + ( + tuple(), + False, + QuantizationStrategy.TENSOR, + torch.Size( + [ + 1, + ] + ), + ), + (tuple(), True, QuantizationStrategy.CHANNEL, torch.Size([1, 1])), + (tuple(), True, QuantizationStrategy.GROUP, torch.Size([1, 1])), + ( + tuple(), + False, + QuantizationStrategy.BLOCK, + torch.Size( + [ + 1, + ] + ), + ), (tuple(), True, QuantizationStrategy.TOKEN, torch.Size([1, 1])), ], ) def test_calculate_qparams(dim, keepdims, strategy, exp_shape): - value = torch.randn(_OUT_DIMS, _IN_DIMS) + value = torch.randn(14, 5) min_val = torch.amin(value, dim=dim, keepdims=keepdims) max_val = torch.amax(value, dim=dim, keepdims=keepdims) if strategy == QuantizationStrategy.GROUP: - args = QuantizationArgs(strategy=strategy, group_size=_GROUP_SIZE) - scale, zp = calculate_qparams(min_val, max_val, args) - assert scale.shape == exp_shape - assert zp.shape == exp_shape - + args = QuantizationArgs(strategy=strategy, group_size=2) else: args = QuantizationArgs(strategy=strategy) - scale, zp = calculate_qparams(min_val, max_val, args) assert scale.shape == exp_shape assert zp.shape == exp_shape