more cleanup
lucidrains committed Feb 23, 2025
1 parent 2f14297 commit 059bf78
Showing 1 changed file with 8 additions and 112 deletions.
120 changes: 8 additions & 112 deletions native_sparse_attention_pytorch/triton_native_sparse_attention.py
@@ -56,7 +56,6 @@ def _fwd_kernel(
Q,
K,
V,
Bias,
Out,
M,
Lse,
@@ -70,9 +69,6 @@
stride_vb,
stride_vh,
stride_vn,
stride_bb,
stride_bh,
stride_bm,
stride_ob,
stride_oh,
stride_om,
@@ -83,7 +79,6 @@
headdim,
CACHE_KEY_SEQLEN_Q,
CACHE_KEY_SEQLEN_K,
HAS_BIAS: tl.constexpr,
BLOCK_HEADDIM: tl.constexpr,
EVEN_M: tl.constexpr,
EVEN_N: tl.constexpr,
@@ -110,9 +105,6 @@
V + off_b * stride_vb + off_h * stride_vh + (offs_n[:, None] * stride_vn + offs_d[None, :])
)

if HAS_BIAS:
b_ptrs = Bias + off_b * stride_bb + off_h * stride_bh + offs_n

# maximum

m_ptrs = M + off_hb * seqlen_q_rounded + offs_m
@@ -183,22 +175,8 @@ def _fwd_kernel(

qk += tl.where(offs_m[:, None] >= (start_n + offs_n)[None, :], 0, float("-inf"))

if HAS_BIAS:
if EVEN_N:
bias = tl.load(b_ptrs + start_n)
else:
bias = tl.load(
b_ptrs + start_n, mask=(start_n + offs_n) < seqlen_k, other=0.0
)
bias = bias[None, :]

bias = bias.to(tl.float32)
qk = qk * softmax_scale + bias
m_ij = tl.maximum(tl.max(qk, 1), lse_i)
p = tl.exp(qk - m_ij[:, None])
else:
m_ij = tl.maximum(tl.max(qk, 1) * softmax_scale, lse_i)
p = tl.exp(qk * softmax_scale - m_ij[:, None])
m_ij = tl.maximum(tl.max(qk, 1) * softmax_scale, lse_i)
p = tl.exp(qk * softmax_scale - m_ij[:, None])

l_ij = tl.sum(p, 1)

Expand Down Expand Up @@ -264,7 +242,6 @@ def flash_attn_forward(
q,
k,
v,
bias = None,
o = None,
m = None,
lse = None,
@@ -285,23 +262,6 @@

softmax_scale = default(softmax_scale, d ** -0.5)

has_bias = exists(bias)

if has_bias:
assert bias.dtype in [q.dtype, torch.float]
assert bias.is_cuda

if bias.ndim == 2:
bias = repeat(bias, 'b j -> b h i j', h = nheads, i = seqlen_q)

if not is_contiguous(bias):
bias = bias.contiguous()

assert bias.shape[-2:] == (seqlen_q, seqlen_k)
bias = bias.expand(batch, nheads, seqlen_q, seqlen_k)

bias_strides = (bias.stride(0), bias.stride(1), bias.stride(2)) if has_bias else (0, 0, 0)

seqlen_q_rounded = ceil(seqlen_q / 128) * 128

if not exists(lse):
@@ -324,7 +284,6 @@
q,
k,
v,
bias,
o,
m,
lse,
@@ -338,7 +297,6 @@
v.stride(0),
v.stride(2),
v.stride(1),
*bias_strides,
o.stride(0),
o.stride(2),
o.stride(1),
@@ -349,7 +307,6 @@
d,
seqlen_q // 32,
seqlen_k // 32,
has_bias,
BLOCK_HEADDIM,
BLOCK_M = BLOCK,
BLOCK_N = BLOCK,
@@ -445,7 +402,6 @@ def _bwd_kernel_one_col_block(
Q,
K,
V,
Bias,
DO,
DQ,
DK,
@@ -456,7 +412,6 @@
stride_qm,
stride_kn,
stride_vn,
stride_bm,
stride_dom,
stride_dqm,
stride_dkn,
@@ -465,7 +420,6 @@
seqlen_k,
headdim,
ATOMIC_ADD: tl.constexpr,
BIAS_TYPE: tl.constexpr,
BLOCK_HEADDIM: tl.constexpr,
EVEN_M: tl.constexpr,
EVEN_N: tl.constexpr,
@@ -486,10 +440,7 @@
v_ptrs = V + (offs_n[:, None] * stride_vn + offs_d[None, :])
do_ptrs = DO + (offs_qm[:, None] * stride_dom + offs_d[None, :])
dq_ptrs = DQ + (offs_qm[:, None] * stride_dqm + offs_d[None, :])
if BIAS_TYPE == "vector":
b_ptrs = Bias + offs_n
elif BIAS_TYPE == "matrix":
b_ptrs = Bias + (offs_qm[:, None] * stride_bm + offs_n[None, :])

# initialize dv and dk
dv = tl.zeros([BLOCK_N, BLOCK_HEADDIM], dtype=tl.float32)
dk = tl.zeros([BLOCK_N, BLOCK_HEADDIM], dtype=tl.float32)
@@ -562,33 +513,14 @@ def _bwd_kernel_one_col_block(

qk = tl.where(offs_m_curr[:, None] >= (offs_n[None, :]), qk, float("-inf"))

if BIAS_TYPE != "none":
tl.debug_barrier() # Race condition otherwise
if BIAS_TYPE == "vector":
if EVEN_N:
bias = tl.load(b_ptrs).to(tl.float32)
else:
bias = tl.load(b_ptrs, mask=offs_n < seqlen_k, other=0.0).to(tl.float32)
bias = bias[None, :]
elif BIAS_TYPE == "matrix":
if EVEN_M & EVEN_N:
bias = tl.load(b_ptrs).to(tl.float32)
else:
bias = tl.load(
b_ptrs,
mask=(offs_m_curr[:, None] < seqlen_q) & (offs_n[None, :] < seqlen_k),
other=0.0,
).to(tl.float32)
qk = qk * softmax_scale + bias
# There seems to be a race condition when headdim=48/96, and dq, dk, dv are wrong.
# Also wrong for headdim=64.
if not (EVEN_M & EVEN_HEADDIM):
tl.debug_barrier()
lse_i = tl.load(LSE + offs_m_curr)
if BIAS_TYPE == "none":
p = tl.exp(qk * softmax_scale - lse_i[:, None])
else:
p = tl.exp(qk - lse_i[:, None])

p = tl.exp(qk * softmax_scale - lse_i[:, None])

# compute dv
# [2022-10-30] TD: A Triton bug: if EVEN_M=True and EVEN_HEADDIM=False, if we call
# do = tl.load(do_ptrs, mask=offs_d[None, :] < headdim, other=0.0), we get wrong outputs
@@ -693,8 +625,7 @@ def _bwd_kernel_one_col_block(
dq_ptrs += BLOCK_M * stride_dqm
q_ptrs += BLOCK_M * stride_qm
do_ptrs += BLOCK_M * stride_dom
if BIAS_TYPE == "matrix":
b_ptrs += BLOCK_M * stride_bm

# write-back
dv_ptrs = DV + (offs_n[:, None] * stride_dvn + offs_d[None, :])
dk_ptrs = DK + (offs_n[:, None] * stride_dkn + offs_d[None, :])
@@ -738,7 +669,7 @@ def init_to_zero(name):
# triton.Config({"BLOCK_M": 64, "BLOCK_N": 64, "SEQUENCE_PARALLEL": False}, num_warps=4, num_stages=1, pre_hook=init_to_zero('DQ')),
# triton.Config({"BLOCK_M": 64, "BLOCK_N": 64, "SEQUENCE_PARALLEL": True}, num_warps=4, num_stages=1, pre_hook=init_to_zero('DQ')),
],
key=["CACHE_KEY_SEQLEN_Q", "CACHE_KEY_SEQLEN_K", "BIAS_TYPE", "BLOCK_HEADDIM"],
key=["CACHE_KEY_SEQLEN_Q", "CACHE_KEY_SEQLEN_K", "BLOCK_HEADDIM"],
)
@triton.heuristics(
{
@@ -752,7 +683,6 @@ def _bwd_kernel(
Q,
K,
V,
Bias,
DO,
DQ,
DK,
@@ -769,9 +699,6 @@
stride_vb,
stride_vh,
stride_vn,
stride_bb,
stride_bh,
stride_bm,
stride_dob,
stride_doh,
stride_dom,
@@ -791,7 +718,6 @@
headdim,
CACHE_KEY_SEQLEN_Q,
CACHE_KEY_SEQLEN_K,
BIAS_TYPE: tl.constexpr,
BLOCK_HEADDIM: tl.constexpr,
SEQUENCE_PARALLEL: tl.constexpr,
EVEN_M: tl.constexpr,
@@ -811,8 +737,6 @@
DQ += off_b * stride_dqb + off_h * stride_dqh
DK += off_b * stride_dkb + off_h * stride_dkh
DV += off_b * stride_dvb + off_h * stride_dvh
if BIAS_TYPE != "none":
Bias += off_b * stride_bb + off_h * stride_bh
# pointer to row-wise quantities in value-like data
D += off_hb * seqlen_q_rounded
LSE += off_hb * seqlen_q_rounded
@@ -824,7 +748,6 @@
Q,
K,
V,
Bias,
DO,
DQ,
DK,
@@ -835,7 +758,6 @@
stride_qm,
stride_kn,
stride_vn,
stride_bm,
stride_dom,
stride_dqm,
stride_dkn,
@@ -844,7 +766,6 @@
seqlen_k,
headdim,
ATOMIC_ADD=False,
BIAS_TYPE=BIAS_TYPE,
BLOCK_HEADDIM=BLOCK_HEADDIM,
EVEN_M=EVEN_M,
EVEN_N=EVEN_N,
@@ -859,7 +780,6 @@
Q,
K,
V,
Bias,
DO,
DQ,
DK,
@@ -870,7 +790,6 @@
stride_qm,
stride_kn,
stride_vn,
stride_bm,
stride_dom,
stride_dqm,
stride_dkn,
@@ -879,7 +798,6 @@
seqlen_k,
headdim,
ATOMIC_ADD=True,
BIAS_TYPE=BIAS_TYPE,
BLOCK_HEADDIM=BLOCK_HEADDIM,
EVEN_M=EVEN_M,
EVEN_N=EVEN_N,
@@ -899,7 +817,6 @@ def flash_attn_backward(
dk,
dv,
delta = None,
bias = None,
softmax_scale = None,
):
# Make sure that the last dimension is contiguous
@@ -944,24 +861,6 @@ def flash_attn_backward(
BLOCK_HEADDIM=BLOCK_HEADDIM,
)

has_bias = bias is not None
bias_type = "none"
if has_bias:
assert bias.dtype in [q.dtype, torch.float]
assert bias.is_cuda
assert bias.dim() == 4
assert bias.stride(-1) == 1
if bias.shape[2:] == (1, seqlen_k):
bias_type = "vector"
elif bias.shape[2:] == (seqlen_q, seqlen_k):
bias_type = "matrix"
else:
raise RuntimeError(
"Last 2 dimensions of bias must be (1, seqlen_k)" " or (seqlen_q, seqlen_k)"
)
bias = bias.expand(batch, nheads, seqlen_q, seqlen_k)
bias_strides = (bias.stride(0), bias.stride(1), bias.stride(2)) if has_bias else (0, 0, 0)

# BLOCK_M = 128
# BLOCK_N = 64
# num_warps = 4
@@ -973,7 +872,6 @@
q,
k,
v,
bias,
do,
dq_accum,
dk,
@@ -990,7 +888,6 @@
v.stride(0),
v.stride(2),
v.stride(1),
*bias_strides,
do.stride(0),
do.stride(2),
do.stride(1),
@@ -1012,7 +909,6 @@
seqlen_k // 32, # key for triton cache (limit number of compilations)
# Can't use kwargs here because triton autotune expects key to be args, not kwargs
# IS_CAUSAL=causal, BLOCK_HEADDIM=d,
bias_type,
BLOCK_HEADDIM,
# SEQUENCE_PARALLEL=False,
# BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N,

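For reference, a minimal usage sketch of the forward path after this cleanup (not part of the commit): with the bias argument removed, a call reduces to q, k, v plus the optional buffers in the signature. The (batch, seqlen, heads, head_dim) layout and fp16 dtype below are assumptions inferred from the strides passed to the kernel, and the return values are not visible in this diff, so they are left opaque.

import torch
from native_sparse_attention_pytorch.triton_native_sparse_attention import flash_attn_forward

# assumed layout: (batch, seqlen, heads, head_dim) — not confirmed by this diff
batch, seqlen, heads, dim = 2, 1024, 8, 64

q = torch.randn(batch, seqlen, heads, dim, device = 'cuda', dtype = torch.float16)
k = torch.randn(batch, seqlen, heads, dim, device = 'cuda', dtype = torch.float16)
v = torch.randn(batch, seqlen, heads, dim, device = 'cuda', dtype = torch.float16)

# softmax_scale defaults to dim ** -0.5 inside flash_attn_forward (see the hunk at -285 above)
result = flash_attn_forward(q, k, v)  # return value(s) not shown in this diff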