Add: Sparse24_compressor + tests
rahul-tuli committed Nov 27, 2024
1 parent 305904c commit 3a6ccc8
Showing 7 changed files with 213 additions and 4 deletions.
1 change: 1 addition & 0 deletions src/compressed_tensors/compressors/sparse_compressors/__init__.py
@@ -15,4 +15,5 @@

from .base import *
from .dense import *
+from .sparse_24 import *
from .sparse_bitmask import *
92 changes: 92 additions & 0 deletions src/compressed_tensors/compressors/sparse_compressors/sparse_24.py
@@ -0,0 +1,92 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from typing import Dict

from compressed_tensors.compressors.base import BaseCompressor
from compressed_tensors.compressors.sparse_compressors.base import BaseSparseCompressor
from compressed_tensors.config import CompressionFormat, SparsityStructure
from compressed_tensors.utils import (
    merge_names,
    sparse_semi_structured_from_dense_cutlass,
    sparse_semi_structured_to_dense_cutlass,
    tensor_follows_mask_structure,
)
from torch import Tensor


@BaseCompressor.register(name=CompressionFormat.sparse_24.value)
class Sparse24Compressor(BaseSparseCompressor):
    """
    Compresses a model with a 2:4 sparsity structure for inference with
    sparse 2:4 kernels for float/float16/bfloat16.
    https://github.com/pytorch/pytorch/blob/78cf8df4a019e919e8eac5f5d048d8842d4fc692/torch/sparse/semi_structured.py
    """

    COMPRESSION_PARAM_NAMES = ["sparse_24_packed_weight", "meta"]

    @staticmethod
    def validate_sparsity_structure(name: str, weight: Tensor) -> bool:
        """
        Checks if a tensor fits the required 2:4 sparsity structure

        :param name: name of the tensor to check
        :param weight: tensor to check for sparsity structure
        :return: True if all rows match the 2:4 sparsity structure, raises
            ValueError otherwise
        """
        if not tensor_follows_mask_structure(
            weight, mask=SparsityStructure.TWO_FOUR.value
        ):
            raise ValueError(
                "Sparse24Compressor is only compatible with weights that have "
                f"a 2:4 sparsity structure. Found segments in {name} "
                "that do not match the expected structure."
            )

        return True

    def compress_weight(self, name: str, value: Tensor) -> Dict[str, Tensor]:
        """
        Compresses a given tensor with a 2:4 sparsity structure.

        :param name: name of the tensor in the state dict of the uncompressed
            model
        :param value: 2:4 sparse tensor to compress
        :return: dictionary containing the compressed weight and associated
            metadata
        """
        weight_suffix = ".weight"
        if not name.endswith(weight_suffix):
            return {}

        prefix = name[: -len(weight_suffix)]
        self.validate_sparsity_structure(name=prefix, weight=value)
        sparse_24_packed_weight, meta = sparse_semi_structured_from_dense_cutlass(
            dense=value
        )
        return {
            merge_names(name, "sparse_24_packed_weight"): sparse_24_packed_weight.cpu(),
            merge_names(name, "meta"): meta.cpu(),
        }

    def decompress_weight(self, weight_data: Dict[str, Tensor]) -> Tensor:
        """
        Decompresses a 2:4 packed weight back to its dense representation.

        :param weight_data: dictionary containing the packed weight and its
            reordered metadata
        :return: decompressed dense tensor
        """
        assert (
            "sparse_24_packed_weight" in weight_data
        ), "sparse_24_packed_weight not found in weight_data"
        assert "meta" in weight_data, "meta not found in weight_data"

        return sparse_semi_structured_to_dense_cutlass(
            sparse=weight_data["sparse_24_packed_weight"],
            meta_reordered=weight_data["meta"],
        )
1 change: 1 addition & 0 deletions src/compressed_tensors/config/__init__.py
@@ -15,4 +15,5 @@
# flake8: noqa
from .base import *
from .dense import *
+from .sparse_24 import *
from .sparse_bitmask import *
1 change: 1 addition & 0 deletions src/compressed_tensors/config/base.py
@@ -26,6 +26,7 @@
class CompressionFormat(Enum):
    dense = "dense"
    sparse_bitmask = "sparse-bitmask"
+    sparse_24 = "sparse-24"
    int_quantized = "int-quantized"
    float_quantized = "float-quantized"
    naive_quantized = "naive-quantized"
37 changes: 37 additions & 0 deletions src/compressed_tensors/config/sparse_24.py
@@ -0,0 +1,37 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional

from compressed_tensors.config import (
    CompressionFormat,
    SparsityCompressionConfig,
    SparsityStructure,
)


__all__ = ["Sparse24Config"]


@SparsityCompressionConfig.register(name=CompressionFormat.sparse_24.value)
class Sparse24Config(SparsityCompressionConfig):
    """
    Configuration for storing a sparse model using 2:4 compression.

    :param global_sparsity: average sparsity of the entire model
    :param sparsity_structure: structure of the sparsity, which should be "2:4"
    """

    format: str = CompressionFormat.sparse_24.value
    global_sparsity: Optional[float] = 0.0
    sparsity_structure: Optional[str] = SparsityStructure.TWO_FOUR.value
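A short usage sketch for the config above (assuming SparsityStructure.TWO_FOUR.value is the string "2:4"):

from compressed_tensors.config import Sparse24Config

config = Sparse24Config(global_sparsity=0.5)
assert config.format == "sparse-24"
assert config.sparsity_structure == "2:4"

Since both this config and Sparse24Compressor register under CompressionFormat.sparse_24.value, the same "sparse-24" string selects the matching pair.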
19 changes: 15 additions & 4 deletions src/compressed_tensors/utils/semi_structured_conversions.py
@@ -75,6 +75,7 @@ def _calculate_meta_reordering_scatter_offsets(m, meta_ncols, meta_dtype, device
# This function converts dense matrix into sparse semi-structured
# representation, producing "compressed" matrix, in the layout used by
# CUTLASS backend, and corresponding metadata matrix.
+# Modified from https://github.com/pytorch/pytorch/blob/78cf8df4a019e919e8eac5f5d048d8842d4fc692/torch/sparse/_semi_structured_conversions.py#L47
def sparse_semi_structured_from_dense_cutlass(dense):
    if dense.dim() != 2:
        raise RuntimeError(
@@ -85,7 +86,7 @@ def sparse_semi_structured_from_dense_cutlass(dense):
    device = dense.device

    meta_dtype = torch.int8
-    if dense.dtype == torch.int8:
+    if dense.dtype == torch.int8 or dense.dtype == torch.float8_e4m3fn:
        meta_dtype = torch.int32
    elif dense.dtype in [torch.half, torch.bfloat16, torch.float, torch.int32]:
        meta_dtype = torch.int16
@@ -165,11 +166,15 @@ def sparse_semi_structured_from_dense_cutlass(dense):
    idxs1 = bit2 | (bit3.to(torch.int64) << 1)

    if dense.dtype != torch.float:
+        if dense.dtype == torch.float8_e4m3fn:
+            dense_4 = dense_4.view(torch.int8)
        sparse0 = dense_4.gather(
            -1, idxs0.unsqueeze(-1)
        )  # type: ignore[possibly-undefined]
        sparse1 = dense_4.gather(-1, idxs1.unsqueeze(-1))
        sparse = torch.stack((sparse0, sparse1), dim=-1).view(m, k // 2)
+        if dense.dtype == torch.float8_e4m3fn:
+            sparse = sparse.view(torch.float8_e4m3fn)
    else:
        sparse = dense_2.gather(-1, idxs0.unsqueeze(-1) // 2).view(
            m, k // 2
@@ -213,6 +218,7 @@ def sparse_semi_structured_from_dense_cutlass(dense):
# reconstructs dense matrix from a pair of "compressed" matrix, given
# in the layout used by CUTLASS backend, and accompanying metadata
# matrix.
+# Copied from https://github.com/pytorch/pytorch/blob/78cf8df4a019e919e8eac5f5d048d8842d4fc692/torch/sparse/_semi_structured_conversions.py#L180
def sparse_semi_structured_to_dense_cutlass(sparse, meta_reordered):
    if sparse.dim() != 2:
        raise RuntimeError(
@@ -298,16 +304,21 @@ def sparse_semi_structured_to_dense_cutlass(sparse, meta_reordered):
        torch.arange(0, 2 * m * k // ksparse, device=device) * 4
    ).view(-1, 1).repeat(1, 2).view(-1)

-    dense = torch.zeros((m * 2 * k,), dtype=sparse.dtype, device=device)
+    sparse_dtype = sparse.dtype if sparse.dtype != torch.float8_e4m3fn else torch.int8
+    dense = torch.zeros((m * 2 * k,), dtype=sparse_dtype, device=device)
    if sparse.dtype != torch.float:
        # dense.scatter_(0, dense_offsets, sparse.view(-1))
-        dense.scatter_(0, dense_offsets, sparse.reshape(-1))
+        if sparse.dtype == torch.float8_e4m3fn:
+            dense.scatter_(0, dense_offsets, sparse.view(torch.int8).view(-1))
+        else:
+            dense.scatter_(0, dense_offsets, sparse.reshape(-1))
    else:
        dense.view(torch.half).scatter_(
            0, dense_offsets, sparse.view(torch.half).view(-1)
        )

-    return dense.view(m, 2 * k)
+    result = dense.view(m, 2 * k)
+    return result.view(sparse.dtype)


def mask_creator(tensor):
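The float8 branches above reinterpret bytes rather than convert values: the tensor is viewed as int8 (same one-byte element size), shuffled with gather/scatter_, and viewed back, presumably because those kernels did not cover float8_e4m3fn. A small standalone illustration (not from the commit; assumes a PyTorch build with float8_e4m3fn):

import torch

x = torch.tensor([0.5, 1.0, -2.0, 3.0], dtype=torch.float8_e4m3fn)

# view(torch.int8) reinterprets the same bytes; no numeric conversion happens
as_int8 = x.view(torch.int8)

# gather works on int8; viewing back recovers the original float8 values
picked = as_int8.gather(0, torch.tensor([2, 0])).view(torch.float8_e4m3fn)
print(picked.to(torch.float32))  # -> [-2.0, 0.5]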
66 changes: 66 additions & 0 deletions tests/test_utils/test_semi_structured_conversions.py
@@ -0,0 +1,66 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import pytest
import torch
from compressed_tensors.utils.semi_structured_conversions import (
    sparse_semi_structured_from_dense_cutlass,
    sparse_semi_structured_to_dense_cutlass,
)


def supported_dtypes():
    return [torch.int8, torch.float16, torch.bfloat16, torch.float8_e4m3fn]


def get_random_mat(M, K, dtype):
    rand_tensor_dtype = dtype
    if dtype in [torch.int8, torch.float8_e4m3fn]:
        rand_tensor_dtype = torch.float16
    mat = torch.rand(M, K, dtype=rand_tensor_dtype).cuda()
    mat = mat.masked_fill_(mat == 0, 1)
    return mat.to(dtype)


def generate_pruned_semi_structured_mat(M, K, dtype):
    mask = torch.Tensor([0, 0, 1, 1]).tile((M, K // 4)).bool()
    rand_tensor_dtype = dtype
    if dtype in [torch.int8, torch.float8_e4m3fn]:
        rand_tensor_dtype = torch.float16
    mat = torch.rand(M, K, dtype=rand_tensor_dtype)
    mat = mat.masked_fill_(mat == 0, 1)
    if dtype == torch.float8_e4m3fn:
        # some float8_e4m3fn operations are not supported on CPU
        mat = mat.cuda()
        mask = mask.cuda()
    mat = mat * mask
    return mat.to(dtype)


@pytest.mark.parametrize("dtype", supported_dtypes())
def test_inverse_property_from_dense_then_to_dense(dtype):
    M, K = 1024, 1024
    dense_matrix = generate_pruned_semi_structured_mat(M, K, dtype)
    compressed_matrix, meta = sparse_semi_structured_from_dense_cutlass(dense_matrix)
    result = sparse_semi_structured_to_dense_cutlass(compressed_matrix, meta)

    assert (
        dense_matrix.dtype == result.dtype
    ), f"Dtype mismatch: {dense_matrix.dtype} and {result.dtype}"
    assert (
        dense_matrix.shape == result.shape
    ), f"Shape mismatch: {dense_matrix.shape} and {result.shape}"
    assert torch.equal(
        dense_matrix, result
    ), f"Failed for dtype: {dense_matrix.dtype} and input: {dense_matrix}"
