From 0af8349de2df27b7dca737e37e98a09af3023e4b Mon Sep 17 00:00:00 2001 From: Rahul Tuli Date: Wed, 8 Jan 2025 20:33:47 +0000 Subject: [PATCH] Convert to a dataclass --- .../sparse_compressors/sparse_24_bitmask.py | 44 ++++++++++++------- 1 file changed, 27 insertions(+), 17 deletions(-) diff --git a/src/compressed_tensors/compressors/sparse_compressors/sparse_24_bitmask.py b/src/compressed_tensors/compressors/sparse_compressors/sparse_24_bitmask.py index 54ea7200..0eec9dc9 100644 --- a/src/compressed_tensors/compressors/sparse_compressors/sparse_24_bitmask.py +++ b/src/compressed_tensors/compressors/sparse_compressors/sparse_24_bitmask.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +from dataclasses import dataclass from typing import Dict, List, Tuple, Union import torch @@ -53,30 +54,25 @@ def compress_weight(self, name, value): return bitmask_dict def decompress_weight(self, weight_data): - data = Sparse24BitMaskTensor(**weight_data) + data = Sparse24BitMaskTensor.from_compressed_data(**weight_data) decompressed = data.decompress() return decompressed +@dataclass class Sparse24BitMaskTensor: """ Owns compressions and decompression for a single 2:4 sparse bitmask compressed tensor. :param shape: shape of dense tensor - :compressed: 2d tensor of non-zero values - :bitmask: 2d bitmask of non-zero values + :param compressed: 2d tensor of non-zero values + :param bitmask: 2d bitmask of non-zero values """ - def __init__( - self, - shape: Union[torch.Size, List], - compressed: Tensor, - bitmask: Tensor, - ): - self.shape = list(shape) - self.compressed = compressed - self.bitmask = bitmask + shape: List[int] + compressed: Tensor + bitmask: Tensor @staticmethod def from_dense( @@ -87,7 +83,7 @@ def from_dense( :param tensor: dense tensor to compress :return: instantiated compressed tensor """ - shape = tensor.shape + shape = list(tensor.shape) compressed, bitmask = sparse24_bitmask_compress( tensor.cpu(), sparsity_structure=sparsity_structure ) @@ -96,6 +92,20 @@ def from_dense( compressed=compressed, bitmask=bitmask, ) + + @staticmethod + def from_compressed_data( + shape: Union[List[int], Tensor], compressed: Tensor, bitmask: Tensor + ) -> "Sparse24BitMaskTensor": + """ + :param shape: shape of the dense tensor (can be a list or a tensor) + :param compressed: 2d tensor of non-zero values + :param bitmask: 2d bitmask of non-zero values + :return: instantiated Sparse24BitMaskTensor + """ + if isinstance(shape, Tensor): + shape = shape.tolist() + return Sparse24BitMaskTensor(shape=shape, compressed=compressed, bitmask=bitmask) def decompress(self) -> Tensor: """ @@ -103,19 +113,19 @@ def decompress(self) -> Tensor: """ return sparse24_bitmask_decompress(self.compressed, self.bitmask, self.shape) - def curr_memory_size_bytes(self): + def curr_memory_size_bytes(self) -> int: """ :return: size in bytes required to store compressed tensor on disk """ - def sizeof_tensor(a): + def sizeof_tensor(a: Tensor) -> int: return a.element_size() * a.nelement() return sizeof_tensor(self.compressed) + sizeof_tensor(self.bitmask) def dict(self, name_prefix: str, device: str = "cpu") -> Dict[str, Tensor]: """ - :name_prefix: name of original tensor to store compressed weight as + :param name_prefix: name of original tensor to store compressed weight as :return: dict of compressed data for the stored weight """ if name_prefix.endswith(".weight"): @@ -128,7 +138,7 @@ def dict(self, name_prefix: str, device: str = "cpu") -> Dict[str, Tensor]: merge_names(name_prefix, "bitmask"): self.bitmask.to(device), } - def __repr__(self): + def __repr__(self) -> str: return f"BitMaskTensor(shape={self.shape}, compressed=True)"