diff --git a/src/compressed_tensors/quantization/lifecycle/initialize.py b/src/compressed_tensors/quantization/lifecycle/initialize.py index 5955b5e7..49e7b1a9 100644 --- a/src/compressed_tensors/quantization/lifecycle/initialize.py +++ b/src/compressed_tensors/quantization/lifecycle/initialize.py @@ -172,9 +172,10 @@ def _initialize_scale_zero_point_observer( # (output_channels, 1) expected_shape = (weight_shape[0], 1) elif quantization_args.strategy == QuantizationStrategy.GROUP: + num_groups = weight_shape[1] // quantization_args.group_size expected_shape = ( weight_shape[0], - weight_shape[1] // quantization_args.group_size, + max(num_groups, 1) ) scale_dtype = module.weight.dtype