@Compressor.register(name=CompressionFormat.sparse_24.value)
class Sparse24Compressor(Compressor):
    """
    Compresses a weight tensor with a 2:4 sparsity structure for inference
    with sparse 2:4 kernels for float/float16/bfloat16.

    https://github.com/pytorch/pytorch/blob/78cf8df4a019e919e8eac5f5d048d8842d4fc692/torch/sparse/semi_structured.py
    """

    # suffixes of the entries produced by compress_weight and required by
    # decompress_weight
    COMPRESSION_PARAM_NAMES = ["weight_packed", "meta"]

    @staticmethod
    def validate_sparsity_structure(name: str, weight: Tensor) -> bool:
        """
        Checks if a tensor fits the required 2:4 sparsity structure

        :param name: name of the tensor to check, only used in the error message
        :param weight: tensor to check for sparsity structure
        :return: True if the tensor matches the 2:4 sparsity structure
        :raises ValueError: if the tensor does not follow the 2:4 structure
        """

        if not tensor_follows_mask_structure(weight, mask="2:4"):
            raise ValueError(
                "Sparse24Compressor is only compatible with weights that have "
                f"a 2:4 sparsity structure. Found segments in {name} "
                "that do not match the expected structure."
            )

        return True

    def compress_weight(self, name: str, value: Tensor) -> Dict[str, Tensor]:
        """
        Compresses a given weight tensor with 2:4 sparsity structure.

        :param name: name of the tensor in state dict of uncompressed model
        :param value: 2:4 sparse tensor to compress
        :return: dictionary containing the compressed weight and associated
            metadata, both moved to the CPU; an empty dictionary if `name`
            does not end with ".weight"
        :raises ValueError: if `value` does not follow the 2:4 structure
        """
        weight_suffix = ".weight"
        if not name.endswith(weight_suffix):
            # only tensors named "<prefix>.weight" are compressed
            return {}

        prefix = name[: -len(weight_suffix)]
        self.validate_sparsity_structure(name=prefix, weight=value)
        weight_packed, meta = sparse_semi_structured_from_dense_cutlass(dense=value)
        return {
            merge_names(prefix, "weight_packed"): weight_packed.cpu(),
            merge_names(prefix, "meta"): meta.cpu(),
        }

    def decompress_weight(self, weight_data: Dict[str, Tensor]) -> Tensor:
        """
        Reconstructs the dense weight from its 2:4 compressed representation.

        :param weight_data: dictionary holding the "weight_packed" and "meta"
            tensors of a single compressed weight
        :return: result of sparse_semi_structured_to_dense_cutlass applied to
            the packed weight and its metadata
        :raises ValueError: if "weight_packed" or "meta" is missing from
            `weight_data`
        """
        # Explicit raise instead of `assert`: assertions are stripped when
        # Python runs with -O, which would turn a missing entry into a bare
        # KeyError further down.
        for param_name in self.COMPRESSION_PARAM_NAMES:
            if param_name not in weight_data:
                raise ValueError(f"{param_name} not found in weight_data")

        return sparse_semi_structured_to_dense_cutlass(
            sparse=weight_data["weight_packed"], meta_reordered=weight_data["meta"]
        )
a/src/compressed_tensors/config/__init__.py +++ b/src/compressed_tensors/config/__init__.py @@ -15,4 +15,5 @@ # flake8: noqa from .base import * from .dense import * +from .sparse_24 import * from .sparse_bitmask import * diff --git a/src/compressed_tensors/config/base.py b/src/compressed_tensors/config/base.py index ccc3e649..65129687 100644 --- a/src/compressed_tensors/config/base.py +++ b/src/compressed_tensors/config/base.py @@ -25,6 +25,7 @@ class CompressionFormat(Enum): dense = "dense" sparse_bitmask = "sparse-bitmask" + sparse_24 = "sparse-24" int_quantized = "int-quantized" float_quantized = "float-quantized" naive_quantized = "naive-quantized" diff --git a/src/compressed_tensors/config/sparse_24.py b/src/compressed_tensors/config/sparse_24.py new file mode 100644 index 00000000..a0c477c7 --- /dev/null +++ b/src/compressed_tensors/config/sparse_24.py @@ -0,0 +1,33 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
# Single source for the format identifier: used both as the registry key and
# as the default of the ``format`` field below.
_SPARSE_24_FORMAT = CompressionFormat.sparse_24.value


@SparsityCompressionConfig.register(name=_SPARSE_24_FORMAT)
class Sparse24Config(SparsityCompressionConfig):
    """
    Sparsity compression configuration for a model stored with 2:4 compression

    :param format: identifier of the compression format, defaults to the
        value of CompressionFormat.sparse_24
    :param global_sparsity: average sparsity of the entire model
    :param sparsity_structure: structure of the sparsity, "2:4"
    """

    format: str = _SPARSE_24_FORMAT
    global_sparsity: Optional[float] = 0.0
    sparsity_structure: Optional[str] = "2:4"