From 41dbb548dcba5ee815abeef763c76262edc5a4ec Mon Sep 17 00:00:00 2001 From: lucidrains Date: Wed, 19 Feb 2025 20:58:08 +0000 Subject: [PATCH] redo get_at with gather, but keep around the ein notation for readability --- .../native_sparse_attention.py | 11 +++++++++-- pyproject.toml | 2 +- 2 files changed, 10 insertions(+), 3 deletions(-) diff --git a/native_sparse_attention_pytorch/native_sparse_attention.py b/native_sparse_attention_pytorch/native_sparse_attention.py index 5c10310..dc3a997 100644 --- a/native_sparse_attention_pytorch/native_sparse_attention.py +++ b/native_sparse_attention_pytorch/native_sparse_attention.py @@ -265,8 +265,15 @@ def forward( fk = rearrange(fk, 'b h (w n) d -> b h w n d', w = num_fine_blocks) fv = rearrange(fv, 'b h (w n) d -> b h w n d', w = num_fine_blocks) - fk = einx.get_at('b h [w] j d, b h i selected -> b h i selected j d', fk, selected_block_indices) - fv = einx.get_at('b h [w] j d, b h i selected -> b h i selected j d', fv, selected_block_indices) + # get_at("b h [w] j d, b h i selected -> b h i selected j d", fkv, selected_block_indices) + + fk = repeat(fk, 'b h w j d -> b h i w j d', i = selected_block_indices.shape[2]) + fv = repeat(fv, 'b h w j d -> b h i w j d', i = selected_block_indices.shape[2]) + + selected_block_indices = repeat(selected_block_indices, 'b h i sel -> b h i sel j d', j = fk.shape[-2], d = fk.shape[-1]) + + fk = fk.gather(3, selected_block_indices) + fv = fv.gather(3, selected_block_indices) # handle maybe gating diff --git a/pyproject.toml b/pyproject.toml index cd3446e..c34ab57 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "native-sparse-attention-pytorch" -version = "0.0.6" +version = "0.0.7" description = "Native Sparse Attention" authors = [ { name = "Phil Wang", email = "lucidrains@gmail.com" }