-
Notifications
You must be signed in to change notification settings - Fork 6
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
305904c
commit 3a6ccc8
Showing
7 changed files
with
213 additions
and
4 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -15,4 +15,5 @@ | |
|
||
from .base import * | ||
from .dense import * | ||
from .sparse_24 import * | ||
from .sparse_bitmask import * |
92 changes: 92 additions & 0 deletions
92
src/compressed_tensors/compressors/sparse_compressors/sparse_24.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,92 @@ | ||
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, | ||
# software distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
|
||
from typing import Dict | ||
|
||
from compressed_tensors.compressors.base import BaseCompressor | ||
from compressed_tensors.compressors.sparse_compressors.base import BaseSparseCompressor | ||
from compressed_tensors.config import CompressionFormat, SparsityStructure | ||
from compressed_tensors.utils import ( | ||
merge_names, | ||
sparse_semi_structured_from_dense_cutlass, | ||
sparse_semi_structured_to_dense_cutlass, | ||
tensor_follows_mask_structure, | ||
) | ||
from torch import Tensor | ||
|
||
|
||
@BaseCompressor.register(name=CompressionFormat.sparse_24.value)
class Sparse24Compressor(BaseSparseCompressor):
    """
    Compresses a model with a 2:4 sparsity structure for inference with
    sparse 2:4 kernels for float/float16/bfloat16.

    https://github.com/pytorch/pytorch/blob/78cf8df4a019e919e8eac5f5d048d8842d4fc692/torch/sparse/semi_structured.py
    """

    COMPRESSION_PARAM_NAMES = ["sparse_24_packed_weight", "meta"]

    @staticmethod
    def validate_sparsity_structure(name: str, weight: Tensor) -> bool:
        """
        Checks if a tensor fits the required 2:4 sparsity structure.

        :param name: name of the tensor to check
        :param weight: tensor to check for sparsity structure
        :return: True if all rows match the 2:4 sparsity structure, raises
            ValueError otherwise
        """
        if not tensor_follows_mask_structure(
            weight, mask=SparsityStructure.TWO_FOUR.value
        ):
            raise ValueError(
                "Sparse24Compressor is only compatible with weights that have "
                f"a 2:4 sparsity structure. Found segments in {name} "
                "that do not match the expected structure."
            )

        return True

    def compress_weight(self, name: str, value: Tensor) -> Dict[str, Tensor]:
        """
        Compresses a given weight tensor using its 2:4 sparsity structure.

        Parameters whose name does not end in ".weight" are skipped and an
        empty dict is returned for them.

        :param name: name of the tensor in the state dict of the uncompressed
            model
        :param value: 2:4 sparse tensor to compress
        :return: dictionary containing the compressed weight and associated
            metadata
        """
        weight_suffix = ".weight"
        if not name.endswith(weight_suffix):
            return {}

        # validate against the parameter's prefix so the error message names
        # the layer rather than the full tensor path
        prefix = name[: -len(weight_suffix)]
        self.validate_sparsity_structure(name=prefix, weight=value)
        sparse_24_packed_weight, meta = sparse_semi_structured_from_dense_cutlass(
            dense=value
        )
        # move to CPU so the compressed state dict can be serialized
        return {
            merge_names(name, "sparse_24_packed_weight"): sparse_24_packed_weight.cpu(),
            merge_names(name, "meta"): meta.cpu(),
        }

    def decompress_weight(self, weight_data: Dict[str, Tensor]) -> Tensor:
        """
        Decompresses a packed 2:4 sparse weight back into its dense form.

        :param weight_data: dict holding the "sparse_24_packed_weight" tensor
            and its reordered "meta" tensor, as produced by
            :meth:`compress_weight`
        :return: the reconstructed dense weight tensor
        """
        assert (
            "sparse_24_packed_weight" in weight_data
        ), "sparse_24_packed_weight not found in weight_data"
        assert "meta" in weight_data, "meta not found in weight_data"

        return sparse_semi_structured_to_dense_cutlass(
            sparse=weight_data["sparse_24_packed_weight"],
            meta_reordered=weight_data["meta"],
        )
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -15,4 +15,5 @@ | |
# flake8: noqa | ||
from .base import * | ||
from .dense import * | ||
from .sparse_24 import * | ||
from .sparse_bitmask import * |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,37 @@ | ||
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, | ||
# software distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
from typing import Optional | ||
|
||
from compressed_tensors.config import ( | ||
CompressionFormat, | ||
SparsityCompressionConfig, | ||
SparsityStructure, | ||
) | ||
|
||
|
||
__all__ = ["Sparse24Config"] | ||
|
||
|
||
@SparsityCompressionConfig.register(name=CompressionFormat.sparse_24.value)
class Sparse24Config(SparsityCompressionConfig):
    """
    Configuration for storing a sparse model using 2:4 compression.

    :param global_sparsity: average sparsity of the entire model
    :param sparsity_structure: structure of the sparsity, "2:4"
    """

    # identifies this config with the "sparse-24" compression format
    format: str = CompressionFormat.sparse_24.value
    global_sparsity: Optional[float] = 0.0
    sparsity_structure: Optional[str] = SparsityStructure.TWO_FOUR.value
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,66 @@ | ||
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, | ||
# software distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
import pytest | ||
import torch | ||
from compressed_tensors.utils.semi_structured_conversions import ( | ||
sparse_semi_structured_from_dense_cutlass, | ||
sparse_semi_structured_to_dense_cutlass, | ||
) | ||
|
||
|
||
def supported_dtypes(): | ||
return [torch.int8, torch.float16, torch.bfloat16, torch.float8_e4m3fn] | ||
|
||
|
||
def get_random_mat(M, K, dtype):
    """
    Builds a random M x K CUDA matrix of the given dtype with no zero entries.

    :param M: number of rows
    :param K: number of columns
    :param dtype: target dtype of the returned matrix
    :return: random matrix on the CUDA device, cast to ``dtype``
    """
    rand_tensor_dtype = dtype
    if dtype in [torch.int8, torch.float8_e4m3fn]:
        # torch.rand cannot sample these dtypes directly; draw in fp16 first
        rand_tensor_dtype = torch.float16
    mat = torch.rand(M, K, dtype=rand_tensor_dtype).cuda()
    # torch.rand samples [0, 1); replace any exact zeros to keep full density
    mat = mat.masked_fill_(mat == 0, 1)
    if dtype == torch.int8:
        # values in (0, 1) all truncate to 0 when cast to int8, which would
        # make the "random" matrix entirely zero; scale into [1, 127] first
        mat = (mat * 127).clamp_(min=1)
    return mat.to(dtype)
