Commit

Address more comments
Add Table in docstring
Add test for compressor inference
rahul-tuli committed Jan 21, 2025
1 parent dc38c30 commit 2c51f2d
Showing 3 changed files with 193 additions and 6 deletions.
9 changes: 5 additions & 4 deletions examples/sparse_2of4_quantization_fp8/README.md
@@ -93,7 +93,7 @@ oneshot(
)
```

3. **Save the Compressed Model**
### Saving the Compressed Model

The compressed model and tokenizer are saved to the output directory:

@@ -106,15 +106,16 @@ Output Directories:
- Without FP8: `Meta-Llama-3-8B-Instruct-2of4-sparse`
- With FP8: `Meta-Llama-3-8B-Instruct-2of4-W8A8-FP8-Dynamic-Per-Token`

Save Model on disk without sparse_compression:
#### Saving Without Sparse Compression

To save the model on disk without sparse compression:

```python
model.save_pretrained(save_dir, save_compressed=True, no_sparse_compression=True)
tokenizer.save_pretrained(save_dir)
```

Note: This only affects how the model is saved on disk, and not the actual
pruning/quantization run.
> **Note:** This will compress the model using the quantization compressor; however, instead of using the optimal sparsity compressor, the dense sparsity compressor will be used. This affects only how the model is saved on disk and does not change the actual pruning/quantization process.
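
To confirm how the checkpoint was written, you can inspect the compression metadata in the saved config. The snippet below is only a sketch: it assumes the metadata is recorded in `config.json` under a `quantization_config` entry with a nested `sparsity_config`, which may differ across `compressed-tensors` versions.

```python
import json
import os

# Sketch: key names assumed; adjust to your compressed-tensors version.
with open(os.path.join(save_dir, "config.json")) as f:
    config = json.load(f)

sparsity_config = config.get("quantization_config", {}).get("sparsity_config")

# With no_sparse_compression=True the weights are written by the dense
# sparsity compressor, so the recorded format should be "dense" (or absent).
assert sparsity_config is None or sparsity_config["format"] == "dense"
```

When saving without the flag, a 2:4-sparse model would instead record the optimal sparsity compressor's format.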
### Validation

@@ -19,7 +19,27 @@ def infer_quantization_format(
    sparsity_structure: Optional[str] = None,
) -> str:
    """
    Infers a quantization format based on model state and compression args
    Infers the quantization format for a model based on its state and provided
    compression arguments.
    The following table outlines the possible quantization and sparsity formats
    along with their corresponding compressor formats:
    +---------------+----------+----------------------+---------------------+
    | Quantization  | Sparsity | Quant Compressor     | Sparsity Compressor |
    |               |          | Format               | Format              |
    +---------------+----------+----------------------+---------------------+
    | W8A8 - int    | None     | int_quantized        | Dense               |
    | W8A8 - float  | None     | float_quantized      | Dense               |
    | W4A16 - int   | None     | pack_quantized       | Dense               |
    | W8A16 - int   | None     | pack_quantized       | Dense               |
    | W8A16 - float | None     | naive_quantized      | Dense               |
    | W8A8 - int    | 2:4      | int_quantized        | Sparse24            |
    | W8A8 - float  | 2:4      | float_quantized      | Sparse24            |
    | W4A16 - int   | 2:4      | marlin_24            | Dense               |
    | W8A16 - int   | 2:4      | marlin_24            | Dense               |
    | W8A16 - float | 2:4      | naive_quantized      | Dense               |
    +---------------+----------+----------------------+---------------------+
    :param model: model to check for quantization; if the model is not quantized,
        no quantization format is returned
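Read as data, the docstring table above is simply a lookup from (quantization scheme, weight type, sparsity structure) to a pair of compressor formats. The sketch below restates that table for illustration only; it is not the library's inference logic, and the names follow the table rather than exact `CompressionFormat` values.

```python
from typing import Optional, Tuple

# Illustrative restatement of the docstring table; not the library's
# inference logic, which inspects the model's quantization scheme directly.
FORMAT_TABLE = {
    # (scheme, weight type, sparsity): (quant compressor, sparsity compressor)
    ("W8A8", "int", None): ("int_quantized", "Dense"),
    ("W8A8", "float", None): ("float_quantized", "Dense"),
    ("W4A16", "int", None): ("pack_quantized", "Dense"),
    ("W8A16", "int", None): ("pack_quantized", "Dense"),
    ("W8A16", "float", None): ("naive_quantized", "Dense"),
    ("W8A8", "int", "2:4"): ("int_quantized", "Sparse24"),
    ("W8A8", "float", "2:4"): ("float_quantized", "Sparse24"),
    ("W4A16", "int", "2:4"): ("marlin_24", "Dense"),
    ("W8A16", "int", "2:4"): ("marlin_24", "Dense"),
    ("W8A16", "float", "2:4"): ("naive_quantized", "Dense"),
}


def lookup_compressors(
    scheme: str, weight_type: str, sparsity: Optional[str]
) -> Tuple[str, str]:
    """Return (quant_compressor, sparsity_compressor) for a table row."""
    return FORMAT_TABLE[(scheme, weight_type, sparsity)]


print(lookup_compressors("W4A16", "int", "2:4"))  # ('marlin_24', 'Dense')
```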
@@ -9,8 +9,13 @@
from compressed_tensors import QUANTIZATION_CONFIG_NAME, CompressionFormat
from compressed_tensors.compressors import ModelCompressor
from compressed_tensors.config import BitmaskConfig, DenseSparsityConfig
from compressed_tensors.quantization import QuantizationStatus
from compressed_tensors.quantization import (
    QuantizationConfig,
    QuantizationStatus,
    quantize,
)
from compressed_tensors.utils import get_offloaded_device, update_prefix_dict
from torch import nn
from transformers import AutoConfig, AutoModelForCausalLM
from transformers.utils.quantization_config import CompressedTensorsConfig

@@ -21,6 +26,7 @@
    SparsityConfigMetadata,
)
from llmcompressor.transformers.sparsification.compressed_tensors_utils import (
    get_model_compressor,
    modify_save_pretrained,
    patch_tied_tensors_bug,
)
@@ -535,3 +541,163 @@ def test_no_sparse_compression_flag(tmp_path):
    assert sparsity_config
    assert sparsity_config["format"] == "dense"
    shutil.rmtree(tmp_path)


class DummyLinearModel(nn.Module):
    """
    A dummy linear model for testing purposes, simulating a quantized linear layer.
    """

    def __init__(self, weights, weight_scale=None, weight_zero_point=None):
        super().__init__()
        out_features, in_features = weights.shape

        # Linear layer without bias
        self.linear = nn.Linear(in_features, out_features, bias=False)
        self.linear.weight = nn.Parameter(weights, requires_grad=False)

        # Attach scale and zero-point if provided
        if weight_scale is not None:
            self.linear.weight_scale = nn.Parameter(
                torch.tensor(weight_scale), requires_grad=False
            )
        if weight_zero_point is not None:
            self.linear.weight_zero_point = nn.Parameter(
                torch.tensor(weight_zero_point), requires_grad=False
            )

    def forward(self, x):
        return self.linear(x)


def _create_quantization_config(
    w_bits=8,
    w_type="int",
    w_strategy="tensor",
    quantize_activations=False,
    a_bits=8,
    a_type="int",
    a_strategy="tensor",
):
    """
    Create a quantization configuration for testing.
    """
    config_dict = {
        "global_compression_ratio": 1.0,
        "quant_method": "compressed-tensors",
        "config_groups": {
            "group_0": {
                "targets": ["Linear"],
                "weights": {
                    "num_bits": w_bits,
                    "strategy": w_strategy,
                    "symmetric": True,
                    "type": w_type,
                },
            }
        },
    }

    if quantize_activations:
        config_dict["config_groups"]["group_0"]["input_activations"] = {
            "num_bits": a_bits,
            "strategy": a_strategy,
            "symmetric": True,
            "type": a_type,
        }

    return QuantizationConfig.model_validate(config_dict)


def _quantization_config_from_string(config_str, q_type):
    """
    Parse quantization config from string and type.
    """
    w_bits = int(config_str[1])  # e.g. "W4A16" -> weight bits = 4
    a_bits = int(config_str[3:])  # e.g. "W4A16" -> activation bits = 16
    quantize_activations = a_bits < 16  # 16-bit activations mean weight-only

    return _create_quantization_config(
        w_bits=w_bits,
        w_type=q_type,
        w_strategy="channel",
        quantize_activations=quantize_activations,
        a_bits=a_bits,
        a_type=q_type,
        a_strategy="channel",
    )


def _make_24_sparse(tensor):
    """
    Apply 2:4 sparsity pattern to the given tensor.
    """
    # Group elements into contiguous blocks of four and keep only the first
    # two in each block, zeroing out the rest (a deterministic 2:4 pattern).
    reshaped_tensor = tensor.view(tensor.size(0), -1, 4)
    mask = torch.zeros_like(reshaped_tensor, dtype=torch.bool)
    mask[..., :2] = True
    sparsified_tensor = torch.where(
        mask, reshaped_tensor, torch.tensor(0.0, dtype=tensor.dtype)
    )
    return sparsified_tensor.view_as(tensor)
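

# Illustrative check of the deterministic 2:4 pattern produced by
# _make_24_sparse: the first two entries of each group of four are kept,
# the rest zeroed.
def test_make_24_sparse_keeps_first_two_of_each_group():
    tensor = torch.arange(8.0).reshape(2, 4)  # [[0, 1, 2, 3], [4, 5, 6, 7]]
    expected = torch.tensor([[0.0, 1.0, 0.0, 0.0], [4.0, 5.0, 0.0, 0.0]])
    assert torch.equal(_make_24_sparse(tensor), expected)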


@pytest.mark.parametrize(
    "quant_style, quant_type, is_24, expected_quant_compressor, "
    "expected_sparsity_compressor",
    [
        ("W8A8", "int", False, "int-quantized", "dense"),
        ("W4A16", "int", False, "pack-quantized", "dense"),
        ("W8A16", "int", False, "pack-quantized", "dense"),
        ("W8A8", "int", True, "int-quantized", "sparse-24-bitmask"),
        ("W4A16", "int", True, "marlin-24", "dense"),
        ("W8A16", "int", True, "marlin-24", "dense"),
        ("W8A8", "float", False, "float-quantized", "dense"),
        ("W8A16", "float", False, "naive-quantized", "dense"),
        ("W8A8", "float", True, "float-quantized", "sparse-24-bitmask"),
        ("W8A16", "float", True, "naive-quantized", "dense"),
    ],
)
def test_correct_compressor_inferred(
    quant_style,
    quant_type,
    is_24,
    expected_quant_compressor,
    expected_sparsity_compressor,
):
    """
    Test if the correct compressor is inferred based on
    quantization and sparsity configurations.
    """
    weights = torch.rand(10, 4)
    if is_24:
        weights = _make_24_sparse(weights)

    quantization_config = _quantization_config_from_string(quant_style, quant_type)
    quantization_args = quantization_config.config_groups["group_0"].weights

    scale = (
        torch.ones((weights.shape[0], 1))
        if quantization_args.strategy == "channel"
        else torch.tensor([1.0])
    )
    zero_point = torch.zeros_like(scale)

    quantized_weights = quantize(
        weights, scale=scale, zero_point=zero_point, args=quantization_args
    )

    model = DummyLinearModel(quantized_weights, scale, zero_point)
    model.linear.quantization_scheme = quantization_config.config_groups["group_0"]
    model.linear.quantization_status = QuantizationStatus.FROZEN

    compressor = get_model_compressor(model)

    assert compressor.quantization_config.format == expected_quant_compressor

    if expected_sparsity_compressor == "dense":
        assert (
            compressor.sparsity_config is None
            or compressor.sparsity_config.format == expected_sparsity_compressor
        )
    else:
        assert compressor.sparsity_config.format == expected_sparsity_compressor
