Revert function name change
rahul-tuli committed Jan 8, 2025
1 parent e8f26d4 · commit 0ca476f
Showing 3 changed files with 6 additions and 6 deletions.
@@ -20,7 +20,7 @@
 from compressed_tensors.compressors.sparse_compressors.base import BaseSparseCompressor
 from compressed_tensors.config import CompressionFormat, SparsityStructure
 from compressed_tensors.quantization import FP8_DTYPE
-from compressed_tensors.utils import merge_names, pack_into_bitmasks, unpack_bitmasks
+from compressed_tensors.utils import merge_names, pack_bitmasks, unpack_bitmasks
 from torch import Tensor

@@ -173,7 +173,7 @@ def sparse24_bitmask_compress(

     num_rows, num_cols = tensor.shape
     compressed_values = values.reshape(num_rows, num_cols // 2)
-    bitmasks_packed = pack_into_bitmasks(bytemasks)
+    bitmasks_packed = pack_bitmasks(bytemasks)
     return compressed_values, bitmasks_packed
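For context on the hunk above: in the 2:4 sparse path exactly half of the values in each row survive, so the kept values reshape cleanly to num_cols // 2 columns while the boolean bytemask is packed down to bits. A minimal sketch of that shape transformation with a made-up tensor (the pack_bitmasks call is left commented since its implementation is not shown in this hunk):

import torch

# Toy 1x8 row with a 2:4 pattern (2 nonzeros in every group of 4).
tensor = torch.tensor([[0.0, 1.5, 0.0, -2.0, 3.0, 0.0, 0.0, 4.0]])
bytemasks = tensor != 0                     # boolean mask of kept positions, shape (1, 8)
values = tensor[bytemasks]                  # the 4 surviving values, flattened

num_rows, num_cols = tensor.shape
compressed_values = values.reshape(num_rows, num_cols // 2)   # shape (1, 4)
# bitmasks_packed = pack_bitmasks(bytemasks)                  # packed mask, shape (1, 1) uint8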
@@ -19,7 +19,7 @@
 from compressed_tensors.compressors.sparse_compressors.base import BaseSparseCompressor
 from compressed_tensors.config import CompressionFormat
 from compressed_tensors.quantization import FP8_DTYPE
-from compressed_tensors.utils import merge_names, pack_into_bitmasks, unpack_bitmasks
+from compressed_tensors.utils import merge_names, pack_bitmasks, unpack_bitmasks
 from torch import Tensor

@@ -139,7 +139,7 @@ def bitmask_compress(tensor: Tensor) -> Tuple[Tensor, Tensor, Tensor]:
         values = values.view(FP8_DTYPE)
     else:
         values = tensor[bytemasks]
-    bitmasks_packed = pack_into_bitmasks(bytemasks)
+    bitmasks_packed = pack_bitmasks(bytemasks)
     return values, bitmasks_packed, row_offsets
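As a rough illustration of how the (values, packed bitmask, row_offsets) triple could be reassembled into a dense tensor. This is not the library's decompressor; the already-unpacked boolean mask and the scatter-based reconstruction below are assumptions:

import torch

def bitmask_decompress_sketch(values: torch.Tensor,
                              bitmask_bool: torch.Tensor,
                              shape) -> torch.Tensor:
    # Scatter the kept values back into a dense zero tensor using the
    # (already unpacked) boolean mask; row_offsets would only be needed
    # if rows were reconstructed independently.
    decompressed = torch.zeros(shape, dtype=values.dtype)
    decompressed[bitmask_bool] = values
    return decompressed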
src/compressed_tensors/utils/helpers.py (2 additions & 2 deletions)

@@ -32,7 +32,7 @@
     "Aliasable",
     "combine_shards",
     "shard_tensor",
-    "pack_into_bitmasks",
+    "pack_bitmasks",
     "unpack_bitmasks",
 ]

@@ -286,7 +286,7 @@ def combine_shards(shards, dim=0):
     return combined


-def pack_into_bitmasks(bytemasks: torch.Tensor) -> torch.Tensor:
+def pack_bitmasks(bytemasks: torch.Tensor) -> torch.Tensor:
     """
     Converts a bytemask tensor to a bitmask tensor to reduce memory. Shape RxC will be
     compressed to R x ceil(C/8)
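The renamed helper is only shown down to its docstring in this diff. One way such a packer can be written, sketched here with numpy.packbits; the real implementation and its bit order are not visible in the diff, so treat these details as assumptions:

import numpy
import torch

def pack_bitmasks_sketch(bytemasks: torch.Tensor) -> torch.Tensor:
    # Pack a boolean (R x C) mask into bits along the last dim,
    # producing an (R x ceil(C/8)) uint8 tensor as the docstring describes.
    packed = numpy.packbits(bytemasks.cpu().numpy(), axis=-1, bitorder="little")
    return torch.from_numpy(packed)

def unpack_bitmasks_sketch(packed: torch.Tensor, original_shape) -> torch.Tensor:
    # Inverse: expand bits back to booleans and trim the padding added
    # to round the column count up to a multiple of 8.
    unpacked = numpy.unpackbits(packed.cpu().numpy(), axis=-1, bitorder="little")
    return torch.from_numpy(unpacked[..., : original_shape[-1]]).to(torch.bool)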
