Skip to content

Commit

Permalink
Adds: Fully shardable Sparse24BitMaskCompressor
Browse files Browse the repository at this point in the history
Adds: Sharding test
  • Loading branch information
rahul-tuli committed Jan 8, 2025
1 parent 00e8419 commit 8acd9b8
Show file tree
Hide file tree
Showing 8 changed files with 524 additions and 40 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -15,4 +15,5 @@

from .base import *
from .dense import *
from .sparse_24_bitmask import *
from .sparse_bitmask import *
Original file line number Diff line number Diff line change
@@ -0,0 +1,226 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Dict, List, Tuple, Union

import torch
from compressed_tensors.compressors.base import BaseCompressor
from compressed_tensors.compressors.sparse_compressors.base import BaseSparseCompressor
from compressed_tensors.config import CompressionFormat, SparsityStructure
from compressed_tensors.quantization import FP8_DTYPE
from compressed_tensors.utils import merge_names, pack_into_bitmasks, unpack_bitmasks
from torch import Tensor


__all__ = [
"Sparse24BitMaskCompressor",
"Sparse24BitMaskTensor",
"sparse24_bitmask_compress",
"sparse24_bitmask_decompress",
"get_24_bytemasks",
]


@BaseCompressor.register(name=CompressionFormat.sparse_24_bitmask.value)
class Sparse24BitMaskCompressor(BaseSparseCompressor):
    """
    Compression for sparse models using bitmasks. Non-zero weights are stored in a 2d
    values tensor, with their locations stored in a 2d bitmask
    """

    COMPRESSION_PARAM_NAMES = [
        "shape",
        "compressed",
        "bitmask",
    ]

    def compress_weight(self, name, value):
        # Delegate the bit-packing to the tensor wrapper, then flatten its
        # fields into a name-prefixed dict of CPU tensors for serialization.
        compressed_tensor = Sparse24BitMaskTensor.from_dense(
            value, self.config.sparsity_structure
        )
        return compressed_tensor.dict(name_prefix=name, device="cpu")

    def decompress_weight(self, weight_data):
        # Rehydrate the wrapper from its serialized fields and expand it
        # back into a dense tensor.
        return Sparse24BitMaskTensor(**weight_data).decompress()


class Sparse24BitMaskTensor:
    """
    Owns compression and decompression for a single 2:4 sparse
    bitmask compressed tensor.

    :param shape: shape of dense tensor
    :param compressed: 2d tensor of non-zero values
    :param bitmask: 2d bitmask of non-zero values
    """

    def __init__(
        self,
        shape: Union[torch.Size, List],
        compressed: Tensor,
        bitmask: Tensor,
    ):
        self.shape = list(shape)
        self.compressed = compressed
        self.bitmask = bitmask

    @staticmethod
    def from_dense(
        tensor: Tensor,
        sparsity_structure: Union[SparsityStructure, str] = SparsityStructure.TWO_FOUR,
    ) -> "Sparse24BitMaskTensor":
        """
        :param tensor: dense tensor to compress
        :param sparsity_structure: sparsity structure to enforce; only 2:4 is
            supported by the underlying compression routine
        :return: instantiated compressed tensor
        """
        shape = tensor.shape
        # compression always runs on CPU; callers move the results as needed
        compressed, bitmask = sparse24_bitmask_compress(
            tensor.cpu(), sparsity_structure=sparsity_structure
        )
        return Sparse24BitMaskTensor(
            shape=shape,
            compressed=compressed,
            bitmask=bitmask,
        )

    def decompress(self) -> Tensor:
        """
        :return: reconstructed dense tensor
        """
        return sparse24_bitmask_decompress(self.compressed, self.bitmask, self.shape)

    def curr_memory_size_bytes(self) -> int:
        """
        :return: size in bytes required to store compressed tensor on disk
        """

        def sizeof_tensor(a: Tensor) -> int:
            return a.element_size() * a.nelement()

        return sizeof_tensor(self.compressed) + sizeof_tensor(self.bitmask)

    def dict(self, name_prefix: str, device: str = "cpu") -> Dict[str, Tensor]:
        """
        :param name_prefix: name of original tensor to store compressed weight as
        :param device: device to place the serialized tensors on
        :return: dict of compressed data for the stored weight
        """
        if name_prefix.endswith(".weight"):
            name_prefix = name_prefix[: -len(".weight")]
        return {
            # NOTE(review): reshape(-1, 1) stores the shape as a column vector,
            # presumably so each dimension can be sharded alongside the weight
            # shards — confirm against the sharding test
            merge_names(name_prefix, "shape"): torch.tensor(
                self.shape, device=device
            ).reshape(-1, 1),
            merge_names(name_prefix, "compressed"): self.compressed.to(device),
            merge_names(name_prefix, "bitmask"): self.bitmask.to(device),
        }

    def __repr__(self) -> str:
        # fix: previously reported "BitMaskTensor" — the name of a different
        # class in the sibling sparse_bitmask module — which made debug
        # output misleading
        return f"{type(self).__name__}(shape={self.shape}, compressed=True)"


