diff --git a/src/compressed_tensors/utils/helpers.py b/src/compressed_tensors/utils/helpers.py
index 39066e8b..3e970484 100644
--- a/src/compressed_tensors/utils/helpers.py
+++ b/src/compressed_tensors/utils/helpers.py
@@ -255,35 +255,35 @@ def shard_tensor(
 
 def combine_shards(shards, dim=0):
     """
-    Combine decompressed shards along a given dimension without using torch.cat
-    for unsupported dtypes like float8_e4m3fn.
+    Combine decompressed shards along a given dimension using `narrow`.
 
     :param shards: List of decompressed shard tensors.
    :param dim: Dimension to combine along (default: 0).
     :return: Combined decompressed tensor.
     """
-    try:
-        # Attempt regular concatenation
-        return torch.cat(shards, dim=dim)
-    except RuntimeError as e:
-        # Handle unsupported concatenation
-        if all(shard.dtype == torch.float8_e4m3fn for shard in shards):
-            total_shape = list(shards[0].shape)
-            total_shape[dim] = sum(shard.shape[dim] for shard in shards)
-            combined = torch.zeros(
-                total_shape, dtype=shards[0].dtype, device=shards[0].device
-            )
+    if not shards:
+        raise ValueError("The list of shards is empty.")
 
-            shard_offset = 0
-            for shard in shards:
-                shard_size = shard.shape[dim]
-                combined.narrow(dim, shard_offset, shard_size).copy_(shard)
-                shard_offset += shard_size
+    # Ensure that all shards have the same dtype
+    shard_dtypes = {shard.dtype for shard in shards}
+    if len(shard_dtypes) > 1:
+        raise ValueError("All shards must have the same dtype.")
 
-            return combined
-        else:
-            # Re-raise unexpected errors
-            raise e
+    # Determine the total shape of the combined tensor
+    total_shape = list(shards[0].shape)
+    total_shape[dim] = sum(shard.shape[dim] for shard in shards)
+
+    # Create the combined tensor
+    combined = torch.zeros(total_shape, dtype=shards[0].dtype, device=shards[0].device)
+
+    # Fill the combined tensor using narrow
+    shard_offset = 0
+    for shard in shards:
+        shard_size = shard.shape[dim]
+        combined.narrow(dim, shard_offset, shard_size).copy_(shard)
+        shard_offset += shard_size
+
+    return combined
 
 
 def pack_into_bitmasks(bytemasks: torch.Tensor) -> torch.Tensor:
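
A minimal usage sketch of the rewritten helper (not part of the diff): it splits a tensor into uneven shards along dim 0 and recombines them, assuming combine_shards is importable from compressed_tensors.utils.helpers once this change lands.

    import torch

    from compressed_tensors.utils.helpers import combine_shards

    # Build a small tensor and split it into uneven shards along dim 0,
    # mimicking decompressed weight shards of the same dtype.
    original = torch.arange(12, dtype=torch.float32).reshape(6, 2)
    shards = [original[:2], original[2:5], original[5:]]

    # combine_shards allocates the full-size buffer and copies each shard
    # into its slice via narrow(), so no torch.cat is needed.
    recombined = combine_shards(shards, dim=0)
    assert torch.equal(recombined, original)

Because the result is filled with narrow().copy_() rather than torch.cat, the same path also works for dtypes that torch.cat does not support, such as float8_e4m3fn.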