diff --git a/src/compressed_tensors/compressors/sparse_compressors/sparse_24_bitmask.py b/src/compressed_tensors/compressors/sparse_compressors/sparse_24_bitmask.py index 63896846..e51433c2 100644 --- a/src/compressed_tensors/compressors/sparse_compressors/sparse_24_bitmask.py +++ b/src/compressed_tensors/compressors/sparse_compressors/sparse_24_bitmask.py @@ -20,7 +20,7 @@ from compressed_tensors.compressors.sparse_compressors.base import BaseSparseCompressor from compressed_tensors.config import CompressionFormat, SparsityStructure from compressed_tensors.quantization import FP8_DTYPE -from compressed_tensors.utils import merge_names, pack_into_bitmasks, unpack_bitmasks +from compressed_tensors.utils import merge_names, pack_bitmasks, unpack_bitmasks from torch import Tensor @@ -173,7 +173,7 @@ def sparse24_bitmask_compress( num_rows, num_cols = tensor.shape compressed_values = values.reshape(num_rows, num_cols // 2) - bitmasks_packed = pack_into_bitmasks(bytemasks) + bitmasks_packed = pack_bitmasks(bytemasks) return compressed_values, bitmasks_packed diff --git a/src/compressed_tensors/compressors/sparse_compressors/sparse_bitmask.py b/src/compressed_tensors/compressors/sparse_compressors/sparse_bitmask.py index 9c2e10ae..7c2023cf 100644 --- a/src/compressed_tensors/compressors/sparse_compressors/sparse_bitmask.py +++ b/src/compressed_tensors/compressors/sparse_compressors/sparse_bitmask.py @@ -19,7 +19,7 @@ from compressed_tensors.compressors.sparse_compressors.base import BaseSparseCompressor from compressed_tensors.config import CompressionFormat from compressed_tensors.quantization import FP8_DTYPE -from compressed_tensors.utils import merge_names, pack_into_bitmasks, unpack_bitmasks +from compressed_tensors.utils import merge_names, pack_bitmasks, unpack_bitmasks from torch import Tensor @@ -139,7 +139,7 @@ def bitmask_compress(tensor: Tensor) -> Tuple[Tensor, Tensor, Tensor]: values = values.view(FP8_DTYPE) else: values = tensor[bytemasks] - bitmasks_packed = pack_into_bitmasks(bytemasks) + bitmasks_packed = pack_bitmasks(bytemasks) return values, bitmasks_packed, row_offsets diff --git a/src/compressed_tensors/utils/helpers.py b/src/compressed_tensors/utils/helpers.py index 3e970484..9bef7d47 100644 --- a/src/compressed_tensors/utils/helpers.py +++ b/src/compressed_tensors/utils/helpers.py @@ -32,7 +32,7 @@ "Aliasable", "combine_shards", "shard_tensor", - "pack_into_bitmasks", + "pack_bitmasks", "unpack_bitmasks", ] @@ -286,7 +286,7 @@ def combine_shards(shards, dim=0): return combined -def pack_into_bitmasks(bytemasks: torch.Tensor) -> torch.Tensor: +def pack_bitmasks(bytemasks: torch.Tensor) -> torch.Tensor: """ Converts a bytemask tensor to a bitmask tensor to reduce memory. Shape RxC will be compressed to R x ceil(C/8)