add 2:4 sparse-only support, add test cases, add torch.compile workaround
dsikka committed Dec 8, 2024
1 parent 4820ebe commit 2c32ce0
Showing 3 changed files with 74 additions and 21 deletions.
27 changes: 26 additions & 1 deletion tests/quantization/test_compressed_tensors.py
@@ -221,7 +221,7 @@ def test_compressed_tensors_kv_cache(vllm_runner):
        ("nm-testing/Meta-Llama-3-8B-Instruct-FP8-Dynamic-IA-Per-Tensor-Weight-testing",
         "tensor", "token")
    ])
-def test_compressed_tensors_2of4(vllm_runner, args_2of4):
+def test_compressed_tensors_2of4_quant(vllm_runner, args_2of4):
    model, weight_strategy, input_strategy = args_2of4
    with vllm_runner(model) as llm:
        model = llm.model.llm_engine.model_executor.driver_worker.model_runner.model  # noqa: E501
@@ -239,6 +239,31 @@ def test_compressed_tensors_2of4(vllm_runner, args_2of4):
        assert sparsity_map.get("Linear").format == "dense"
        assert sparsity_map.get("Linear").sparsity_structure == "2:4"

        output = llm.generate_greedy("Hello my name is", max_tokens=20)
        assert output
+
+
+@pytest.mark.parametrize(
+    "args_2of4",
+    [("nm-testing/TinyLlama-1.1B-Chat-v1.0-2of4-Sparse-Dense-Compressor")])
+def test_compressed_tensors_2of4_sparse(vllm_runner, args_2of4):
+    model = args_2of4
+    with vllm_runner(model) as llm:
+        model = llm.model.llm_engine.model_executor.driver_worker.model_runner.model  # noqa: E501
+        layer = model.model.layers[0]
+
+        qkv_proj = layer.self_attn.qkv_proj
+        assert isinstance(qkv_proj.quant_method, CompressedTensorsLinearMethod)
+        assert isinstance(qkv_proj.scheme, CompressedTensors24)
+
+        assert qkv_proj.scheme.weight_quant is None
+        assert qkv_proj.scheme.input_quant is None
+        assert not qkv_proj.scheme.quantized
+        assert qkv_proj.quant_method.quantization_config.sparsity_scheme_map
+        sparsity_map = qkv_proj.quant_method.quantization_config.sparsity_scheme_map  # noqa: E501
+        assert sparsity_map.get("Linear").format == "dense"
+        assert sparsity_map.get("Linear").sparsity_structure == "2:4"
+
+        output = llm.generate_greedy("Hello my name is", max_tokens=20)
+        print(output)
+        assert output
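The sparse-only test above asserts sparsity_structure == "2:4". As a minimal sketch of what that structure means (not part of this commit; the helper name is_2of4_sparse is hypothetical): in every contiguous group of four weights along the last dimension, at most two entries are nonzero.

# Sketch only: verify the 2:4 (semi-structured) sparsity pattern asserted
# in the test above. Not taken from vLLM; `is_2of4_sparse` is hypothetical.
import torch


def is_2of4_sparse(weight: torch.Tensor) -> bool:
    # Group the last dimension into blocks of 4 and count nonzeros per block.
    assert weight.shape[-1] % 4 == 0
    groups = weight.reshape(*weight.shape[:-1], -1, 4)
    nonzeros_per_group = (groups != 0).sum(dim=-1)
    return bool((nonzeros_per_group <= 2).all())


# A 2x8 weight where each group of 4 keeps at most 2 values.
w = torch.tensor([[1., 0., 2., 0., 0., 3., 0., 4.],
                  [0., 5., 0., 6., 7., 0., 8., 0.]])
assert is_2of4_sparse(w)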
@@ -385,6 +385,8 @@ def get_scheme(
        weight_quant = None
        input_quant = None

+        # For models with sparsity, assumes that the sparse layers are also
+        # quantized for cutlass 2:4 support
        sparsity_scheme: Optional[
            SparsityCompressionConfig] = self.sparsity_scheme_map.get(
                matched_target)
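The added comment records that get_scheme still assumes sparse layers are also quantized for cutlass 2:4 support, which the rest of this commit relaxes for sparse-only checkpoints. Below is a minimal sketch of the lookup shape the tests rely on (sparsity_scheme_map.get("Linear") exposing format and sparsity_structure); the simplified SparsityConfigSketch stand-in and the pick_scheme helper are assumptions for illustration, not vLLM's real classes.

# Sketch only: a simplified stand-in for the sparsity config lookup above.
from dataclasses import dataclass
from typing import Dict, Optional


@dataclass
class SparsityConfigSketch:  # stand-in for SparsityCompressionConfig
    format: str              # on-disk format, e.g. "dense"
    sparsity_structure: str  # e.g. "2:4"


sparsity_scheme_map: Dict[str, SparsityConfigSketch] = {
    "Linear": SparsityConfigSketch(format="dense", sparsity_structure="2:4"),
}


def pick_scheme(matched_target: str) -> Optional[str]:
    # Mirrors the .get(matched_target) lookup; quantization may be absent
    # (weight_quant/input_quant of None) for the new sparse-only path.
    cfg = sparsity_scheme_map.get(matched_target)
    if cfg is not None and cfg.sparsity_structure == "2:4":
        return "CompressedTensors24"
    return None


assert pick_scheme("Linear") == "CompressedTensors24"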
@@ -82,6 +82,17 @@ def create_weights(self, layer: torch.nn.Module, input_size: int,

                layer.register_parameter("input_scale", input_scale)

+        else:
+            # for sparse-only, pass in 1 for weight/input scales
+            weight_scale = torch.nn.Parameter(data=torch.ones(
+                1, dtype=torch.float32),
+                                              requires_grad=False)
+            input_scale = torch.nn.Parameter(data=torch.ones(
+                1, dtype=torch.float32),
+                                             requires_grad=False)
+            layer.register_parameter("input_scale", input_scale)
+            layer.register_parameter("weight_scale", weight_scale)
+
        layer.register_parameter("weight", weight)

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
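The new else branch registers weight_scale and input_scale of 1.0 when the layer is sparse-only. A small illustration (not from the commit) of why unit scales are enough: multiplying by 1.0 is a no-op, so downstream code that always applies the registered scales degrades gracefully to a plain matmul for the unquantized case.

# Sketch only: unit scales make the scaled path equivalent to a plain matmul.
import torch

layer = torch.nn.Module()
layer.register_parameter(
    "weight_scale",
    torch.nn.Parameter(torch.ones(1, dtype=torch.float32),
                       requires_grad=False))
layer.register_parameter(
    "input_scale",
    torch.nn.Parameter(torch.ones(1, dtype=torch.float32),
                       requires_grad=False))

x = torch.randn(2, 4)
w = torch.randn(8, 4)
y = (x * layer.input_scale) @ (w * layer.weight_scale).t()
assert torch.allclose(y, x @ w.t())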
@@ -95,12 +106,22 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        :param layer: The layer with the weights to be processed
        """
-        if (self.weight_quant and self.weight_quant.strategy
-                == QuantizationStrategy.TENSOR.value):
-            layer.weight_scale = torch.nn.Parameter(convert_to_channelwise(
-                weight_scale=layer.weight_scale,
-                logical_widths=layer.logical_widths),
-                                                    requires_grad=False)
+        # torch.compile workaround
+        if hasattr(layer, "input_scale"):
+            layer.input_scale = torch.nn.Parameter(layer.input_scale.data,
+                                                   requires_grad=False)
+
+        if self.weight_quant:
+            if self.weight_quant.strategy == QuantizationStrategy.TENSOR.value:
+                layer.weight_scale = torch.nn.Parameter(convert_to_channelwise(
+                    weight_scale=layer.weight_scale,
+                    logical_widths=layer.logical_widths),
+                                                        requires_grad=False)
+            else:
+                # torch.compile workaround
+                layer.weight_scale = torch.nn.Parameter(
+                    layer.weight_scale.data, requires_grad=False)

        w_compressed, meta = ops.cutlass_compress_entry(layer.weight.data)
        layer.weight = torch.nn.Parameter(w_compressed, requires_grad=False)
        layer.meta = torch.nn.Parameter(meta, requires_grad=False)
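The "torch.compile workaround" above re-wraps the loaded scale tensors as plain torch.nn.Parameter objects holding the same data. A hedged sketch of that re-wrapping follows; LoadTimeParameter is a hypothetical stand-in for whatever parameter class the loader produced, not a vLLM class.

# Sketch only: replace a Parameter subclass with a plain torch.nn.Parameter
# that wraps the same underlying data, mirroring the workaround above.
import torch


class LoadTimeParameter(torch.nn.Parameter):  # hypothetical stand-in
    pass


layer = torch.nn.Module()
layer.weight_scale = LoadTimeParameter(torch.ones(1), requires_grad=False)

layer.weight_scale = torch.nn.Parameter(layer.weight_scale.data,
                                        requires_grad=False)
assert type(layer.weight_scale) is torch.nn.Parameter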
@@ -120,21 +141,27 @@ def apply_weights(self,
        :param bias: The bias to be added to the output tensor
        :return: The output tensor of the layer
        """
-        scale = None
-        if hasattr(layer, "input_scale"):
-            scale = layer.input_scale
+        if self.quantized:
+            scale = None
+            if hasattr(layer, "input_scale"):
+                scale = layer.input_scale

-        if self.weights_dtype == torch.int8:
-            ops_output = ops.scaled_int8_quant(x, scale=scale)
-            q_input = ops_output[0]
-            input_scale = ops_output[1]
-        else:
-            assert self.weights_dtype == torch.float8_e4m3fn
-            if scale is not None:
-                q_input, input_scale = ops.scaled_fp8_quant(x, scale=scale)
-            else:
-                q_input, input_scale = ops.scaled_fp8_quant(
-                    x, use_per_token_if_dynamic=True)
+            if self.weights_dtype == torch.int8:
+                ops_output = ops.scaled_int8_quant(x, scale=scale)
+                q_input = ops_output[0]
+                input_scale = ops_output[1]
+            else:
+                assert self.weights_dtype == torch.float8_e4m3fn
+                if scale is not None:
+                    q_input, input_scale = ops.scaled_fp8_quant(x, scale=scale)
+                else:
+                    q_input, input_scale = ops.scaled_fp8_quant(
+                        x, use_per_token_if_dynamic=True)
+        else:
+            # Not quantized, nothing to do with the input_scales, use as is
+            input_scale = layer.input_scale
+            q_input = x

        out = ops.cutlass_scaled_sparse_mm(a=layer.weight,
                                           e=layer.meta,
@@ -143,7 +170,6 @@
                                           scale_b=input_scale,
                                           out_dtype=self.output_dtype,
                                           bias=bias)
-
        assert out.is_contiguous()
        return out
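A pure-PyTorch sketch of the input-handling split added above (no vLLM ops; prepare_input and the naive symmetric int8 quantization are illustrative stand-ins for ops.scaled_int8_quant / ops.scaled_fp8_quant, not the real kernels): when the scheme is quantized, the activation is quantized with a static scale if one was loaded, otherwise a dynamic one; when it is sparse-only, the activation and the registered unit input_scale pass through unchanged.

# Sketch only: stand-in for the quantized vs. sparse-only input handling.
from typing import Optional, Tuple

import torch


def prepare_input(
        x: torch.Tensor,
        quantized: bool,
        input_scale: torch.Tensor,
        static_scale: Optional[torch.Tensor] = None
) -> Tuple[torch.Tensor, torch.Tensor]:
    if quantized:
        # Naive symmetric per-tensor int8 quantization, static or dynamic.
        scale = (static_scale if static_scale is not None else
                 x.abs().amax().clamp(min=1e-8) / 127.0)
        q_input = torch.clamp((x / scale).round(), -128, 127).to(torch.int8)
        return q_input, scale.reshape(1)
    # Not quantized: nothing to do with the input scale, use x as-is.
    return x, input_scale


q_x, s = prepare_input(torch.randn(4, 16), quantized=True,
                       input_scale=torch.ones(1))
assert q_x.dtype == torch.int8 and s.numel() == 1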
