diff --git a/src/compressed_tensors/compressors/sparse_quantized_compressors/marlin_24.py b/src/compressed_tensors/compressors/sparse_quantized_compressors/marlin_24.py index 7954b6b8..4c544588 100644 --- a/src/compressed_tensors/compressors/sparse_quantized_compressors/marlin_24.py +++ b/src/compressed_tensors/compressors/sparse_quantized_compressors/marlin_24.py @@ -238,7 +238,7 @@ def pack_scales_24(scales, quantization_args, w_shape): _, scale_perm_2_4, scale_perm_single_2_4 = get_permutations_24(num_bits) if ( - quantization_args.strategy is QuantizationStrategy.GROUP + quantization_args.strategy == QuantizationStrategy.GROUP and quantization_args.group_size < size_k ): scales = scales.reshape((-1, len(scale_perm_2_4)))[:, scale_perm_2_4]