
Commit

dq down
lucidrains committed Feb 25, 2025
1 parent af261a6 commit f6515e0
Showing 3 changed files with 26 additions and 51 deletions.
74 changes: 25 additions & 49 deletions native_sparse_attention_pytorch/triton_native_sparse_attention.py
@@ -676,54 +676,21 @@ def backward_kernel_one_col_block(
         EVEN_M & EVEN_HEADDIM
     ):  # Otherewise there's a race condition when BIAS_TYPE='matrix'
         tl.debug_barrier()
-    if not ATOMIC_ADD:
-        if EVEN_M & EVEN_HEADDIM:  # Race condition if we just do EVEN_M
-            dq = tl.load(dq_ptrs, eviction_policy="evict_last")
-            dq += tl.dot(ds, k)
-            tl.store(dq_ptrs, dq, eviction_policy="evict_last")
-        else:
-            if EVEN_HEADDIM:
-                dq = tl.load(
-                    dq_ptrs,
-                    mask=offs_m[:, None] < seqlen_q,
-                    other=0.0,
-                    eviction_policy="evict_last",
-                )
-                dq += tl.dot(ds, k)
-                tl.store(
-                    dq_ptrs,
-                    dq,
-                    mask=offs_m[:, None] < seqlen_q,
-                    eviction_policy="evict_last",
-                )
-            else:
-                dq = tl.load(
-                    dq_ptrs,
-                    mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim),
-                    other=0.0,
-                    eviction_policy="evict_last",
-                )
-                dq += tl.dot(ds, k)
-                tl.store(
-                    dq_ptrs,
-                    dq,
-                    mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim),
-                    eviction_policy="evict_last",
-                )
-    else:  # If we're parallelizing across the seqlen_k dimension
-        dq = tl.dot(ds, k)
-        if EVEN_M & EVEN_HEADDIM:  # Race condition if we just do EVEN_M
-            tl.atomic_add(dq_ptrs, dq, sem = 'relaxed')
-        else:
-            if EVEN_HEADDIM:
-                tl.atomic_add(dq_ptrs, dq, mask=offs_m[:, None] < seqlen_q, sem = 'relaxed')
-            else:
-                tl.atomic_add(
-                    dq_ptrs,
-                    dq,
-                    mask = (offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim),
-                    sem = 'relaxed',
-                )
+
+    dq = tl.dot(ds, k)
+
+    if EVEN_M & EVEN_HEADDIM:  # Race condition if we just do EVEN_M
+        tl.atomic_add(dq_ptrs, dq, sem = 'relaxed')
+    else:
+        if EVEN_HEADDIM:
+            tl.atomic_add(dq_ptrs, dq, mask=offs_m[:, None] < seqlen_q, sem = 'relaxed')
+        else:
+            tl.atomic_add(
+                dq_ptrs,
+                dq,
+                mask = (offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim),
+                sem = 'relaxed',
+            )

     # handle kv block indices using atomic adds for starters, todo: swap dq and dk/dv loops at some point, semi big refactor

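For orientation (not part of the diff): with the non-atomic path removed, every Triton program that owns a column block of keys now contributes its partial `dq = tl.dot(ds, k)` via `tl.atomic_add` with relaxed memory ordering, since several programs along `seqlen_k` may write the same `dq` rows concurrently. Below is a minimal sequential PyTorch sketch of the accumulation being performed; the sizes `seqlen_q`, `seqlen_k`, `headdim`, and `block` are illustrative assumptions, not values taken from the kernel.

```python
# Sequential sketch (assumed shapes, not the kernel): dq is the sum over key-column
# blocks of ds @ k, where ds is the gradient w.r.t. the (scaled) attention scores.
# In the kernel each column block is a separate program that adds its partial dq
# with tl.atomic_add(..., sem = 'relaxed'); adding the partials serially here
# produces the same result.
import torch

seqlen_q, seqlen_k, headdim, block = 64, 128, 32, 32    # illustrative sizes
scale = headdim ** -0.5

q = torch.randn(seqlen_q, headdim)
k = torch.randn(seqlen_k, headdim)
v = torch.randn(seqlen_k, headdim)
do = torch.randn(seqlen_q, headdim)                     # upstream gradient dL/d(out)

p = (q @ k.t() * scale).softmax(dim = -1)               # attention probabilities
dp = do @ v.t()
ds = p * (dp - (p * dp).sum(dim = -1, keepdim = True))  # softmax backward

dq = torch.zeros_like(q)
for start in range(0, seqlen_k, block):                 # one "program" per column block
    dq += ds[:, start:start + block] @ k[start:start + block] * scale

# agrees with autograd
q_ref = q.clone().requires_grad_()
((q_ref @ k.t() * scale).softmax(dim = -1) @ v).backward(do)
assert torch.allclose(dq, q_ref.grad, atol = 1e-4)
```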
@@ -765,8 +732,8 @@ def backward_kernel_one_col_block(
         q_expanded = tl.expand_dims(q, 1)
         q_expanded = tl.broadcast_to(q_expanded, (BLOCK, 16, BLOCK_HEADDIM))

-        block_k = tl.permute(block_k, (0, 2, 1))
-        block_qk = tl.dot(q_expanded, block_k)
+        block_k_permuted = tl.permute(block_k, (0, 2, 1))
+        block_qk = tl.dot(q_expanded, block_k_permuted)

         qk = tl.sum(block_qk, 1) / 16.
         qk += tl.where(block_masks[:, None], 0, float("-inf"))
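A note on the `block_k_permuted` rename and the `sum(...) / 16.` pattern (my reading, not stated in the diff): `tl.dot` requires every matmul dimension to be at least 16, so the per-query scores against a fine block of keys are computed by broadcasting the query over a dummy axis of length 16, taking a batched dot with the permuted keys, then collapsing the dummy axis with a sum divided by 16; keeping the permuted tensor under a new name leaves `block_k` itself intact for the block `dq` computation added further down. A PyTorch sketch of the trick, with `n_queries`, `headdim`, and `fine_block` as illustrative assumptions:

```python
# Sketch of the dummy-axis trick (assumed, illustrative shapes): tl.dot needs every
# matmul dimension to be >= 16, so a per-query (1, headdim) @ (headdim, fine_block)
# product is emulated by broadcasting the query over a dummy axis of 16, doing a
# batched dot against the permuted keys, then collapsing the dummy axis with sum / 16.
import torch

n_queries, headdim, fine_block = 64, 32, 16             # illustrative sizes

q = torch.randn(n_queries, headdim)
block_k = torch.randn(n_queries, fine_block, headdim)   # one selected key block per query

# direct per-query scores against the block: (n_queries, fine_block)
qk_direct = torch.einsum('nd,nfd->nf', q, block_k)

# dummy-axis version, mirroring the kernel lines above
q_expanded = q[:, None, :].expand(n_queries, 16, headdim)   # tl.expand_dims + broadcast_to
block_k_permuted = block_k.permute(0, 2, 1)                 # (n_queries, headdim, fine_block)
block_qk = q_expanded @ block_k_permuted                    # (n_queries, 16, fine_block)
qk = block_qk.sum(dim = 1) / 16.                            # collapse the dummy axis

assert torch.allclose(qk, qk_direct, atol = 1e-4)
```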
@@ -800,6 +767,15 @@

         tl.atomic_add(block_dk_ptrs, block_dk, sem = 'relaxed')

+        # block dq
+
+        ds_expanded = tl.expand_dims(ds, 1)
+        ds_expanded = tl.broadcast_to(ds_expanded, (BLOCK, 16, BLOCK))
+        block_dq = tl.dot(ds_expanded, block_k)
+        block_dq = tl.sum(block_dq, 1) / 16
+
+        tl.atomic_add(dq_ptrs, block_dq, sem = 'relaxed')
+
         # # increment pointers
         # dq_ptrs += BLOCK * stride_dqm
         # q_ptrs += BLOCK * stride_qm
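For reference (not part of the diff): the added lines give each query's `dq` its contribution from the selected fine block, essentially `block_dq = ds @ block_k`, computed with the same dummy-axis-of-16 pattern and accumulated into `dq_ptrs` with a relaxed `tl.atomic_add`, since different blocks are handled by concurrent programs. A PyTorch sketch with illustrative, assumed shapes (the names below are mine, not the kernel's):

```python
# Sketch of the added block-dq path (assumed shapes, not the kernel): each query's
# gradient w.r.t. its selected fine block of scores (ds) times that block's keys
# gives a partial dq, which the kernel accumulates with tl.atomic_add because
# different blocks are processed by concurrent programs.
import torch

n_queries, headdim, fine_block = 64, 32, 16             # illustrative sizes

ds = torch.randn(n_queries, fine_block)                 # dL/d(scores) for the block
block_k = torch.randn(n_queries, fine_block, headdim)   # that block's keys, per query

# direct contribution: (n_queries, headdim)
block_dq_direct = torch.einsum('nf,nfd->nd', ds, block_k)

# dummy-axis version, mirroring the added lines
ds_expanded = ds[:, None, :].expand(n_queries, 16, fine_block)
block_dq = (ds_expanded @ block_k).sum(dim = 1) / 16    # collapse the dummy axis

assert torch.allclose(block_dq, block_dq_direct, atol = 1e-4)

dq = torch.zeros(n_queries, headdim)
dq += block_dq                                          # kernel: tl.atomic_add(dq_ptrs, block_dq, ...)
```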
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "native-sparse-attention-pytorch"
-version = "0.0.47"
+version = "0.0.48"
 description = "Native Sparse Attention"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
1 change: 0 additions & 1 deletion test_triton_nsa.py
@@ -105,6 +105,5 @@ def regular_attend(
 assert torch.allclose(out, nsa_out, atol = 1e-2)

 assert torch.allclose(nv.grad, rv.grad, atol = 1e-2)
-print((nk.grad - rk.grad).abs().amax())
 assert torch.allclose(nk.grad, rk.grad, atol = 1e-2)
 assert torch.allclose(nq.grad, rq.grad, atol = 1e-2)
