Skip to content

Commit

Permalink
add g_idx back
Browse files: browse the repository at this point in the history
  • Loading branch information
LucasWilkinson committed Sep 11, 2024
1 parent 378de64 commit b7c13e7
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 7 deletions.
Original file line number | Diff line number | Diff line change
Expand Up @@ -14,7 +14,8 @@
from vllm.model_executor.parameter import (BasevLLMParameter,
ChannelQuantScaleParameter,
GroupQuantScaleParameter,
PackedvLLMParameter)
PackedvLLMParameter,
RowvLLMParameter)
from vllm.scalar_type import scalar_types

logger = init_logger(__name__)
Expand Down Expand Up @@ -136,11 +137,21 @@ def create_weights(self, layer: torch.nn.Module, output_size: int,
layer.register_parameter("weight_scale", weight_scale)
layer.register_parameter("weight_shape", weight_shape)

# group index (for activation reordering)
if self.has_g_idx:
weight_g_idx = RowvLLMParameter(data=torch.empty(
input_size_per_partition,
dtype=torch.int32,
),
input_dim=0,
weight_loader=weight_loader)
layer.register_parameter("weight_g_idx", weight_g_idx)

self.kernel = kernel_type(mp_linear_kernel_config,
w_q_param_name="weight_packed",
w_s_param_name="weight_scale",
w_zp_param_name=None,
w_gidx_param_name=None)
w_gidx_param_name="weight_g_idx")

# Checkpoints are serialized in compressed-tensors format, which is
# different from the format the kernel may want. Handle repacking here.
Expand Down
Original file line number | Diff line number | Diff line change
Expand Up @@ -63,11 +63,11 @@ def _transform_param(self, layer: torch.nn.Module, name: Optional[str],

def _get_weight_params(
self, layer: torch.nn.Module
) -> Tuple[torch.Tensor, # w_q
torch.Tensor, # w_s
Optional[torch.Tensor],# w_zp,
Optional[torch.Tensor] # w_gidx
]:
) -> Tuple[torch.Tensor, # w_q
torch.Tensor, # w_s
Optional[torch.Tensor], # w_zp,
Optional[torch.Tensor] # w_gidx
]:
return (
getattr(layer, self.w_q_name),
getattr(layer, self.w_s_name),
Expand Down

0 comments on commit b7c13e7

Please sign in to comment.