Skip to content

Commit

Permalink
Add: Sparse24Compressor
Browse files Browse the repository at this point in the history
  • Loading branch information
rahul-tuli committed Oct 2, 2024
1 parent fc4b23c commit 68ca6c3
Show file tree
Hide file tree
Showing 6 changed files with 124 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -15,4 +15,5 @@

from .base import *
from .dense import *
from .sparse_24 import *
from .sparse_bitmask import *
86 changes: 86 additions & 0 deletions src/compressed_tensors/compressors/sparse_compressors/sparse_24.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from typing import Dict

from compressed_tensors.compressors import Compressor
from compressed_tensors.config import CompressionFormat
from compressed_tensors.utils import (
merge_names,
sparse_semi_structured_from_dense_cutlass,
sparse_semi_structured_to_dense_cutlass,
tensor_follows_mask_structure,
)
from torch import Tensor


@Compressor.register(name=CompressionFormat.sparse_24.value)
class Sparse24Compressor(Compressor):
"""
Compresses a with 2:4 sparsity structure for inference
with sparse 2:4 kernels for float/float16/bfloat16.
https://github.com/pytorch/pytorch/blob/78cf8df4a019e919e8eac5f5d048d8842d4fc692/torch/sparse/semi_structured.py
"""

COMPRESSION_PARAM_NAMES = ["weight_packed", "meta"]

@staticmethod
def validate_sparsity_structure(name: str, weight: Tensor) -> bool:
"""
Checks if a tensor fits the required 2:4 sparsity structure
:param name: name of the tensor to check
:param weight: tensor to check for sparsity structure
:return: True if all rows match the 2:4 sparsity structure, raises
ValueError otherwise
"""

if not tensor_follows_mask_structure(weight, mask="2:4"):
raise ValueError(
"Sparse24Compressor is only compatible with weights that have "
f"a 2:4 sparsity structure. Found segments in {name} "
"that do not match the expected structure."
)

return True

def compress_weight(self, name: str, value: Tensor) -> Dict[str, Tensor]:
"""
Compresses a given with 2:4 sparsity structure.
:param name: name of the tensor in state dict of uncompressed model
:param value: 2:4 sparse tensor to compress
:return: dictionary containing the compressed weight and associated
metadata
"""
weight_suffix = ".weight"
if not name.endswith(weight_suffix):
return {}

prefix = name[: -len(weight_suffix)]
self.validate_sparsity_structure(name=prefix, weight=value)
weight_packed, meta = sparse_semi_structured_from_dense_cutlass(dense=value)
return {
merge_names(prefix, "weight_packed"): weight_packed.cpu(),
merge_names(prefix, "meta"): meta.cpu(),
}

def decompress_weight(self, weight_data):
assert "weight_packed" in weight_data, "weight_packed not found in weight_data"
assert "meta" in weight_data, "meta not found in weight_data"

return sparse_semi_structured_to_dense_cutlass(
sparse=weight_data["weight_packed"], meta_reordered=weight_data["meta"]
)
1 change: 1 addition & 0 deletions src/compressed_tensors/config/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,4 +15,5 @@
# flake8: noqa
from .base import *
from .dense import *
from .sparse_24 import *
from .sparse_bitmask import *
1 change: 1 addition & 0 deletions src/compressed_tensors/config/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
class CompressionFormat(Enum):
dense = "dense"
sparse_bitmask = "sparse-bitmask"
sparse_24 = "sparse-24"
int_quantized = "int-quantized"
float_quantized = "float-quantized"
naive_quantized = "naive-quantized"
Expand Down
33 changes: 33 additions & 0 deletions src/compressed_tensors/config/sparse_24.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional

from compressed_tensors.config import CompressionFormat, SparsityCompressionConfig


__all__ = ["Sparse24Config"]


@SparsityCompressionConfig.register(name=CompressionFormat.sparse_24.value)
class Sparse24Config(SparsityCompressionConfig):
"""
Configuration for storing a sparse model using 2:4 compression
:param global_sparsity: average sparsity of the entire model
:param sparsity_structure: structure of the sparsity, "2:4"
"""

format: str = CompressionFormat.sparse_24.value
global_sparsity: Optional[float] = 0.0
sparsity_structure: Optional[str] = "2:4"
2 changes: 2 additions & 0 deletions src/compressed_tensors/utils/semi_structured_conversions.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ def _calculate_meta_reordering_scatter_offsets(m, meta_ncols, meta_dtype, device
# This function converts dense matrix into sparse semi-structured
# representation, producing "compressed" matrix, in the layout used by
# CUTLASS backend, and corresponding metadata matrix.
# Copied from https://github.com/pytorch/pytorch/blob/78cf8df4a019e919e8eac5f5d048d8842d4fc692/torch/sparse/_semi_structured_conversions.py#L47
def sparse_semi_structured_from_dense_cutlass(dense):
if dense.dim() != 2:
raise RuntimeError(
Expand Down Expand Up @@ -213,6 +214,7 @@ def sparse_semi_structured_from_dense_cutlass(dense):
# reconstructs dense matrix from a pair of "compressed" matrix, given
# in the layout used by CUTLASS backend, and accompanying metadata
# matrix.
# Copied from https://github.com/pytorch/pytorch/blob/78cf8df4a019e919e8eac5f5d048d8842d4fc692/torch/sparse/_semi_structured_conversions.py#L180
def sparse_semi_structured_to_dense_cutlass(sparse, meta_reordered):
if sparse.dim() != 2:
raise RuntimeError(
Expand Down

0 comments on commit 68ca6c3

Please sign in to comment.