def sparse24_bitmask_compress(
    tensor: Tensor,
    sparsity_structure: Union[SparsityStructure, str] = SparsityStructure.TWO_FOUR,
) -> Tuple[Tensor, Tensor]:
    """
    Compresses a dense 2:4 sparse tensor using bitmask compression

    :param tensor: dense 2D tensor to compress; total number of elements must
        be a multiple of 4 and the column count must be even
    :param sparsity_structure: structure of sparsity in the tensor; only
        ``2:4`` (``SparsityStructure.TWO_FOUR``) is supported
    :return: tuple of (compressed values of shape [rows, cols // 2],
        packed bitmask flagging the locations of the kept values)
    """
    assert len(tensor.shape) == 2, "Only 2D tensors are supported"
    assert (
        SparsityStructure(sparsity_structure) == SparsityStructure.TWO_FOUR
    ), "Only 2:4 sparsity is supported"

    bytemasks = get_24_bytemasks(tensor=tensor)

    if tensor.dtype == FP8_DTYPE:
        # access the raw bytes of the tensor through an int8 view
        # NOTE(review): presumably fp8 does not support boolean-mask indexing,
        # so the selection runs on the int8 view and the result is
        # reinterpreted back as fp8 — confirm
        tensor_view = tensor.view(torch.int8)
        values = tensor_view[bytemasks]
        values = values.view(FP8_DTYPE)
    else:
        values = tensor[bytemasks]

    num_rows, num_cols = tensor.shape
    # 2:4 sparsity keeps exactly half of the elements of every row
    compressed_values = values.reshape(num_rows, num_cols // 2)
    bitmasks_packed = pack_into_bitmasks(bytemasks)
    # fix: return annotation previously declared a 3-tuple; this function
    # returns exactly two tensors
    return compressed_values, bitmasks_packed


def sparse24_bitmask_decompress(
    values: Tensor, bitmasks: Tensor, original_shape: torch.Size
) -> Tensor:
    """
    Reconstructs a dense tensor from a compressed one

    :param values: tensor of non-zero values; flattened before scattering, so
        the 2d [rows, cols // 2] compressed tensor is accepted directly
    :param bitmasks: 2d int8 tensor flagging locations of non-zero values in the
        tensors original shape
    :param original_shape: shape of the dense tensor
    :return: decompressed dense tensor
    """
    # expand the packed bitmask back into a boolean mask in the dense shape
    bytemasks_unpacked = unpack_bitmasks(bitmasks, original_shape)

    # allocate the dense result, then move it to the device of the values so
    # the masked scatter below operates on matching devices
    decompressed_tensor = torch.zeros(original_shape, dtype=values.dtype)
    decompressed_tensor = decompressed_tensor.to(values.device)
    values = values.flatten()
    if decompressed_tensor.dtype == FP8_DTYPE:
        decompressed_tensor[bytemasks_unpacked] = values
        # NOTE(review): fp8 results are unconditionally moved to CUDA while
        # every other dtype stays on values.device — this fails on CPU-only
        # hosts; confirm the asymmetry is intentional
        decompressed_tensor = decompressed_tensor.cuda()
    else:
        decompressed_tensor[bytemasks_unpacked] = values
    return decompressed_tensor


def get_24_bytemasks(tensor):
    """
    Generate a 2:4 sparsity mask for the given tensor.

    Creates a boolean mask where exactly 2 out of every 4 consecutive elements
    (in flattened order) are preserved, chosen as the two with the highest
    absolute values in each group of 4.

    :param tensor: The input tensor for which the 2:4 sparsity mask is to be created.
        The tensor can be of any shape but its total number of elements
        must be a multiple of 4.
    :return: A boolean tensor of the same shape as the input tensor, where `True`
        indicates the preserved elements and `False` indicates the pruned elements.
    :raises ValueError: If the total number of elements in the tensor is not a
        multiple of 4.
    """
    if tensor.dtype == FP8_DTYPE:
        # NOTE(review): presumably abs/topk are unsupported for fp8, so the
        # ranking runs on the raw int8 byte view; int8 abs is NOT the fp8
        # magnitude (sign-magnitude encoding) — confirm this ordering is
        # acceptable for fp8 weights
        tensor = tensor.view(torch.int8)
    original_shape = tensor.shape
    num_elements = tensor.numel()

    if num_elements % 4 != 0:
        raise ValueError("Tensor size must be a multiple of 4 for TWO_FOUR sparsity")

    # group the flattened tensor into rows of 4 and keep the top-2 magnitudes
    reshaped_tensor = tensor.view(-1, 4)
    topk_indices = reshaped_tensor.abs().topk(2, dim=1).indices
    mask = torch.zeros_like(reshaped_tensor, dtype=torch.bool)
    mask.scatter_(1, topk_indices, True)
    # fix: removed the dead trailing statement
    # `tensor = tensor.view(original_dtype)` — it rebound a local just before
    # returning, so it had no observable effect
    return mask.view(original_shape)
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,12 @@

from typing import Dict, List, Tuple, Union

import numpy
import torch
from compressed_tensors.compressors.base import BaseCompressor
from compressed_tensors.compressors.sparse_compressors.base import BaseSparseCompressor
from compressed_tensors.config import CompressionFormat
from compressed_tensors.quantization import FP8_DTYPE
from compressed_tensors.utils import merge_names
from compressed_tensors.utils import merge_names, pack_into_bitmasks, unpack_bitmasks
from torch import Tensor


Expand All @@ -29,8 +28,6 @@
"BitmaskTensor",
"bitmask_compress",
"bitmask_decompress",
"pack_bitmasks",
"unpack_bitmasks",
]


