diff --git a/tests/quantization/test_compressed_tensors.py b/tests/quantization/test_compressed_tensors.py
index d7deedb2dc49e..ea6d019872654 100644
--- a/tests/quantization/test_compressed_tensors.py
+++ b/tests/quantization/test_compressed_tensors.py
@@ -179,9 +179,17 @@ def test_compressed_tensors_kv_cache(vllm_runner):
         output = llm.generate_greedy("Hello world!", max_tokens=20)
         assert output
 
-@pytest.mark.parametrize(
-    "args_2of4",
-    [("nm-testing/Meta-Llama-3-8B-Instruct-FP8-Dynamic-2of4-testing", "channel", "token")])
+
+@pytest.mark.parametrize("args_2of4", [
+    ("nm-testing/Meta-Llama-3-8B-Instruct-FP8-Dynamic-2of4-testing", "channel",
+     "token"),
+    ("nm-testing/Meta-Llama-3-8B-Instruct-FP8-Static-Per-Tensor-testing",
+     "channel", "tensor"),
+    ("nm-testing/Meta-Llama-3-8B-Instruct-FP8-Static-testing", "tensor",
+     "tensor"),
+    ("nm-testing/Meta-Llama-3-8B-Instruct-FP8-Dynamic-IA-Per-Tensor-Weight-testing",
+     "tensor", "token")
+])
 def test_compressed_tensors_2of4(vllm_runner, args_2of4):
     model, weight_strategy, input_strategy = args_2of4
     with vllm_runner(model) as llm:
@@ -202,4 +210,4 @@ def test_compressed_tensors_2of4(vllm_runner, args_2of4):
 
         output = llm.generate_greedy("Hello my name is", max_tokens=20)
         print(output)
-        assert output
\ No newline at end of file
+        assert output
diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py
index b8337fd250a8e..da34b9b9aa68c 100644
--- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py
+++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py
@@ -388,13 +388,12 @@ def get_scheme(
         if self.supports_cutlass_24(weight_quant=weight_quant,
                                     input_quant=input_quant,
                                     sparsity_scheme=sparsity_scheme):
-            # Have a valid sparsity scheme 
+            # Have a valid sparsity scheme
             # Validate layer is supported by Cutlass 2:4 Kernel
-            scheme = CompressedTensors24(
-                quantized=weight_quant is not None or input_quant is not None,
-                weight_quant=weight_quant,
-                input_quant=input_quant
-            )
+            scheme = CompressedTensors24(quantized=weight_quant is not None
+                                         or input_quant is not None,
+                                         weight_quant=weight_quant,
+                                         input_quant=input_quant)
         else:
             # Find the quant_scheme
             scheme = self._get_scheme_from_parts(
diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_24.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_24.py
index 133e8a236f722..5fee61d340f7c 100644
--- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_24.py
+++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_24.py
@@ -4,20 +4,20 @@
 from compressed_tensors.quantization import QuantizationType, QuantizationStrategy
 from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
     CompressedTensorsScheme)
-from vllm.model_executor.parameter import ModelWeightParameter, ChannelQuantScaleParameter, PerTensorScaleParameter
+from vllm.model_executor.parameter import ModelWeightParameter, ChannelQuantScaleParameter, PerTensorScaleParameter, BasevLLMParameter
 from vllm import _custom_ops as ops
+from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
+    convert_to_channelwise)
 
 __all__ = ["CompressedTensors24"]
 
 
 class CompressedTensors24(CompressedTensorsScheme):
 
-    def __init__(
-        self,
-        quantized: bool = False,
-        weight_quant=None,
-        input_quant=None
-    ):
+    def __init__(self,
+                 quantized: bool = False,
+                 weight_quant=None,
+                 input_quant=None):
         self.quantized = quantized
         self.weight_quant = weight_quant
         self.input_quant = input_quant
@@ -33,6 +33,7 @@ def create_weights(self, layer: torch.nn.Module, input_size: int,
                        **kwargs):
 
         self.output_dtype = params_dtype
+        layer.logical_widths = output_partition_sizes
         weights_dtype: torch.dtype = self._get_params_dtype(params_dtype)
 
         # parameter to store uncompressed weight
@@ -51,7 +52,7 @@ def create_weights(self, layer: torch.nn.Module, input_size: int,
                 output_dim=0,
                 weight_loader=weight_loader)
         else:
-            assert self.weight_quant.strategy == QuantizationStrategy.TOKEN.value
+            assert self.weight_quant.strategy == QuantizationStrategy.TENSOR.value
             weight_scale = PerTensorScaleParameter(data=torch.empty(
                 len(output_partition_sizes), dtype=torch.float32),
                                                    weight_loader=weight_loader)
@@ -62,9 +63,9 @@ def create_weights(self, layer: torch.nn.Module, input_size: int,
         if not self.input_quant.dynamic:
             # register input quant scale
             assert self.input_quant.strategy == QuantizationStrategy.TENSOR.value
-            input_scale = PerTensorScaleParameter(data=torch.empty(
-                len(output_partition_sizes), dtype=torch.float32),
-                                                  weight_loader=weight_loader)
+            input_scale = BasevLLMParameter(data=torch.empty(
+                1, dtype=torch.float32),
+                                            weight_loader=weight_loader)
 
             layer.register_parameter("input_scale", input_scale)
 
@@ -81,6 +82,11 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
 
         :param layer: The layer with the weights to be processed
         """
+        if self.weight_quant.strategy == QuantizationStrategy.TENSOR.value:
+            layer.weight_scale = torch.nn.Parameter(convert_to_channelwise(
+                weight_scale=layer.weight_scale,
+                logical_widths=layer.logical_widths),
+                                                    requires_grad=False)
 
         w_compressed, meta = ops.cutlass_compress_entry(layer.weight.data)
         layer.w_compressed = torch.nn.Parameter(w_compressed,
                                                 requires_grad=False)