From c0c13ec26f64b504dc1bde36ee13d7b79a012290 Mon Sep 17 00:00:00 2001 From: ElizaWszola Date: Thu, 19 Sep 2024 10:41:36 -0400 Subject: [PATCH 1/5] Dynamic group blocks in marlin MoE --- csrc/moe/marlin_moe_ops.cu | 135 ++++++++++++++++++------------------- 1 file changed, 67 insertions(+), 68 deletions(-) diff --git a/csrc/moe/marlin_moe_ops.cu b/csrc/moe/marlin_moe_ops.cu index 49cc03f827f68..d22ba7fe4335a 100644 --- a/csrc/moe/marlin_moe_ops.cu +++ b/csrc/moe/marlin_moe_ops.cu @@ -344,9 +344,7 @@ template shared // fetch pipeline - const bool has_act_order, // whether act_order is enabled - const int group_blocks = -1 // number of consecutive 16x16 blocks - // with a separate quantization scale + const bool has_act_order // whether act_order is enabled > __device__ inline void MarlinMoESingle( const int4* __restrict__ A, // fp16 input matrix of shape mxk @@ -358,6 +356,8 @@ __device__ inline void MarlinMoESingle( // (k/groupsize)xn const int* __restrict__ g_idx, // int32 group indices of shape k const int* __restrict__ expert_offsets, + int group_blocks, // number of consecutive 16x16 blocks + // with a separate quantization scale int num_groups, // number of scale groups per output channel int expert_idx, // idx of current expert int num_experts, // number of experts @@ -386,8 +386,8 @@ __device__ inline void MarlinMoESingle( int n_tiles = prob_n / 16 / thread_n_blocks; int iters = ceildiv(k_tiles * n_tiles * parallel, gridDim.x); - if constexpr (!has_act_order && group_blocks != -1) { - if (group_blocks >= thread_k_blocks) { + if constexpr (!has_act_order) { + if (group_blocks != -1 && group_blocks >= thread_k_blocks) { // Ensure that the number of tiles in each stripe is a multiple of the // groupsize; this avoids an annoying special case where a stripe starts // in the middle of group. @@ -481,11 +481,11 @@ __device__ inline void MarlinMoESingle( // Scale sizes/strides without act_order int s_gl_stride = prob_n / 8; constexpr int s_sh_stride = 16 * thread_n_blocks / 8; - constexpr int s_tb_groups = + int s_tb_groups = !has_act_order && group_blocks != -1 && group_blocks < thread_k_blocks ? thread_k_blocks / group_blocks : 1; - constexpr int s_sh_stage = s_tb_groups * s_sh_stride; + int s_sh_stage = s_tb_groups * s_sh_stride; int s_gl_rd_delta = s_gl_stride; // Scale size/strides with act_order constexpr int tb_k = 16 * thread_k_blocks; @@ -529,7 +529,7 @@ __device__ inline void MarlinMoESingle( // No act_order int s_gl_rd; if constexpr (!has_act_order) { - if constexpr (group_blocks == -1) { + if (group_blocks == -1) { s_gl_rd = s_sh_stride * slice_col + threadIdx.x; } else { s_gl_rd = s_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) + @@ -543,7 +543,7 @@ __device__ inline void MarlinMoESingle( // we scale a `half2` tile in column-major layout in the former and in // row-major in the latter case. 
int s_sh_rd; - if constexpr (group_blocks != -1) + if (group_blocks != -1) s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + (threadIdx.x % 32) / 4; else @@ -709,10 +709,10 @@ __device__ inline void MarlinMoESingle( } } } else { - if constexpr (group_blocks != -1) { + if (group_blocks != -1) { int4* sh_s_stage = sh_s + s_sh_stage * pipe; - if constexpr (group_blocks >= thread_k_blocks) { + if (group_blocks >= thread_k_blocks) { // Only fetch scales if this tile starts a new group if (pipe % (group_blocks / thread_k_blocks) == 0) { if (s_sh_wr_pred) { @@ -800,8 +800,8 @@ __device__ inline void MarlinMoESingle( if constexpr (!has_act_order) { // No act-order case - if constexpr (group_blocks != -1) { - if constexpr (group_blocks >= thread_k_blocks) { + if (group_blocks != -1) { + if (group_blocks >= thread_k_blocks) { int4* sh_s_stage = sh_s + s_sh_stage * ((group_blocks / thread_k_blocks) * (pipe / (group_blocks / thread_k_blocks))); @@ -921,7 +921,7 @@ __device__ inline void MarlinMoESingle( scale4(frag_b0, act_frag_s[k % 2][0][j], act_frag_s[k % 2][1][j], act_frag_s[k % 2][2][j], act_frag_s[k % 2][3][j], 0); } else { - if constexpr (group_blocks != -1) { + if (group_blocks != -1) { scale(frag_b0, frag_s[k % 2][j], 0); } } @@ -932,7 +932,7 @@ __device__ inline void MarlinMoESingle( act_frag_s[k % 2][2][j], act_frag_s[k % 2][3][j], 1); } else { - if constexpr (group_blocks != -1) { + if (group_blocks != -1) { scale(frag_b1, frag_s[k % 2][j], 1); } } @@ -1106,9 +1106,10 @@ __device__ inline void MarlinMoESingle( // For per-column quantization we finally apply the scale here (only for // 4-bit) - if constexpr (!has_act_order && group_blocks == -1 && - w_type.size_bits() == 4) { - res = __hmul2(res, s[0]); + if constexpr (!has_act_order && w_type.size_bits() == 4) { + if (group_blocks == -1) { + res = __hmul2(res, s[0]); + } } ((half2*)sh)[idx] = res; @@ -1237,36 +1238,32 @@ __device__ inline void MarlinMoESingle( if (slice_iters == 0) { cp_async_wait<0>(); bool last = slice_idx == slice_count - 1; - if constexpr (!has_act_order && group_blocks == -1) { + if constexpr (!has_act_order) { if constexpr (w_type.size_bits() == 8) { - if (s_sh_wr_pred) { - cp_async4(&sh_s[s_sh_wr], &scales_ptr[s_gl_rd]); - } - cp_async_fence(); - } else { - // For 4-bit per-column scales, we only fetch them here in the - // final step before write-out - if (last) { + if (group_blocks == -1) { if (s_sh_wr_pred) { cp_async4(&sh_s[s_sh_wr], &scales_ptr[s_gl_rd]); } cp_async_fence(); } + } else { + if (group_blocks == -1) { + // For 4-bit per-column scales, we only fetch them here in the + // final step before write-out + if (last) { + if (s_sh_wr_pred) { + cp_async4(&sh_s[s_sh_wr], &scales_ptr[s_gl_rd]); + } + cp_async_fence(); + } + } } } thread_block_reduce(); if constexpr (!has_act_order && group_blocks == -1) { if constexpr (w_type.size_bits() == 8) { - cp_async_wait<0>(); - __syncthreads(); - if (threadIdx.x / 32 < thread_n_blocks / 4) { - reinterpret_cast(&frag_s)[0] = sh_s[s_sh_rd + 0]; - reinterpret_cast(&frag_s)[1] = sh_s[s_sh_rd + 4]; - } - - } else { - if (last) { + if (group_blocks == -1) { cp_async_wait<0>(); __syncthreads(); if (threadIdx.x / 32 < thread_n_blocks / 4) { @@ -1274,15 +1271,25 @@ __device__ inline void MarlinMoESingle( reinterpret_cast(&frag_s)[1] = sh_s[s_sh_rd + 4]; } } + } else { + if (group_blocks == -1) { + if (last) { + cp_async_wait<0>(); + __syncthreads(); + if (threadIdx.x / 32 < thread_n_blocks / 4) { + reinterpret_cast(&frag_s)[0] = sh_s[s_sh_rd + 0]; + 
reinterpret_cast(&frag_s)[1] = sh_s[s_sh_rd + 4]; + } + } + } } } // For 8-bit channelwise, we apply the scale before the global reduction // that converts the fp32 results to fp16 (so that we avoid possible // overflow in fp16) - if constexpr (!has_act_order && group_blocks == -1 && - w_type.size_bits() == 8) { - if (threadIdx.x / 32 < thread_n_blocks / 4) { + if constexpr (!has_act_order && && w_type.size_bits() == 8) { + if (group_blocks == -1 && threadIdx.x / 32 < thread_n_blocks / 4) { #pragma unroll for (int i = 0; i < thread_m_blocks; i++) { #pragma unroll @@ -1346,9 +1353,7 @@ template shared // fetch pipeline - const bool has_act_order, // whether act_order is enabled - const int group_blocks = -1 // number of consecutive 16x16 blocks - // with a separate quantization scale + const bool has_act_order // whether act_order is enabled > __global__ void MarlinMoE( const int4* __restrict__ A, // fp16 input matrix of shape mxk @@ -1360,6 +1365,8 @@ __global__ void MarlinMoE( // (k/groupsize)xn const int* __restrict__ g_idx, // int32 group indices of shape k const int* __restrict__ expert_offsets, + int group_blocks, // number of consecutive 16x16 blocks + // with a separate quantization scale int num_groups, // number of scale groups per output channel int expert_idx, // idx of current expert int num_experts, // number of experts @@ -1406,30 +1413,30 @@ __global__ void MarlinMoE( if (max_block == 1) { MarlinMoESingle( + stages, has_act_order>( A, B, C, sorted_ids_expert, topk_weights, scales_ptr, g_idx, - expert_offsets, num_groups, expert_idx, num_experts, topk, prob_m, + expert_offsets, group_blocks, num_groups, expert_idx, num_experts, topk, prob_m, prob_n, prob_k, tot_m, locks, replicate_input, apply_weights, current_m_block); } else if (max_block == 2) { MarlinMoESingle( + stages, has_act_order>( A, B, C, sorted_ids_expert, topk_weights, scales_ptr, g_idx, - expert_offsets, num_groups, expert_idx, num_experts, topk, prob_m, + expert_offsets, group_blocks, num_groups, expert_idx, num_experts, topk, prob_m, prob_n, prob_k, tot_m, locks, replicate_input, apply_weights, current_m_block); } else if (max_block == 3) { MarlinMoESingle( + stages, has_act_order>( A, B, C, sorted_ids_expert, topk_weights, scales_ptr, g_idx, - expert_offsets, num_groups, expert_idx, num_experts, topk, prob_m, + expert_offsets, group_blocks, num_groups, expert_idx, num_experts, topk, prob_m, prob_n, prob_k, tot_m, locks, replicate_input, apply_weights, current_m_block); } else { MarlinMoESingle( + stages, has_act_order>( A, B, C, sorted_ids_expert, topk_weights, scales_ptr, g_idx, - expert_offsets, num_groups, expert_idx, num_experts, topk, prob_m, + expert_offsets, group_blocks, num_groups, expert_idx, num_experts, topk, prob_m, prob_n, prob_k, tot_m, locks, replicate_input, apply_weights, current_m_block); } @@ -1460,9 +1467,7 @@ template shared // fetch pipeline - const bool has_act_order, // whether act_order is enabled - const int group_blocks = -1 // number of consecutive 16x16 blocks - // with a separate quantization scale + const bool has_act_order // whether act_order is enabled > __global__ void MarlinMoE( const int4* __restrict__ A, // fp16 input matrix of shape mxk @@ -1474,6 +1479,8 @@ __global__ void MarlinMoE( // (k/groupsize)xn const int* __restrict__ g_idx, // int32 group indices of shape k const int* __restrict__ expert_offsets, + int group_blocks, // number of consecutive 16x16 blocks + // with a separate quantization scale int num_groups, // number of scale groups per output channel int 
expert_idx, // idx of current expert int num_experts, // number of experts @@ -1510,20 +1517,19 @@ static constexpr int min_thread_n = 64; static constexpr int min_thread_k = 64; #define __CALL_IF_MOE(W_TYPE, THREAD_N_BLOCKS, THREAD_K_BLOCKS, HAS_ACT_ORDER, \ - GROUP_BLOCKS, NUM_THREADS) \ + NUM_THREADS) \ else if (q_type == W_TYPE && thread_n_blocks == THREAD_N_BLOCKS && \ thread_k_blocks == THREAD_K_BLOCKS && \ - has_act_order == HAS_ACT_ORDER && group_blocks == GROUP_BLOCKS && \ - num_threads == NUM_THREADS) { \ + has_act_order == HAS_ACT_ORDER && num_threads == NUM_THREADS) { \ cudaFuncSetAttribute( \ MarlinMoE, \ + STAGES, HAS_ACT_ORDER>, \ cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \ MarlinMoE \ + STAGES, HAS_ACT_ORDER> \ <<>>( \ A_ptr, B_ptr, C_ptr, sorted_ids_ptr, topk_weights_ptr, s_ptr, \ - g_idx_ptr, expert_offsets_ptr, num_groups, expert_idx, \ + g_idx_ptr, expert_offsets_ptr, group_blocks, num_groups, expert_idx, \ num_experts, topk, prob_m, prob_n, prob_k, tot_m, locks, \ replicate_input, apply_weights, m_block, max_par, \ exec_cfg.max_m_blocks); \ @@ -1704,15 +1710,8 @@ exec_config_t determine_thread_config(int prob_m, int prob_n, int prob_k, } #define CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ - __CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \ - __CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \ - __CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \ - __CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \ - \ - __CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \ - __CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS) \ - __CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \ - __CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS) + __CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, true, NUM_THREADS) \ + __CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, false, NUM_THREADS) void marlin_mm_moe_f16i4(const void* A, const void* B, void* C, const void* sorted_ids, const void* topk_weights, From 72d150362b1ee0b069be4c8175c141c5622b50ea Mon Sep 17 00:00:00 2001 From: ElizaWszola Date: Thu, 19 Sep 2024 16:32:06 +0000 Subject: [PATCH 2/5] fixes --- csrc/moe/marlin_moe_ops.cu | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/csrc/moe/marlin_moe_ops.cu b/csrc/moe/marlin_moe_ops.cu index d22ba7fe4335a..6fdf75863c99a 100644 --- a/csrc/moe/marlin_moe_ops.cu +++ b/csrc/moe/marlin_moe_ops.cu @@ -1261,7 +1261,7 @@ __device__ inline void MarlinMoESingle( } thread_block_reduce(); - if constexpr (!has_act_order && group_blocks == -1) { + if constexpr (!has_act_order) { if constexpr (w_type.size_bits() == 8) { if (group_blocks == -1) { cp_async_wait<0>(); @@ -1288,7 +1288,7 @@ __device__ inline void MarlinMoESingle( // For 8-bit channelwise, we apply the scale before the global reduction // that converts the fp32 results to fp16 (so that we avoid possible // overflow in fp16) - if constexpr (!has_act_order && && w_type.size_bits() == 8) { + if constexpr (!has_act_order && w_type.size_bits() == 8) { if (group_blocks == -1 && threadIdx.x / 32 < thread_n_blocks / 4) { #pragma unroll for (int i = 0; i < thread_m_blocks; i++) { @@ -1713,7 +1713,7 @@ exec_config_t determine_thread_config(int prob_m, int prob_n, int prob_k, __CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, true, NUM_THREADS) \ __CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, false, NUM_THREADS) -void marlin_mm_moe_f16i4(const void* A, const void* B, void* C, +void marlin_mm_moe(const void* A, const void* 
B, void* C, const void* sorted_ids, const void* topk_weights, const void* topk_ids, const void* s, const void* g_idx, const void* perm, void* a_tmp, void* expert_offsets, @@ -1888,6 +1888,8 @@ torch::Tensor marlin_gemm_moe( TORCH_CHECK(*b_q_type == vllm::kU4B8 || *b_q_type == vllm::kU8B128, "b_q_type must be uint4b8 or uint8b128. Got = ", b_q_type->str()); + TORCH_CHECK(is_k_full, "NYI: Marlin MoE kernel does not currently support !is_k_full case."); + int pack_factor = 32 / b_q_type->size_bits(); int max_par = 4; @@ -1945,7 +1947,7 @@ torch::Tensor marlin_gemm_moe( } } - marlin_moe::marlin_mm_moe_f16i4( + marlin_moe::marlin_mm_moe( a.data_ptr(), b_q_weights.data_ptr(), c.data_ptr(), sorted_ids.data_ptr(), topk_weights.data_ptr(), topk_ids.data_ptr(), b_scales.data_ptr(), g_idx.data_ptr(), perm.data_ptr(), a_tmp.data_ptr(), From 00adeed583d0d85384e6bd639a1f32c0cf18271d Mon Sep 17 00:00:00 2001 From: ElizaWszola Date: Fri, 20 Sep 2024 15:25:08 +0000 Subject: [PATCH 3/5] format --- csrc/moe/marlin_moe_ops.cu | 65 +++++++++++++++++++------------------- 1 file changed, 33 insertions(+), 32 deletions(-) diff --git a/csrc/moe/marlin_moe_ops.cu b/csrc/moe/marlin_moe_ops.cu index 6fdf75863c99a..519e3c75a5424 100644 --- a/csrc/moe/marlin_moe_ops.cu +++ b/csrc/moe/marlin_moe_ops.cu @@ -344,7 +344,7 @@ template shared // fetch pipeline - const bool has_act_order // whether act_order is enabled + const bool has_act_order // whether act_order is enabled > __device__ inline void MarlinMoESingle( const int4* __restrict__ A, // fp16 input matrix of shape mxk @@ -1353,7 +1353,7 @@ template shared // fetch pipeline - const bool has_act_order // whether act_order is enabled + const bool has_act_order // whether act_order is enabled > __global__ void MarlinMoE( const int4* __restrict__ A, // fp16 input matrix of shape mxk @@ -1415,29 +1415,29 @@ __global__ void MarlinMoE( MarlinMoESingle( A, B, C, sorted_ids_expert, topk_weights, scales_ptr, g_idx, - expert_offsets, group_blocks, num_groups, expert_idx, num_experts, topk, prob_m, - prob_n, prob_k, tot_m, locks, replicate_input, apply_weights, + expert_offsets, group_blocks, num_groups, expert_idx, num_experts, topk, + prob_m, prob_n, prob_k, tot_m, locks, replicate_input, apply_weights, current_m_block); } else if (max_block == 2) { MarlinMoESingle( A, B, C, sorted_ids_expert, topk_weights, scales_ptr, g_idx, - expert_offsets, group_blocks, num_groups, expert_idx, num_experts, topk, prob_m, - prob_n, prob_k, tot_m, locks, replicate_input, apply_weights, + expert_offsets, group_blocks, num_groups, expert_idx, num_experts, topk, + prob_m, prob_n, prob_k, tot_m, locks, replicate_input, apply_weights, current_m_block); } else if (max_block == 3) { MarlinMoESingle( A, B, C, sorted_ids_expert, topk_weights, scales_ptr, g_idx, - expert_offsets, group_blocks, num_groups, expert_idx, num_experts, topk, prob_m, - prob_n, prob_k, tot_m, locks, replicate_input, apply_weights, + expert_offsets, group_blocks, num_groups, expert_idx, num_experts, topk, + prob_m, prob_n, prob_k, tot_m, locks, replicate_input, apply_weights, current_m_block); } else { MarlinMoESingle( A, B, C, sorted_ids_expert, topk_weights, scales_ptr, g_idx, - expert_offsets, group_blocks, num_groups, expert_idx, num_experts, topk, prob_m, - prob_n, prob_k, tot_m, locks, replicate_input, apply_weights, + expert_offsets, group_blocks, num_groups, expert_idx, num_experts, topk, + prob_m, prob_n, prob_k, tot_m, locks, replicate_input, apply_weights, current_m_block); } } @@ -1467,7 +1467,7 @@ template shared 
// fetch pipeline - const bool has_act_order // whether act_order is enabled + const bool has_act_order // whether act_order is enabled > __global__ void MarlinMoE( const int4* __restrict__ A, // fp16 input matrix of shape mxk @@ -1521,17 +1521,17 @@ static constexpr int min_thread_k = 64; else if (q_type == W_TYPE && thread_n_blocks == THREAD_N_BLOCKS && \ thread_k_blocks == THREAD_K_BLOCKS && \ has_act_order == HAS_ACT_ORDER && num_threads == NUM_THREADS) { \ - cudaFuncSetAttribute( \ - MarlinMoE, \ - cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \ + cudaFuncSetAttribute(MarlinMoE, \ + cudaFuncAttributeMaxDynamicSharedMemorySize, \ + max_shared_mem); \ MarlinMoE \ + STAGES, HAS_ACT_ORDER> \ <<>>( \ A_ptr, B_ptr, C_ptr, sorted_ids_ptr, topk_weights_ptr, s_ptr, \ - g_idx_ptr, expert_offsets_ptr, group_blocks, num_groups, expert_idx, \ - num_experts, topk, prob_m, prob_n, prob_k, tot_m, locks, \ - replicate_input, apply_weights, m_block, max_par, \ + g_idx_ptr, expert_offsets_ptr, group_blocks, num_groups, \ + expert_idx, num_experts, topk, prob_m, prob_n, prob_k, tot_m, \ + locks, replicate_input, apply_weights, m_block, max_par, \ exec_cfg.max_m_blocks); \ } @@ -1709,21 +1709,20 @@ exec_config_t determine_thread_config(int prob_m, int prob_n, int prob_k, return exec_config_t{0, {-1, -1, -1}}; } -#define CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ - __CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, true, NUM_THREADS) \ +#define CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + __CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, true, NUM_THREADS) \ __CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, false, NUM_THREADS) void marlin_mm_moe(const void* A, const void* B, void* C, - const void* sorted_ids, const void* topk_weights, - const void* topk_ids, const void* s, const void* g_idx, - const void* perm, void* a_tmp, void* expert_offsets, - int prob_m, int prob_n, int prob_k, void* workspace, - vllm::ScalarType const& q_type, bool has_act_order, - bool is_k_full, int num_groups, int group_size, - int num_experts, int topk, int moe_block_size, int dev, - cudaStream_t stream, int thread_k, int thread_n, - int sms, int max_par, bool replicate_input, - bool apply_weights) { + const void* sorted_ids, const void* topk_weights, + const void* topk_ids, const void* s, const void* g_idx, + const void* perm, void* a_tmp, void* expert_offsets, + int prob_m, int prob_n, int prob_k, void* workspace, + vllm::ScalarType const& q_type, bool has_act_order, + bool is_k_full, int num_groups, int group_size, + int num_experts, int topk, int moe_block_size, int dev, + cudaStream_t stream, int thread_k, int thread_n, int sms, + int max_par, bool replicate_input, bool apply_weights) { TORCH_CHECK(prob_m > 0 && prob_n > 0 && prob_k > 0, "Invalid MNK = [", prob_m, ", ", prob_n, ", ", prob_k, "]"); @@ -1888,7 +1887,9 @@ torch::Tensor marlin_gemm_moe( TORCH_CHECK(*b_q_type == vllm::kU4B8 || *b_q_type == vllm::kU8B128, "b_q_type must be uint4b8 or uint8b128. 
Got = ", b_q_type->str()); - TORCH_CHECK(is_k_full, "NYI: Marlin MoE kernel does not currently support !is_k_full case."); + TORCH_CHECK( + is_k_full, + "NYI: Marlin MoE kernel does not currently support !is_k_full case."); int pack_factor = 32 / b_q_type->size_bits(); From 0f329261da18751f8e077f13ba5db6ec7177db14 Mon Sep 17 00:00:00 2001 From: ElizaWszola Date: Wed, 25 Sep 2024 11:56:39 +0000 Subject: [PATCH 4/5] delete check --- csrc/moe/marlin_moe_ops.cu | 4 ---- 1 file changed, 4 deletions(-) diff --git a/csrc/moe/marlin_moe_ops.cu b/csrc/moe/marlin_moe_ops.cu index 5f34efeb323d3..dfe0437414013 100644 --- a/csrc/moe/marlin_moe_ops.cu +++ b/csrc/moe/marlin_moe_ops.cu @@ -485,10 +485,6 @@ torch::Tensor marlin_gemm_moe( TORCH_CHECK(*b_q_type == vllm::kU4B8 || *b_q_type == vllm::kU8B128, "b_q_type must be uint4b8 or uint8b128. Got = ", b_q_type->str()); - TORCH_CHECK( - is_k_full, - "NYI: Marlin MoE kernel does not currently support !is_k_full case."); - int pack_factor = 32 / b_q_type->size_bits(); int max_par = 4; From c553e53bf172129ba5f8dc1212f97f6085db7c28 Mon Sep 17 00:00:00 2001 From: ElizaWszola Date: Wed, 25 Sep 2024 12:00:37 +0000 Subject: [PATCH 5/5] Delete duplicate macros --- csrc/moe/marlin_kernels/marlin_moe_kernel.h | 3 --- 1 file changed, 3 deletions(-) diff --git a/csrc/moe/marlin_kernels/marlin_moe_kernel.h b/csrc/moe/marlin_kernels/marlin_moe_kernel.h index 7337faa968414..c08321285bd07 100644 --- a/csrc/moe/marlin_kernels/marlin_moe_kernel.h +++ b/csrc/moe/marlin_kernels/marlin_moe_kernel.h @@ -1424,9 +1424,6 @@ static constexpr int min_thread_k = 64; #define GPTQ_CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ __CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, true, NUM_THREADS) \ - __CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, false, NUM_THREADS) \ - __CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, false, NUM_THREADS) \ - __CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, false, NUM_THREADS) \ __CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, false, NUM_THREADS) } // namespace marlin_moe