# test_triton_nsa.py
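# this test compares `native_sparse_attend`, the triton kernel shipped with
# native_sparse_attention_pytorch, against `regular_attend`, a plain pytorch / einops
# reference that combines a block-causal diagonal with a set of selected kv blocks.
# both the forward output and the q / k / v gradients are checked.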
import torch

from native_sparse_attention_pytorch.triton_native_sparse_attention import native_sparse_attend

import einx
from einops import rearrange, einsum, repeat

assert torch.cuda.is_available()

def exists(v):
    return v is not None
def regular_attend(
    q, k, v,
    indices,
    mask,
    block_size = None,
):
    # reference implementation in plain pytorch / einops, used to check the triton kernel

    q_heads, kv_heads = q.shape[1], k.shape[1]

    if exists(block_size):
        w = q.shape[-2] // block_size
        q, k, v = tuple(rearrange(t, 'b h (w n) d -> b (h w) n d', n = block_size) for t in (q, k, v))

    seq_len, device = q.shape[-2], q.device
    scale = q.shape[-1] ** -0.5

    q = q * scale

    # block causal diagonal

    sim = einsum(q, k, 'b h i d, b h j d -> b h i j')

    causal_mask = torch.ones((seq_len, seq_len), device = device, dtype = torch.bool).triu(1)
    sim = sim.masked_fill(causal_mask, -torch.finfo(sim.dtype).max)

    # rest of the indices

    num_sel_kv_blocks = indices.shape[-1]
    has_sel_kv_blocks = num_sel_kv_blocks > 0

    if has_sel_kv_blocks:
        bk, bv = tuple(rearrange(t, 'b (h w) n d -> b h w n d', h = kv_heads) for t in (k, v))
        sel_bk = einx.get_at('b h [w] n d, b h i sel -> b h i (sel n) d', bk, indices)
        sel_bv = einx.get_at('b h [w] n d, b h i sel -> b h i (sel n) d', bv, indices)

        q = rearrange(q, 'b (h w) n d -> b h (w n) d', h = q_heads)
        bsim = einsum(q, sel_bk, 'b h i d, b h i j d -> b h i j')

        # `fine_block_size` is the module level global defined below under "mock inputs"
        bsim = rearrange(bsim, 'b h (w i) (sel j) -> b h w i sel j', sel = num_sel_kv_blocks, i = fine_block_size)

        mask = rearrange(mask, 'b h (w i) sel -> b h w i sel', i = fine_block_size)

        bsim = torch.where(mask[..., None], bsim, -torch.finfo(bsim.dtype).max)

        sim = rearrange(sim, 'b (h w) i j -> b h w i 1 j', h = q_heads)

        sim = torch.cat((sim, bsim), dim = -2)
        sim = rearrange(sim, 'b h w i causal_and_sel j -> b h (w i) (causal_and_sel j)')

        sel_bv = rearrange(sel_bv, 'b h (w i) j d -> b h w i j d', i = fine_block_size)

        v = repeat(v, 'b (h w) j d -> b h w i j d', h = kv_heads, i = fine_block_size)
        v = torch.cat((v, sel_bv), dim = -2)
        v = rearrange(v, 'b h w i j d -> b h (w i) j d')

    # attend

    attn = sim.softmax(dim = -1)

    if has_sel_kv_blocks:
        out = einsum(attn, v, 'b h i j, b h i j d -> b h i d')
    else:
        out = einsum(attn, v, 'b h i j, b h j d -> b h i d')

    if exists(block_size):
        out = rearrange(out, 'b (h w) n d -> b h (w n) d', w = w)

    return out
# mock inputs

fine_block_size = 16

q = torch.randn(1, 2, 512, 64).cuda()
k = torch.randn(1, 2, 512, 64).cuda()
v = torch.randn(1, 2, 512, 64).cuda()

indices = torch.zeros(1, 2, 512, 1).long().cuda()
mask = torch.ones(1, 2, 512, 1).bool().cuda()
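# shapes above: q / k / v are (batch = 1, heads = 2, seq_len = 512, dim = 64).
# `indices` appears to hold one selected kv block index per (head, query position), with
# `mask` marking that selection as valid - this reading is inferred from the einx gather
# pattern in `regular_attend` above, not from library documentation.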
# both regular and nsa pathways `r` and `n`

rq, rk, rv = tuple(t.clone().requires_grad_() for t in (q, k, v))
nq, nk, nv = tuple(t.clone().requires_grad_() for t in (q, k, v))

# regular forwards and backwards

out = regular_attend(rq, rk, rv, indices, mask, block_size = fine_block_size)
out.sum().backward()

# triton nsa forwards and backwards

nsa_out = native_sparse_attend(nq, nk, nv, fine_block_size, indices, mask, 1)
nsa_out.sum().backward()

# asserts

assert torch.allclose(out, nsa_out, atol = 1e-2)
assert torch.allclose(nv.grad, rv.grad, atol = 1e-2)
assert torch.allclose(nk.grad, rk.grad, atol = 1e-2)
assert torch.allclose(nq.grad, rq.grad, atol = 1e-2)
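# note: this script is meant to be run directly (e.g. `python test_triton_nsa.py`) on a
# machine with a cuda gpu, and depends on torch, einx, einops and the triton kernel that
# ships with native_sparse_attention_pytorch. if all four asserts pass, the triton
# forward and backward match the reference implementation to within atol = 1e-2.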