diff --git a/python/sglang/srt/layers/attention/triton_ops/extend_attention.py b/python/sglang/srt/layers/attention/triton_ops/extend_attention.py
index 238351d2995..74ad872bc2e 100644
--- a/python/sglang/srt/layers/attention/triton_ops/extend_attention.py
+++ b/python/sglang/srt/layers/attention/triton_ops/extend_attention.py
@@ -74,6 +74,7 @@ def _fwd_kernel(
     BLOCK_M: tl.constexpr,
     BLOCK_N: tl.constexpr,
     USE_CUSTOM_MASK: tl.constexpr,
+    STORE_TRANSPOSE: tl.constexpr,
 ):
     cur_seq = tl.program_id(0)
     cur_head = tl.program_id(1)
@@ -272,9 +273,18 @@ def _fwd_kernel(
         + cur_head * stride_oh
         + offs_dv[None, :]
     )
-    tl.store(
-        O_Extend + offs_o, acc / deno[:, None], mask=mask_m[:, None] & mask_dv[None, :]
-    )
+    if STORE_TRANSPOSE:
+        tl.store(
+            O_Extend + offs_o.T,
+            (acc / deno[:, None]).T,
+            mask=(mask_m[:, None] & mask_dv[None, :]).T,
+        )
+    else:
+        tl.store(
+            O_Extend + offs_o,
+            acc / deno[:, None],
+            mask=mask_m[:, None] & mask_dv[None, :],
+        )


 def extend_attention_fwd(
@@ -319,8 +329,8 @@ def extend_attention_fwd(
     BLOCK_DV = triton.next_power_of_2(Lv)

     if is_hip_:
-        BLOCK_M, BLOCK_N = (64, 64)
-        num_warps = 4
+        BLOCK_M, BLOCK_N = (32, 32)
+        num_warps = 2

     else:
         if is_cuda_available and CUDA_CAPABILITY[0] >= 9:
@@ -388,6 +398,7 @@ def extend_attention_fwd(
         Lq=Lq,
         Lv=Lv,
         USE_CUSTOM_MASK=USE_CUSTOM_MASK,
+        STORE_TRANSPOSE=is_hip_,
         num_warps=num_warps,
         num_stages=num_stages,
         **extra_kargs,