From 8acd9b84302e6d1e0555fbd5cd441a5f7a3f4871 Mon Sep 17 00:00:00 2001 From: Rahul Tuli Date: Fri, 20 Dec 2024 15:56:38 +0000 Subject: [PATCH] Adds: Fully shardable Sparse24BitMaskCompressor Adds: Sharding test --- .../sparse_compressors/__init__.py | 1 + .../sparse_compressors/sparse_24_bitmask.py | 226 ++++++++++++++++++ .../sparse_compressors/sparse_bitmask.py | 41 +--- src/compressed_tensors/config/__init__.py | 1 + src/compressed_tensors/config/base.py | 1 + .../config/sparse_24_bitmask.py | 40 ++++ src/compressed_tensors/utils/helpers.py | 112 ++++++++- .../test_sparse_24_bitmask.py | 142 +++++++++++ 8 files changed, 524 insertions(+), 40 deletions(-) create mode 100644 src/compressed_tensors/compressors/sparse_compressors/sparse_24_bitmask.py create mode 100644 src/compressed_tensors/config/sparse_24_bitmask.py create mode 100644 tests/test_compressors/sparse_compressors/test_sparse_24_bitmask.py diff --git a/src/compressed_tensors/compressors/sparse_compressors/__init__.py b/src/compressed_tensors/compressors/sparse_compressors/__init__.py index de4fd887..871079ac 100644 --- a/src/compressed_tensors/compressors/sparse_compressors/__init__.py +++ b/src/compressed_tensors/compressors/sparse_compressors/__init__.py @@ -15,4 +15,5 @@ from .base import * from .dense import * +from .sparse_24_bitmask import * from .sparse_bitmask import * diff --git a/src/compressed_tensors/compressors/sparse_compressors/sparse_24_bitmask.py b/src/compressed_tensors/compressors/sparse_compressors/sparse_24_bitmask.py new file mode 100644 index 00000000..54ea7200 --- /dev/null +++ b/src/compressed_tensors/compressors/sparse_compressors/sparse_24_bitmask.py @@ -0,0 +1,226 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Dict, List, Tuple, Union + +import torch +from compressed_tensors.compressors.base import BaseCompressor +from compressed_tensors.compressors.sparse_compressors.base import BaseSparseCompressor +from compressed_tensors.config import CompressionFormat, SparsityStructure +from compressed_tensors.quantization import FP8_DTYPE +from compressed_tensors.utils import merge_names, pack_into_bitmasks, unpack_bitmasks +from torch import Tensor + + +__all__ = [ + "Sparse24BitMaskCompressor", + "Sparse24BitMaskTensor", + "sparse24_bitmask_compress", + "sparse24_bitmask_decompress", + "get_24_bytemasks", +] + + +@BaseCompressor.register(name=CompressionFormat.sparse_24_bitmask.value) +class Sparse24BitMaskCompressor(BaseSparseCompressor): + """ + Compression for sparse models using bitmasks. Non-zero weights are stored in a 2d + values tensor, with their locations stored in a 2d bitmask + """ + + COMPRESSION_PARAM_NAMES = [ + "shape", + "compressed", + "bitmask", + ] + + def compress_weight(self, name, value): + bitmask_tensor = Sparse24BitMaskTensor.from_dense( + value, self.config.sparsity_structure + ) + bitmask_dict = bitmask_tensor.dict(name_prefix=name, device="cpu") + return bitmask_dict + + def decompress_weight(self, weight_data): + data = Sparse24BitMaskTensor(**weight_data) + decompressed = data.decompress() + return decompressed + + +class Sparse24BitMaskTensor: + """ + Owns compressions and decompression for a single 2:4 sparse + bitmask compressed tensor. + + :param shape: shape of dense tensor + :compressed: 2d tensor of non-zero values + :bitmask: 2d bitmask of non-zero values + """ + + def __init__( + self, + shape: Union[torch.Size, List], + compressed: Tensor, + bitmask: Tensor, + ): + self.shape = list(shape) + self.compressed = compressed + self.bitmask = bitmask + + @staticmethod + def from_dense( + tensor: Tensor, + sparsity_structure: Union[SparsityStructure, str] = SparsityStructure.TWO_FOUR, + ) -> "Sparse24BitMaskTensor": + """ + :param tensor: dense tensor to compress + :return: instantiated compressed tensor + """ + shape = tensor.shape + compressed, bitmask = sparse24_bitmask_compress( + tensor.cpu(), sparsity_structure=sparsity_structure + ) + return Sparse24BitMaskTensor( + shape=shape, + compressed=compressed, + bitmask=bitmask, + ) + + def decompress(self) -> Tensor: + """ + :return: reconstructed dense tensor + """ + return sparse24_bitmask_decompress(self.compressed, self.bitmask, self.shape) + + def curr_memory_size_bytes(self): + """ + :return: size in bytes required to store compressed tensor on disk + """ + + def sizeof_tensor(a): + return a.element_size() * a.nelement() + + return sizeof_tensor(self.compressed) + sizeof_tensor(self.bitmask) + + def dict(self, name_prefix: str, device: str = "cpu") -> Dict[str, Tensor]: + """ + :name_prefix: name of original tensor to store compressed weight as + :return: dict of compressed data for the stored weight + """ + if name_prefix.endswith(".weight"): + name_prefix = name_prefix[: -len(".weight")] + return { + merge_names(name_prefix, "shape"): torch.tensor( + self.shape, device=device + ).reshape(-1, 1), + merge_names(name_prefix, "compressed"): self.compressed.to(device), + merge_names(name_prefix, "bitmask"): self.bitmask.to(device), + } + + def __repr__(self): + return f"BitMaskTensor(shape={self.shape}, compressed=True)" + + +def sparse24_bitmask_compress( + tensor: Tensor, + sparsity_structure: Union[SparsityStructure, str] = SparsityStructure.TWO_FOUR, +) -> Tuple[Tensor, Tensor, Tensor]: + """ + Compresses a dense tensor using bitmask compression + + :param tensor: dense 2D tensor to compress + :param sparsity_structure: structure of sparsity in the tensor, defaults + to unstructured, can also be set to `2:4` + :return: tuple of compressed data representing tensor + """ + assert len(tensor.shape) == 2, "Only 2D tensors are supported" + assert ( + SparsityStructure(sparsity_structure) == SparsityStructure.TWO_FOUR + ), "Only 2:4 sparsity is supported" + + bytemasks = get_24_bytemasks(tensor=tensor) + + if tensor.dtype == FP8_DTYPE: + # acces raw bytes of the tensor + tensor_view = tensor.view(torch.int8) + values = tensor_view[bytemasks] + values = values.view(FP8_DTYPE) + else: + values = tensor[bytemasks] + + num_rows, num_cols = tensor.shape + compressed_values = values.reshape(num_rows, num_cols // 2) + bitmasks_packed = pack_into_bitmasks(bytemasks) + return compressed_values, bitmasks_packed + + +def sparse24_bitmask_decompress( + values: Tensor, bitmasks: Tensor, original_shape: torch.Size +) -> Tensor: + """ + Reconstructs a dense tensor from a compressed one + + :param values: 1d tensor of non-zero values + :param bitmasks: 2d int8 tensor flagging locations of non-zero values in the + tensors original shape + :param original_shape: shape of the dense tensor + :return: decompressed dense tensor + """ + bytemasks_unpacked = unpack_bitmasks(bitmasks, original_shape) + + decompressed_tensor = torch.zeros(original_shape, dtype=values.dtype) + decompressed_tensor = decompressed_tensor.to(values.device) + values = values.flatten() + if decompressed_tensor.dtype == FP8_DTYPE: + decompressed_tensor[bytemasks_unpacked] = values + decompressed_tensor = decompressed_tensor.cuda() + else: + decompressed_tensor[bytemasks_unpacked] = values + return decompressed_tensor + + +def get_24_bytemasks(tensor): + """ + Generate a 2:4 sparsity mask for the given tensor. + + This function creates a mask where exactly 2 out of every 4 elements are + preserved based on their magnitudes. The preserved elements are the ones + with the highest absolute values in each group of 4 elements. + + :param tensor: The input tensor for which the 2:4 sparsity mask is to be created. + The tensor can be of any shape but its total number of elements + must be a multiple of 4. + :return: A boolean tensor of the same shape as the input tensor, where `True` + indicates the preserved elements and `False` indicates the pruned elements. + :raises ValueError: If the total number of elements in the tensor is not a + multiple of 4. + """ + original_dtype = tensor.dtype + if tensor.dtype == FP8_DTYPE: + tensor = tensor.view(torch.int8) + original_shape = tensor.shape + num_elements = tensor.numel() + + if num_elements % 4 != 0: + raise ValueError("Tensor size must be a multiple of 4 for TWO_FOUR sparsity") + + reshaped_tensor = tensor.view(-1, 4) + abs_tensor = reshaped_tensor.abs() + topk_indices = abs_tensor.topk(2, dim=1).indices + mask = torch.zeros_like(reshaped_tensor, dtype=torch.bool) + mask.scatter_(1, topk_indices, True) + mask = mask.view(original_shape) + tensor = tensor.view(original_dtype) + + return mask diff --git a/src/compressed_tensors/compressors/sparse_compressors/sparse_bitmask.py b/src/compressed_tensors/compressors/sparse_compressors/sparse_bitmask.py index 0434499d..9c2e10ae 100644 --- a/src/compressed_tensors/compressors/sparse_compressors/sparse_bitmask.py +++ b/src/compressed_tensors/compressors/sparse_compressors/sparse_bitmask.py @@ -14,13 +14,12 @@ from typing import Dict, List, Tuple, Union -import numpy import torch from compressed_tensors.compressors.base import BaseCompressor from compressed_tensors.compressors.sparse_compressors.base import BaseSparseCompressor from compressed_tensors.config import CompressionFormat from compressed_tensors.quantization import FP8_DTYPE -from compressed_tensors.utils import merge_names +from compressed_tensors.utils import merge_names, pack_into_bitmasks, unpack_bitmasks from torch import Tensor @@ -29,8 +28,6 @@ "BitmaskTensor", "bitmask_compress", "bitmask_decompress", - "pack_bitmasks", - "unpack_bitmasks", ] @@ -142,7 +139,7 @@ def bitmask_compress(tensor: Tensor) -> Tuple[Tensor, Tensor, Tensor]: values = values.view(FP8_DTYPE) else: values = tensor[bytemasks] - bitmasks_packed = pack_bitmasks(bytemasks) + bitmasks_packed = pack_into_bitmasks(bytemasks) return values, bitmasks_packed, row_offsets @@ -164,37 +161,3 @@ def bitmask_decompress( decompressed_tensor[bytemasks_unpacked] = values return decompressed_tensor - - -def pack_bitmasks(bytemasks: Tensor) -> Tensor: - """ - Converts a bytemask tensor to a bitmask tensor to reduce memory. Shape RxC will be - compressed to R x ceil(C/8) - :param bytemasks: mask tensor where each byte corresponds to a weight - :return: mask tensor where each bit corresounds to a weight - """ - packed_bits_numpy = numpy.packbits(bytemasks.numpy(), axis=-1, bitorder="little") - packed_bits_torch = torch.from_numpy(packed_bits_numpy) - - return packed_bits_torch - - -def unpack_bitmasks(packed_bitmasks: Tensor, original_shape: torch.Size) -> Tensor: - """ - Converts a bitmask tensor back to a bytemask tensor for use during decompression - - :param packed_bitmasks: mask tensor where each bit corresponds to a weight - :param original_shape: dense shape to decompress to - :return: boolean mask of weights in the original dense shape - """ - # Unpack the bits - unpacked_bits = numpy.unpackbits( - packed_bitmasks.numpy(), axis=-1, count=original_shape[-1], bitorder="little" - ) - - # Reshape to match the original shape - unpacked_bitmasks_torch = torch.from_numpy( - unpacked_bits.reshape(original_shape).astype(bool) - ) - - return unpacked_bitmasks_torch diff --git a/src/compressed_tensors/config/__init__.py b/src/compressed_tensors/config/__init__.py index ff83f5af..582b8a9e 100644 --- a/src/compressed_tensors/config/__init__.py +++ b/src/compressed_tensors/config/__init__.py @@ -15,4 +15,5 @@ # flake8: noqa from .base import * from .dense import * +from .sparse_24_bitmask import * from .sparse_bitmask import * diff --git a/src/compressed_tensors/config/base.py b/src/compressed_tensors/config/base.py index 79a4fcdd..9ca6f2cf 100644 --- a/src/compressed_tensors/config/base.py +++ b/src/compressed_tensors/config/base.py @@ -26,6 +26,7 @@ class CompressionFormat(Enum): dense = "dense" sparse_bitmask = "sparse-bitmask" + sparse_24_bitmask = "sparse-24-bitmask" int_quantized = "int-quantized" float_quantized = "float-quantized" naive_quantized = "naive-quantized" diff --git a/src/compressed_tensors/config/sparse_24_bitmask.py b/src/compressed_tensors/config/sparse_24_bitmask.py new file mode 100644 index 00000000..7aae2dbe --- /dev/null +++ b/src/compressed_tensors/config/sparse_24_bitmask.py @@ -0,0 +1,40 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional + +from compressed_tensors.config import ( + CompressionFormat, + SparsityCompressionConfig, + SparsityStructure, +) + + +__all__ = ["Sparse24BitMaskConfig"] + + +@SparsityCompressionConfig.register(name=CompressionFormat.sparse_24_bitmask.value) +class Sparse24BitMaskConfig(SparsityCompressionConfig): + """ + Configuration for storing a 24 sparse model using + bytemask compression + + :param global_sparsity: average sparsity of the entire model + :param sparsity_structure: structure of the sparsity, should always be + "2:4" for this compression format + """ + + format: str = CompressionFormat.sparse_24_bitmask.value + global_sparsity: Optional[float] = 0.0 + sparsity_structure: Optional[str] = SparsityStructure.TWO_FOUR.value diff --git a/src/compressed_tensors/utils/helpers.py b/src/compressed_tensors/utils/helpers.py index 910436eb..39066e8b 100644 --- a/src/compressed_tensors/utils/helpers.py +++ b/src/compressed_tensors/utils/helpers.py @@ -14,8 +14,9 @@ import warnings from functools import wraps -from typing import Any, Callable, Dict, Optional +from typing import Any, Callable, Dict, List, Optional +import numpy import torch from transformers import AutoConfig @@ -29,6 +30,10 @@ "getattr_chain", "deprecated", "Aliasable", + "combine_shards", + "shard_tensor", + "pack_into_bitmasks", + "unpack_bitmasks", ] FSDP_WRAPPER_NAME = "_fsdp_wrapped_module" @@ -214,3 +219,108 @@ def __eq__(self, other): def __hash__(self): canonical_value = self.aliases.get(self.value, self.value) return hash(canonical_value) + + +def shard_tensor( + tensor: torch.Tensor, shard_sizes: List[int], dim: int = 0 +) -> List[torch.Tensor]: + """ + Shards a tensor into a list of tensors along a given dimension. + + raises: ValueError: If the sum of shard_sizes does not match the + size of the tensor along the given dimension. + + :param tensor: The input tensor to shard. + :param shard_sizes : List of sizes for each shard along the specified dimension. + :param dim : The dimension along which to shard the tensor. + :returns: A list of tensors sharded along the specified dimension. + """ + if sum(shard_sizes) != tensor.size(dim): + raise ValueError( + "Sum of shard_sizes must equal the size of the tensor " + "along the specified dimension." + ) + + shards = [] + start_idx = 0 + + for size in shard_sizes: + end_idx = start_idx + size + shard = tensor.narrow(dim, start_idx, size) + shards.append(shard) + start_idx = end_idx + + return shards + + +def combine_shards(shards, dim=0): + """ + Combine decompressed shards along a given dimension without using torch.cat + for unsupported dtypes like float8_e4m3fn. + + :param shards: List of decompressed shard tensors. + :param dim: Dimension to combine along (default: 0). + :return: Combined decompressed tensor. + """ + try: + # Attempt regular concatenation + return torch.cat(shards, dim=dim) + except RuntimeError as e: + # Handle unsupported concatenation + if all(shard.dtype == torch.float8_e4m3fn for shard in shards): + total_shape = list(shards[0].shape) + total_shape[dim] = sum(shard.shape[dim] for shard in shards) + combined = torch.zeros( + total_shape, dtype=shards[0].dtype, device=shards[0].device + ) + + shard_offset = 0 + for shard in shards: + shard_size = shard.shape[dim] + combined.narrow(dim, shard_offset, shard_size).copy_(shard) + shard_offset += shard_size + + return combined + else: + # Re-raise unexpected errors + raise e + + +def pack_into_bitmasks(bytemasks: torch.Tensor) -> torch.Tensor: + """ + Converts a bytemask tensor to a bitmask tensor to reduce memory. Shape RxC will be + compressed to R x ceil(C/8) + + :param bytemasks: mask tensor where each byte corresponds to a weight + :return: mask tensor where each bit corresounds to a weight + """ + packed_bits_numpy = numpy.packbits(bytemasks.numpy(), axis=-1, bitorder="little") + packed_bits_torch = torch.from_numpy(packed_bits_numpy) + + return packed_bits_torch + + +def unpack_bitmasks( + packed_bitmasks: torch.Tensor, original_shape: torch.Size +) -> torch.Tensor: + """ + Converts a bitmask tensor back to a bytemask tensor for use during decompression + + :param packed_bitmasks: mask tensor where each bit corresponds to a weight + :param original_shape: dense shape to decompress to + :return: boolean mask of weights in the original dense shape + """ + # Unpack the bits + unpacked_bits = numpy.unpackbits( + packed_bitmasks.cpu().numpy(), + axis=-1, + count=original_shape[-1], + bitorder="little", + ) + + # Reshape to match the original shape + unpacked_bitmasks_torch = torch.from_numpy( + unpacked_bits.reshape(original_shape).astype(bool) + ) + + return unpacked_bitmasks_torch diff --git a/tests/test_compressors/sparse_compressors/test_sparse_24_bitmask.py b/tests/test_compressors/sparse_compressors/test_sparse_24_bitmask.py new file mode 100644 index 00000000..4e1b5c47 --- /dev/null +++ b/tests/test_compressors/sparse_compressors/test_sparse_24_bitmask.py @@ -0,0 +1,142 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import pytest +import torch +from compressed_tensors import Sparse24BitMaskTensor +from compressed_tensors.quantization import FP8_DTYPE +from compressed_tensors.utils import combine_shards, shard_tensor +from tests.testing_utils import generate_pruned_semi_structured_mat + + +@pytest.fixture +def dense_matrix_fixture(): + def _generate_dense_matrix(M, K, dtype): + return generate_pruned_semi_structured_mat(M, K, dtype) + + return _generate_dense_matrix + + +@pytest.fixture +def shard_validation(): + def _validate_shard_shapes(sharded_values, sharded_bitmask, expected_shapes): + for shard_values, shard_bitmask, expected_shape in zip( + sharded_values, sharded_bitmask, expected_shapes + ): + assert ( + shard_values.shape == expected_shape["compressed"] + ), f"Shape mismatch: {shard_values.shape} != {expected_shape['compressed']}" + assert ( + shard_bitmask.shape == expected_shape["bitmask"] + ), f"Shape mismatch: {shard_bitmask.shape} != {expected_shape['bitmask']}" + + return _validate_shard_shapes + + +@pytest.mark.parametrize("dtype", [FP8_DTYPE]) +def test_bitmask_compress_decompress_fp8(dense_matrix_fixture, dtype): + M, K = 1024, 1024 + dense_matrix = dense_matrix_fixture(M, K, dtype) + + bitmask_tensor = Sparse24BitMaskTensor.from_dense( + dense_matrix, sparsity_structure="2:4" + ) + decompressed_tensor = bitmask_tensor.decompress() + + dense_matrix = dense_matrix.to(decompressed_tensor.device) + + assert dense_matrix.dtype == decompressed_tensor.dtype, "Dtype mismatch" + assert dense_matrix.shape == decompressed_tensor.shape, "Shape mismatch" + assert torch.equal(dense_matrix, decompressed_tensor), "Decompression failed" + + +@pytest.mark.parametrize( + "dtype, M, K, shard_sizes, shard_dim, expected_shapes", + [ + ( + FP8_DTYPE, + 2560, + 2048, + [2048, 256, 256], + 0, + [ + {"compressed": (2048, 1024), "bitmask": (2048, 2048 // 8)}, + {"compressed": (256, 1024), "bitmask": (256, 2048 // 8)}, + {"compressed": (256, 1024), "bitmask": (256, 2048 // 8)}, + ], + ), + ( + FP8_DTYPE, + 2048, + 2048, + [1024, 1024], + 1, + [ + {"compressed": (2048, 512), "bitmask": (2048, 2048 // 8 // 2)}, + {"compressed": (2048, 512), "bitmask": (2048, 2048 // 8 // 2)}, + ], + ), + ], +) +def test_bitmask_compress_decompress_sharded( + dense_matrix_fixture, + shard_validation, + dtype, + M, + K, + shard_sizes, + shard_dim, + expected_shapes, +): + dense_matrix = dense_matrix_fixture(M, K, dtype) + + bitmask_tensor = Sparse24BitMaskTensor.from_dense(dense_matrix) + compressed_values = bitmask_tensor.compressed + compressed_bitmask = bitmask_tensor.bitmask + + if shard_dim == 1: + compressed_shard_sizes = [size // 2 for size in shard_sizes] + bitmask_shard_sizes = [size // 8 for size in shard_sizes] + else: + compressed_shard_sizes = shard_sizes + bitmask_shard_sizes = shard_sizes + + sharded_compressed_values = shard_tensor( + compressed_values, compressed_shard_sizes, dim=shard_dim + ) + sharded_compressed_bitmask = shard_tensor( + compressed_bitmask, bitmask_shard_sizes, dim=shard_dim + ) + + shard_validation( + sharded_compressed_values, sharded_compressed_bitmask, expected_shapes + ) + + decompressed_shards = [ + Sparse24BitMaskTensor( + shape=(expected_shape["bitmask"][0], expected_shape["bitmask"][1] * 8), + compressed=shard_values, + bitmask=shard_bitmask, + ).decompress() + for shard_values, shard_bitmask, expected_shape in zip( + sharded_compressed_values, sharded_compressed_bitmask, expected_shapes + ) + ] + + decompressed_combined = combine_shards(decompressed_shards, dim=shard_dim) + + assert dense_matrix.dtype == decompressed_combined.dtype, "Dtype mismatch" + assert dense_matrix.shape == decompressed_combined.shape, "Shape mismatch" + assert torch.equal(dense_matrix, decompressed_combined), "Decompression failed"