From 49bdcf5de6d28f697ef8db5fbf52b9ffc4929697 Mon Sep 17 00:00:00 2001 From: Dipika Date: Mon, 4 Nov 2024 23:28:59 +0000 Subject: [PATCH 1/2] patch --- .../schemes/compressed_tensors_24.py | 39 +++++++++++++++---- 1 file changed, 32 insertions(+), 7 deletions(-) diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_24.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_24.py index b96c915dceca8..930bc75f04b47 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_24.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_24.py @@ -92,14 +92,39 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: weight_packed_data = layer.weight_packed.data meta = layer.meta.data - weight_data = { - "weight_packed": weight_packed_data, - "meta": meta - } - #decompress = self.model_compressor.sparsity_compressor.decompress_weight(weight_data).contiguous() + qkv_sizes = [2048, 256, 256] + gate_up_sizes = [5632, 5632] + split_weights = None + split_meta = None + + def _process_split(input_weight, input_meta): + decompress = sparse_semi_structured_to_dense_cutlass(input_weight, input_meta) + return decompress + + print(self.layer_name) + if "qkv" in self.layer_name: + split_weights = torch.split(weight_packed_data, qkv_sizes) + split_meta = torch.split(meta, qkv_sizes) + elif "gate_up" in self.layer_name: + split_weights = torch.split(weight_packed_data, gate_up_sizes) + split_meta = torch.split(meta, gate_up_sizes) + + if split_weights: + all_compress = [] + for i in range(len(split_weights)): + print(split_weights[i].shape, split_meta[i].shape) + compress_i = _process_split(split_weights[i], split_meta[i]) + all_compress.append(compress_i) + + compressed = torch.cat(all_compress) + compressed = compress_to_torch_sparse_semi_structured_mat(compressed) + else: + decompress = sparse_semi_structured_to_dense_cutlass(weight_packed_data, meta) + compressed = compress_to_torch_sparse_semi_structured_mat(decompress) + + #decompress = self.model_compressor.sparsity_compressor.decompress_weight(weight_data) # Temporarily swap in to use Alex's method. Seems like the compression might be wrong? - decompress = sparse_semi_structured_to_dense_cutlass(weight_packed_data, meta) - compressed = compress_to_torch_sparse_semi_structured_mat(decompress) + layer.weight = Parameter(compressed, requires_grad=False) else: From 6133f81b239a760e2a4c0cdb0057aac0042e8147 Mon Sep 17 00:00:00 2001 From: Dipika Date: Mon, 4 Nov 2024 23:34:40 +0000 Subject: [PATCH 2/2] use our decompressor --- .../schemes/compressed_tensors_24.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_24.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_24.py index 930bc75f04b47..2844018cc98cb 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_24.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_24.py @@ -98,7 +98,11 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: split_meta = None def _process_split(input_weight, input_meta): - decompress = sparse_semi_structured_to_dense_cutlass(input_weight, input_meta) + weight_data = { + "weight_packed": input_weight, + "meta": input_meta + } + decompress = self.model_compressor.sparsity_compressor.decompress_weight(weight_data) return decompress print(self.layer_name) @@ -119,12 +123,9 @@ def _process_split(input_weight, input_meta): compressed = torch.cat(all_compress) compressed = compress_to_torch_sparse_semi_structured_mat(compressed) else: - decompress = sparse_semi_structured_to_dense_cutlass(weight_packed_data, meta) + decompress = _process_split(weight_packed_data, meta) compressed = compress_to_torch_sparse_semi_structured_mat(decompress) - #decompress = self.model_compressor.sparsity_compressor.decompress_weight(weight_data) - # Temporarily swap in to use Alex's method. Seems like the compression might be wrong? - layer.weight = Parameter(compressed, requires_grad=False) else: