From 05b68aa8c3675099753aa1c4bafc46a27881ab47 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Thu, 19 Dec 2024 16:34:45 +0000 Subject: [PATCH] remove unnecessary arg --- .../test_quantization/test_utils/test_helpers.py | 16 +++++++--------- 1 file changed, 7 insertions(+), 9 deletions(-) diff --git a/tests/test_quantization/test_utils/test_helpers.py b/tests/test_quantization/test_utils/test_helpers.py index d89a9c7f..b106ee2d 100644 --- a/tests/test_quantization/test_utils/test_helpers.py +++ b/tests/test_quantization/test_utils/test_helpers.py @@ -19,10 +19,9 @@ @pytest.mark.parametrize( - "dim,keepdims,strategy,exp_shape", + "keepdims,strategy,exp_shape", [ ( - tuple(), False, QuantizationStrategy.TENSOR, torch.Size( @@ -31,10 +30,9 @@ ] ), ), - (tuple(), True, QuantizationStrategy.CHANNEL, torch.Size([1, 1])), - (tuple(), True, QuantizationStrategy.GROUP, torch.Size([1, 1])), + (True, QuantizationStrategy.CHANNEL, torch.Size([1, 1])), + (True, QuantizationStrategy.GROUP, torch.Size([1, 1])), ( - tuple(), False, QuantizationStrategy.BLOCK, torch.Size( @@ -43,13 +41,13 @@ ] ), ), - (tuple(), True, QuantizationStrategy.TOKEN, torch.Size([1, 1])), + (True, QuantizationStrategy.TOKEN, torch.Size([1, 1])), ], ) -def test_calculate_qparams(dim, keepdims, strategy, exp_shape): +def test_calculate_qparams(keepdims, strategy, exp_shape): value = torch.randn(14, 5) - min_val = torch.amin(value, dim=dim, keepdims=keepdims) - max_val = torch.amax(value, dim=dim, keepdims=keepdims) + min_val = torch.amin(value, dim=tuple(), keepdims=keepdims) + max_val = torch.amax(value, dim=tuple(), keepdims=keepdims) if strategy == QuantizationStrategy.GROUP: args = QuantizationArgs(strategy=strategy, group_size=2)