Skip to content

Commit

Permalink
redo get_at with gather, but keep around the ein notation for readabi…
Browse files Browse the repository at this point in the history
…lity
  • Loading branch information
lucidrains committed Feb 19, 2025
1 parent 82a28be commit 41dbb54
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 3 deletions.
11 changes: 9 additions & 2 deletions native_sparse_attention_pytorch/native_sparse_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,8 +265,15 @@ def forward(
fk = rearrange(fk, 'b h (w n) d -> b h w n d', w = num_fine_blocks)
fv = rearrange(fv, 'b h (w n) d -> b h w n d', w = num_fine_blocks)

fk = einx.get_at('b h [w] j d, b h i selected -> b h i selected j d', fk, selected_block_indices)
fv = einx.get_at('b h [w] j d, b h i selected -> b h i selected j d', fv, selected_block_indices)
# get_at("b h [w] j d, b h i selected -> b h i selected j d", fkv, selected_block_indices)

fk = repeat(fk, 'b h w j d -> b h i w j d', i = selected_block_indices.shape[2])
fv = repeat(fv, 'b h w j d -> b h i w j d', i = selected_block_indices.shape[2])

selected_block_indices = repeat(selected_block_indices, 'b h i sel -> b h i sel j d', j = fk.shape[-2], d = fk.shape[-1])

fk = fk.gather(3, selected_block_indices)
fv = fv.gather(3, selected_block_indices)

# handle maybe gating

Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "native-sparse-attention-pytorch"
version = "0.0.6"
version = "0.0.7"
description = "Native Sparse Attention"
authors = [
{ name = "Phil Wang", email = "lucidrains@gmail.com" }
Expand Down

0 comments on commit 41dbb54

Please sign in to comment.