|
||
|
||
def generate_pruned_semi_structured_mat(M, K, dtype): | ||
mask = torch.Tensor([0, 0, 1, 1]).tile((M, K // 4)).bool() | ||
rand_tensor_dtype = dtype | ||
if dtype in [torch.int8, torch.float8_e4m3fn]: | ||
rand_tensor_dtype = torch.float16 | ||
mat = torch.rand(M, K, dtype=rand_tensor_dtype) | ||
mat = mat.masked_fill_(mat == 0, 1) | ||
if dtype == torch.float8_e4m3fn: | ||
# some float8_e4m3fn operations are not supported on CPU | ||
mat = mat.cuda() | ||
mask = mask.cuda() | ||
mat = mat * mask | ||
return mat.to(dtype) | ||
|
||
|
||
@pytest.mark.parametrize("dtype", supported_dtypes())
def test_inverse_property_from_dense_then_to_dense(dtype):
    """Round-trip check: dense -> CUTLASS 2:4 packed -> dense is lossless."""
    shape = (1024, 1024)
    dense_matrix = generate_pruned_semi_structured_mat(*shape, dtype)

    # forward conversion to the packed representation plus metadata, then the
    # inverse conversion back to a dense tensor
    compressed_matrix, meta = sparse_semi_structured_from_dense_cutlass(dense_matrix)
    result = sparse_semi_structured_to_dense_cutlass(compressed_matrix, meta)

    assert (
        dense_matrix.dtype == result.dtype
    ), f"Dtype Mis-match: {dense_matrix.dtype} and {result.dtype}"
    assert (
        dense_matrix.shape == result.shape
    ), f"Shape Mis-match: {dense_matrix.shape} and {result.shape}"
    assert torch.equal(
        dense_matrix, result
    ), f"Failed for dtype: {dense_matrix.dtype} and input: {dense_matrix}"