From 000796acdf1e6184eeb36272c5ddd6ffbc41fac3 Mon Sep 17 00:00:00 2001 From: Dipika Date: Thu, 26 Sep 2024 18:03:11 +0000 Subject: [PATCH 01/11] add awq moe --- .../model_executor/layers/quantization/awq.py | 191 +++++++++++++++++- vllm/model_executor/model_loader/utils.py | 2 +- 2 files changed, 188 insertions(+), 5 deletions(-) diff --git a/vllm/model_executor/layers/quantization/awq.py b/vllm/model_executor/layers/quantization/awq.py index 410b3cb5321cb..e564b18e7d323 100644 --- a/vllm/model_executor/layers/quantization/awq.py +++ b/vllm/model_executor/layers/quantization/awq.py @@ -1,14 +1,22 @@ -from typing import Any, Dict, List, Optional +from typing import Callable, Any, Dict, List, Optional import torch - +from torch.nn import Parameter from vllm import _custom_ops as ops -from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase +from vllm.model_executor.layers.fused_moe.layer import ( + FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported) +from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase, set_weight_attrs from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) from vllm.model_executor.parameter import (GroupQuantScaleParameter, PackedvLLMParameter) +from vllm.model_executor.layers.quantization.utils.marlin_utils import ( + apply_gptq_marlin_linear, check_marlin_supported, marlin_is_k_full, + marlin_make_empty_g_idx, marlin_make_workspace, marlin_moe_permute_scales, + marlin_permute_scales, marlin_repeat_scales_on_all_ranks, + marlin_sort_g_idx, replace_tensor, verify_marlin_supported, + verify_marlin_supports_shape) class AWQConfig(QuantizationConfig): """Config class for AWQ. @@ -64,9 +72,11 @@ def from_config(cls, config: Dict[str, Any]) -> "AWQConfig": return cls(weight_bits, group_size, zero_point) def get_quant_method(self, layer: torch.nn.Module, - prefix: str) -> Optional["AWQLinearMethod"]: + prefix: str) -> Optional["QuantizedMethodBase"]: if isinstance(layer, LinearBase): return AWQLinearMethod(self) + elif isinstance(layer, FusedMoE): + return AWQMoEMethod(self) return None def get_scaled_act_names(self) -> List[str]: @@ -170,3 +180,176 @@ def apply(self, if bias is not None: out.add_(bias) return out.reshape(out_shape) + +class AWQMoEMethod(FusedMoEMethodBase): + + def __init__(self, quant_config: AWQConfig): + self.quant_config = quant_config + self.num_bits = self.quant_config.weight_bits + self.packed_factor = self.quant_config.pack_factor + self.group_size = self.quant_config.group_size + + def create_weights(self, layer: torch.nn.Module, num_experts: int, + hidden_size: int, intermediate_size: int, + params_dtype: torch.dtype, **extra_weight_attrs): + extra_weight_attrs.update({ + "is_transposed": True, + "quant_method": "group", + }) + + w13_qweight = Parameter(torch.empty(num_experts, + hidden_size, + 2 * intermediate_size // + self.quant_config.pack_factor, + dtype=torch.int32), + requires_grad=False) + layer.register_parameter("w13_qweight", w13_qweight) + set_weight_attrs(w13_qweight, extra_weight_attrs) + + w2_qweight = Parameter(torch.empty(num_experts, + intermediate_size, + hidden_size // + self.quant_config.pack_factor, + dtype=torch.int32), + requires_grad=False) + layer.register_parameter("w2_qweight", w2_qweight) + set_weight_attrs(w2_qweight, extra_weight_attrs) + + num_groups_w13 = hidden_size // self.quant_config.group_size + num_groups_w2 = intermediate_size // self.quant_config.group_size + + # WEIGHT_SCALES + # Allocate 2 scales for w1 and w3 respectively. 
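For reference, a standalone sketch of the shapes this method allocates, written as plain Python shape arithmetic; the layer sizes used are illustrative assumptions, not values from this patch:

# Sketch: expected AWQ MoE parameter shapes for example layer sizes.
weight_bits = 4                      # AWQ weights here are 4-bit
pack_factor = 32 // weight_bits      # 8 int4 values packed per int32 word
group_size = 128
num_experts, hidden_size, intermediate_size = 8, 4096, 14336  # assumed sizes

# w13 fuses the gate (w1) and up (w3) projections, hence the factor of 2.
w13_qweight_shape = (num_experts, hidden_size,
                     2 * intermediate_size // pack_factor)
w2_qweight_shape = (num_experts, intermediate_size,
                    hidden_size // pack_factor)

num_groups_w13 = hidden_size // group_size
num_groups_w2 = intermediate_size // group_size
w13_scales_shape = (num_experts, num_groups_w13, 2 * intermediate_size)
w2_scales_shape = (num_experts, num_groups_w2, hidden_size)

print(w13_qweight_shape, w13_scales_shape)   # (8, 4096, 3584) (8, 32, 28672)
print(w2_qweight_shape, w2_scales_shape)     # (8, 14336, 512) (8, 112, 4096)

Only the qweight and qzeros dimensions divide by pack_factor; the scales stay unpacked in the activation dtype, which matches the allocations in this hunk.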
+ w13_scales = Parameter(torch.empty(num_experts, + num_groups_w13, + intermediate_size * 2, + dtype=params_dtype), + requires_grad=False) + layer.register_parameter("w13_scales", w13_scales) + set_weight_attrs(w13_scales, extra_weight_attrs) + + w2_scales = Parameter(torch.empty(num_experts, + num_groups_w2, + hidden_size, + dtype=params_dtype), + requires_grad=False) + layer.register_parameter("w2_scales", w2_scales) + set_weight_attrs(w2_scales, extra_weight_attrs) + + # WEIGHT_ZERO_POINT + # Allocate 2 zero points for w1 and w3 respectively. + w13_qzeros = Parameter(torch.empty(num_experts, + num_groups_w13, + 2 * intermediate_size // + self.quant_config.pack_factor, + dtype=torch.int32), + requires_grad=False) + layer.register_parameter("w13_qzeros", w13_qzeros) + set_weight_attrs(w13_qzeros, extra_weight_attrs) + + w2_qzeros = Parameter(torch.empty(num_experts, + num_groups_w2, + hidden_size // + self.quant_config.pack_factor, + dtype=torch.int32), + requires_grad=False) + layer.register_parameter("w2_qzeros", w2_qzeros) + set_weight_attrs(w2_qzeros, extra_weight_attrs) + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + num_experts = layer.w13_qweight.shape[0] + device = layer.w13_qweight.device + + layer.w13_g_idx = torch.nn.Parameter( + torch.empty((num_experts, 0), dtype=torch.int32, + device=device), + requires_grad=False, + ) + layer.w2_g_idx = torch.nn.Parameter( + torch.empty((num_experts, 0), dtype=torch.int32, + device=device), + requires_grad=False, + ) + layer.w13_g_idx_sort_indices = torch.nn.Parameter( + torch.empty((num_experts, 0), dtype=torch.int32, + device=device), + requires_grad=False, + ) + layer.w2_g_idx_sort_indices = torch.nn.Parameter( + torch.empty((num_experts, 0), dtype=torch.int32, + device=device), + requires_grad=False, + ) + marlin_w13_qweight = ops.gptq_marlin_moe_repack( + layer.w13_qweight, + layer.w13_g_idx_sort_indices, + layer.w13_qweight.shape[1], + layer.w13_qweight.shape[2] * self.packed_factor, + self.num_bits, + ) + replace_tensor(layer, "w13_qweight", marlin_w13_qweight) + marlin_w2_qweight = ops.gptq_marlin_moe_repack( + layer.w2_qweight, + layer.w2_g_idx_sort_indices, + layer.w2_qweight.shape[1], + layer.w2_qweight.shape[2] * self.packed_factor, + self.num_bits, + ) + replace_tensor(layer, "w2_qweight", marlin_w2_qweight) + # Repack scales + marlin_w13_scales = marlin_moe_permute_scales( + s=layer.w13_scales, + size_k=layer.intermediate_size_per_partition, + size_n=layer.w13_scales.shape[2], + group_size=self.group_size + ) + + replace_tensor(layer, "w13_scales", marlin_w13_scales) + marlin_w2_scales = marlin_moe_permute_scales( + s=layer.w2_scales, + size_k=layer.w2_scales.shape[1] , + size_n=layer.w2_scales.shape[2] * self.packed_factor, + group_size=self.group_size, + ) + replace_tensor(layer, "w2_scales", marlin_w2_scales) + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + router_logits: torch.Tensor, + top_k: int, + renormalize: bool = True, + use_grouped_topk: bool = False, + num_expert_group: Optional[int] = None, + topk_group: Optional[int] = None, + custom_routing_function: Optional[Callable] = None, + ) -> torch.Tensor: + + from vllm.model_executor.layers.fused_moe.fused_marlin_moe import ( + fused_marlin_moe) + + topk_weights, topk_ids = FusedMoE.select_experts( + hidden_states=x, + router_logits=router_logits, + use_grouped_topk=use_grouped_topk, + top_k=top_k, + renormalize=renormalize, + topk_group=topk_group, + num_expert_group=num_expert_group, + 
custom_routing_function=custom_routing_function) + + return fused_marlin_moe( + x, + layer.w13_qweight, + layer.w2_qweight, + layer.w13_scales, + layer.w2_scales, + router_logits, + topk_weights, + topk_ids, + g_idx1=layer.w13_g_idx, + g_idx2=layer.w2_g_idx, + sort_indices1=layer.w13_g_idx_sort_indices, + sort_indices2=layer.w2_g_idx_sort_indices, + num_bits=self.num_bits) \ No newline at end of file diff --git a/vllm/model_executor/model_loader/utils.py b/vllm/model_executor/model_loader/utils.py index 2bfe6ea09bd62..995bb253db8a1 100644 --- a/vllm/model_executor/model_loader/utils.py +++ b/vllm/model_executor/model_loader/utils.py @@ -23,7 +23,7 @@ def get_model_architecture( architectures = getattr(model_config.hf_config, "architectures", []) # Special handling for quantized Mixtral. # FIXME(woosuk): This is a temporary hack. - mixtral_supported = ["fp8", "compressed-tensors", "gptq_marlin"] + mixtral_supported = ["fp8", "compressed-tensors", "gptq_marlin", "awq"] if (model_config.quantization is not None and model_config.quantization not in mixtral_supported From e8289ae95dbe7f100898935997906791b01adcc2 Mon Sep 17 00:00:00 2001 From: Dipika Date: Thu, 26 Sep 2024 19:39:56 +0000 Subject: [PATCH 02/11] update --- .../model_executor/layers/quantization/awq.py | 49 +++++++++---------- 1 file changed, 24 insertions(+), 25 deletions(-) diff --git a/vllm/model_executor/layers/quantization/awq.py b/vllm/model_executor/layers/quantization/awq.py index e564b18e7d323..b82714bd5ba55 100644 --- a/vllm/model_executor/layers/quantization/awq.py +++ b/vllm/model_executor/layers/quantization/awq.py @@ -18,6 +18,7 @@ marlin_sort_g_idx, replace_tensor, verify_marlin_supported, verify_marlin_supports_shape) + class AWQConfig(QuantizationConfig): """Config class for AWQ. 
@@ -181,13 +182,11 @@ def apply(self, out.add_(bias) return out.reshape(out_shape) + class AWQMoEMethod(FusedMoEMethodBase): def __init__(self, quant_config: AWQConfig): self.quant_config = quant_config - self.num_bits = self.quant_config.weight_bits - self.packed_factor = self.quant_config.pack_factor - self.group_size = self.quant_config.group_size def create_weights(self, layer: torch.nn.Module, num_experts: int, hidden_size: int, intermediate_size: int, @@ -255,61 +254,60 @@ def create_weights(self, layer: torch.nn.Module, num_experts: int, requires_grad=False) layer.register_parameter("w2_qzeros", w2_qzeros) set_weight_attrs(w2_qzeros, extra_weight_attrs) - + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: num_experts = layer.w13_qweight.shape[0] device = layer.w13_qweight.device layer.w13_g_idx = torch.nn.Parameter( - torch.empty((num_experts, 0), dtype=torch.int32, - device=device), + torch.empty((num_experts, 0), dtype=torch.int32, device=device), requires_grad=False, ) layer.w2_g_idx = torch.nn.Parameter( - torch.empty((num_experts, 0), dtype=torch.int32, - device=device), + torch.empty((num_experts, 0), dtype=torch.int32, device=device), requires_grad=False, ) layer.w13_g_idx_sort_indices = torch.nn.Parameter( - torch.empty((num_experts, 0), dtype=torch.int32, - device=device), + torch.empty((num_experts, 0), dtype=torch.int32, device=device), requires_grad=False, ) layer.w2_g_idx_sort_indices = torch.nn.Parameter( - torch.empty((num_experts, 0), dtype=torch.int32, - device=device), + torch.empty((num_experts, 0), dtype=torch.int32, device=device), requires_grad=False, ) + marlin_w13_qweight = ops.gptq_marlin_moe_repack( layer.w13_qweight, layer.w13_g_idx_sort_indices, - layer.w13_qweight.shape[1], - layer.w13_qweight.shape[2] * self.packed_factor, - self.num_bits, + size_k=layer.w13_qweight.shape[1], + size_n=layer.w13_qweight.shape[2] * self.quant_config.pack_factor, + num_bits=self.quant_config.weight_bits, ) replace_tensor(layer, "w13_qweight", marlin_w13_qweight) + marlin_w2_qweight = ops.gptq_marlin_moe_repack( layer.w2_qweight, layer.w2_g_idx_sort_indices, - layer.w2_qweight.shape[1], - layer.w2_qweight.shape[2] * self.packed_factor, - self.num_bits, + size_k=layer.w2_qweight.shape[1], + size_n=layer.w2_qweight.shape[2] * self.quant_config.pack_factor, + num_bits=self.quant_config.weight_bits, ) replace_tensor(layer, "w2_qweight", marlin_w2_qweight) - # Repack scales + marlin_w13_scales = marlin_moe_permute_scales( s=layer.w13_scales, size_k=layer.intermediate_size_per_partition, size_n=layer.w13_scales.shape[2], - group_size=self.group_size + group_size=self.quant_config.group_size, ) - + # for @eliza: why do we need to apply a pack factor to the scales? + # they're not in packed format? 
replace_tensor(layer, "w13_scales", marlin_w13_scales) marlin_w2_scales = marlin_moe_permute_scales( s=layer.w2_scales, - size_k=layer.w2_scales.shape[1] , - size_n=layer.w2_scales.shape[2] * self.packed_factor, - group_size=self.group_size, + size_k=layer.w2_scales.shape[1] * self.quant_config.pack_factor, + size_n=layer.w2_scales.shape[2], + group_size=self.quant_config.group_size, ) replace_tensor(layer, "w2_scales", marlin_w2_scales) @@ -352,4 +350,5 @@ def apply( g_idx2=layer.w2_g_idx, sort_indices1=layer.w13_g_idx_sort_indices, sort_indices2=layer.w2_g_idx_sort_indices, - num_bits=self.num_bits) \ No newline at end of file + num_bits=self.quant_config.weight_bits, + ) From 0385aa85eabc5005e706f112e96f370ae3e28326 Mon Sep 17 00:00:00 2001 From: Dipika Date: Fri, 27 Sep 2024 17:07:25 +0000 Subject: [PATCH 03/11] update awq --- vllm/_custom_ops.py | 14 ++++++++ .../layers/fused_moe/fused_moe.py | 2 +- .../model_executor/layers/quantization/awq.py | 36 +++++++++++++------ .../layers/quantization/gptq_marlin.py | 1 + .../layers/quantization/utils/marlin_utils.py | 15 ++++++++ 5 files changed, 57 insertions(+), 11 deletions(-) diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index 77c46584ef530..8ce01b2d82532 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -317,6 +317,20 @@ def gptq_marlin_moe_repack(b_q_weight: torch.Tensor, perm: torch.Tensor, return output +def awq_marlin_moe_repack(b_q_weight: torch.Tensor, perm: torch.Tensor, + size_k: int, size_n: int, + num_bits: int) -> torch.Tensor: + num_experts = b_q_weight.shape[0] + assert size_k % 16 == 0 + output = torch.empty((num_experts, size_k // 16, size_n * (num_bits // 2)), + device=b_q_weight.device, + dtype=b_q_weight.dtype) + for e in range(num_experts): + output[e] = torch.ops._C.awq_marlin_repack(b_q_weight[e], size_k, + size_n, num_bits) + return output + + def gptq_marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor, b_scales: torch.Tensor, diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index bd13d8fecbb96..1a98666204f93 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -443,7 +443,7 @@ def grouped_topk(hidden_states: torch.Tensor, if renormalize: topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True) - return topk_weights, topk_ids + return topk_weights.to(torch.float32), topk_ids.to(torch.int32) def get_config_dtype_str(dtype: torch.dtype, diff --git a/vllm/model_executor/layers/quantization/awq.py b/vllm/model_executor/layers/quantization/awq.py index b82714bd5ba55..ba912aa6552d3 100644 --- a/vllm/model_executor/layers/quantization/awq.py +++ b/vllm/model_executor/layers/quantization/awq.py @@ -12,11 +12,10 @@ PackedvLLMParameter) from vllm.model_executor.layers.quantization.utils.marlin_utils import ( - apply_gptq_marlin_linear, check_marlin_supported, marlin_is_k_full, - marlin_make_empty_g_idx, marlin_make_workspace, marlin_moe_permute_scales, - marlin_permute_scales, marlin_repeat_scales_on_all_ranks, - marlin_sort_g_idx, replace_tensor, verify_marlin_supported, - verify_marlin_supports_shape) + marlin_moe_permute_scales, moe_awq_to_marlin_zero_points, + apply_awq_marlin_linear, awq_to_marlin_zero_points, check_marlin_supported, + marlin_make_empty_g_idx, marlin_make_workspace, marlin_permute_scales, + replace_tensor, verify_marlin_supported, verify_marlin_supports_shape) class AWQConfig(QuantizationConfig): @@ -276,7 +275,7 @@ def 
process_weights_after_loading(self, layer: torch.nn.Module) -> None: requires_grad=False, ) - marlin_w13_qweight = ops.gptq_marlin_moe_repack( + marlin_w13_qweight = ops.awq_marlin_moe_repack( layer.w13_qweight, layer.w13_g_idx_sort_indices, size_k=layer.w13_qweight.shape[1], @@ -285,7 +284,7 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: ) replace_tensor(layer, "w13_qweight", marlin_w13_qweight) - marlin_w2_qweight = ops.gptq_marlin_moe_repack( + marlin_w2_qweight = ops.awq_marlin_moe_repack( layer.w2_qweight, layer.w2_g_idx_sort_indices, size_k=layer.w2_qweight.shape[1], @@ -294,23 +293,38 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: ) replace_tensor(layer, "w2_qweight", marlin_w2_qweight) + # Why does this take the intermediate size for size_k? marlin_w13_scales = marlin_moe_permute_scales( s=layer.w13_scales, size_k=layer.intermediate_size_per_partition, size_n=layer.w13_scales.shape[2], group_size=self.quant_config.group_size, ) - # for @eliza: why do we need to apply a pack factor to the scales? - # they're not in packed format? + replace_tensor(layer, "w13_scales", marlin_w13_scales) + marlin_w2_scales = marlin_moe_permute_scales( s=layer.w2_scales, - size_k=layer.w2_scales.shape[1] * self.quant_config.pack_factor, + size_k=layer.intermediate_size_per_partition, size_n=layer.w2_scales.shape[2], group_size=self.quant_config.group_size, ) replace_tensor(layer, "w2_scales", marlin_w2_scales) + marlin_w13_zp = moe_awq_to_marlin_zero_points( + layer.w13_qzeros, + size_k=layer.w13_qzeros.shape[1], + size_n=layer.w13_qzeros.shape[2] * self.quant_config.pack_factor, + num_bits=self.quant_config.weight_bits) + replace_tensor(layer, "w13_qzeros", marlin_w13_zp) + + marlin_w2_zp = moe_awq_to_marlin_zero_points( + layer.w2_qzeros, + size_k=layer.w2_qzeros.shape[1], + size_n=layer.w2_qzeros.shape[2] * self.quant_config.pack_factor, + num_bits=self.quant_config.weight_bits) + replace_tensor(layer, "w2_qzeros", marlin_w2_zp) + def apply( self, layer: torch.nn.Module, @@ -346,6 +360,8 @@ def apply( router_logits, topk_weights, topk_ids, + w1_zeros=layer.w13_qzeros, + w2_zeros=layer.w2_qzeros, g_idx1=layer.w13_g_idx, g_idx2=layer.w2_g_idx, sort_indices1=layer.w13_g_idx_sort_indices, diff --git a/vllm/model_executor/layers/quantization/gptq_marlin.py b/vllm/model_executor/layers/quantization/gptq_marlin.py index dd46f0ce5a39c..04bea28ec4630 100644 --- a/vllm/model_executor/layers/quantization/gptq_marlin.py +++ b/vllm/model_executor/layers/quantization/gptq_marlin.py @@ -554,6 +554,7 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: ) replace_tensor(layer, "w2_qweight", marlin_w2_qweight) # Repack scales + # Why does this take the intermediate size for size_k? 
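The helpers this commit leans on (ops.awq_marlin_moe_repack, moe_awq_to_marlin_zero_points, marlin_moe_permute_scales) all apply a single-expert transform once per expert into an output with a leading expert dimension. A generic, runnable sketch of that loop pattern, with a placeholder transform standing in for the real repack kernel:

import torch

def per_expert_repack(packed: torch.Tensor, transform) -> torch.Tensor:
    # packed: (num_experts, *rest); `transform` is a stand-in for a
    # single-expert kernel such as torch.ops._C.awq_marlin_repack.
    num_experts = packed.shape[0]
    out_slices = [transform(packed[e]) for e in range(num_experts)]
    return torch.stack(out_slices, dim=0)

# Identity "transform" just to exercise the loop; the real kernels
# change the layout (and, for repack, the trailing dimension).
x = torch.randint(0, 2**31 - 1, (4, 64, 16), dtype=torch.int32)
y = per_expert_repack(x, lambda t: t.clone())
assert y.shape == x.shape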
marlin_w13_scales = marlin_moe_permute_scales( s=layer.w13_scales, size_k=layer.intermediate_size_per_partition, diff --git a/vllm/model_executor/layers/quantization/utils/marlin_utils.py b/vllm/model_executor/layers/quantization/utils/marlin_utils.py index 699d5f1844146..db8ec78f937ee 100644 --- a/vllm/model_executor/layers/quantization/utils/marlin_utils.py +++ b/vllm/model_executor/layers/quantization/utils/marlin_utils.py @@ -188,6 +188,7 @@ def marlin_moe_permute_scales( device=s.device, dtype=s.dtype, ) + for e in range(num_experts): output[e] = marlin_permute_scales(s[e], size_k, size_n, group_size) return output @@ -238,6 +239,20 @@ def awq_to_marlin_zero_points(q_zp_packed: torch.Tensor, size_k: int, return marlin_zp +def moe_awq_to_marlin_zero_points(q_zp_packed: torch.Tensor, size_k: int, + size_n: int, num_bits: int): + num_experts = q_zp_packed.shape[0] + output = torch.empty( + (num_experts, q_zp_packed.shape[1], q_zp_packed.shape[2]), + device=q_zp_packed.device, + dtype=q_zp_packed.dtype, + ) + for e in range(num_experts): + output[e] = awq_to_marlin_zero_points(q_zp_packed[e], size_k, size_n, + num_bits) + return output + + # Newly generated tensors need to replace existing tensors that are # already registered as parameters by vLLM (and won't be freed) def replace_tensor(layer: torch.nn.Module, name: str, From 3d125547c775e3048e4c327f2a5dbb272f490a8b Mon Sep 17 00:00:00 2001 From: Dipika Date: Mon, 30 Sep 2024 15:16:38 +0000 Subject: [PATCH 04/11] move to marlin; clean-up --- .../model_executor/layers/quantization/awq.py | 208 +----------------- .../layers/quantization/awq_marlin.py | 206 ++++++++++++++++- vllm/model_executor/model_loader/utils.py | 4 +- 3 files changed, 204 insertions(+), 214 deletions(-) diff --git a/vllm/model_executor/layers/quantization/awq.py b/vllm/model_executor/layers/quantization/awq.py index ba912aa6552d3..30380ec0407c5 100644 --- a/vllm/model_executor/layers/quantization/awq.py +++ b/vllm/model_executor/layers/quantization/awq.py @@ -1,22 +1,14 @@ -from typing import Callable, Any, Dict, List, Optional +from typing import Any, Dict, List, Optional import torch -from torch.nn import Parameter + from vllm import _custom_ops as ops -from vllm.model_executor.layers.fused_moe.layer import ( - FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported) -from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase, set_weight_attrs +from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) from vllm.model_executor.parameter import (GroupQuantScaleParameter, PackedvLLMParameter) -from vllm.model_executor.layers.quantization.utils.marlin_utils import ( - marlin_moe_permute_scales, moe_awq_to_marlin_zero_points, - apply_awq_marlin_linear, awq_to_marlin_zero_points, check_marlin_supported, - marlin_make_empty_g_idx, marlin_make_workspace, marlin_permute_scales, - replace_tensor, verify_marlin_supported, verify_marlin_supports_shape) - class AWQConfig(QuantizationConfig): """Config class for AWQ. 
@@ -72,11 +64,9 @@ def from_config(cls, config: Dict[str, Any]) -> "AWQConfig": return cls(weight_bits, group_size, zero_point) def get_quant_method(self, layer: torch.nn.Module, - prefix: str) -> Optional["QuantizedMethodBase"]: + prefix: str) -> Optional["AWQLinearMethod"]: if isinstance(layer, LinearBase): return AWQLinearMethod(self) - elif isinstance(layer, FusedMoE): - return AWQMoEMethod(self) return None def get_scaled_act_names(self) -> List[str]: @@ -179,192 +169,4 @@ def apply(self, pack_factor) if bias is not None: out.add_(bias) - return out.reshape(out_shape) - - -class AWQMoEMethod(FusedMoEMethodBase): - - def __init__(self, quant_config: AWQConfig): - self.quant_config = quant_config - - def create_weights(self, layer: torch.nn.Module, num_experts: int, - hidden_size: int, intermediate_size: int, - params_dtype: torch.dtype, **extra_weight_attrs): - extra_weight_attrs.update({ - "is_transposed": True, - "quant_method": "group", - }) - - w13_qweight = Parameter(torch.empty(num_experts, - hidden_size, - 2 * intermediate_size // - self.quant_config.pack_factor, - dtype=torch.int32), - requires_grad=False) - layer.register_parameter("w13_qweight", w13_qweight) - set_weight_attrs(w13_qweight, extra_weight_attrs) - - w2_qweight = Parameter(torch.empty(num_experts, - intermediate_size, - hidden_size // - self.quant_config.pack_factor, - dtype=torch.int32), - requires_grad=False) - layer.register_parameter("w2_qweight", w2_qweight) - set_weight_attrs(w2_qweight, extra_weight_attrs) - - num_groups_w13 = hidden_size // self.quant_config.group_size - num_groups_w2 = intermediate_size // self.quant_config.group_size - - # WEIGHT_SCALES - # Allocate 2 scales for w1 and w3 respectively. - w13_scales = Parameter(torch.empty(num_experts, - num_groups_w13, - intermediate_size * 2, - dtype=params_dtype), - requires_grad=False) - layer.register_parameter("w13_scales", w13_scales) - set_weight_attrs(w13_scales, extra_weight_attrs) - - w2_scales = Parameter(torch.empty(num_experts, - num_groups_w2, - hidden_size, - dtype=params_dtype), - requires_grad=False) - layer.register_parameter("w2_scales", w2_scales) - set_weight_attrs(w2_scales, extra_weight_attrs) - - # WEIGHT_ZERO_POINT - # Allocate 2 zero points for w1 and w3 respectively. 
- w13_qzeros = Parameter(torch.empty(num_experts, - num_groups_w13, - 2 * intermediate_size // - self.quant_config.pack_factor, - dtype=torch.int32), - requires_grad=False) - layer.register_parameter("w13_qzeros", w13_qzeros) - set_weight_attrs(w13_qzeros, extra_weight_attrs) - - w2_qzeros = Parameter(torch.empty(num_experts, - num_groups_w2, - hidden_size // - self.quant_config.pack_factor, - dtype=torch.int32), - requires_grad=False) - layer.register_parameter("w2_qzeros", w2_qzeros) - set_weight_attrs(w2_qzeros, extra_weight_attrs) - - def process_weights_after_loading(self, layer: torch.nn.Module) -> None: - num_experts = layer.w13_qweight.shape[0] - device = layer.w13_qweight.device - - layer.w13_g_idx = torch.nn.Parameter( - torch.empty((num_experts, 0), dtype=torch.int32, device=device), - requires_grad=False, - ) - layer.w2_g_idx = torch.nn.Parameter( - torch.empty((num_experts, 0), dtype=torch.int32, device=device), - requires_grad=False, - ) - layer.w13_g_idx_sort_indices = torch.nn.Parameter( - torch.empty((num_experts, 0), dtype=torch.int32, device=device), - requires_grad=False, - ) - layer.w2_g_idx_sort_indices = torch.nn.Parameter( - torch.empty((num_experts, 0), dtype=torch.int32, device=device), - requires_grad=False, - ) - - marlin_w13_qweight = ops.awq_marlin_moe_repack( - layer.w13_qweight, - layer.w13_g_idx_sort_indices, - size_k=layer.w13_qweight.shape[1], - size_n=layer.w13_qweight.shape[2] * self.quant_config.pack_factor, - num_bits=self.quant_config.weight_bits, - ) - replace_tensor(layer, "w13_qweight", marlin_w13_qweight) - - marlin_w2_qweight = ops.awq_marlin_moe_repack( - layer.w2_qweight, - layer.w2_g_idx_sort_indices, - size_k=layer.w2_qweight.shape[1], - size_n=layer.w2_qweight.shape[2] * self.quant_config.pack_factor, - num_bits=self.quant_config.weight_bits, - ) - replace_tensor(layer, "w2_qweight", marlin_w2_qweight) - - # Why does this take the intermediate size for size_k? 
- marlin_w13_scales = marlin_moe_permute_scales( - s=layer.w13_scales, - size_k=layer.intermediate_size_per_partition, - size_n=layer.w13_scales.shape[2], - group_size=self.quant_config.group_size, - ) - - replace_tensor(layer, "w13_scales", marlin_w13_scales) - - marlin_w2_scales = marlin_moe_permute_scales( - s=layer.w2_scales, - size_k=layer.intermediate_size_per_partition, - size_n=layer.w2_scales.shape[2], - group_size=self.quant_config.group_size, - ) - replace_tensor(layer, "w2_scales", marlin_w2_scales) - - marlin_w13_zp = moe_awq_to_marlin_zero_points( - layer.w13_qzeros, - size_k=layer.w13_qzeros.shape[1], - size_n=layer.w13_qzeros.shape[2] * self.quant_config.pack_factor, - num_bits=self.quant_config.weight_bits) - replace_tensor(layer, "w13_qzeros", marlin_w13_zp) - - marlin_w2_zp = moe_awq_to_marlin_zero_points( - layer.w2_qzeros, - size_k=layer.w2_qzeros.shape[1], - size_n=layer.w2_qzeros.shape[2] * self.quant_config.pack_factor, - num_bits=self.quant_config.weight_bits) - replace_tensor(layer, "w2_qzeros", marlin_w2_zp) - - def apply( - self, - layer: torch.nn.Module, - x: torch.Tensor, - router_logits: torch.Tensor, - top_k: int, - renormalize: bool = True, - use_grouped_topk: bool = False, - num_expert_group: Optional[int] = None, - topk_group: Optional[int] = None, - custom_routing_function: Optional[Callable] = None, - ) -> torch.Tensor: - - from vllm.model_executor.layers.fused_moe.fused_marlin_moe import ( - fused_marlin_moe) - - topk_weights, topk_ids = FusedMoE.select_experts( - hidden_states=x, - router_logits=router_logits, - use_grouped_topk=use_grouped_topk, - top_k=top_k, - renormalize=renormalize, - topk_group=topk_group, - num_expert_group=num_expert_group, - custom_routing_function=custom_routing_function) - - return fused_marlin_moe( - x, - layer.w13_qweight, - layer.w2_qweight, - layer.w13_scales, - layer.w2_scales, - router_logits, - topk_weights, - topk_ids, - w1_zeros=layer.w13_qzeros, - w2_zeros=layer.w2_qzeros, - g_idx1=layer.w13_g_idx, - g_idx2=layer.w2_g_idx, - sort_indices1=layer.w13_g_idx_sort_indices, - sort_indices2=layer.w2_g_idx_sort_indices, - num_bits=self.quant_config.weight_bits, - ) + return out.reshape(out_shape) \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/awq_marlin.py b/vllm/model_executor/layers/quantization/awq_marlin.py index eee6a8f7cff49..9704b1adbce55 100644 --- a/vllm/model_executor/layers/quantization/awq_marlin.py +++ b/vllm/model_executor/layers/quantization/awq_marlin.py @@ -1,16 +1,21 @@ -from typing import Any, Dict, List, Optional +from typing import Any, Callable, Dict, List, Optional import torch +from torch.nn import Parameter from vllm import _custom_ops as ops from vllm.logger import init_logger -from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase +from vllm.model_executor.layers.fused_moe.layer import ( + FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported) +from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase, + set_weight_attrs) from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig) + QuantizationConfig, QuantizeMethodBase) from vllm.model_executor.layers.quantization.utils.marlin_utils import ( apply_awq_marlin_linear, awq_to_marlin_zero_points, check_marlin_supported, - marlin_make_empty_g_idx, marlin_make_workspace, marlin_permute_scales, - replace_tensor, verify_marlin_supported, verify_marlin_supports_shape) + marlin_make_empty_g_idx, marlin_make_workspace, marlin_moe_permute_scales, + 
marlin_permute_scales, moe_awq_to_marlin_zero_points, replace_tensor, + verify_marlin_supported, verify_marlin_supports_shape) from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead from vllm.model_executor.parameter import (GroupQuantScaleParameter, PackedvLLMParameter) @@ -34,12 +39,13 @@ def __init__(self, weight_bits: int, group_size: int, has_zp: bool, self.group_size = group_size self.has_zp = has_zp self.lm_head_quantized = lm_head_quantized + self.weight_bits = weight_bits - if weight_bits not in self.TYPE_MAP: - raise ValueError(f"Unsupported num_bits = {weight_bits}. " + if self.weight_bits not in self.TYPE_MAP: + raise ValueError(f"Unsupported num_bits = {self.weight_bits}. " f"Supported num_bits = {self.TYPE_MAP.keys()}") - self.quant_type = self.TYPE_MAP[weight_bits] + self.quant_type = self.TYPE_MAP[self.weight_bits] verify_marlin_supported(self.quant_type, group_size=self.group_size, @@ -97,10 +103,12 @@ def override_quantization_method(cls, hf_quant_cfg, return None def get_quant_method(self, layer: torch.nn.Module, - prefix: str) -> Optional["AWQMarlinLinearMethod"]: + prefix: str) -> Optional["QuantizeMethodBase"]: if (isinstance(layer, LinearBase) or (isinstance(layer, ParallelLMHead) and self.lm_head_quantized)): return AWQMarlinLinearMethod(self) + elif isinstance(layer, FusedMoE): + return AWQMoEMethod(self) return None def get_scaled_act_names(self) -> List[str]: @@ -270,4 +278,182 @@ def apply( quant_type=self.quant_config.quant_type, output_size_per_partition=layer.output_size_per_partition, input_size_per_partition=layer.input_size_per_partition, - bias=bias) \ No newline at end of file + bias=bias) + + +class AWQMoEMethod(FusedMoEMethodBase): + + def __init__(self, quant_config: AWQMarlinConfig): + self.quant_config = quant_config + + def create_weights(self, layer: torch.nn.Module, num_experts: int, + hidden_size: int, intermediate_size: int, + params_dtype: torch.dtype, **extra_weight_attrs): + extra_weight_attrs.update({ + "is_transposed": + True, + "quant_method": + FusedMoeWeightScaleSupported.GROUP, + }) + + w13_qweight = Parameter(torch.empty(num_experts, + hidden_size, + 2 * intermediate_size // + self.quant_config.pack_factor, + dtype=torch.int32), + requires_grad=False) + layer.register_parameter("w13_qweight", w13_qweight) + set_weight_attrs(w13_qweight, extra_weight_attrs) + + w2_qweight = Parameter(torch.empty(num_experts, + intermediate_size, + hidden_size // + self.quant_config.pack_factor, + dtype=torch.int32), + requires_grad=False) + layer.register_parameter("w2_qweight", w2_qweight) + set_weight_attrs(w2_qweight, extra_weight_attrs) + + num_groups_w13 = hidden_size // self.quant_config.group_size + num_groups_w2 = intermediate_size // self.quant_config.group_size + + # WEIGHT_SCALES + # Allocate 2 scales for w1 and w3 respectively. + w13_scales = Parameter(torch.empty(num_experts, + num_groups_w13, + intermediate_size * 2, + dtype=params_dtype), + requires_grad=False) + layer.register_parameter("w13_scales", w13_scales) + set_weight_attrs(w13_scales, extra_weight_attrs) + + w2_scales = Parameter(torch.empty(num_experts, + num_groups_w2, + hidden_size, + dtype=params_dtype), + requires_grad=False) + layer.register_parameter("w2_scales", w2_scales) + set_weight_attrs(w2_scales, extra_weight_attrs) + + # WEIGHT_ZERO_POINT + # Allocate 2 zero points for w1 and w3 respectively. 
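The qzeros allocated below carry one 4-bit zero point per (group, output column), packed into int32 like the qweights. As a reminder of what these tensors encode, a small sketch of group-wise asymmetric dequantization, w ~ (q - zp) * scale, using unpacked integer tensors for clarity; the int32 packing and the later Marlin permutation are intentionally omitted:

import torch

def dequant_grouped(q: torch.Tensor, zp: torch.Tensor, scale: torch.Tensor,
                    group_size: int) -> torch.Tensor:
    # q:     (in_features, out_features), unpacked 4-bit values in [0, 15]
    # zp:    (in_features // group_size, out_features), unpacked zero points
    # scale: (in_features // group_size, out_features), fp scales
    group_ids = torch.arange(q.shape[0]) // group_size  # group of each input row
    return (q.float() - zp.float()[group_ids]) * scale[group_ids]

q = torch.randint(0, 16, (256, 64))
zp = torch.randint(0, 16, (2, 64))           # 256 / group_size groups
scale = torch.rand(2, 64) * 0.01
w = dequant_grouped(q, zp, scale, group_size=128)
assert w.shape == (256, 64)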
+ w13_qzeros = Parameter(torch.empty(num_experts, + num_groups_w13, + 2 * intermediate_size // + self.quant_config.pack_factor, + dtype=torch.int32), + requires_grad=False) + layer.register_parameter("w13_qzeros", w13_qzeros) + set_weight_attrs(w13_qzeros, extra_weight_attrs) + + w2_qzeros = Parameter(torch.empty(num_experts, + num_groups_w2, + hidden_size // + self.quant_config.pack_factor, + dtype=torch.int32), + requires_grad=False) + layer.register_parameter("w2_qzeros", w2_qzeros) + set_weight_attrs(w2_qzeros, extra_weight_attrs) + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + num_experts = layer.w13_qweight.shape[0] + device = layer.w13_qweight.device + + layer.w13_g_idx_sort_indices = torch.nn.Parameter( + torch.empty((num_experts, 0), dtype=torch.int32, device=device), + requires_grad=False, + ) + layer.w2_g_idx_sort_indices = torch.nn.Parameter( + torch.empty((num_experts, 0), dtype=torch.int32, device=device), + requires_grad=False, + ) + + marlin_w13_qweight = ops.awq_marlin_moe_repack( + layer.w13_qweight, + layer.w13_g_idx_sort_indices, + size_k=layer.w13_qweight.shape[1], + size_n=layer.w13_qweight.shape[2] * self.quant_config.pack_factor, + num_bits=self.quant_config.weight_bits, + ) + replace_tensor(layer, "w13_qweight", marlin_w13_qweight) + + marlin_w2_qweight = ops.awq_marlin_moe_repack( + layer.w2_qweight, + layer.w2_g_idx_sort_indices, + size_k=layer.w2_qweight.shape[1], + size_n=layer.w2_qweight.shape[2] * self.quant_config.pack_factor, + num_bits=self.quant_config.weight_bits, + ) + replace_tensor(layer, "w2_qweight", marlin_w2_qweight) + + # Why does this take the intermediate size for size_k? + marlin_w13_scales = marlin_moe_permute_scales( + s=layer.w13_scales, + size_k=layer.intermediate_size_per_partition, + size_n=layer.w13_scales.shape[2], + group_size=self.quant_config.group_size, + ) + + replace_tensor(layer, "w13_scales", marlin_w13_scales) + + marlin_w2_scales = marlin_moe_permute_scales( + s=layer.w2_scales, + size_k=layer.intermediate_size_per_partition, + size_n=layer.w2_scales.shape[2], + group_size=self.quant_config.group_size, + ) + replace_tensor(layer, "w2_scales", marlin_w2_scales) + + marlin_w13_zp = moe_awq_to_marlin_zero_points( + layer.w13_qzeros, + size_k=layer.w13_qzeros.shape[1], + size_n=layer.w13_qzeros.shape[2] * self.quant_config.pack_factor, + num_bits=self.quant_config.weight_bits) + replace_tensor(layer, "w13_qzeros", marlin_w13_zp) + + marlin_w2_zp = moe_awq_to_marlin_zero_points( + layer.w2_qzeros, + size_k=layer.w2_qzeros.shape[1], + size_n=layer.w2_qzeros.shape[2] * self.quant_config.pack_factor, + num_bits=self.quant_config.weight_bits) + replace_tensor(layer, "w2_qzeros", marlin_w2_zp) + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + router_logits: torch.Tensor, + top_k: int, + renormalize: bool = True, + use_grouped_topk: bool = False, + num_expert_group: Optional[int] = None, + topk_group: Optional[int] = None, + custom_routing_function: Optional[Callable] = None, + ) -> torch.Tensor: + + from vllm.model_executor.layers.fused_moe.fused_marlin_moe import ( + fused_marlin_moe) + + topk_weights, topk_ids = FusedMoE.select_experts( + hidden_states=x, + router_logits=router_logits, + use_grouped_topk=use_grouped_topk, + top_k=top_k, + renormalize=renormalize, + topk_group=topk_group, + num_expert_group=num_expert_group, + custom_routing_function=custom_routing_function) + + return fused_marlin_moe( + x, + layer.w13_qweight, + layer.w2_qweight, + layer.w13_scales, + 
layer.w2_scales, + router_logits, + topk_weights, + topk_ids, + w1_zeros=layer.w13_qzeros, + w2_zeros=layer.w2_qzeros, + num_bits=self.quant_config.weight_bits, + ) diff --git a/vllm/model_executor/model_loader/utils.py b/vllm/model_executor/model_loader/utils.py index 995bb253db8a1..792c359a559a9 100644 --- a/vllm/model_executor/model_loader/utils.py +++ b/vllm/model_executor/model_loader/utils.py @@ -23,7 +23,9 @@ def get_model_architecture( architectures = getattr(model_config.hf_config, "architectures", []) # Special handling for quantized Mixtral. # FIXME(woosuk): This is a temporary hack. - mixtral_supported = ["fp8", "compressed-tensors", "gptq_marlin", "awq"] + mixtral_supported = [ + "fp8", "compressed-tensors", "gptq_marlin", "awq", "awq_marlin" + ] if (model_config.quantization is not None and model_config.quantization not in mixtral_supported From b54b633cbf21ae4a2b600b96be3f04603d9d5c9a Mon Sep 17 00:00:00 2001 From: Dipika Date: Mon, 30 Sep 2024 16:35:23 +0000 Subject: [PATCH 05/11] fix typo; add test --- tests/weight_loading/models-large.txt | 1 + vllm/model_executor/layers/quantization/awq_marlin.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/weight_loading/models-large.txt b/tests/weight_loading/models-large.txt index 2f5c6c5a117f3..8ab7f05d7d1b2 100644 --- a/tests/weight_loading/models-large.txt +++ b/tests/weight_loading/models-large.txt @@ -2,3 +2,4 @@ compressed-tensors, nm-testing/Mixtral-8x7B-Instruct-v0.1-W4A16-quantized, main compressed-tensors, nm-testing/Mixtral-8x7B-Instruct-v0.1-W4A16-channel-quantized, main compressed-tensors, nm-testing/Mixtral-8x7B-Instruct-v0.1-W8A16-quantized, main gptq_marlin, TheBloke/Mixtral-8x7B-v0.1-GPTQ, main +awq_marlin, casperhansen/deepseek-coder-v2-instruct-awq, main \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/awq_marlin.py b/vllm/model_executor/layers/quantization/awq_marlin.py index 9704b1adbce55..5c689f03925a1 100644 --- a/vllm/model_executor/layers/quantization/awq_marlin.py +++ b/vllm/model_executor/layers/quantization/awq_marlin.py @@ -293,7 +293,7 @@ def create_weights(self, layer: torch.nn.Module, num_experts: int, "is_transposed": True, "quant_method": - FusedMoeWeightScaleSupported.GROUP, + FusedMoeWeightScaleSupported.GROUP.value, }) w13_qweight = Parameter(torch.empty(num_experts, From e0e5a749b7a41c4554ed02e659b4bf90bc8ac04a Mon Sep 17 00:00:00 2001 From: ElizaWszola Date: Tue, 1 Oct 2024 03:19:05 -0400 Subject: [PATCH 06/11] Michael's feedback, cleanup --- csrc/moe/marlin_moe_ops.cu | 6 +-- csrc/moe/marlin_moe_ops.h | 5 +- csrc/moe/torch_bindings.cpp | 5 +- tests/kernels/test_awq_marlin.py | 2 - tests/kernels/test_moe.py | 8 +-- vllm/_custom_ops.py | 6 +-- .../layers/fused_moe/fused_marlin_moe.py | 49 +++++++++---------- 7 files changed, 38 insertions(+), 43 deletions(-) diff --git a/csrc/moe/marlin_moe_ops.cu b/csrc/moe/marlin_moe_ops.cu index e540f07236498..ec0836131ba82 100644 --- a/csrc/moe/marlin_moe_ops.cu +++ b/csrc/moe/marlin_moe_ops.cu @@ -484,9 +484,9 @@ torch::Tensor marlin_gemm_moe( torch::Tensor& b_zeros, const torch::Tensor& g_idx, const torch::Tensor& perm, torch::Tensor& workspace, vllm::ScalarTypeTorchPtr const& b_q_type, int64_t size_m, int64_t size_n, - int64_t size_k, bool is_k_full, bool has_zp, int64_t num_experts, - int64_t topk, int64_t moe_block_size, bool replicate_input, - bool apply_weights) { + int64_t size_k, bool is_k_full, int64_t num_experts, int64_t topk, + int64_t moe_block_size, bool replicate_input, bool 
apply_weights) { + bool has_zp = b_zeros.size(1) != 0; if (has_zp) { TORCH_CHECK(*b_q_type == vllm::kU4 || *b_q_type == vllm::kU8, "b_q_type must be u4 or u8 when has_zp = True. Got = ", diff --git a/csrc/moe/marlin_moe_ops.h b/csrc/moe/marlin_moe_ops.h index 0a54d93cedebc..0013787a623de 100644 --- a/csrc/moe/marlin_moe_ops.h +++ b/csrc/moe/marlin_moe_ops.h @@ -11,6 +11,5 @@ torch::Tensor marlin_gemm_moe( torch::Tensor& b_zeros, const torch::Tensor& g_idx, const torch::Tensor& perm, torch::Tensor& workspace, vllm::ScalarTypeTorchPtr const& b_q_type, int64_t size_m, int64_t size_n, - int64_t size_k, bool is_k_full, bool has_zp, int64_t num_experts, - int64_t topk, int64_t moe_block_size, bool replicate_input, - bool apply_weights); + int64_t size_k, bool is_k_full, int64_t num_experts, int64_t topk, + int64_t moe_block_size, bool replicate_input, bool apply_weights); diff --git a/csrc/moe/torch_bindings.cpp b/csrc/moe/torch_bindings.cpp index 85098df34b2d0..576305d48ae47 100644 --- a/csrc/moe/torch_bindings.cpp +++ b/csrc/moe/torch_bindings.cpp @@ -15,9 +15,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) { "Tensor! topk_weights, Tensor! topk_ids, Tensor! b_scales, Tensor! " "b_zeros, Tensor! g_idx, Tensor! perm, Tensor! workspace, " "__torch__.torch.classes._core_C.ScalarType b_q_type, int size_m, " - "int size_n, int size_k, bool is_k_full, bool has_zp, int num_experts, " - "int topk, int moe_block_size, bool replicate_input, bool apply_weights)" - " -> Tensor"); + "int size_n, int size_k, bool is_k_full, int num_experts, int topk, int " + "moe_block_size, bool replicate_input, bool apply_weights) -> Tensor"); m.impl("marlin_gemm_moe", torch::kCUDA, &marlin_gemm_moe); #endif } diff --git a/tests/kernels/test_awq_marlin.py b/tests/kernels/test_awq_marlin.py index 338f46cbe09fb..f1a0b09e8e464 100644 --- a/tests/kernels/test_awq_marlin.py +++ b/tests/kernels/test_awq_marlin.py @@ -87,7 +87,6 @@ def test_fused_marlin_moe_awq( score, topk_weights, topk_ids, - has_zero_point=True, w1_zeros=zp1, w2_zeros=zp2, num_bits=num_bits, @@ -155,7 +154,6 @@ def test_single_marlin_moe_multiply_awq( score, topk, renormalize=False, - has_zero_point=True, w_zeros=zp, num_bits=num_bits) diff --git a/tests/kernels/test_moe.py b/tests/kernels/test_moe.py index 360ef1330bd69..b73c45b9cd198 100644 --- a/tests/kernels/test_moe.py +++ b/tests/kernels/test_moe.py @@ -234,12 +234,14 @@ def test_fused_marlin_moe( device="cuda", requires_grad=False) - zp = torch.empty((0), dtype=dtype, device="cuda", requires_grad=False) - + zp = torch.empty((0, 0), + dtype=dtype, + device="cuda", + requires_grad=False) opcheck(torch.ops._moe_C.marlin_gemm_moe, (a, qweight1, sorted_token_ids, topk_weights, topk_ids, scales1, zp, g_idx1, sort_indices1, workspace, quant_type, m, - 2 * n, k, True, False, e, topk, block_size_m, True, False)) + 2 * n, k, True, e, topk, block_size_m, True, False)) @pytest.mark.skip("This test is here for the sake of debugging, " diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index bc7b4293c119e..6081fa674579c 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -822,9 +822,9 @@ def marlin_gemm_moe_fake(a: torch.Tensor, b_q_weights: torch.Tensor, b_zero_points: torch.Tensor, g_idx: torch.Tensor, perm: torch.Tensor, workspace: torch.Tensor, b_q_type: ScalarType, size_m: int, size_n: int, - size_k: int, is_k_full: bool, - has_zero_point: bool, num_experts: int, topk: int, - moe_block_size: int, replicate_input: bool, + size_k: int, is_k_full: bool, num_experts: int, + topk: int, 
moe_block_size: int, + replicate_input: bool, apply_weights: bool) -> torch.Tensor: return torch.empty((size_m, topk, size_n), dtype=a.dtype, diff --git a/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py b/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py index e57b15936aa8b..466b0edd81fe7 100644 --- a/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py @@ -24,7 +24,6 @@ def single_marlin_moe( gating_output: torch.Tensor, topk: int, renormalize: bool, - has_zero_point: bool = False, g_idx: Optional[torch.Tensor] = None, sort_indices: Optional[torch.Tensor] = None, w_zeros: Optional[torch.Tensor] = None, @@ -93,11 +92,9 @@ def single_marlin_moe( device=hidden_states.device, requires_grad=False) - if has_zero_point: - assert w_zeros is not None and w_zeros.nelement() > 0 - + has_zero_point = w_zeros is not None if w_zeros is None: - w_zeros = torch.empty((0), + w_zeros = torch.empty((0, 0), dtype=hidden_states.dtype, device=hidden_states.device, requires_grad=False) @@ -119,7 +116,7 @@ def single_marlin_moe( intermediate_cache = torch.ops._moe_C.marlin_gemm_moe( hidden_states, w, sorted_token_ids, topk_weights, topk_ids, scales, w_zeros, g_idx, sort_indices, workspace, scalar_type, M, N, K, - is_k_full, has_zero_point, E, topk, block_size_m, True, False) + is_k_full, E, topk, block_size_m, True, False) return torch.sum(intermediate_cache.view(*intermediate_cache.shape), dim=1) @@ -133,7 +130,6 @@ def fused_marlin_moe( gating_output: torch.Tensor, topk_weights: torch.Tensor, topk_ids: torch.Tensor, - has_zero_point: bool = False, g_idx1: Optional[torch.Tensor] = None, g_idx2: Optional[torch.Tensor] = None, sort_indices1: Optional[torch.Tensor] = None, @@ -187,6 +183,20 @@ def fused_marlin_moe( assert hidden_states.dtype == torch.float16 assert num_bits in [4, 8] + has_no_act_order = (g_idx1 is None and g_idx2 is None + and sort_indices1 is None and sort_indices2 is None) + has_all_act_order = (g_idx1 is not None and g_idx2 is not None + and sort_indices1 is not None + and sort_indices2 is not None) + assert has_no_act_order or has_all_act_order, ( + "g_idx and sorted_indices " + "must be all not None or must be all None") + + has_no_zp = w1_zeros is None and w2_zeros is None + has_all_zp = w1_zeros is not None and w2_zeros is not None + assert has_no_zp or has_all_zp, ("zero points must be both not None or " + "must be both None") + M, K = hidden_states.shape E = w1.shape[0] N = w2.shape[1] * 16 @@ -213,47 +223,36 @@ def fused_marlin_moe( device="cuda", requires_grad=False) - if has_zero_point: - assert w1_zeros is not None and w1_zeros.nelement() > 0 - assert w2_zeros is not None and w2_zeros.nelement() > 0 - - if w1_zeros is None: - w1_zeros = torch.empty((0), + if has_no_zp: + w1_zeros = torch.empty((0, 0), dtype=hidden_states.dtype, device=hidden_states.device, requires_grad=False) - if w2_zeros is None: - w2_zeros = torch.empty((0), + w2_zeros = torch.empty((0, 0), dtype=hidden_states.dtype, device=hidden_states.device, requires_grad=False) - if g_idx1 is None: + if has_no_act_order: g_idx1 = torch.empty((0, 0), dtype=torch.int32, device=hidden_states.device, requires_grad=False) - - if g_idx2 is None: g_idx2 = torch.empty((0, 0), dtype=torch.int32, device=hidden_states.device, requires_grad=False) - - if sort_indices1 is None: sort_indices1 = torch.empty((0), dtype=torch.int32, device=hidden_states.device, requires_grad=False) - - if sort_indices2 is None: sort_indices2 = torch.empty((0, 0), 
dtype=torch.int32, device=hidden_states.device, requires_grad=False) - scalar_type1 = get_scalar_type(num_bits, has_zero_point) - scalar_type2 = get_scalar_type(num_bits, has_zero_point) + scalar_type1 = get_scalar_type(num_bits, has_all_zp) + scalar_type2 = get_scalar_type(num_bits, has_all_zp) intermediate_cache2 = torch.empty( (M * topk_ids.shape[1], N), @@ -277,7 +276,6 @@ def fused_marlin_moe( 2 * N, K, is_k_full, - has_zero_point, E, topk, block_size_m, @@ -303,7 +301,6 @@ def fused_marlin_moe( K, N, is_k_full, - has_zero_point, E, topk, block_size_m, From bbf575e2985b8476bf858866be9158bb0bf2a0e1 Mon Sep 17 00:00:00 2001 From: Dipika Date: Tue, 1 Oct 2024 13:49:52 +0000 Subject: [PATCH 07/11] use replace_parameters; clean-up --- .../layers/quantization/awq_marlin.py | 19 ++++++++----------- 1 file changed, 8 insertions(+), 11 deletions(-) diff --git a/vllm/model_executor/layers/quantization/awq_marlin.py b/vllm/model_executor/layers/quantization/awq_marlin.py index cc98cbfb70ad2..294fe11815c0f 100644 --- a/vllm/model_executor/layers/quantization/awq_marlin.py +++ b/vllm/model_executor/layers/quantization/awq_marlin.py @@ -11,14 +11,11 @@ set_weight_attrs) from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig, QuantizeMethodBase) -from vllm.model_executor.layers.quantization.utils.marlin_utils import ( - apply_awq_marlin_linear, awq_to_marlin_zero_points, check_marlin_supported, - marlin_make_empty_g_idx, marlin_make_workspace, marlin_moe_permute_scales, - marlin_permute_scales, moe_awq_to_marlin_zero_points) from vllm.model_executor.layers.quantization.utils import replace_parameter from vllm.model_executor.layers.quantization.utils.marlin_utils import ( apply_awq_marlin_linear, awq_to_marlin_zero_points, check_marlin_supported, - marlin_make_empty_g_idx, marlin_make_workspace, marlin_permute_scales, + marlin_make_empty_g_idx, marlin_make_workspace, marlin_moe_permute_scales, + marlin_permute_scales, moe_awq_to_marlin_zero_points, verify_marlin_supported, verify_marlin_supports_shape) from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead from vllm.model_executor.parameter import (GroupQuantScaleParameter, @@ -379,7 +376,7 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: size_n=layer.w13_qweight.shape[2] * self.quant_config.pack_factor, num_bits=self.quant_config.weight_bits, ) - replace_tensor(layer, "w13_qweight", marlin_w13_qweight) + replace_parameter(layer, "w13_qweight", marlin_w13_qweight) marlin_w2_qweight = ops.awq_marlin_moe_repack( layer.w2_qweight, @@ -388,7 +385,7 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: size_n=layer.w2_qweight.shape[2] * self.quant_config.pack_factor, num_bits=self.quant_config.weight_bits, ) - replace_tensor(layer, "w2_qweight", marlin_w2_qweight) + replace_parameter(layer, "w2_qweight", marlin_w2_qweight) # Why does this take the intermediate size for size_k? 
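replace_parameter is imported from vllm.model_executor.layers.quantization.utils and its body is not shown in this patch; what it has to accomplish here is swapping a registered parameter for the freshly repacked tensor so later code can keep reading layer.<name>. A minimal stand-in sketch of that idea (not vLLM's actual implementation):

import torch

def replace_parameter_sketch(layer: torch.nn.Module, name: str,
                             new_t: torch.Tensor) -> None:
    # Unregister the pre-repack parameter and register the repacked tensor
    # under the same attribute name.
    delattr(layer, name)
    layer.register_parameter(
        name, torch.nn.Parameter(new_t, requires_grad=False))

# Tiny usage example on a bare module.
m = torch.nn.Module()
m.register_parameter(
    "w", torch.nn.Parameter(torch.zeros(4), requires_grad=False))
replace_parameter_sketch(m, "w", torch.ones(8))
assert m.w.shape == (8,)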
marlin_w13_scales = marlin_moe_permute_scales( @@ -398,7 +395,7 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: group_size=self.quant_config.group_size, ) - replace_tensor(layer, "w13_scales", marlin_w13_scales) + replace_parameter(layer, "w13_scales", marlin_w13_scales) marlin_w2_scales = marlin_moe_permute_scales( s=layer.w2_scales, @@ -406,21 +403,21 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: size_n=layer.w2_scales.shape[2], group_size=self.quant_config.group_size, ) - replace_tensor(layer, "w2_scales", marlin_w2_scales) + replace_parameter(layer, "w2_scales", marlin_w2_scales) marlin_w13_zp = moe_awq_to_marlin_zero_points( layer.w13_qzeros, size_k=layer.w13_qzeros.shape[1], size_n=layer.w13_qzeros.shape[2] * self.quant_config.pack_factor, num_bits=self.quant_config.weight_bits) - replace_tensor(layer, "w13_qzeros", marlin_w13_zp) + replace_parameter(layer, "w13_qzeros", marlin_w13_zp) marlin_w2_zp = moe_awq_to_marlin_zero_points( layer.w2_qzeros, size_k=layer.w2_qzeros.shape[1], size_n=layer.w2_qzeros.shape[2] * self.quant_config.pack_factor, num_bits=self.quant_config.weight_bits) - replace_tensor(layer, "w2_qzeros", marlin_w2_zp) + replace_parameter(layer, "w2_qzeros", marlin_w2_zp) def apply( self, From 79126f906b5eafda1df6866f7328f6d78f2eeec3 Mon Sep 17 00:00:00 2001 From: Dipika Date: Tue, 1 Oct 2024 13:54:42 +0000 Subject: [PATCH 08/11] more clean-up --- vllm/model_executor/layers/quantization/awq.py | 2 +- .../model_executor/layers/quantization/gptq_marlin.py | 1 - .../layers/quantization/utils/marlin_utils.py | 11 ----------- vllm/model_executor/model_loader/utils.py | 2 +- 4 files changed, 2 insertions(+), 14 deletions(-) diff --git a/vllm/model_executor/layers/quantization/awq.py b/vllm/model_executor/layers/quantization/awq.py index 30380ec0407c5..410b3cb5321cb 100644 --- a/vllm/model_executor/layers/quantization/awq.py +++ b/vllm/model_executor/layers/quantization/awq.py @@ -169,4 +169,4 @@ def apply(self, pack_factor) if bias is not None: out.add_(bias) - return out.reshape(out_shape) \ No newline at end of file + return out.reshape(out_shape) diff --git a/vllm/model_executor/layers/quantization/gptq_marlin.py b/vllm/model_executor/layers/quantization/gptq_marlin.py index b9b43413b35db..e77191796bd7e 100644 --- a/vllm/model_executor/layers/quantization/gptq_marlin.py +++ b/vllm/model_executor/layers/quantization/gptq_marlin.py @@ -509,7 +509,6 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: ) replace_parameter(layer, "w2_qweight", marlin_w2_qweight) # Repack scales - # Why does this take the intermediate size for size_k? 
marlin_w13_scales = marlin_moe_permute_scales( s=layer.w13_scales, size_k=layer.intermediate_size_per_partition, diff --git a/vllm/model_executor/layers/quantization/utils/marlin_utils.py b/vllm/model_executor/layers/quantization/utils/marlin_utils.py index 1275b4474a06c..9a1defa409714 100644 --- a/vllm/model_executor/layers/quantization/utils/marlin_utils.py +++ b/vllm/model_executor/layers/quantization/utils/marlin_utils.py @@ -273,17 +273,6 @@ def moe_awq_to_marlin_zero_points(q_zp_packed: torch.Tensor, size_k: int, return output -# Newly generated tensors need to replace existing tensors that are -# already registered as parameters by vLLM (and won't be freed) -def replace_tensor(layer: torch.nn.Module, name: str, - new_t: torch.Tensor) -> None: - # It is important to use resize_() here since it ensures - # the same buffer is reused - getattr(layer, name).resize_(new_t.shape) - getattr(layer, name).copy_(new_t) - del new_t - - def apply_gptq_marlin_linear( input: torch.Tensor, weight: torch.Tensor, diff --git a/vllm/model_executor/model_loader/utils.py b/vllm/model_executor/model_loader/utils.py index 792c359a559a9..b95c0b7cd0612 100644 --- a/vllm/model_executor/model_loader/utils.py +++ b/vllm/model_executor/model_loader/utils.py @@ -24,7 +24,7 @@ def get_model_architecture( # Special handling for quantized Mixtral. # FIXME(woosuk): This is a temporary hack. mixtral_supported = [ - "fp8", "compressed-tensors", "gptq_marlin", "awq", "awq_marlin" + "fp8", "compressed-tensors", "gptq_marlin", "awq_marlin" ] if (model_config.quantization is not None From 87d46dc91021432b51c7b19730f3b04403c119ed Mon Sep 17 00:00:00 2001 From: ElizaWszola Date: Tue, 1 Oct 2024 11:33:57 -0400 Subject: [PATCH 09/11] Delete 8-bit zero point code --- CMakeLists.txt | 2 -- .../marlin_kernels/marlin_moe_kernel_ku8.cu | 31 ------------------- .../marlin_kernels/marlin_moe_kernel_ku8.h | 20 ------------ csrc/moe/marlin_moe_ops.cu | 8 ++--- tests/kernels/test_awq_marlin.py | 10 +++--- .../layers/fused_moe/fused_marlin_moe.py | 3 +- 6 files changed, 9 insertions(+), 65 deletions(-) delete mode 100644 csrc/moe/marlin_kernels/marlin_moe_kernel_ku8.cu delete mode 100644 csrc/moe/marlin_kernels/marlin_moe_kernel_ku8.h diff --git a/CMakeLists.txt b/CMakeLists.txt index df22ce47e54bf..8c66c31aa6ce4 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -332,8 +332,6 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") "csrc/moe/marlin_kernels/marlin_moe_kernel_ku8b128.cu" "csrc/moe/marlin_kernels/marlin_moe_kernel_ku4.h" "csrc/moe/marlin_kernels/marlin_moe_kernel_ku4.cu" - "csrc/moe/marlin_kernels/marlin_moe_kernel_ku8.h" - "csrc/moe/marlin_kernels/marlin_moe_kernel_ku8.cu" "csrc/moe/marlin_moe_ops.cu") endif() diff --git a/csrc/moe/marlin_kernels/marlin_moe_kernel_ku8.cu b/csrc/moe/marlin_kernels/marlin_moe_kernel_ku8.cu deleted file mode 100644 index 7abbc45440bfc..0000000000000 --- a/csrc/moe/marlin_kernels/marlin_moe_kernel_ku8.cu +++ /dev/null @@ -1,31 +0,0 @@ -#include "marlin_moe_kernel_ku8.h" - -namespace marlin_moe { - -// We return bool so we can create these different kernel calls as a sequence -// of if-elseif's. 
-bool call_marlin_moe_kernel_ku8( - vllm::ScalarType const& q_type, int thread_n_blocks, int thread_k_blocks, - bool has_act_order, int group_blocks, int num_threads, int blocks, - int max_shared_mem, cudaStream_t stream, const int4* A_ptr, - const int4* B_ptr, int4* C_ptr, const int* sorted_ids_ptr, - const float* topk_weights_ptr, const int4* s_ptr, const int4* zp_ptr, - const int* g_idx_ptr, int* expert_offsets_ptr, int num_groups, - int expert_idx, int num_experts, int topk, int prob_m, int prob_n, - int prob_k, int tot_m, int* locks, bool replicate_input, bool apply_weights, - int m_block, int max_par, int cfg_max_m_blocks) { - bool has_zp = true; - - if (false) { - } - AWQ_CALL_IF_MOE(vllm::kU8, 16, 4, 256) - AWQ_CALL_IF_MOE(vllm::kU8, 8, 8, 256) - AWQ_CALL_IF_MOE(vllm::kU8, 8, 4, 128) - AWQ_CALL_IF_MOE(vllm::kU8, 4, 8, 128) - else { - return false; - } - return true; -} - -} // namespace marlin_moe diff --git a/csrc/moe/marlin_kernels/marlin_moe_kernel_ku8.h b/csrc/moe/marlin_kernels/marlin_moe_kernel_ku8.h deleted file mode 100644 index 03a0132aa347c..0000000000000 --- a/csrc/moe/marlin_kernels/marlin_moe_kernel_ku8.h +++ /dev/null @@ -1,20 +0,0 @@ -#pragma once - -#include "marlin_moe_kernel.h" - -namespace marlin_moe { - -// We return bool so we can create these different kernel calls as a sequence -// of if-elseif's. -bool call_marlin_moe_kernel_ku8( - vllm::ScalarType const& q_type, int thread_n_blocks, int thread_k_blocks, - bool has_act_order, int group_blocks, int num_threads, int blocks, - int max_shared_mem, cudaStream_t stream, const int4* A_ptr, - const int4* B_ptr, int4* C_ptr, const int* sorted_ids_ptr, - const float* topk_weights_ptr, const int4* s_ptr, const int4* zp_ptr, - const int* g_idx_ptr, int* expert_offsets_ptr, int num_groups, - int expert_idx, int num_experts, int topk, int prob_m, int prob_n, - int prob_k, int tot_m, int* locks, bool replicate_input, bool apply_weights, - int m_block, int max_par, int cfg_max_m_blocks); - -} // namespace marlin_moe diff --git a/csrc/moe/marlin_moe_ops.cu b/csrc/moe/marlin_moe_ops.cu index ec0836131ba82..b3cccd4c566fb 100644 --- a/csrc/moe/marlin_moe_ops.cu +++ b/csrc/moe/marlin_moe_ops.cu @@ -30,7 +30,6 @@ #include "marlin_kernels/marlin_moe_kernel_ku4b8.h" #include "marlin_kernels/marlin_moe_kernel_ku8b128.h" #include "marlin_kernels/marlin_moe_kernel_ku4.h" -#include "marlin_kernels/marlin_moe_kernel_ku8.h" template inline std::string str(T x) { @@ -461,7 +460,6 @@ void marlin_mm_moe(const void* A, const void* B, void* C, CALL_MOE_KERNEL_FUNCTION(call_marlin_moe_kernel_ku4b8) CALL_MOE_KERNEL_FUNCTION(call_marlin_moe_kernel_ku8b128) CALL_MOE_KERNEL_FUNCTION(call_marlin_moe_kernel_ku4) - CALL_MOE_KERNEL_FUNCTION(call_marlin_moe_kernel_ku8) else { TORCH_CHECK(false, "Unsupported shapes: MNK = [" + str(prob_m) + ", " + str(prob_n) + ", " + str(prob_k) + "]" + @@ -488,9 +486,9 @@ torch::Tensor marlin_gemm_moe( int64_t moe_block_size, bool replicate_input, bool apply_weights) { bool has_zp = b_zeros.size(1) != 0; if (has_zp) { - TORCH_CHECK(*b_q_type == vllm::kU4 || *b_q_type == vllm::kU8, - "b_q_type must be u4 or u8 when has_zp = True. Got = ", - b_q_type->str()); + TORCH_CHECK( + *b_q_type == vllm::kU4, + "b_q_type must be u4 when has_zp = True. 
Got = ", b_q_type->str()); } else { TORCH_CHECK( *b_q_type == vllm::kU4B8 || *b_q_type == vllm::kU8B128, diff --git a/tests/kernels/test_awq_marlin.py b/tests/kernels/test_awq_marlin.py index f1a0b09e8e464..0738ea9b97edb 100644 --- a/tests/kernels/test_awq_marlin.py +++ b/tests/kernels/test_awq_marlin.py @@ -21,7 +21,6 @@ @pytest.mark.parametrize("e", [8, 64]) @pytest.mark.parametrize("topk", [2, 6]) @pytest.mark.parametrize("group_size", [-1, 32, 64, 128]) -@pytest.mark.parametrize("num_bits", [4, 8]) def test_fused_marlin_moe_awq( m: int, n: int, @@ -29,11 +28,11 @@ def test_fused_marlin_moe_awq( e: int, topk: int, group_size: int, - num_bits: int, ): torch.manual_seed(7) - quant_type = (scalar_types.uint4 if num_bits == 4 else scalar_types.uint8) + num_bits = 4 + quant_type = scalar_types.uint4 dtype = torch.float16 a = torch.randn((m, k), device="cuda", dtype=dtype) / 10 w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10 @@ -111,7 +110,6 @@ def test_fused_marlin_moe_awq( @pytest.mark.parametrize("e", [8, 64]) @pytest.mark.parametrize("topk", [2, 6]) @pytest.mark.parametrize("group_size", [-1, 32, 64, 128]) -@pytest.mark.parametrize("num_bits", [4, 8]) def test_single_marlin_moe_multiply_awq( m: int, n: int, @@ -119,11 +117,11 @@ def test_single_marlin_moe_multiply_awq( e: int, topk: int, group_size: int, - num_bits: int, ): torch.manual_seed(7) - quant_type = (scalar_types.uint4 if num_bits == 4 else scalar_types.uint8) + num_bits = 4 + quant_type = scalar_types.uint4 dtype = torch.float16 a = torch.randn((m, k), device="cuda", dtype=dtype) / 10 w = torch.randn((e, n, k), device="cuda", dtype=dtype) / 10 diff --git a/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py b/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py index 466b0edd81fe7..66f589dba7851 100644 --- a/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py @@ -12,7 +12,8 @@ def get_scalar_type(num_bits: int, has_zp: bool): if has_zp: - return scalar_types.uint4 if num_bits == 4 else scalar_types.uint8 + assert num_bits == 4 + return scalar_types.uint4 else: return scalar_types.uint4b8 if num_bits == 4 else scalar_types.uint8b128 From 8fe6da46bf7aa2673a480c72558a80aeeedcc544 Mon Sep 17 00:00:00 2001 From: Dipika Date: Tue, 1 Oct 2024 17:57:00 +0000 Subject: [PATCH 10/11] fix file reverted from some commit hoopla --- .../run_model_weight_loading_test.sh | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/tests/weight_loading/run_model_weight_loading_test.sh b/tests/weight_loading/run_model_weight_loading_test.sh index 0cb45d1780c2c..e80c1d6c5849c 100755 --- a/tests/weight_loading/run_model_weight_loading_test.sh +++ b/tests/weight_loading/run_model_weight_loading_test.sh @@ -1,7 +1,20 @@ #!/bin/bash SUCCESS=0 -IFS=$'\n' read -d '' -r -a MODEL_CONFIGS < "weight_loading/models.txt" +while getopts "c:" OPT; do + case ${OPT} in + c ) + CONFIG="$OPTARG" + ;; + \? 
) + usage + exit 1 + ;; + esac +done + + +IFS=$'\n' read -d '' -r -a MODEL_CONFIGS < $CONFIG for MODEL_CONFIG in "${MODEL_CONFIGS[@]}" do From a966417fa4433c6bbb863f14db27d15c1547dbb4 Mon Sep 17 00:00:00 2001 From: ElizaWszola Date: Wed, 2 Oct 2024 01:14:48 -0400 Subject: [PATCH 11/11] Make workspace smaller, add very small thread config --- csrc/moe/marlin_moe_ops.cu | 2 ++ vllm/model_executor/layers/fused_moe/fused_marlin_moe.py | 2 +- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/csrc/moe/marlin_moe_ops.cu b/csrc/moe/marlin_moe_ops.cu index b3cccd4c566fb..69d66b5d7101e 100644 --- a/csrc/moe/marlin_moe_ops.cu +++ b/csrc/moe/marlin_moe_ops.cu @@ -157,6 +157,7 @@ thread_config_t small_batch_thread_configs[] = { {128, 64, 128}, // Reduce N 2X, same K {64, 256, 256}, // Reduce K 2X, increase N 2X {64, 128, 128}, // Reduce K 2X, same N + {64, 64, 128}, // Reduce both 2X }; thread_config_t large_batch_thread_configs[] = { @@ -167,6 +168,7 @@ thread_config_t large_batch_thread_configs[] = { {128, 128, 256}, // Reduce N 2X, increase K 2X {64, 128, 128}, // Reduce N 2X, same K {128, 64, 128}, // Reduce N 4X, increase K 2X + {64, 64, 128}, // Reduce N 4X, same K }; int get_scales_cache_size(thread_config_t const& th_config, int prob_m, diff --git a/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py b/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py index 66f589dba7851..5964d5a5465fd 100644 --- a/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py @@ -218,7 +218,7 @@ def fused_marlin_moe( sorted_token_ids, _, _ = moe_align_block_size(topk_ids, block_size_m, E) - max_workspace_size = ((M + 255) // 256) * (max(2 * N, K) // 64) * 16 + max_workspace_size = (max(2 * N, K) // 64) * 16 workspace = torch.zeros(max_workspace_size, dtype=torch.int, device="cuda",
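As a sanity check on the final workspace change, the new bound no longer scales with the token count M and depends only on the larger GEMM dimension. The shapes below are illustrative assumptions, not values from the patch:

def old_max_workspace_size(m: int, n: int, k: int) -> int:
    # Sizing used before PATCH 11/11: grows with the number of tokens M.
    return ((m + 255) // 256) * (max(2 * n, k) // 64) * 16

def new_max_workspace_size(n: int, k: int) -> int:
    # Sizing after PATCH 11/11: independent of M.
    return (max(2 * n, k) // 64) * 16

M, N, K = 2048, 14336, 4096                      # assumed problem sizes
print(old_max_workspace_size(M, N, K))           # 57344
print(new_max_workspace_size(N, K))              # 7168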