review comments from @dsikka
rahul-tuli committed Jan 8, 2025
1 parent 8acd9b8 commit e987872
Showing 1 changed file with 22 additions and 22 deletions.
44 changes: 22 additions & 22 deletions src/compressed_tensors/utils/helpers.py
@@ -255,35 +255,35 @@ def shard_tensor(

 def combine_shards(shards, dim=0):
     """
-    Combine decompressed shards along a given dimension without using torch.cat
-    for unsupported dtypes like float8_e4m3fn.
+    Combine decompressed shards along a given dimension using `narrow`.

     :param shards: List of decompressed shard tensors.
     :param dim: Dimension to combine along (default: 0).
     :return: Combined decompressed tensor.
     """
-    try:
-        # Attempt regular concatenation
-        return torch.cat(shards, dim=dim)
-    except RuntimeError as e:
-        # Handle unsupported concatenation
-        if all(shard.dtype == torch.float8_e4m3fn for shard in shards):
-            total_shape = list(shards[0].shape)
-            total_shape[dim] = sum(shard.shape[dim] for shard in shards)
-            combined = torch.zeros(
-                total_shape, dtype=shards[0].dtype, device=shards[0].device
-            )
+    if not shards:
+        raise ValueError("The list of shards is empty.")

-            shard_offset = 0
-            for shard in shards:
-                shard_size = shard.shape[dim]
-                combined.narrow(dim, shard_offset, shard_size).copy_(shard)
-                shard_offset += shard_size
+    # Assert that all shards have the same dtype
+    shard_dtypes = {shard.dtype for shard in shards}
+    if len(shard_dtypes) > 1:
+        raise ValueError("All shards must have the same dtype.")

-            return combined
-        else:
-            # Re-raise unexpected errors
-            raise e
+    # Determine the total shape of the combined tensor
+    total_shape = list(shards[0].shape)
+    total_shape[dim] = sum(shard.shape[dim] for shard in shards)
+
+    # Create the combined tensor
+    combined = torch.zeros(total_shape, dtype=shards[0].dtype, device=shards[0].device)
+
+    # Fill the combined tensor using narrow
+    shard_offset = 0
+    for shard in shards:
+        shard_size = shard.shape[dim]
+        combined.narrow(dim, shard_offset, shard_size).copy_(shard)
+        shard_offset += shard_size
+
+    return combined


 def pack_into_bitmasks(bytemasks: torch.Tensor) -> torch.Tensor:
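For reference, a minimal usage sketch of the updated helper. Shapes here are hypothetical, and it assumes a PyTorch build that ships torch.float8_e4m3fn, the dtype whose missing torch.cat support motivated the original narrow-and-copy_ fallback:

import torch

from compressed_tensors.utils.helpers import combine_shards

# Two row-shards in a dtype for which torch.cat kernels may be unavailable
shard_a = torch.zeros((2, 4), dtype=torch.float8_e4m3fn)
shard_b = torch.zeros((3, 4), dtype=torch.float8_e4m3fn)

combined = combine_shards([shard_a, shard_b], dim=0)
assert combined.shape == (5, 4)
assert combined.dtype == torch.float8_e4m3fn

# Mixed dtypes now fail fast instead of surfacing later as a runtime error
try:
    combine_shards([shard_a, shard_b.to(torch.float16)])
except ValueError as err:
    print(err)  # All shards must have the same dtype.

Preallocating with torch.zeros and writing each shard via narrow(...).copy_() relies only on fill and copy kernels, which is why it can work for dtypes that lack a cat implementation.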
