diff --git a/tests/test_quantization/test_utils/test_helpers.py b/tests/test_quantization/test_utils/test_helpers.py index d89a9c7f..b106ee2d 100644 --- a/tests/test_quantization/test_utils/test_helpers.py +++ b/tests/test_quantization/test_utils/test_helpers.py @@ -19,10 +19,9 @@ @pytest.mark.parametrize( - "dim,keepdims,strategy,exp_shape", + "keepdims,strategy,exp_shape", [ ( - tuple(), False, QuantizationStrategy.TENSOR, torch.Size( @@ -31,10 +30,9 @@ ] ), ), - (tuple(), True, QuantizationStrategy.CHANNEL, torch.Size([1, 1])), - (tuple(), True, QuantizationStrategy.GROUP, torch.Size([1, 1])), + (True, QuantizationStrategy.CHANNEL, torch.Size([1, 1])), + (True, QuantizationStrategy.GROUP, torch.Size([1, 1])), ( - tuple(), False, QuantizationStrategy.BLOCK, torch.Size( @@ -43,13 +41,13 @@ ] ), ), - (tuple(), True, QuantizationStrategy.TOKEN, torch.Size([1, 1])), + (True, QuantizationStrategy.TOKEN, torch.Size([1, 1])), ], ) -def test_calculate_qparams(dim, keepdims, strategy, exp_shape): +def test_calculate_qparams(keepdims, strategy, exp_shape): value = torch.randn(14, 5) - min_val = torch.amin(value, dim=dim, keepdims=keepdims) - max_val = torch.amax(value, dim=dim, keepdims=keepdims) + min_val = torch.amin(value, dim=tuple(), keepdims=keepdims) + max_val = torch.amax(value, dim=tuple(), keepdims=keepdims) if strategy == QuantizationStrategy.GROUP: args = QuantizationArgs(strategy=strategy, group_size=2)