diff --git a/csrc/moe/marlin_moe_ops.h b/csrc/moe/marlin_moe_ops.h index 78e2f5d346652..43d264e0770d6 100644 --- a/csrc/moe/marlin_moe_ops.h +++ b/csrc/moe/marlin_moe_ops.h @@ -7,7 +7,6 @@ torch::Tensor marlin_gemm_moe( const torch::Tensor& sorted_ids, const torch::Tensor& topk_weights, const torch::Tensor& topk_ids, const torch::Tensor& b_scales, const torch::Tensor& g_idx, const torch::Tensor& perm, - const torch::Tensor& expert_offsets, torch::Tensor& workspace, - int64_t size_m, int64_t size_n, int64_t size_k, bool is_k_full, - int64_t num_experts, int64_t topk, int64_t moe_block_size, + torch::Tensor& workspace, int64_t size_m, int64_t size_n, int64_t size_k, + bool is_k_full, int64_t num_experts, int64_t topk, int64_t moe_block_size, bool replicate_input, bool apply_weights); diff --git a/csrc/moe/torch_bindings.cpp b/csrc/moe/torch_bindings.cpp index 0b4b92e16f7b4..0ed9b1f64590a 100644 --- a/csrc/moe/torch_bindings.cpp +++ b/csrc/moe/torch_bindings.cpp @@ -2,23 +2,21 @@ #include "moe_ops.h" #include "marlin_moe_ops.h" -#include - -TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { +TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) { // Apply topk softmax to the gating outputs. - ops.def( + m.def( "topk_softmax(Tensor! topk_weights, Tensor! topk_indices, Tensor! " "token_expert_indices, Tensor gating_output) -> ()"); - ops.impl("topk_softmax", torch::kCUDA, &topk_softmax); + m.impl("topk_softmax", torch::kCUDA, &topk_softmax); - ops.def( + m.def( "marlin_gemm_moe(Tensor! a, Tensor! b_q_weights, Tensor! sorted_ids, " "Tensor! topk_weights, Tensor! topk_ids, Tensor! b_scales, Tensor! " "g_idx, Tensor! perm, Tensor! workspace, int size_m, int size_n, int " "size_k, bool is_k_full, int num_experts, int topk, int moe_block_size, " "bool replicate_input, bool apply_weights) -> Tensor"); - ops.impl("marlin_gemm_moe", torch::kCUDA, &marlin_gemm_moe); + m.impl("marlin_gemm_moe", torch::kCUDA, &marlin_gemm_moe); } REGISTER_EXTENSION(TORCH_EXTENSION_NAME) diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 3774a442f9180..64e47ad803232 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -384,8 +384,7 @@ def fused_topk( topk, dtype=torch.int32, device=hidden_states.device) - from pprint import pprint - pprint(vars(ops)) + ops.topk_softmax( topk_weights, topk_ids, @@ -692,8 +691,7 @@ def single_marlin_moe( block_size_m = config['BLOCK_SIZE_M'] - sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size( - topk_ids, block_size_m, E) + sorted_token_ids, _, _ = moe_align_block_size(topk_ids, block_size_m, E) max_workspace_size = (N // 64) * 16 workspace = torch.zeros(max_workspace_size, @@ -781,8 +779,7 @@ def fused_marlin_moe(hidden_states: torch.Tensor, block_size_m = config['BLOCK_SIZE_M'] - sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size( - topk_ids, block_size_m, E) + sorted_token_ids, _, _ = moe_align_block_size(topk_ids, block_size_m, E) max_workspace_size = ((M + 255) // 256) * (max(2 * N, K) // 64) * 16 workspace = torch.zeros(max_workspace_size, @@ -806,8 +803,5 @@ def fused_marlin_moe(hidden_states: torch.Tensor, w2_scale, g_idx2, rand_perm2, workspace, M, K, N, True, E, topk, block_size_m, False, True) - # intermediate_cache3 = torch.zeros((M, topk, K), - # device=hidden_states.device, - # dtype=hidden_states.dtype) return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape), dim=1)