From 263a870ee18bd6a90e25dbfa342be32c6b92c33e Mon Sep 17 00:00:00 2001 From: Avshalom Manevich <12231371+avshalomman@users.noreply.github.com> Date: Sun, 12 Jan 2025 17:53:51 +0200 Subject: [PATCH] [Hardware][TPU] workaround fix for MoE on TPU (#11764) --- tests/kernels/test_moe.py | 7 +++ vllm/model_executor/layers/fused_moe/layer.py | 3 +- .../layers/fused_moe/moe_torch_iterative.py | 51 +++++++++++++++++++ 3 files changed, 60 insertions(+), 1 deletion(-) create mode 100644 vllm/model_executor/layers/fused_moe/moe_torch_iterative.py diff --git a/tests/kernels/test_moe.py b/tests/kernels/test_moe.py index 8b23b62826053..7fa5de1984452 100644 --- a/tests/kernels/test_moe.py +++ b/tests/kernels/test_moe.py @@ -14,6 +14,8 @@ from vllm.model_executor.layers.fused_moe import fused_moe from vllm.model_executor.layers.fused_moe.fused_moe import ( fused_topk, moe_align_block_size) +from vllm.model_executor.layers.fused_moe.moe_torch_iterative import ( + fused_moe as iterative_moe) from vllm.model_executor.layers.quantization.utils.marlin_utils_test import ( marlin_quantize) from vllm.model_executor.models.mixtral import MixtralMoE @@ -46,6 +48,11 @@ def test_fused_moe( triton_output = fused_moe(a, w1, w2, score, topk, renormalize=False) torch_output = torch_moe(a, w1, w2, score, topk) torch.testing.assert_close(triton_output, torch_output, atol=2e-2, rtol=0) + iterative_output = iterative_moe(a, w1, w2, score, topk, renormalize=False) + torch.testing.assert_close(iterative_output, + torch_output, + atol=2e-2, + rtol=0) @pytest.mark.parametrize("dtype", diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index cf5db368926b4..3d822fc0c7f99 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -20,7 +20,8 @@ else: fused_experts = None # type: ignore if current_platform.is_tpu(): - from .moe_pallas import fused_moe as fused_moe_pallas + # the iterative moe implementation is used until the moe_pallas is fixed + from .moe_torch_iterative import fused_moe as fused_moe_pallas else: fused_moe_pallas = None # type: ignore logger = init_logger(__name__) diff --git a/vllm/model_executor/layers/fused_moe/moe_torch_iterative.py b/vllm/model_executor/layers/fused_moe/moe_torch_iterative.py new file mode 100644 index 0000000000000..bcff55f4fdf16 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/moe_torch_iterative.py @@ -0,0 +1,51 @@ +import torch +import torch.nn.functional as F + + +def fused_moe( + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + gating_output: torch.Tensor, + topk: int, + renormalize: bool, +) -> torch.Tensor: + """ + Args: + hidden_states: [*, hidden_size] + w1: [num_experts, intermediate_size * 2, hidden_size] + w2: [num_experts, hidden_size, intermediate_size] + gating_output: [*, num_experts] + """ + orig_shape = hidden_states.shape + hidden_size = hidden_states.shape[-1] + num_tokens = hidden_states.shape[:-1].numel() + num_experts = w1.shape[0] + intermediate_size = w2.shape[-1] + dtype = hidden_states.dtype + + hidden_states = hidden_states.view(num_tokens, hidden_size) + gating_output = gating_output.view(num_tokens, num_experts) + topk_weights = gating_output.softmax(dim=-1, dtype=torch.float) + topk_weights, selected_experts = topk_weights.topk(topk, dim=-1) + if renormalize: + topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True) + topk_weights = topk_weights.to(dtype) + + final_hidden_states = None + for expert_idx in range(num_experts): + expert_w1 = w1[expert_idx] + expert_w2 = w2[expert_idx] + expert_mask = (selected_experts == expert_idx) + expert_weights = (topk_weights * expert_mask).sum(dim=-1, keepdim=True) + x = F.linear(hidden_states, expert_w1) + gate = F.silu(x[:, :intermediate_size]) + x = x[:, intermediate_size:] * gate + x = F.linear(x, expert_w2) + current_hidden_states = x * expert_weights + if final_hidden_states is None: + final_hidden_states = current_hidden_states + else: + final_hidden_states = final_hidden_states + current_hidden_states + + return final_hidden_states.view(orig_shape) # type: ignore