From 1e4e98c9beb89fc21d04521cf341ca697ec3e597 Mon Sep 17 00:00:00 2001
From: Dipika
Date: Wed, 4 Dec 2024 03:48:14 +0000
Subject: [PATCH] get compressed pathway to work without using torch.cat

Signed-off-by: Dipika
---
 .../schemes/compressed_tensors_24.py | 38 +++++++++++--------
 1 file changed, 23 insertions(+), 15 deletions(-)

diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_24.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_24.py
index e3df8317bdb99..ceb3bcc9819a9 100644
--- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_24.py
+++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_24.py
@@ -175,13 +175,9 @@ def apply_weights(self,
         return out
 
     def _decompress_layer_weight(self, layer: torch.nn.Module) -> torch.Tensor:
-
         sparse_24_packed_weight = layer.sparse_24_packed_weight.data
         meta = layer.meta.data
 
-        split_weights = None
-        split_meta = None
-
         def _process_split(input_weight, input_meta):
             weight_data = {
                 "sparse_24_packed_weight": input_weight,
@@ -190,17 +186,29 @@ def _process_split(input_weight, input_meta):
             decompress = self.model_compressor.sparsity_compressor.decompress_weight(weight_data)
             return decompress
 
-        if isinstance(layer, (QKVParallelLinear, MergedColumnParallelLinear)):
-            split_weights = torch.split(sparse_24_packed_weight, layer.logical_widths)
-            split_meta = torch.split(meta, layer.logical_widths)
-
-        if split_weights:
-            all_compress = []
-            for i in range(len(split_weights)):
-                compress_i = _process_split(split_weights[i], split_meta[i])
-                all_compress.append(compress_i)
-
-            decompressed = torch.cat(all_compress)
+
+        if len(layer.logical_widths) > 1:
+            # to store decompressed weight
+            weight_to_compress_p = torch.empty(
+                sparse_24_packed_weight.shape[0],
+                sparse_24_packed_weight.shape[1] * 2,
+                dtype=torch.float8_e4m3fn,
+                device="cuda"
+            )
+
+            shard_offset = 0
+            for i in range(len(layer.logical_widths)):
+                if i > 0:
+                    shard_offset += layer.logical_widths[i-1]
+
+                sparse_data = sparse_24_packed_weight.narrow(0, shard_offset, layer.logical_widths[i])
+                meta_data = meta.narrow(0, shard_offset, layer.logical_widths[i])
+                decompress_i = _process_split(sparse_data, meta_data)
+
+                param_data = weight_to_compress_p.narrow(0, shard_offset, layer.logical_widths[i])
+                param_data.copy_(decompress_i)
+
+            decompressed = weight_to_compress_p
         else:
             decompressed = _process_split(sparse_24_packed_weight, meta)
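
Note: the second hunk replaces the list-plus-torch.cat pattern with a preallocated
buffer that each logical shard is decompressed into via narrow(...).copy_(...).
The standalone sketch below illustrates only that buffer-filling pattern; the
fake_decompress helper, dtypes, shapes, and widths are placeholders for
illustration and are not vLLM's real sparsity decompressor or fp8 types.

    import torch


    def fake_decompress(packed: torch.Tensor) -> torch.Tensor:
        # Stand-in for decompress_weight(): a 2:4-compressed tensor expands to
        # twice as many columns once decompressed.
        return packed.repeat_interleave(2, dim=1)


    def decompress_without_cat(packed: torch.Tensor,
                               logical_widths: list[int]) -> torch.Tensor:
        # Preallocate the full decompressed weight once (rows unchanged,
        # columns doubled, mirroring torch.empty(..., shape[1] * 2) above).
        out = torch.empty(packed.shape[0], packed.shape[1] * 2,
                          dtype=packed.dtype, device=packed.device)

        shard_offset = 0
        for width in logical_widths:
            # Take the shard's rows from the packed input, decompress them,
            # and write the result directly into the matching rows of the
            # output buffer; no intermediate list or torch.cat is needed.
            shard_in = packed.narrow(0, shard_offset, width)
            out.narrow(0, shard_offset, width).copy_(fake_decompress(shard_in))
            shard_offset += width
        return out


    if __name__ == "__main__":
        widths = [4, 4, 8]                     # e.g. fused q/k/v logical shards
        packed = torch.randn(sum(widths), 16)  # toy "compressed" weight
        reference = torch.cat(
            [fake_decompress(w) for w in torch.split(packed, widths)])
        assert torch.equal(decompress_without_cat(packed, widths), reference)

Writing each shard straight into the preallocated buffer avoids materializing
per-shard temporaries and the extra full-size copy that torch.cat performs,
which is the motivation stated in the subject line.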