From e0e5a749b7a41c4554ed02e659b4bf90bc8ac04a Mon Sep 17 00:00:00 2001 From: ElizaWszola Date: Tue, 1 Oct 2024 03:19:05 -0400 Subject: [PATCH] Michael's feedback, cleanup --- csrc/moe/marlin_moe_ops.cu | 6 +-- csrc/moe/marlin_moe_ops.h | 5 +- csrc/moe/torch_bindings.cpp | 5 +- tests/kernels/test_awq_marlin.py | 2 - tests/kernels/test_moe.py | 8 +-- vllm/_custom_ops.py | 6 +-- .../layers/fused_moe/fused_marlin_moe.py | 49 +++++++++---------- 7 files changed, 38 insertions(+), 43 deletions(-) diff --git a/csrc/moe/marlin_moe_ops.cu b/csrc/moe/marlin_moe_ops.cu index e540f07236498..ec0836131ba82 100644 --- a/csrc/moe/marlin_moe_ops.cu +++ b/csrc/moe/marlin_moe_ops.cu @@ -484,9 +484,9 @@ torch::Tensor marlin_gemm_moe( torch::Tensor& b_zeros, const torch::Tensor& g_idx, const torch::Tensor& perm, torch::Tensor& workspace, vllm::ScalarTypeTorchPtr const& b_q_type, int64_t size_m, int64_t size_n, - int64_t size_k, bool is_k_full, bool has_zp, int64_t num_experts, - int64_t topk, int64_t moe_block_size, bool replicate_input, - bool apply_weights) { + int64_t size_k, bool is_k_full, int64_t num_experts, int64_t topk, + int64_t moe_block_size, bool replicate_input, bool apply_weights) { + bool has_zp = b_zeros.size(1) != 0; if (has_zp) { TORCH_CHECK(*b_q_type == vllm::kU4 || *b_q_type == vllm::kU8, "b_q_type must be u4 or u8 when has_zp = True. Got = ", diff --git a/csrc/moe/marlin_moe_ops.h b/csrc/moe/marlin_moe_ops.h index 0a54d93cedebc..0013787a623de 100644 --- a/csrc/moe/marlin_moe_ops.h +++ b/csrc/moe/marlin_moe_ops.h @@ -11,6 +11,5 @@ torch::Tensor marlin_gemm_moe( torch::Tensor& b_zeros, const torch::Tensor& g_idx, const torch::Tensor& perm, torch::Tensor& workspace, vllm::ScalarTypeTorchPtr const& b_q_type, int64_t size_m, int64_t size_n, - int64_t size_k, bool is_k_full, bool has_zp, int64_t num_experts, - int64_t topk, int64_t moe_block_size, bool replicate_input, - bool apply_weights); + int64_t size_k, bool is_k_full, int64_t num_experts, int64_t topk, + int64_t moe_block_size, bool replicate_input, bool apply_weights); diff --git a/csrc/moe/torch_bindings.cpp b/csrc/moe/torch_bindings.cpp index 85098df34b2d0..576305d48ae47 100644 --- a/csrc/moe/torch_bindings.cpp +++ b/csrc/moe/torch_bindings.cpp @@ -15,9 +15,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) { "Tensor! topk_weights, Tensor! topk_ids, Tensor! b_scales, Tensor! " "b_zeros, Tensor! g_idx, Tensor! perm, Tensor! workspace, " "__torch__.torch.classes._core_C.ScalarType b_q_type, int size_m, " - "int size_n, int size_k, bool is_k_full, bool has_zp, int num_experts, " - "int topk, int moe_block_size, bool replicate_input, bool apply_weights)" - " -> Tensor"); + "int size_n, int size_k, bool is_k_full, int num_experts, int topk, int " + "moe_block_size, bool replicate_input, bool apply_weights) -> Tensor"); m.impl("marlin_gemm_moe", torch::kCUDA, &marlin_gemm_moe); #endif } diff --git a/tests/kernels/test_awq_marlin.py b/tests/kernels/test_awq_marlin.py index 338f46cbe09fb..f1a0b09e8e464 100644 --- a/tests/kernels/test_awq_marlin.py +++ b/tests/kernels/test_awq_marlin.py @@ -87,7 +87,6 @@ def test_fused_marlin_moe_awq( score, topk_weights, topk_ids, - has_zero_point=True, w1_zeros=zp1, w2_zeros=zp2, num_bits=num_bits, @@ -155,7 +154,6 @@ def test_single_marlin_moe_multiply_awq( score, topk, renormalize=False, - has_zero_point=True, w_zeros=zp, num_bits=num_bits) diff --git a/tests/kernels/test_moe.py b/tests/kernels/test_moe.py index 360ef1330bd69..b73c45b9cd198 100644 --- a/tests/kernels/test_moe.py +++ b/tests/kernels/test_moe.py @@ -234,12 +234,14 @@ def test_fused_marlin_moe( device="cuda", requires_grad=False) - zp = torch.empty((0), dtype=dtype, device="cuda", requires_grad=False) - + zp = torch.empty((0, 0), + dtype=dtype, + device="cuda", + requires_grad=False) opcheck(torch.ops._moe_C.marlin_gemm_moe, (a, qweight1, sorted_token_ids, topk_weights, topk_ids, scales1, zp, g_idx1, sort_indices1, workspace, quant_type, m, - 2 * n, k, True, False, e, topk, block_size_m, True, False)) + 2 * n, k, True, e, topk, block_size_m, True, False)) @pytest.mark.skip("This test is here for the sake of debugging, " diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index bc7b4293c119e..6081fa674579c 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -822,9 +822,9 @@ def marlin_gemm_moe_fake(a: torch.Tensor, b_q_weights: torch.Tensor, b_zero_points: torch.Tensor, g_idx: torch.Tensor, perm: torch.Tensor, workspace: torch.Tensor, b_q_type: ScalarType, size_m: int, size_n: int, - size_k: int, is_k_full: bool, - has_zero_point: bool, num_experts: int, topk: int, - moe_block_size: int, replicate_input: bool, + size_k: int, is_k_full: bool, num_experts: int, + topk: int, moe_block_size: int, + replicate_input: bool, apply_weights: bool) -> torch.Tensor: return torch.empty((size_m, topk, size_n), dtype=a.dtype, diff --git a/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py b/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py index e57b15936aa8b..466b0edd81fe7 100644 --- a/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py @@ -24,7 +24,6 @@ def single_marlin_moe( gating_output: torch.Tensor, topk: int, renormalize: bool, - has_zero_point: bool = False, g_idx: Optional[torch.Tensor] = None, sort_indices: Optional[torch.Tensor] = None, w_zeros: Optional[torch.Tensor] = None, @@ -93,11 +92,9 @@ def single_marlin_moe( device=hidden_states.device, requires_grad=False) - if has_zero_point: - assert w_zeros is not None and w_zeros.nelement() > 0 - + has_zero_point = w_zeros is not None if w_zeros is None: - w_zeros = torch.empty((0), + w_zeros = torch.empty((0, 0), dtype=hidden_states.dtype, device=hidden_states.device, requires_grad=False) @@ -119,7 +116,7 @@ def single_marlin_moe( intermediate_cache = torch.ops._moe_C.marlin_gemm_moe( hidden_states, w, sorted_token_ids, topk_weights, topk_ids, scales, w_zeros, g_idx, sort_indices, workspace, scalar_type, M, N, K, - is_k_full, has_zero_point, E, topk, block_size_m, True, False) + is_k_full, E, topk, block_size_m, True, False) return torch.sum(intermediate_cache.view(*intermediate_cache.shape), dim=1) @@ -133,7 +130,6 @@ def fused_marlin_moe( gating_output: torch.Tensor, topk_weights: torch.Tensor, topk_ids: torch.Tensor, - has_zero_point: bool = False, g_idx1: Optional[torch.Tensor] = None, g_idx2: Optional[torch.Tensor] = None, sort_indices1: Optional[torch.Tensor] = None, @@ -187,6 +183,20 @@ def fused_marlin_moe( assert hidden_states.dtype == torch.float16 assert num_bits in [4, 8] + has_no_act_order = (g_idx1 is None and g_idx2 is None + and sort_indices1 is None and sort_indices2 is None) + has_all_act_order = (g_idx1 is not None and g_idx2 is not None + and sort_indices1 is not None + and sort_indices2 is not None) + assert has_no_act_order or has_all_act_order, ( + "g_idx and sorted_indices " + "must be all not None or must be all None") + + has_no_zp = w1_zeros is None and w2_zeros is None + has_all_zp = w1_zeros is not None and w2_zeros is not None + assert has_no_zp or has_all_zp, ("zero points must be both not None or " + "must be both None") + M, K = hidden_states.shape E = w1.shape[0] N = w2.shape[1] * 16 @@ -213,47 +223,36 @@ def fused_marlin_moe( device="cuda", requires_grad=False) - if has_zero_point: - assert w1_zeros is not None and w1_zeros.nelement() > 0 - assert w2_zeros is not None and w2_zeros.nelement() > 0 - - if w1_zeros is None: - w1_zeros = torch.empty((0), + if has_no_zp: + w1_zeros = torch.empty((0, 0), dtype=hidden_states.dtype, device=hidden_states.device, requires_grad=False) - if w2_zeros is None: - w2_zeros = torch.empty((0), + w2_zeros = torch.empty((0, 0), dtype=hidden_states.dtype, device=hidden_states.device, requires_grad=False) - if g_idx1 is None: + if has_no_act_order: g_idx1 = torch.empty((0, 0), dtype=torch.int32, device=hidden_states.device, requires_grad=False) - - if g_idx2 is None: g_idx2 = torch.empty((0, 0), dtype=torch.int32, device=hidden_states.device, requires_grad=False) - - if sort_indices1 is None: sort_indices1 = torch.empty((0), dtype=torch.int32, device=hidden_states.device, requires_grad=False) - - if sort_indices2 is None: sort_indices2 = torch.empty((0, 0), dtype=torch.int32, device=hidden_states.device, requires_grad=False) - scalar_type1 = get_scalar_type(num_bits, has_zero_point) - scalar_type2 = get_scalar_type(num_bits, has_zero_point) + scalar_type1 = get_scalar_type(num_bits, has_all_zp) + scalar_type2 = get_scalar_type(num_bits, has_all_zp) intermediate_cache2 = torch.empty( (M * topk_ids.shape[1], N), @@ -277,7 +276,6 @@ def fused_marlin_moe( 2 * N, K, is_k_full, - has_zero_point, E, topk, block_size_m, @@ -303,7 +301,6 @@ def fused_marlin_moe( K, N, is_k_full, - has_zero_point, E, topk, block_size_m,