Expand Down Expand Up @@ -142,7 +139,7 @@ def bitmask_compress(tensor: Tensor) -> Tuple[Tensor, Tensor, Tensor]:
values = values.view(FP8_DTYPE)
else:
values = tensor[bytemasks]
bitmasks_packed = pack_bitmasks(bytemasks)
bitmasks_packed = pack_into_bitmasks(bytemasks)
return values, bitmasks_packed, row_offsets


Expand All @@ -164,37 +161,3 @@ def bitmask_decompress(
decompressed_tensor[bytemasks_unpacked] = values

return decompressed_tensor


def pack_bitmasks(bytemasks: Tensor) -> Tensor:
    """
    Converts a bytemask tensor to a bitmask tensor to reduce memory. Shape RxC will be
    compressed to R x ceil(C/8)

    :param bytemasks: mask tensor where each byte corresponds to a weight
    :return: mask tensor where each bit corresponds to a weight
    """
    # numpy performs the actual bit packing; little-endian bit order keeps
    # the mapping between bit position and column index straightforward
    mask_array = bytemasks.numpy()
    packed = numpy.packbits(mask_array, axis=-1, bitorder="little")
    return torch.from_numpy(packed)


def unpack_bitmasks(packed_bitmasks: Tensor, original_shape: torch.Size) -> Tensor:
    """
    Converts a bitmask tensor back to a bytemask tensor for use during decompression

    :param packed_bitmasks: mask tensor where each bit corresponds to a weight
    :param original_shape: dense shape to decompress to
    :return: boolean mask of weights in the original dense shape
    """
    # count= trims the padding bits that were added when the column count was
    # not a multiple of 8; bitorder must match the order used when packing
    unpacked = numpy.unpackbits(
        packed_bitmasks.numpy(), axis=-1, count=original_shape[-1], bitorder="little"
    )
    # restore the dense shape and convert the 0/1 bytes to a boolean mask
    return torch.from_numpy(unpacked.reshape(original_shape).astype(bool))
1 change: 1 addition & 0 deletions src/compressed_tensors/config/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,4 +15,5 @@
# flake8: noqa
from .base import *
from .dense import *
from .sparse_24_bitmask import *
from .sparse_bitmask import *
1 change: 1 addition & 0 deletions src/compressed_tensors/config/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
class CompressionFormat(Enum):
dense = "dense"
sparse_bitmask = "sparse-bitmask"
sparse_24_bitmask = "sparse-24-bitmask"
int_quantized = "int-quantized"
float_quantized = "float-quantized"
naive_quantized = "naive-quantized"
Expand Down
40 changes: 40 additions & 0 deletions src/compressed_tensors/config/sparse_24_bitmask.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional

from compressed_tensors.config import (
CompressionFormat,
SparsityCompressionConfig,
SparsityStructure,
)


__all__ = ["Sparse24BitMaskConfig"]


@SparsityCompressionConfig.register(name=CompressionFormat.sparse_24_bitmask.value)
class Sparse24BitMaskConfig(SparsityCompressionConfig):
    """
    Configuration for storing a 2:4 sparse model using
    bitmask compression

    :param global_sparsity: average sparsity of the entire model
    :param sparsity_structure: structure of the sparsity, should always be
        "2:4" for this compression format
    """

    # registered format identifier; must match the value used in the
    # @register decorator above
    format: str = CompressionFormat.sparse_24_bitmask.value
    global_sparsity: Optional[float] = 0.0
    sparsity_structure: Optional[str] = SparsityStructure.TWO_FOUR.value
Loading

0 comments on commit 8acd9b8

Please sign in to comment.