diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py
index 246835dbf1d24..cb65557be8f90 100644
--- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py
+++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py
@@ -14,7 +14,8 @@
 from vllm.model_executor.parameter import (BasevLLMParameter,
                                            ChannelQuantScaleParameter,
                                            GroupQuantScaleParameter,
-                                           PackedvLLMParameter)
+                                           PackedvLLMParameter,
+                                           RowvLLMParameter)
 from vllm.scalar_type import scalar_types
 
 logger = init_logger(__name__)
@@ -136,11 +137,21 @@ def create_weights(self, layer: torch.nn.Module, output_size: int,
         layer.register_parameter("weight_scale", weight_scale)
         layer.register_parameter("weight_shape", weight_shape)
 
+        # group index (for activation reordering)
+        if self.has_g_idx:
+            weight_g_idx = RowvLLMParameter(data=torch.empty(
+                input_size_per_partition,
+                dtype=torch.int32,
+            ),
+                                            input_dim=0,
+                                            weight_loader=weight_loader)
+            layer.register_parameter("weight_g_idx", weight_g_idx)
+
         self.kernel = kernel_type(mp_linear_kernel_config,
                                   w_q_param_name="weight_packed",
                                   w_s_param_name="weight_scale",
                                   w_zp_param_name=None,
-                                  w_gidx_param_name=None)
+                                  w_gidx_param_name="weight_g_idx")
 
         # Checkpoints are serialized in compressed-tensors format, which is
         # different from the format the kernel may want. Handle repacking here.
diff --git a/vllm/model_executor/layers/quantization/kernels/MPLinearKernel.py b/vllm/model_executor/layers/quantization/kernels/MPLinearKernel.py
index 89d2cd6b3051c..efd33c9395e91 100644
--- a/vllm/model_executor/layers/quantization/kernels/MPLinearKernel.py
+++ b/vllm/model_executor/layers/quantization/kernels/MPLinearKernel.py
@@ -63,11 +63,11 @@ def _transform_param(self, layer: torch.nn.Module, name: Optional[str],
 
     def _get_weight_params(
             self, layer: torch.nn.Module
-    ) -> Tuple[torch.Tensor, # w_q
-               torch.Tensor, # w_s
-               Optional[torch.Tensor],# w_zp,
-               Optional[torch.Tensor] # w_gidx
-               ]:
+    ) -> Tuple[torch.Tensor,  # w_q
+               torch.Tensor,  # w_s
+               Optional[torch.Tensor],  # w_zp,
+               Optional[torch.Tensor]  # w_gidx
+               ]:
         return (
             getattr(layer, self.w_q_name),
             getattr(layer, self.w_s_name),
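For context on what the new `weight_g_idx` parameter carries: in GPTQ-style checkpoints with activation reordering (act_order), `g_idx[i]` maps input channel `i` to its quantization group, so scales are looked up per input channel rather than in contiguous blocks of `group_size` channels. Below is a minimal, illustrative sketch of that indexing in plain PyTorch. It is not vLLM's kernel path; the helper name and the scalar `zero_point` are assumptions made for the example.

```python
import torch


def dequantize_with_g_idx(w_q: torch.Tensor,     # [in_features, out_features], int
                          scales: torch.Tensor,  # [num_groups, out_features]
                          g_idx: torch.Tensor,   # [in_features], int32
                          zero_point: int) -> torch.Tensor:
    """Hypothetical helper (not part of vLLM) showing what a g_idx tensor
    expresses for a checkpoint quantized with activation reordering."""
    # g_idx[i] is the scale group of input channel i. With act_order the
    # channel-to-group assignment is not contiguous, so each row gathers
    # its own group's scale instead of using i // group_size.
    per_row_scales = scales[g_idx.long()]        # [in_features, out_features]
    return (w_q.float() - zero_point) * per_row_scales
```

This also motivates using `RowvLLMParameter` with `input_dim=0` in the diff above: `weight_g_idx` has one entry per input channel, so under tensor parallelism it should be sharded along the same input dimension as the rows it indexes.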