
Commit dc15660
just improvise a solution for compress and selection block sizes not equal
lucidrains committed Feb 20, 2025
1 parent 169739f commit dc15660
Showing 3 changed files with 23 additions and 9 deletions.
24 changes: 18 additions & 6 deletions native_sparse_attention_pytorch/native_sparse_attention.py
@@ -110,8 +110,6 @@ def __init__(

self.scale = dim_head ** -0.5

assert compress_block_size == selection_block_size, 'start off with compressed being equal to selection block sizes'

dim_inner = dim_head * heads
dim_kv_inner = dim_head * kv_heads

@@ -174,6 +172,8 @@ def __init__(
self.use_diff_topk = use_diff_topk

self.selection_block_size = selection_block_size

assert num_selected_blocks > 0
self.num_selected_blocks = num_selected_blocks

# they combine the three sparse branches through a learned combine with sigmoid activation
@@ -219,7 +219,7 @@ def forward(

q, k, v = map(self.split_heads, (q, k, v))

# compressed key / values
# compressed key / values - variables prepended with `c` stand for compressed

k_pos = repeat(self.k_intrablock_positions, 'h n d -> h (r n) d', r = num_compress_blocks)
v_pos = repeat(self.v_intrablock_positions, 'h n d -> h (r n) d', r = num_compress_blocks)
@@ -262,15 +262,27 @@ def forward(

rotated_q, rotated_k = self.rotary_emb.rotate_queries_with_cached_keys(q, k)

# 2. fine attention over selected based on compressed attention logits

# 2. fine attention over selected based on compressed attention logits - variables prepended with `f` stand for the fine attention pathway

importance_scores = cattn[..., num_mem_compress_kv:]

# for gqa, we will average the compressed attention across each grouped queries (per key / values)

importance_scores = reduce(importance_scores, 'b (grouped_queries h) ... -> b h ...', 'mean', grouped_queries = self.num_grouped_queries)

# handle if compress block size not equal to the fine block size
# cannot parse their equation, so will just improvise
# first we expand all the compressed scores to the full sequence length, then average within each fine / selection block size - pad on the right with 0s, which should be fine as the sliding window covers the local context anyway

if self.compress_block_size != self.selection_block_size:
importance_scores = repeat(importance_scores, '... j -> ... (j block_size)', block_size = self.compress_block_size)
padding = fine_divisible_seq_len - importance_scores.shape[-1]

importance_scores = F.pad(importance_scores, (0, padding))
importance_scores = reduce(importance_scores, '... (j block_size) -> ... j', 'mean', block_size = self.selection_block_size)

# handle if number of total blocks is less than number to select for fine attention

num_selected = min(self.num_selected_blocks, importance_scores.shape[-1])

fq = rotated_q
@@ -367,7 +379,7 @@ def forward(

fine_attn_out = einsum(fattn, fv, 'b h i j, b h j d -> b h i d')

# 3. overlapping sliding window, this is unsurprising and expected
# 3. overlapping sliding window, this is unsurprising and expected - `s` for sliding

sq = rotated_q
sk = rotated_k
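The improvised remapping above can be exercised in isolation. Below is a minimal, self-contained sketch of that step (not the module's own code path); the batch size, head count, sequence length, and block sizes are illustrative assumptions chosen only so the shapes line up.

import torch
import torch.nn.functional as F
from einops import repeat, reduce

compress_block_size = 5       # illustrative
selection_block_size = 4      # illustrative, deliberately different from compress_block_size
fine_divisible_seq_len = 12   # sequence length rounded up to a multiple of the selection block size

# one importance score per compressed block: (batch, kv heads, query length, num compressed blocks)
importance_scores = torch.randn(1, 2, 10, 2)

# broadcast each compressed score across every position of its block
scores = repeat(importance_scores, '... j -> ... (j block_size)', block_size = compress_block_size)

# right-pad with zeros up to the fine-divisible length - the sliding window branch covers the local context anyway
scores = F.pad(scores, (0, fine_divisible_seq_len - scores.shape[-1]))

# average back down to one score per fine / selection block
scores = reduce(scores, '... (j block_size) -> ... j', 'mean', block_size = selection_block_size)

print(scores.shape)   # torch.Size([1, 2, 10, 3]) - one score per selection block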
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -1,6 +1,6 @@
[project]
name = "native-sparse-attention-pytorch"
version = "0.0.19"
version = "0.0.20"
description = "Native Sparse Attention"
authors = [
{ name = "Phil Wang", email = "lucidrains@gmail.com" }
6 changes: 4 additions & 2 deletions tests/test_sparse_attn.py
@@ -9,10 +9,12 @@
@pytest.mark.parametrize('use_diff_topk', (False, True))
@pytest.mark.parametrize('seq_len', (1, 4, 31, 32, 120))
@pytest.mark.parametrize('kv_heads', (8, 4))
@pytest.mark.parametrize('selection_block_size', (4, 2))
def test_sparse_attn(
use_diff_topk,
seq_len,
kv_heads
kv_heads,
selection_block_size
):
attn = SparseAttention(
dim = 512,
@@ -21,7 +23,7 @@ def test_sparse_attn(
kv_heads = kv_heads,
sliding_window_size = 2,
compress_block_size = 4,
selection_block_size = 4,
selection_block_size = selection_block_size,
num_selected_blocks = 2,
use_diff_topk = use_diff_topk
)
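For completeness, a minimal usage sketch with unequal block sizes, mirroring the new test parametrization. The dim_head / heads values, the package-root import, and the forward signature (a single (batch, seq, dim) token tensor) are assumptions not visible in this diff.

import torch
from native_sparse_attention_pytorch import SparseAttention  # assumed package-root export

attn = SparseAttention(
    dim = 512,
    dim_head = 64,               # assumed - hidden by the collapsed part of the diff
    heads = 8,                   # assumed - hidden by the collapsed part of the diff
    kv_heads = 4,
    sliding_window_size = 2,
    compress_block_size = 4,
    selection_block_size = 2,    # no longer required to equal compress_block_size
    num_selected_blocks = 2
)

tokens = torch.randn(1, 31, 512)  # (batch, seq, dim) - assumed forward signature
out = attn(tokens)                # expected to come back as (1, 31, 512)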
