diff --git a/native_sparse_attention_pytorch/triton_native_sparse_attention.py b/native_sparse_attention_pytorch/triton_native_sparse_attention.py
index 2e1c05c..f8fa678 100644
--- a/native_sparse_attention_pytorch/triton_native_sparse_attention.py
+++ b/native_sparse_attention_pytorch/triton_native_sparse_attention.py
@@ -56,7 +56,6 @@ def _fwd_kernel(
     Q,
     K,
     V,
-    Bias,
     Out,
     M,
     Lse,
@@ -70,9 +69,6 @@
     stride_vb,
     stride_vh,
     stride_vn,
-    stride_bb,
-    stride_bh,
-    stride_bm,
     stride_ob,
     stride_oh,
     stride_om,
@@ -83,7 +79,6 @@
     headdim,
     CACHE_KEY_SEQLEN_Q,
     CACHE_KEY_SEQLEN_K,
-    HAS_BIAS: tl.constexpr,
     BLOCK_HEADDIM: tl.constexpr,
     EVEN_M: tl.constexpr,
     EVEN_N: tl.constexpr,
@@ -110,9 +105,6 @@ def _fwd_kernel(
         V + off_b * stride_vb + off_h * stride_vh + (offs_n[:, None] * stride_vn + offs_d[None, :])
     )
 
-    if HAS_BIAS:
-        b_ptrs = Bias + off_b * stride_bb + off_h * stride_bh + offs_n
-
     # maximum
 
     m_ptrs = M + off_hb * seqlen_q_rounded + offs_m
@@ -183,22 +175,8 @@
 
             qk += tl.where(offs_m[:, None] >= (start_n + offs_n)[None, :], 0, float("-inf"))
 
-        if HAS_BIAS:
-            if EVEN_N:
-                bias = tl.load(b_ptrs + start_n)
-            else:
-                bias = tl.load(
-                    b_ptrs + start_n, mask=(start_n + offs_n) < seqlen_k, other=0.0
-                )
-            bias = bias[None, :]
-
-            bias = bias.to(tl.float32)
-            qk = qk * softmax_scale + bias
-            m_ij = tl.maximum(tl.max(qk, 1), lse_i)
-            p = tl.exp(qk - m_ij[:, None])
-        else:
-            m_ij = tl.maximum(tl.max(qk, 1) * softmax_scale, lse_i)
-            p = tl.exp(qk * softmax_scale - m_ij[:, None])
+        m_ij = tl.maximum(tl.max(qk, 1) * softmax_scale, lse_i)
+        p = tl.exp(qk * softmax_scale - m_ij[:, None])
 
         l_ij = tl.sum(p, 1)
 
@@ -264,7 +242,6 @@ def flash_attn_forward(
     q,
     k,
     v,
-    bias = None,
     o = None,
     m = None,
     lse = None,
@@ -285,23 +262,6 @@
 
     softmax_scale = default(softmax_scale, d ** -0.5)
 
-    has_bias = exists(bias)
-
-    if has_bias:
-        assert bias.dtype in [q.dtype, torch.float]
-        assert bias.is_cuda
-
-        if bias.ndim == 2:
-            bias = repeat(bias, 'b j -> b h i j', h = nheads, i = seqlen_q)
-
-        if not is_contiguous(bias):
-            bias = bias.contiguous()
-
-        assert bias.shape[-2:] == (seqlen_q, seqlen_k)
-        bias = bias.expand(batch, nheads, seqlen_q, seqlen_k)
-
-    bias_strides = (bias.stride(0), bias.stride(1), bias.stride(2)) if has_bias else (0, 0, 0)
-
     seqlen_q_rounded = ceil(seqlen_q / 128) * 128
 
     if not exists(lse):
@@ -324,7 +284,6 @@
         q,
         k,
         v,
-        bias,
         o,
         m,
         lse,
@@ -338,7 +297,6 @@
         v.stride(0),
        v.stride(2),
         v.stride(1),
-        *bias_strides,
         o.stride(0),
         o.stride(2),
         o.stride(1),
@@ -349,7 +307,6 @@
         d,
         seqlen_q // 32,
         seqlen_k // 32,
-        has_bias,
         BLOCK_HEADDIM,
         BLOCK_M = BLOCK,
         BLOCK_N = BLOCK,
@@ -445,7 +402,6 @@ def _bwd_kernel_one_col_block(
     Q,
     K,
     V,
-    Bias,
     DO,
     DQ,
     DK,
@@ -456,7 +412,6 @@
     stride_qm,
     stride_kn,
     stride_vn,
-    stride_bm,
     stride_dom,
     stride_dqm,
     stride_dkn,
@@ -465,7 +420,6 @@
     seqlen_k,
     headdim,
     ATOMIC_ADD: tl.constexpr,
-    BIAS_TYPE: tl.constexpr,
     BLOCK_HEADDIM: tl.constexpr,
     EVEN_M: tl.constexpr,
     EVEN_N: tl.constexpr,
@@ -486,10 +440,7 @@
     v_ptrs = V + (offs_n[:, None] * stride_vn + offs_d[None, :])
     do_ptrs = DO + (offs_qm[:, None] * stride_dom + offs_d[None, :])
     dq_ptrs = DQ + (offs_qm[:, None] * stride_dqm + offs_d[None, :])
-    if BIAS_TYPE == "vector":
-        b_ptrs = Bias + offs_n
-    elif BIAS_TYPE == "matrix":
-        b_ptrs = Bias + (offs_qm[:, None] * stride_bm + offs_n[None, :])
+
     # initialize dv and dk
     dv = tl.zeros([BLOCK_N, BLOCK_HEADDIM], dtype=tl.float32)
     dk = tl.zeros([BLOCK_N, BLOCK_HEADDIM], dtype=tl.float32)
@@ -562,33 +513,14 @@
 
            qk = tl.where(offs_m_curr[:, None] >= (offs_n[None, :]), qk, float("-inf"))
 
-        if BIAS_TYPE != "none":
-            tl.debug_barrier()  # Race condition otherwise
-            if BIAS_TYPE == "vector":
-                if EVEN_N:
-                    bias = tl.load(b_ptrs).to(tl.float32)
-                else:
-                    bias = tl.load(b_ptrs, mask=offs_n < seqlen_k, other=0.0).to(tl.float32)
-                bias = bias[None, :]
-            elif BIAS_TYPE == "matrix":
-                if EVEN_M & EVEN_N:
-                    bias = tl.load(b_ptrs).to(tl.float32)
-                else:
-                    bias = tl.load(
-                        b_ptrs,
-                        mask=(offs_m_curr[:, None] < seqlen_q) & (offs_n[None, :] < seqlen_k),
-                        other=0.0,
-                    ).to(tl.float32)
-            qk = qk * softmax_scale + bias
         # There seems to be a race condition when headdim=48/96, and dq, dk, dv are wrong.
         # Also wrong for headdim=64.
         if not (EVEN_M & EVEN_HEADDIM):
             tl.debug_barrier()
         lse_i = tl.load(LSE + offs_m_curr)
-        if BIAS_TYPE == "none":
-            p = tl.exp(qk * softmax_scale - lse_i[:, None])
-        else:
-            p = tl.exp(qk - lse_i[:, None])
+
+        p = tl.exp(qk * softmax_scale - lse_i[:, None])
+
         # compute dv
         # [2022-10-30] TD: A Triton bug: if EVEN_M=True and EVEN_HEADDIM=False, if we call
         # do = tl.load(do_ptrs, mask=offs_d[None, :] < headdim, other=0.0), we get wrong outputs
@@ -693,8 +625,7 @@
         dq_ptrs += BLOCK_M * stride_dqm
         q_ptrs += BLOCK_M * stride_qm
         do_ptrs += BLOCK_M * stride_dom
-        if BIAS_TYPE == "matrix":
-            b_ptrs += BLOCK_M * stride_bm
+
     # write-back
     dv_ptrs = DV + (offs_n[:, None] * stride_dvn + offs_d[None, :])
     dk_ptrs = DK + (offs_n[:, None] * stride_dkn + offs_d[None, :])
@@ -738,7 +669,7 @@ def init_to_zero(name):
         # triton.Config({"BLOCK_M": 64, "BLOCK_N": 64, "SEQUENCE_PARALLEL": False}, num_warps=4, num_stages=1, pre_hook=init_to_zero('DQ')),
         # triton.Config({"BLOCK_M": 64, "BLOCK_N": 64, "SEQUENCE_PARALLEL": True}, num_warps=4, num_stages=1, pre_hook=init_to_zero('DQ')),
     ],
-    key=["CACHE_KEY_SEQLEN_Q", "CACHE_KEY_SEQLEN_K", "BIAS_TYPE", "BLOCK_HEADDIM"],
+    key=["CACHE_KEY_SEQLEN_Q", "CACHE_KEY_SEQLEN_K", "BLOCK_HEADDIM"],
 )
 @triton.heuristics(
     {
@@ -752,7 +683,6 @@ def _bwd_kernel(
     Q,
     K,
     V,
-    Bias,
     DO,
     DQ,
     DK,
@@ -769,9 +699,6 @@
     stride_vb,
     stride_vh,
     stride_vn,
-    stride_bb,
-    stride_bh,
-    stride_bm,
     stride_dob,
     stride_doh,
     stride_dom,
@@ -791,7 +718,6 @@
     headdim,
     CACHE_KEY_SEQLEN_Q,
     CACHE_KEY_SEQLEN_K,
-    BIAS_TYPE: tl.constexpr,
     BLOCK_HEADDIM: tl.constexpr,
     SEQUENCE_PARALLEL: tl.constexpr,
     EVEN_M: tl.constexpr,
@@ -811,8 +737,6 @@
     DQ += off_b * stride_dqb + off_h * stride_dqh
     DK += off_b * stride_dkb + off_h * stride_dkh
     DV += off_b * stride_dvb + off_h * stride_dvh
-    if BIAS_TYPE != "none":
-        Bias += off_b * stride_bb + off_h * stride_bh
     # pointer to row-wise quantities in value-like data
     D += off_hb * seqlen_q_rounded
     LSE += off_hb * seqlen_q_rounded
@@ -824,7 +748,6 @@
                 Q,
                 K,
                 V,
-                Bias,
                 DO,
                 DQ,
                 DK,
@@ -835,7 +758,6 @@
                 stride_qm,
                 stride_kn,
                 stride_vn,
-                stride_bm,
                 stride_dom,
                 stride_dqm,
                 stride_dkn,
@@ -844,7 +766,6 @@
                 seqlen_k,
                 headdim,
                 ATOMIC_ADD=False,
-                BIAS_TYPE=BIAS_TYPE,
                 BLOCK_HEADDIM=BLOCK_HEADDIM,
                 EVEN_M=EVEN_M,
                 EVEN_N=EVEN_N,
@@ -859,7 +780,6 @@
             Q,
             K,
             V,
-            Bias,
             DO,
             DQ,
             DK,
@@ -870,7 +790,6 @@
             stride_qm,
             stride_kn,
             stride_vn,
-            stride_bm,
             stride_dom,
             stride_dqm,
             stride_dkn,
@@ -879,7 +798,6 @@
             seqlen_k,
             headdim,
             ATOMIC_ADD=True,
-            BIAS_TYPE=BIAS_TYPE,
             BLOCK_HEADDIM=BLOCK_HEADDIM,
             EVEN_M=EVEN_M,
             EVEN_N=EVEN_N,
@@ -899,7 +817,6 @@ def flash_attn_backward(
     dk,
     dv,
     delta = None,
-    bias = None,
     softmax_scale = None,
 ):
     # Make sure that the last dimension is contiguous
@@ -944,24 +861,6 @@
         BLOCK_HEADDIM=BLOCK_HEADDIM,
     )
 
-    has_bias = bias is not None
-    bias_type = "none"
-    if has_bias:
-        assert bias.dtype in [q.dtype, torch.float]
-        assert bias.is_cuda
-        assert bias.dim() == 4
-        assert bias.stride(-1) == 1
-        if bias.shape[2:] == (1, seqlen_k):
-            bias_type = "vector"
-        elif bias.shape[2:] == (seqlen_q, seqlen_k):
-            bias_type = "matrix"
-        else:
-            raise RuntimeError(
-                "Last 2 dimensions of bias must be (1, seqlen_k)" " or (seqlen_q, seqlen_k)"
-            )
-        bias = bias.expand(batch, nheads, seqlen_q, seqlen_k)
-    bias_strides = (bias.stride(0), bias.stride(1), bias.stride(2)) if has_bias else (0, 0, 0)
-
     # BLOCK_M = 128
     # BLOCK_N = 64
     # num_warps = 4
@@ -973,7 +872,6 @@
         q,
         k,
         v,
-        bias,
         do,
         dq_accum,
         dk,
@@ -990,7 +888,6 @@
         v.stride(0),
         v.stride(2),
         v.stride(1),
-        *bias_strides,
         do.stride(0),
         do.stride(2),
         do.stride(1),
@@ -1012,7 +909,6 @@
         seqlen_k // 32,  # key for triton cache (limit number of compilations)
         # Can't use kwargs here because triton autotune expects key to be args, not kwargs
        # IS_CAUSAL=causal, BLOCK_HEADDIM=d,
-        bias_type,
         BLOCK_HEADDIM,
         # SEQUENCE_PARALLEL=False,
         # BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N,
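
Note: the two `+` lines kept in the forward kernel, `m_ij = tl.maximum(tl.max(qk, 1) * softmax_scale, lse_i)` and `p = tl.exp(qk * softmax_scale - m_ij[:, None])`, are the standard streaming log-sum-exp update that remains once the bias branch is gone. The following is a minimal PyTorch sketch of that update outside Triton, useful only for checking against a dense softmax; it assumes a single head, no bias, no causal mask, and the names (streaming_attention_reference, block_n) are illustrative, not part of the repository.

import torch

def streaming_attention_reference(q, k, v, softmax_scale, block_n = 64):
    # q: (seqlen_q, d), k / v: (seqlen_k, d) - one head, no bias, no causal mask
    seqlen_q, d = q.shape
    acc_o = torch.zeros(seqlen_q, d)
    m_i = torch.full((seqlen_q,), float('-inf'))    # running (scaled) row maximum
    lse_i = torch.full((seqlen_q,), float('-inf'))  # running log-sum-exp per row

    for start_n in range(0, k.shape[0], block_n):
        kb, vb = k[start_n:start_n + block_n], v[start_n:start_n + block_n]
        qk = q @ kb.t()

        # the two lines the kernel now runs unconditionally
        m_ij = torch.maximum(qk.amax(dim = 1) * softmax_scale, lse_i)
        p = torch.exp(qk * softmax_scale - m_ij[:, None])

        l_ij = p.sum(dim = 1)
        acc_o = acc_o * torch.exp(m_i - m_ij)[:, None] + p @ vb   # rescale old accumulator
        m_i = m_ij
        lse_i = m_ij + torch.log(torch.exp(lse_i - m_ij) + l_ij)  # fold block into running lse

    # final normalization, as in the kernel epilogue
    return acc_o * torch.exp(m_i - lse_i)[:, None], lse_i

# sanity check against dense softmax attention
q, k, v = (torch.randn(128, 64) for _ in range(3))
out, _ = streaming_attention_reference(q, k, v, softmax_scale = 64 ** -0.5)
dense = torch.softmax((q @ k.t()) * 64 ** -0.5, dim = -1) @ v
assert torch.allclose(out, dense, atol = 1e-4)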