From a322051e31d81d74e1d03819e6c9585e24f8d023 Mon Sep 17 00:00:00 2001 From: Ke Bao Date: Thu, 6 Feb 2025 01:16:02 +0800 Subject: [PATCH] Support custom mask for Triton attention (#3317) --- .../srt/layers/attention/triton_backend.py | 20 +++++-- .../attention/triton_ops/extend_attention.py | 55 ++++++++++++++++--- test/srt/test_triton_attention_kernels.py | 43 +++++++++++++++ 3 files changed, 107 insertions(+), 11 deletions(-) diff --git a/python/sglang/srt/layers/attention/triton_backend.py b/python/sglang/srt/layers/attention/triton_backend.py index 3475df72192..4da1654864e 100644 --- a/python/sglang/srt/layers/attention/triton_backend.py +++ b/python/sglang/srt/layers/attention/triton_backend.py @@ -91,6 +91,7 @@ def init_forward_metadata(self, forward_batch: ForwardBatch): qo_indptr = None custom_mask = None + mask_offsets = None else: kv_indptr[1 : bs + 1] = torch.cumsum( forward_batch.extend_prefix_lens, dim=0 @@ -115,6 +116,7 @@ def init_forward_metadata(self, forward_batch: ForwardBatch): qo_indptr[1 : bs + 1] = torch.cumsum(forward_batch.extend_seq_lens, dim=0) qo_indptr = qo_indptr[: bs + 1] custom_mask = None + mask_offsets = None attn_logits = None max_extend_len = torch.max(forward_batch.extend_seq_lens).item() @@ -126,6 +128,7 @@ def init_forward_metadata(self, forward_batch: ForwardBatch): kv_indices, qo_indptr, custom_mask, + mask_offsets, ) def init_cuda_graph_state(self, max_bs: int): @@ -180,6 +183,7 @@ def init_forward_metadata_capture_cuda_graph( kv_indices, None, None, + None, ) def init_forward_metadata_replay_cuda_graph( @@ -233,9 +237,15 @@ def forward_extend( layer, forward_batch.out_cache_loc, k, v ) - _, max_extend_len, kv_indptr, kv_indices, qo_indptr, custom_mask = ( - self.forward_metadata - ) + ( + _, + max_extend_len, + kv_indptr, + kv_indices, + qo_indptr, + custom_mask, + mask_offsets, + ) = self.forward_metadata self.extend_attention_fwd( q.view(-1, layer.tp_q_head_num, layer.qk_head_dim), k.contiguous(), @@ -246,6 +256,8 @@ def forward_extend( qo_indptr, kv_indptr, kv_indices, + custom_mask, + mask_offsets, max_extend_len, layer.scaling, layer.logit_cap, @@ -271,7 +283,7 @@ def forward_decode( else: o = torch.empty_like(q) - attn_logits, _, kv_indptr, kv_indices, _, _ = self.forward_metadata + attn_logits, _, kv_indptr, kv_indices, _, _, _ = self.forward_metadata if save_kv_cache: forward_batch.token_to_kv_pool.set_kv_buffer( diff --git a/python/sglang/srt/layers/attention/triton_ops/extend_attention.py b/python/sglang/srt/layers/attention/triton_ops/extend_attention.py index 6c9976931d0..e070bc3a916 100644 --- a/python/sglang/srt/layers/attention/triton_ops/extend_attention.py +++ b/python/sglang/srt/layers/attention/triton_ops/extend_attention.py @@ -49,6 +49,8 @@ def _fwd_kernel( qo_indptr, kv_indptr, kv_indices, + mask_ptr, + mask_offsets, sm_scale, kv_group_num, stride_qbs, @@ -71,6 +73,7 @@ def _fwd_kernel( BLOCK_DV: tl.constexpr, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, + USE_CUSTOM_MASK: tl.constexpr, ): cur_seq = tl.program_id(0) cur_head = tl.program_id(1) @@ -81,6 +84,10 @@ def _fwd_kernel( cur_seq_len_extend = tl.load(qo_indptr + cur_seq + 1) - cur_seq_extend_start_idx cur_seq_kv_start_idx = tl.load(kv_indptr + cur_seq) cur_seq_len_prefix = tl.load(kv_indptr + cur_seq + 1) - cur_seq_kv_start_idx + cur_seq_len = cur_seq_len_prefix + cur_seq_len_extend + + if USE_CUSTOM_MASK: + cur_seq_mask_start_idx = tl.load(mask_offsets + cur_seq) offs_d = tl.arange(0, BLOCK_DMODEL) offs_dv = tl.arange(0, BLOCK_DV) @@ -152,7 +159,20 @@ def 
_fwd_kernel( if logit_cap > 0: qk = logit_cap * tanh(qk / logit_cap) - qk = tl.where(mask_m[:, None] & mask_n[None, :], qk, float("-inf")) + if USE_CUSTOM_MASK: + custom_mask = tl.load( + mask_ptr + + cur_seq_mask_start_idx + + (cur_block_m * BLOCK_M + offs_m[:, None]) * cur_seq_len + + start_n + + offs_n[None, :], + mask=(mask_m[:, None] & mask_n[None, :]), + other=0, + ) + custom_mask &= mask_m[:, None] & mask_n[None, :] + qk = tl.where(custom_mask, qk, float("-inf")) + else: + qk = tl.where(mask_m[:, None] & mask_n[None, :], qk, float("-inf")) n_e_max = tl.maximum(tl.max(qk, 1), e_max) re_scale = tl.exp(e_max - n_e_max) @@ -172,7 +192,7 @@ def _fwd_kernel( e_max = n_e_max - # stage 2: compute the trianlge part + # stage 2: compute the triangle part cur_block_m_end = tl.minimum(cur_seq_len_extend, (cur_block_m + 1) * BLOCK_M) for start_n in range(0, cur_block_m_end, BLOCK_N): @@ -208,11 +228,25 @@ def _fwd_kernel( if logit_cap > 0: qk = logit_cap * tanh(qk / logit_cap) - mask_causual = (cur_block_m * BLOCK_M + offs_m[:, None]) >= ( - start_n + offs_n[None, :] - ) - mask_causual &= mask_m[:, None] & mask_n[None, :] - qk = tl.where(mask_causual, qk, float("-inf")) + if USE_CUSTOM_MASK: + custom_mask = tl.load( + mask_ptr + + cur_seq_mask_start_idx + + (cur_block_m * BLOCK_M + offs_m[:, None]) * cur_seq_len + + cur_seq_len_prefix + + start_n + + offs_n[None, :], + mask=(mask_m[:, None] & mask_n[None, :]), + other=0, + ) + custom_mask &= mask_m[:, None] & mask_n[None, :] + qk = tl.where(custom_mask, qk, float("-inf")) + else: + mask_causual = (cur_block_m * BLOCK_M + offs_m[:, None]) >= ( + start_n + offs_n[None, :] + ) + mask_causual &= mask_m[:, None] & mask_n[None, :] + qk = tl.where(mask_causual, qk, float("-inf")) n_e_max = tl.maximum(tl.max(qk, 1), e_max) re_scale = tl.exp(e_max - n_e_max) @@ -253,6 +287,8 @@ def extend_attention_fwd( qo_indptr, kv_indptr, kv_indices, + custom_mask, + mask_offsets, max_len_extend, sm_scale=None, logit_cap=0.0, @@ -308,6 +344,8 @@ def extend_attention_fwd( batch_size, head_num = qo_indptr.shape[0] - 1, q_extend.shape[1] kv_group_num = q_extend.shape[1] // k_extend.shape[1] + USE_CUSTOM_MASK = custom_mask is not None + grid = (batch_size, head_num, triton.cdiv(max_len_extend, BLOCK_M)) num_stages = 1 @@ -325,6 +363,8 @@ def extend_attention_fwd( qo_indptr, kv_indptr, kv_indices, + custom_mask, + mask_offsets, sm_scale, kv_group_num, q_extend.stride(0), @@ -347,6 +387,7 @@ def extend_attention_fwd( BLOCK_N=BLOCK_N, Lq=Lq, Lv=Lv, + USE_CUSTOM_MASK=USE_CUSTOM_MASK, num_warps=num_warps, num_stages=num_stages, **extra_kargs, diff --git a/test/srt/test_triton_attention_kernels.py b/test/srt/test_triton_attention_kernels.py index 3617e17be2a..73e304fec27 100644 --- a/test/srt/test_triton_attention_kernels.py +++ b/test/srt/test_triton_attention_kernels.py @@ -89,6 +89,9 @@ def _test_extend_attention_once(self, B, N_CTX, H_Q, H_KV, D): ).normal_(mean=0.1, std=0.2) o_extend = torch.empty((extend_token_num, H_Q, D), dtype=dtype, device="cuda") + o_extend_mask = torch.empty( + (extend_token_num, H_Q, D), dtype=dtype, device="cuda" + ) o_redundant = torch.empty( (extend_token_num, H_Q, D), dtype=dtype, device="cuda" ) @@ -98,6 +101,9 @@ def _test_extend_attention_once(self, B, N_CTX, H_Q, H_KV, D): qo_indptr = torch.zeros((B + 1,), dtype=torch.int32, device="cuda") qo_indptr[1 : B + 1] = torch.cumsum(b_seq_len_extend[:B], dim=0) + custom_mask = None + mask_offsets = None + extend_attention_fwd( q_extend, k_extend, @@ -108,6 +114,42 @@ def 
_test_extend_attention_once(self, B, N_CTX, H_Q, H_KV, D): qo_indptr, kv_indptr, kv_indices, + custom_mask, + mask_offsets, + max_len_extend, + ) + + b_seq_mask_len = b_seq_len_extend * b_seq_len + custom_mask = torch.ones( + (b_seq_mask_len.sum().item(),), dtype=torch.bool, device="cuda" + ) + mask_offsets = torch.zeros((B + 1,), dtype=torch.int64, device="cuda") + mask_offsets[1 : B + 1] = torch.cumsum(b_seq_mask_len[:B], dim=0) + for i in range(B): + causal_mask = ( + torch.tril( + torch.ones(b_seq_len_extend[i], b_seq_len_extend[i]), diagonal=0 + ) + == 1 + ) + prefix_mask = torch.ones( + b_seq_len_extend[i], b_seq_len_prefix[i], dtype=torch.bool + ) + mask_flatten = torch.cat([prefix_mask, causal_mask], dim=1).flatten() + custom_mask[mask_offsets[i] : mask_offsets[i + 1]] = mask_flatten + + extend_attention_fwd( + q_extend, + k_extend, + v_extend, + o_extend_mask, + k_buffer, + v_buffer, + qo_indptr, + kv_indptr, + kv_indices, + custom_mask, + mask_offsets, max_len_extend, ) @@ -124,6 +166,7 @@ def _test_extend_attention_once(self, B, N_CTX, H_Q, H_KV, D): ) self.assertTrue(torch.allclose(o_extend, o_redundant, rtol=1e-2)) + self.assertTrue(torch.allclose(o_extend_mask, o_redundant, rtol=1e-2)) def test_extend_attention(self):
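
Usage sketch (illustrative, not part of the patch): the snippet below shows one way a caller
could build the flattened custom_mask and mask_offsets tensors consumed by the updated
extend_attention_fwd, mirroring the mask construction added in the test above. The kernel
reads the mask of sequence i as an (extend_len_i, prefix_len_i + extend_len_i) boolean matrix
flattened row-major, prefix columns first, starting at mask_offsets[i] in the flat buffer.
The helper name build_causal_custom_mask and its prefix_lens / extend_lens arguments are
assumptions for illustration, not an existing SGLang API.

import torch

def build_causal_custom_mask(prefix_lens, extend_lens, device="cuda"):
    # Illustrative helper (assumed, not from this patch): builds a mask equivalent
    # to the default causal behaviour; arbitrary per-sequence patterns can be
    # written into each block instead.
    seq_lens = prefix_lens + extend_lens      # total kv length per sequence
    mask_lens = extend_lens * seq_lens        # flattened mask size per sequence
    mask_offsets = torch.zeros(len(seq_lens) + 1, dtype=torch.int64, device=device)
    mask_offsets[1:] = torch.cumsum(mask_lens, dim=0)
    custom_mask = torch.empty(int(mask_lens.sum()), dtype=torch.bool, device=device)
    for i in range(len(seq_lens)):
        # Prefix tokens are fully visible; extend tokens get a causal triangle.
        prefix_part = torch.ones(
            int(extend_lens[i]), int(prefix_lens[i]), dtype=torch.bool, device=device
        )
        causal_part = (
            torch.tril(
                torch.ones(int(extend_lens[i]), int(extend_lens[i]), device=device)
            )
            == 1
        )
        custom_mask[mask_offsets[i] : mask_offsets[i + 1]] = torch.cat(
            [prefix_part, causal_part], dim=1
        ).flatten()
    return custom_mask, mask_offsets

With the remaining inputs prepared as before, the call becomes
extend_attention_fwd(q_extend, k_extend, v_extend, o_extend, k_buffer, v_buffer,
qo_indptr, kv_indptr, kv_indices, custom_mask, mask_offsets, max_len_extend);
passing None for both new arguments keeps the previous purely causal path
(USE_CUSTOM_MASK is False in that case).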