Commit bb83891

wip
kylesayrs committed Dec 17, 2024
1 parent 14795ad commit bb83891
Showing 2 changed files with 39 additions and 2 deletions.
3 changes: 1 addition & 2 deletions tests/conftest.py
@@ -44,8 +44,6 @@ def update_scale_zp(module: torch.nn.Module, base_name: str, value: torch.Tensor
min_val = torch.amin(value, dim=dim, keepdims=True)
max_val = torch.amax(value, dim=dim, keepdims=True)
scale, zp = calculate_qparams(min_val, max_val, args)
-scale = scale.reshape((1, 1))
-zp = zp.reshape((1, 1))
update_parameter_data(module, scale, f"{base_name}_scale")
update_parameter_data(module, zp, f"{base_name}_zero_point")

@@ -129,6 +127,7 @@ def update_scale_zp(module: torch.nn.Module, base_name: str, value: torch.Tensor

# per tensor quantization just calls calculate_qparams directly
min_val, max_val = torch.aminmax(value)
+breakpoint()
scale, zp = calculate_qparams(min_val, max_val, args)
update_parameter_data(module, scale, f"{base_name}_scale")
update_parameter_data(module, zp, f"{base_name}_zero_point")
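For context only (not part of this commit): the hunk above sits in the per-tensor branch, which reduces the whole tensor to a single min/max pair before calling calculate_qparams. A minimal sketch of that call pattern, using an illustrative tensor shape rather than a repository fixture, would be:

import torch

from compressed_tensors.quantization import QuantizationArgs, QuantizationStrategy
from compressed_tensors.quantization.utils import calculate_qparams

value = torch.randn(14, 5)  # illustrative weight tensor, not a repository fixture
min_val, max_val = torch.aminmax(value)  # single min/max over the whole tensor

args = QuantizationArgs(strategy=QuantizationStrategy.TENSOR)
scale, zero_point = calculate_qparams(min_val, max_val, args)

# the new test below expects a single-element shape here, e.g. torch.Size([1])
print(scale.shape, zero_point.shape)
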
38 changes: 38 additions & 0 deletions tests/test_quantization/test_utils/test_helpers.py
@@ -0,0 +1,38 @@
import pytest
import torch

from compressed_tensors.quantization import QuantizationArgs, QuantizationStrategy
from compressed_tensors.quantization.utils import calculate_qparams


_IN_DIMS = 5
_OUT_DIMS = 14
_GROUP_SIZE = 2

@pytest.mark.parametrize(
    "dim,keepdims,strategy,exp_shape",
    [
        (tuple(), False, QuantizationStrategy.TENSOR, torch.Size([1])),
        (0, True, QuantizationStrategy.CHANNEL, torch.Size([_OUT_DIMS, 1])),
        (tuple(), True, QuantizationStrategy.GROUP, torch.Size([_OUT_DIMS // _GROUP_SIZE, 1])),
        (tuple(), False, QuantizationStrategy.BLOCK, torch.Size([1])),
        (tuple(), True, QuantizationStrategy.TOKEN, torch.Size([1, 1])),
    ],
)
def test_calculate_qparams(dim, keepdims, strategy, exp_shape):
    value = torch.randn(_OUT_DIMS, _IN_DIMS)
    min_val = torch.amin(value, dim=dim, keepdims=keepdims)
    max_val = torch.amax(value, dim=dim, keepdims=keepdims)

    if strategy == QuantizationStrategy.GROUP:
        args = QuantizationArgs(strategy=strategy, group_size=_GROUP_SIZE)
        scale, zp = calculate_qparams(min_val, max_val, args)
        assert scale.shape == exp_shape
        assert zp.shape == exp_shape

    else:
        args = QuantizationArgs(strategy=strategy)

        scale, zp = calculate_qparams(min_val, max_val, args)
        assert scale.shape == exp_shape
        assert zp.shape == exp_shape
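
For orientation only, a minimal sketch of the channel-strategy case the new test parametrizes, assuming min/max are reduced to one value per output channel before being passed in; the tensor name and shape below are illustrative, not taken from the repository:

import torch

from compressed_tensors.quantization import QuantizationArgs, QuantizationStrategy
from compressed_tensors.quantization.utils import calculate_qparams

weight = torch.randn(14, 5)  # illustrative: 14 output channels, 5 input features

# one min/max per output channel; keepdim keeps the tensors 2-D
min_val = torch.amin(weight, dim=1, keepdim=True)
max_val = torch.amax(weight, dim=1, keepdim=True)

args = QuantizationArgs(strategy=QuantizationStrategy.CHANNEL)
scale, zero_point = calculate_qparams(min_val, max_val, args)

# the parametrized test above expects torch.Size([_OUT_DIMS, 1]) for the channel strategy
print(scale.shape, zero_point.shape)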
