diff --git a/csrc/moe/marlin_kernels/marlin_moe_kernel.h b/csrc/moe/marlin_kernels/marlin_moe_kernel.h index 0bd3017226c94..c08321285bd07 100644 --- a/csrc/moe/marlin_kernels/marlin_moe_kernel.h +++ b/csrc/moe/marlin_kernels/marlin_moe_kernel.h @@ -247,9 +247,7 @@ template shared // fetch pipeline - const bool has_act_order, // whether act_order is enabled - const int group_blocks = -1 // number of consecutive 16x16 blocks - // with a separate quantization scale + const bool has_act_order // whether act_order is enabled > __device__ inline void MarlinMoESingle( const int4* __restrict__ A, // fp16 input matrix of shape mxk @@ -261,6 +259,8 @@ __device__ inline void MarlinMoESingle( // (k/groupsize)xn const int* __restrict__ g_idx, // int32 group indices of shape k const int* __restrict__ expert_offsets, + int group_blocks, // number of consecutive 16x16 blocks + // with a separate quantization scale int num_groups, // number of scale groups per output channel int expert_idx, // idx of current expert int num_experts, // number of experts @@ -289,8 +289,8 @@ __device__ inline void MarlinMoESingle( int n_tiles = prob_n / 16 / thread_n_blocks; int iters = ceildiv(k_tiles * n_tiles * parallel, gridDim.x); - if constexpr (!has_act_order && group_blocks != -1) { - if (group_blocks >= thread_k_blocks) { + if constexpr (!has_act_order) { + if (group_blocks != -1 && group_blocks >= thread_k_blocks) { // Ensure that the number of tiles in each stripe is a multiple of the // groupsize; this avoids an annoying special case where a stripe starts // in the middle of group. @@ -384,11 +384,11 @@ __device__ inline void MarlinMoESingle( // Scale sizes/strides without act_order int s_gl_stride = prob_n / 8; constexpr int s_sh_stride = 16 * thread_n_blocks / 8; - constexpr int s_tb_groups = + int s_tb_groups = !has_act_order && group_blocks != -1 && group_blocks < thread_k_blocks ? thread_k_blocks / group_blocks : 1; - constexpr int s_sh_stage = s_tb_groups * s_sh_stride; + int s_sh_stage = s_tb_groups * s_sh_stride; int s_gl_rd_delta = s_gl_stride; // Scale size/strides with act_order constexpr int tb_k = 16 * thread_k_blocks; @@ -432,7 +432,7 @@ __device__ inline void MarlinMoESingle( // No act_order int s_gl_rd; if constexpr (!has_act_order) { - if constexpr (group_blocks == -1) { + if (group_blocks == -1) { s_gl_rd = s_sh_stride * slice_col + threadIdx.x; } else { s_gl_rd = s_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) + @@ -446,7 +446,7 @@ __device__ inline void MarlinMoESingle( // we scale a `half2` tile in column-major layout in the former and in // row-major in the latter case. int s_sh_rd; - if constexpr (group_blocks != -1) + if (group_blocks != -1) s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + (threadIdx.x % 32) / 4; else @@ -612,10 +612,10 @@ __device__ inline void MarlinMoESingle( } } } else { - if constexpr (group_blocks != -1) { + if (group_blocks != -1) { int4* sh_s_stage = sh_s + s_sh_stage * pipe; - if constexpr (group_blocks >= thread_k_blocks) { + if (group_blocks >= thread_k_blocks) { // Only fetch scales if this tile starts a new group if (pipe % (group_blocks / thread_k_blocks) == 0) { if (s_sh_wr_pred) { @@ -703,8 +703,8 @@ __device__ inline void MarlinMoESingle( if constexpr (!has_act_order) { // No act-order case - if constexpr (group_blocks != -1) { - if constexpr (group_blocks >= thread_k_blocks) { + if (group_blocks != -1) { + if (group_blocks >= thread_k_blocks) { int4* sh_s_stage = sh_s + s_sh_stage * ((group_blocks / thread_k_blocks) * (pipe / (group_blocks / thread_k_blocks))); @@ -824,7 +824,7 @@ __device__ inline void MarlinMoESingle( scale4(frag_b0, act_frag_s[k % 2][0][j], act_frag_s[k % 2][1][j], act_frag_s[k % 2][2][j], act_frag_s[k % 2][3][j], 0); } else { - if constexpr (group_blocks != -1) { + if (group_blocks != -1) { scale(frag_b0, frag_s[k % 2][j], 0); } } @@ -835,7 +835,7 @@ __device__ inline void MarlinMoESingle( act_frag_s[k % 2][2][j], act_frag_s[k % 2][3][j], 1); } else { - if constexpr (group_blocks != -1) { + if (group_blocks != -1) { scale(frag_b1, frag_s[k % 2][j], 1); } } @@ -1009,9 +1009,10 @@ __device__ inline void MarlinMoESingle( // For per-column quantization we finally apply the scale here (only for // 4-bit) - if constexpr (!has_act_order && group_blocks == -1 && - w_type.size_bits() == 4) { - res = __hmul2(res, s[0]); + if constexpr (!has_act_order && w_type.size_bits() == 4) { + if (group_blocks == -1) { + res = __hmul2(res, s[0]); + } } ((half2*)sh)[idx] = res; @@ -1140,36 +1141,32 @@ __device__ inline void MarlinMoESingle( if (slice_iters == 0) { cp_async_wait<0>(); bool last = slice_idx == slice_count - 1; - if constexpr (!has_act_order && group_blocks == -1) { + if constexpr (!has_act_order) { if constexpr (w_type.size_bits() == 8) { - if (s_sh_wr_pred) { - cp_async4(&sh_s[s_sh_wr], &scales_ptr[s_gl_rd]); - } - cp_async_fence(); - } else { - // For 4-bit per-column scales, we only fetch them here in the - // final step before write-out - if (last) { + if (group_blocks == -1) { if (s_sh_wr_pred) { cp_async4(&sh_s[s_sh_wr], &scales_ptr[s_gl_rd]); } cp_async_fence(); } + } else { + if (group_blocks == -1) { + // For 4-bit per-column scales, we only fetch them here in the + // final step before write-out + if (last) { + if (s_sh_wr_pred) { + cp_async4(&sh_s[s_sh_wr], &scales_ptr[s_gl_rd]); + } + cp_async_fence(); + } + } } } thread_block_reduce(); - if constexpr (!has_act_order && group_blocks == -1) { + if constexpr (!has_act_order) { if constexpr (w_type.size_bits() == 8) { - cp_async_wait<0>(); - __syncthreads(); - if (threadIdx.x / 32 < thread_n_blocks / 4) { - reinterpret_cast(&frag_s)[0] = sh_s[s_sh_rd + 0]; - reinterpret_cast(&frag_s)[1] = sh_s[s_sh_rd + 4]; - } - - } else { - if (last) { + if (group_blocks == -1) { cp_async_wait<0>(); __syncthreads(); if (threadIdx.x / 32 < thread_n_blocks / 4) { @@ -1177,15 +1174,26 @@ __device__ inline void MarlinMoESingle( reinterpret_cast(&frag_s)[1] = sh_s[s_sh_rd + 4]; } } + + } else { + if (group_blocks == -1) { + if (last) { + cp_async_wait<0>(); + __syncthreads(); + if (threadIdx.x / 32 < thread_n_blocks / 4) { + reinterpret_cast(&frag_s)[0] = sh_s[s_sh_rd + 0]; + reinterpret_cast(&frag_s)[1] = sh_s[s_sh_rd + 4]; + } + } + } } } // For 8-bit channelwise, we apply the scale before the global reduction // that converts the fp32 results to fp16 (so that we avoid possible // overflow in fp16) - if constexpr (!has_act_order && group_blocks == -1 && - w_type.size_bits() == 8) { - if (threadIdx.x / 32 < thread_n_blocks / 4) { + if constexpr (!has_act_order && w_type.size_bits() == 8) { + if (group_blocks == -1 && threadIdx.x / 32 < thread_n_blocks / 4) { #pragma unroll for (int i = 0; i < thread_m_blocks; i++) { #pragma unroll @@ -1249,9 +1257,7 @@ template shared // fetch pipeline - const bool has_act_order, // whether act_order is enabled - const int group_blocks = -1 // number of consecutive 16x16 blocks - // with a separate quantization scale + const bool has_act_order // whether act_order is enabled > __global__ void MarlinMoE( const int4* __restrict__ A, // fp16 input matrix of shape mxk @@ -1263,6 +1269,8 @@ __global__ void MarlinMoE( // (k/groupsize)xn const int* __restrict__ g_idx, // int32 group indices of shape k const int* __restrict__ expert_offsets, + int group_blocks, // number of consecutive 16x16 blocks + // with a separate quantization scale int num_groups, // number of scale groups per output channel int expert_idx, // idx of current expert int num_experts, // number of experts @@ -1309,31 +1317,31 @@ __global__ void MarlinMoE( if (max_block == 1) { MarlinMoESingle( + stages, has_act_order>( A, B, C, sorted_ids_expert, topk_weights, scales_ptr, g_idx, - expert_offsets, num_groups, expert_idx, num_experts, topk, prob_m, - prob_n, prob_k, tot_m, locks, replicate_input, apply_weights, + expert_offsets, group_blocks, num_groups, expert_idx, num_experts, topk, + prob_m, prob_n, prob_k, tot_m, locks, replicate_input, apply_weights, current_m_block); } else if (max_block == 2) { MarlinMoESingle( + stages, has_act_order>( A, B, C, sorted_ids_expert, topk_weights, scales_ptr, g_idx, - expert_offsets, num_groups, expert_idx, num_experts, topk, prob_m, - prob_n, prob_k, tot_m, locks, replicate_input, apply_weights, + expert_offsets, group_blocks, num_groups, expert_idx, num_experts, topk, + prob_m, prob_n, prob_k, tot_m, locks, replicate_input, apply_weights, current_m_block); } else if (max_block == 3) { MarlinMoESingle( + stages, has_act_order>( A, B, C, sorted_ids_expert, topk_weights, scales_ptr, g_idx, - expert_offsets, num_groups, expert_idx, num_experts, topk, prob_m, - prob_n, prob_k, tot_m, locks, replicate_input, apply_weights, + expert_offsets, group_blocks, num_groups, expert_idx, num_experts, topk, + prob_m, prob_n, prob_k, tot_m, locks, replicate_input, apply_weights, current_m_block); } else { MarlinMoESingle( + stages, has_act_order>( A, B, C, sorted_ids_expert, topk_weights, scales_ptr, g_idx, - expert_offsets, num_groups, expert_idx, num_experts, topk, prob_m, - prob_n, prob_k, tot_m, locks, replicate_input, apply_weights, + expert_offsets, group_blocks, num_groups, expert_idx, num_experts, topk, + prob_m, prob_n, prob_k, tot_m, locks, replicate_input, apply_weights, current_m_block); } } @@ -1346,9 +1354,7 @@ template shared // fetch pipeline - const bool has_act_order, // whether act_order is enabled - const int group_blocks = -1 // number of consecutive 16x16 blocks - // with a separate quantization scale + const bool has_act_order // whether act_order is enabled > __global__ void MarlinMoE( const int4* __restrict__ A, // fp16 input matrix of shape mxk @@ -1360,6 +1366,8 @@ __global__ void MarlinMoE( // (k/groupsize)xn const int* __restrict__ g_idx, // int32 group indices of shape k const int* __restrict__ expert_offsets, + int group_blocks, // number of consecutive 16x16 blocks + // with a separate quantization scale int num_groups, // number of scale groups per output channel int expert_idx, // idx of current expert int num_experts, // number of experts @@ -1396,30 +1404,26 @@ static constexpr int min_thread_n = 64; static constexpr int min_thread_k = 64; #define __CALL_IF_MOE(W_TYPE, THREAD_N_BLOCKS, THREAD_K_BLOCKS, HAS_ACT_ORDER, \ - GROUP_BLOCKS, NUM_THREADS) \ + NUM_THREADS) \ else if (q_type == W_TYPE && thread_n_blocks == THREAD_N_BLOCKS && \ thread_k_blocks == THREAD_K_BLOCKS && \ - has_act_order == HAS_ACT_ORDER && group_blocks == GROUP_BLOCKS && \ - num_threads == NUM_THREADS) { \ - cudaFuncSetAttribute( \ - MarlinMoE, \ - cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \ + has_act_order == HAS_ACT_ORDER && num_threads == NUM_THREADS) { \ + cudaFuncSetAttribute(MarlinMoE, \ + cudaFuncAttributeMaxDynamicSharedMemorySize, \ + max_shared_mem); \ MarlinMoE \ + STAGES, HAS_ACT_ORDER> \ <<>>( \ A_ptr, B_ptr, C_ptr, sorted_ids_ptr, topk_weights_ptr, s_ptr, \ - g_idx_ptr, expert_offsets_ptr, num_groups, expert_idx, \ - num_experts, topk, prob_m, prob_n, prob_k, tot_m, locks, \ - replicate_input, apply_weights, m_block, max_par, \ + g_idx_ptr, expert_offsets_ptr, group_blocks, num_groups, \ + expert_idx, num_experts, topk, prob_m, prob_n, prob_k, tot_m, \ + locks, replicate_input, apply_weights, m_block, max_par, \ cfg_max_m_blocks); \ } -#define GPTQ_CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ - __CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \ - __CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \ - __CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS) \ - __CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \ - __CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS) +#define GPTQ_CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + __CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, true, NUM_THREADS) \ + __CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, false, NUM_THREADS) } // namespace marlin_moe