Skip to content

Commit

Permalink
Convert to a dataclass
Browse files Browse the repository at this point in the history
  • Loading branch information
rahul-tuli committed Jan 8, 2025
1 parent e987872 commit 0af8349
Showing 1 changed file with 27 additions and 17 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import dataclass
from typing import Dict, List, Tuple, Union

import torch
Expand Down Expand Up @@ -53,30 +54,25 @@ def compress_weight(self, name, value):
return bitmask_dict

def decompress_weight(self, weight_data):
data = Sparse24BitMaskTensor(**weight_data)
data = Sparse24BitMaskTensor.from_compressed_data(**weight_data)
decompressed = data.decompress()
return decompressed


@dataclass
class Sparse24BitMaskTensor:
"""
Owns compressions and decompression for a single 2:4 sparse
bitmask compressed tensor.
:param shape: shape of dense tensor
:compressed: 2d tensor of non-zero values
:bitmask: 2d bitmask of non-zero values
:param compressed: 2d tensor of non-zero values
:param bitmask: 2d bitmask of non-zero values
"""

def __init__(
self,
shape: Union[torch.Size, List],
compressed: Tensor,
bitmask: Tensor,
):
self.shape = list(shape)
self.compressed = compressed
self.bitmask = bitmask
shape: List[int]
compressed: Tensor
bitmask: Tensor

@staticmethod
def from_dense(
Expand All @@ -87,7 +83,7 @@ def from_dense(
:param tensor: dense tensor to compress
:return: instantiated compressed tensor
"""
shape = tensor.shape
shape = list(tensor.shape)
compressed, bitmask = sparse24_bitmask_compress(
tensor.cpu(), sparsity_structure=sparsity_structure
)
Expand All @@ -96,26 +92,40 @@ def from_dense(
compressed=compressed,
bitmask=bitmask,
)

@staticmethod
def from_compressed_data(
shape: Union[List[int], Tensor], compressed: Tensor, bitmask: Tensor
) -> "Sparse24BitMaskTensor":
"""
:param shape: shape of the dense tensor (can be a list or a tensor)
:param compressed: 2d tensor of non-zero values
:param bitmask: 2d bitmask of non-zero values
:return: instantiated Sparse24BitMaskTensor
"""
if isinstance(shape, Tensor):
shape = shape.tolist()
return Sparse24BitMaskTensor(shape=shape, compressed=compressed, bitmask=bitmask)

def decompress(self) -> Tensor:
"""
:return: reconstructed dense tensor
"""
return sparse24_bitmask_decompress(self.compressed, self.bitmask, self.shape)

def curr_memory_size_bytes(self):
def curr_memory_size_bytes(self) -> int:
"""
:return: size in bytes required to store compressed tensor on disk
"""

def sizeof_tensor(a):
def sizeof_tensor(a: Tensor) -> int:
return a.element_size() * a.nelement()

return sizeof_tensor(self.compressed) + sizeof_tensor(self.bitmask)

def dict(self, name_prefix: str, device: str = "cpu") -> Dict[str, Tensor]:
"""
:name_prefix: name of original tensor to store compressed weight as
:param name_prefix: name of original tensor to store compressed weight as
:return: dict of compressed data for the stored weight
"""
if name_prefix.endswith(".weight"):
Expand All @@ -128,7 +138,7 @@ def dict(self, name_prefix: str, device: str = "cpu") -> Dict[str, Tensor]:
merge_names(name_prefix, "bitmask"): self.bitmask.to(device),
}

def __repr__(self):
def __repr__(self) -> str:
return f"BitMaskTensor(shape={self.shape}, compressed=True)"


Expand Down

0 comments on commit 0af8349

Please sign in to comment.