-
Notifications
You must be signed in to change notification settings - Fork 6
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
fc4b23c
commit 68ca6c3
Showing
6 changed files
with
124 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -15,4 +15,5 @@ | |
|
||
from .base import * | ||
from .dense import * | ||
from .sparse_24 import * | ||
from .sparse_bitmask import * |
86 changes: 86 additions & 0 deletions
86
src/compressed_tensors/compressors/sparse_compressors/sparse_24.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,86 @@ | ||
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, | ||
# software distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
|
||
from typing import Dict | ||
|
||
from compressed_tensors.compressors import Compressor | ||
from compressed_tensors.config import CompressionFormat | ||
from compressed_tensors.utils import ( | ||
merge_names, | ||
sparse_semi_structured_from_dense_cutlass, | ||
sparse_semi_structured_to_dense_cutlass, | ||
tensor_follows_mask_structure, | ||
) | ||
from torch import Tensor | ||
|
||
|
||
@Compressor.register(name=CompressionFormat.sparse_24.value) | ||
class Sparse24Compressor(Compressor): | ||
""" | ||
Compresses a with 2:4 sparsity structure for inference | ||
with sparse 2:4 kernels for float/float16/bfloat16. | ||
https://github.com/pytorch/pytorch/blob/78cf8df4a019e919e8eac5f5d048d8842d4fc692/torch/sparse/semi_structured.py | ||
""" | ||
|
||
COMPRESSION_PARAM_NAMES = ["weight_packed", "meta"] | ||
|
||
@staticmethod | ||
def validate_sparsity_structure(name: str, weight: Tensor) -> bool: | ||
""" | ||
Checks if a tensor fits the required 2:4 sparsity structure | ||
:param name: name of the tensor to check | ||
:param weight: tensor to check for sparsity structure | ||
:return: True if all rows match the 2:4 sparsity structure, raises | ||
ValueError otherwise | ||
""" | ||
|
||
if not tensor_follows_mask_structure(weight, mask="2:4"): | ||
raise ValueError( | ||
"Sparse24Compressor is only compatible with weights that have " | ||
f"a 2:4 sparsity structure. Found segments in {name} " | ||
"that do not match the expected structure." | ||
) | ||
|
||
return True | ||
|
||
def compress_weight(self, name: str, value: Tensor) -> Dict[str, Tensor]: | ||
""" | ||
Compresses a given with 2:4 sparsity structure. | ||
:param name: name of the tensor in state dict of uncompressed model | ||
:param value: 2:4 sparse tensor to compress | ||
:return: dictionary containing the compressed weight and associated | ||
metadata | ||
""" | ||
weight_suffix = ".weight" | ||
if not name.endswith(weight_suffix): | ||
return {} | ||
|
||
prefix = name[: -len(weight_suffix)] | ||
self.validate_sparsity_structure(name=prefix, weight=value) | ||
weight_packed, meta = sparse_semi_structured_from_dense_cutlass(dense=value) | ||
return { | ||
merge_names(prefix, "weight_packed"): weight_packed.cpu(), | ||
merge_names(prefix, "meta"): meta.cpu(), | ||
} | ||
|
||
def decompress_weight(self, weight_data): | ||
assert "weight_packed" in weight_data, "weight_packed not found in weight_data" | ||
assert "meta" in weight_data, "meta not found in weight_data" | ||
|
||
return sparse_semi_structured_to_dense_cutlass( | ||
sparse=weight_data["weight_packed"], meta_reordered=weight_data["meta"] | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -15,4 +15,5 @@ | |
# flake8: noqa | ||
from .base import * | ||
from .dense import * | ||
from .sparse_24 import * | ||
from .sparse_bitmask import * |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, | ||
# software distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
from typing import Optional | ||
|
||
from compressed_tensors.config import CompressionFormat, SparsityCompressionConfig | ||
|
||
|
||
__all__ = ["Sparse24Config"] | ||
|
||
|
||
@SparsityCompressionConfig.register(name=CompressionFormat.sparse_24.value) | ||
class Sparse24Config(SparsityCompressionConfig): | ||
""" | ||
Configuration for storing a sparse model using 2:4 compression | ||
:param global_sparsity: average sparsity of the entire model | ||
:param sparsity_structure: structure of the sparsity, "2:4" | ||
""" | ||
|
||
format: str = CompressionFormat.sparse_24.value | ||
global_sparsity: Optional[float] = 0.0 | ||
sparsity_structure: Optional[str] = "2:4" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters