diff --git a/tests/quantization/test_compressed_tensors.py b/tests/quantization/test_compressed_tensors.py
index fd1a1d05de208..31624cbf1f6f0 100644
--- a/tests/quantization/test_compressed_tensors.py
+++ b/tests/quantization/test_compressed_tensors.py
@@ -221,7 +221,7 @@ def test_compressed_tensors_kv_cache(vllm_runner):
     ("nm-testing/Meta-Llama-3-8B-Instruct-FP8-Dynamic-IA-Per-Tensor-Weight-testing",
      "tensor", "token")
 ])
-def test_compressed_tensors_2of4(vllm_runner, args_2of4):
+def test_compressed_tensors_2of4_quant(vllm_runner, args_2of4):
     model, weight_strategy, input_strategy = args_2of4
     with vllm_runner(model) as llm:
         model = llm.model.llm_engine.model_executor.driver_worker.model_runner.model  # noqa: E501
@@ -239,6 +239,31 @@ def test_compressed_tensors_2of4(vllm_runner, args_2of4):
         assert sparsity_map.get("Linear").format == "dense"
         assert sparsity_map.get("Linear").sparsity_structure == "2:4"
 
+        output = llm.generate_greedy("Hello my name is", max_tokens=20)
+        assert output
+
+
+@pytest.mark.parametrize(
+    "args_2of4",
+    [("nm-testing/TinyLlama-1.1B-Chat-v1.0-2of4-Sparse-Dense-Compressor")])
+def test_compressed_tensors_2of4_sparse(vllm_runner, args_2of4):
+    model = args_2of4
+    with vllm_runner(model) as llm:
+        model = llm.model.llm_engine.model_executor.driver_worker.model_runner.model  # noqa: E501
+        layer = model.model.layers[0]
+
+        qkv_proj = layer.self_attn.qkv_proj
+        assert isinstance(qkv_proj.quant_method, CompressedTensorsLinearMethod)
+        assert isinstance(qkv_proj.scheme, CompressedTensors24)
+
+        assert qkv_proj.scheme.weight_quant is None
+        assert qkv_proj.scheme.input_quant is None
+        assert not qkv_proj.scheme.quantized
+        assert qkv_proj.quant_method.quantization_config.sparsity_scheme_map
+        sparsity_map = qkv_proj.quant_method.quantization_config.sparsity_scheme_map  # noqa: E501
+        assert sparsity_map.get("Linear").format == "dense"
+        assert sparsity_map.get("Linear").sparsity_structure == "2:4"
+
         output = llm.generate_greedy("Hello my name is", max_tokens=20)
         print(output)
         assert output
diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py
index 87d543ced7d0c..0c1fc18228f5c 100644
--- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py
+++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py
@@ -385,6 +385,8 @@ def get_scheme(
         weight_quant = None
         input_quant = None
 
+        # For models with sparsity, assumes that the sparse layers are also
+        # quantized for cutlass 2:4 support
         sparsity_scheme: Optional[
             SparsityCompressionConfig] = self.sparsity_scheme_map.get(
                 matched_target)
diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_24.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_24.py
index 9da7151ff1b15..12e1b26d87081 100644
--- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_24.py
+++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_24.py
@@ -82,6 +82,17 @@ def create_weights(self, layer: torch.nn.Module, input_size: int,
 
                 layer.register_parameter("input_scale", input_scale)
 
+        else:
+            # for sparse-only, pass in 1 for weight/input scales
+            weight_scale = torch.nn.Parameter(data=torch.ones(
+                1, dtype=torch.float32),
+                                              requires_grad=False)
+            input_scale = torch.nn.Parameter(data=torch.ones(
+                1, dtype=torch.float32),
+                                             requires_grad=False)
+            layer.register_parameter("input_scale", input_scale)
+            layer.register_parameter("weight_scale", weight_scale)
+
         layer.register_parameter("weight", weight)
 
     def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
@@ -95,12 +106,22 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
 
         :param layer: The layer with the weights to be processed
         """
-        if (self.weight_quant and self.weight_quant.strategy
-                == QuantizationStrategy.TENSOR.value):
-            layer.weight_scale = torch.nn.Parameter(convert_to_channelwise(
-                weight_scale=layer.weight_scale,
-                logical_widths=layer.logical_widths),
-                                                    requires_grad=False)
+        # torch.compile workaround
+        if hasattr(layer, "input_scale"):
+            layer.input_scale = torch.nn.Parameter(layer.input_scale.data,
+                                                   requires_grad=False)
+
+        if self.weight_quant:
+            if self.weight_quant.strategy == QuantizationStrategy.TENSOR.value:
+                layer.weight_scale = torch.nn.Parameter(convert_to_channelwise(
+                    weight_scale=layer.weight_scale,
+                    logical_widths=layer.logical_widths),
+                                                        requires_grad=False)
+            else:
+                # torch.compile workaround
+                layer.weight_scale = torch.nn.Parameter(
+                    layer.weight_scale.data, requires_grad=False)
+
         w_compressed, meta = ops.cutlass_compress_entry(layer.weight.data)
         layer.weight = torch.nn.Parameter(w_compressed, requires_grad=False)
         layer.meta = torch.nn.Parameter(meta, requires_grad=False)
@@ -120,21 +141,27 @@ def apply_weights(self,
         :param bias: The bias to be added to the output tensor
         :return: The output tensor of the layer
         """
-        scale = None
-        if hasattr(layer, "input_scale"):
-            scale = layer.input_scale
+        if self.quantized:
+            scale = None
+            if hasattr(layer, "input_scale"):
+                scale = layer.input_scale
+
+            if self.weights_dtype == torch.int8:
+                ops_output = ops.scaled_int8_quant(x, scale=scale)
+                q_input = ops_output[0]
+                input_scale = ops_output[1]
+            else:
+                assert self.weights_dtype == torch.float8_e4m3fn
+                if scale is not None:
+                    q_input, input_scale = ops.scaled_fp8_quant(x, scale=scale)
+                else:
+                    q_input, input_scale = ops.scaled_fp8_quant(
+                        x, use_per_token_if_dynamic=True)
 
-        if self.weights_dtype == torch.int8:
-            ops_output = ops.scaled_int8_quant(x, scale=scale)
-            q_input = ops_output[0]
-            input_scale = ops_output[1]
         else:
-            assert self.weights_dtype == torch.float8_e4m3fn
-            if scale is not None:
-                q_input, input_scale = ops.scaled_fp8_quant(x, scale=scale)
-            else:
-                q_input, input_scale = ops.scaled_fp8_quant(
-                    x, use_per_token_if_dynamic=True)
+            # Not quantized, nothing to do with the input_scales, use as is
+            input_scale = layer.input_scale
+            q_input = x
 
         out = ops.cutlass_scaled_sparse_mm(a=layer.weight,
                                            e=layer.meta,
@@ -143,7 +170,6 @@ def apply_weights(self,
                                            scale_b=input_scale,
                                            out_dtype=self.output_dtype,
                                            bias=bias)
-        assert out.is_contiguous()
 
         return out
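
Usage note (illustrative, not part of the diff): the sparse-only 2:4 path added above can also be exercised end to end through the public vLLM API, mirroring what test_compressed_tensors_2of4_sparse checks via vllm_runner. The sketch below is a minimal example under the assumption of a GPU with CUTLASS 2:4 sparse support and a vLLM build containing this change; the model name is taken from the test above.

# Illustrative sketch only -- not part of the diff. Loads the sparse-only 2:4
# checkpoint used by test_compressed_tensors_2of4_sparse and runs one greedy
# generation, the same behavior the test asserts via llm.generate_greedy().
from vllm import LLM, SamplingParams

llm = LLM(model="nm-testing/TinyLlama-1.1B-Chat-v1.0-2of4-Sparse-Dense-Compressor")
# temperature=0.0 gives greedy decoding, matching generate_greedy in the test.
params = SamplingParams(temperature=0.0, max_tokens=20)
outputs = llm.generate(["Hello my name is"], params)
print(outputs[0].outputs[0].text)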