From ae46170d561f30d7f7003d76e1d33facb94e506e Mon Sep 17 00:00:00 2001
From: lucidrains
Date: Thu, 20 Feb 2025 16:16:16 +0000
Subject: [PATCH] make a decision to deviate from their diagram, where the
 last token of the compressed block does not attend to the compressed
 version of that block. this is so that, during the selection process, it
 does not re-select the current local block, which is already covered by
 both fine and sliding window attention.

---
 native_sparse_attention_pytorch/native_sparse_attention.py | 4 ++--
 pyproject.toml                                              | 2 +-
 2 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/native_sparse_attention_pytorch/native_sparse_attention.py b/native_sparse_attention_pytorch/native_sparse_attention.py
index bf0ff97..4dfa350 100644
--- a/native_sparse_attention_pytorch/native_sparse_attention.py
+++ b/native_sparse_attention_pytorch/native_sparse_attention.py
@@ -59,7 +59,7 @@ def create_compress_mask(seq_len, kv_seq_len, compress_block_size):
     def compress_mask(_, __, q_idx, kv_idx):
         compress_kv_idx = (kv_idx * compress_block_size) + (compress_block_size - 1)
 
-        causal_mask = q_idx >= compress_kv_idx
+        causal_mask = q_idx > compress_kv_idx
         return causal_mask
 
     block_mask = create_block_mask(compress_mask, B = None, H = None, Q_LEN = seq_len, KV_LEN = kv_seq_len, _compile = True)
@@ -266,7 +266,7 @@ def forward(
         ck_seq = ((arange(num_compress_blocks, device = device) + 1) * self.compress_block_size) - 1
         ck_seq = F.pad(ck_seq, (num_mem_compress_kv, 0), value = -1)
 
-        cmask = einx.less_equal('j, i -> i j', ck_seq, cq_seq)
+        cmask = einx.less('j, i -> i j', ck_seq, cq_seq)
 
         mask_value = -torch.finfo(csim.dtype).max
 
diff --git a/pyproject.toml b/pyproject.toml
index 0dca1e6..a8c770f 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "native-sparse-attention-pytorch"
-version = "0.0.21"
+version = "0.0.22"
 description = "Native Sparse Attention"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
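
For context, a minimal standalone sketch (not part of the patch) of what moving from >= to strict > changes in the compression mask. The sequence length, block size, and variable names below are assumptions chosen for illustration:

# hypothetical illustration only, not from the repository
# shows which compressed blocks each query position may attend, before vs. after the patch

import torch

seq_len = 8
compress_block_size = 4                      # assumed block size for this example
num_blocks = seq_len // compress_block_size

q_idx = torch.arange(seq_len)[:, None]       # query positions, shape (8, 1)
kv_idx = torch.arange(num_blocks)[None, :]   # compressed block indices, shape (1, 2)

# last original position summarized by each compressed block (3 and 7 here)
compress_kv_idx = (kv_idx * compress_block_size) + (compress_block_size - 1)

old_mask = q_idx >= compress_kv_idx   # a block's last token attends its own compressed block
new_mask = q_idx >  compress_kv_idx   # strict inequality: the current block is excluded

print(old_mask[3])   # tensor([ True, False]) - position 3 attended compressed block 0
print(new_mask[3])   # tensor([False, False]) - now it cannot, so selection skips it

The einx change enforces the same rule on the selection scores: a compressed block is kept only when its last summarized position ck_seq[j] is strictly before the query position cq_seq[i], while the memory positions padded with -1 remain visible to every query. The current local block is then left entirely to the fine and sliding-window branches, as the commit message explains.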