Commit ae46170
make a decision to deviate from the paper's diagram: the last token of a compressed block does not attend to the compressed version of that block. this is so that, during the selection process, it does not re-select the current local block, which is already covered by both fine and sliding window attention.
lucidrains committed Feb 20, 2025
1 parent 2b9c54d commit ae46170
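
To illustrate what the strict inequality changes, here is a minimal sketch (not code from this repo; the toy sizes are made up, while q_idx, kv_idx, and compress_kv_idx follow the diff below):

import torch

# toy setting: 8 tokens, compressed in blocks of 4, giving 2 compressed kv
seq_len = 8
compress_block_size = 4
q_idx = torch.arange(seq_len)                           # query token positions
kv_idx = torch.arange(seq_len // compress_block_size)   # compressed block indices

# position of the last token each compressed block summarizes -> [3, 7]
compress_kv_idx = (kv_idx * compress_block_size) + (compress_block_size - 1)

old_mask = q_idx[:, None] >= compress_kv_idx[None, :]   # before this commit
new_mask = q_idx[:, None] >  compress_kv_idx[None, :]   # after this commit

# query 3 is the last token of block 0: under >= it attends to its own
# block's compressed summary; under > it does not, so the selection step
# can no longer re-select the current local block through this score
print(old_mask[3])   # tensor([ True, False])
print(new_mask[3])   # tensor([False, False])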
Showing 2 changed files with 3 additions and 3 deletions.
4 changes: 2 additions & 2 deletions native_sparse_attention_pytorch/native_sparse_attention.py
@@ -59,7 +59,7 @@ def create_compress_mask(seq_len, kv_seq_len, compress_block_size):
     def compress_mask(_, __, q_idx, kv_idx):
         compress_kv_idx = (kv_idx * compress_block_size) + (compress_block_size - 1)

-        causal_mask = q_idx >= compress_kv_idx
+        causal_mask = q_idx > compress_kv_idx
         return causal_mask

     block_mask = create_block_mask(compress_mask, B = None, H = None, Q_LEN = seq_len, KV_LEN = kv_seq_len, _compile = True)
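
For context, a sketch of how a mask_mod like this one is consumed by PyTorch's flex attention. This is not the repo's actual call site: the leading memory compress token here (whose covered position works out to -1, mirroring the F.pad value in the next hunk) is my own addition so that no query row is fully masked, and the sizes, device, and shapes are illustrative (assumes PyTorch >= 2.5):

import torch
from torch.nn.attention.flex_attention import flex_attention, create_block_mask

compress_block_size = 4
seq_len = 16
kv_seq_len = (seq_len // compress_block_size) + 1   # +1 hypothetical memory compress token

def compress_mask(_, __, q_idx, kv_idx):
    # kv_idx 0 plays the role of a memory token: its "last covered position"
    # evaluates to -1, so the strict > below keeps it visible to every query
    compress_kv_idx = ((kv_idx - 1) * compress_block_size) + (compress_block_size - 1)
    return q_idx > compress_kv_idx   # strict >: a token never sees its own block's summary

block_mask = create_block_mask(compress_mask, B = None, H = None, Q_LEN = seq_len, KV_LEN = kv_seq_len, device = 'cpu')

q = torch.randn(1, 1, seq_len, 64)
k = torch.randn(1, 1, kv_seq_len, 64)
v = torch.randn(1, 1, kv_seq_len, 64)

out = flex_attention(q, k, v, block_mask = block_mask)   # (1, 1, seq_len, 64)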
@@ -266,7 +266,7 @@ def forward(
         ck_seq = ((arange(num_compress_blocks, device = device) + 1) * self.compress_block_size) - 1
         ck_seq = F.pad(ck_seq, (num_mem_compress_kv, 0), value = -1)

-        cmask = einx.less_equal('j, i -> i j', ck_seq, cq_seq)
+        cmask = einx.less('j, i -> i j', ck_seq, cq_seq)

         mask_value = -torch.finfo(csim.dtype).max

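The einx expression builds the same strict mask by broadcasting: cmask[i, j] is true iff compressed kv j summarizes positions strictly before query position i. A minimal sketch under assumed toy sizes (num_compress_blocks, num_mem_compress_kv, and cq_seq are illustrative stand-ins for the runtime values):

import torch
import torch.nn.functional as F
import einx

compress_block_size = 4
num_compress_blocks = 2
num_mem_compress_kv = 1
seq_len = num_compress_blocks * compress_block_size

cq_seq = torch.arange(seq_len)   # assumed query positions, 0..7

# last original position covered by each compressed block -> [3, 7]
ck_seq = ((torch.arange(num_compress_blocks) + 1) * compress_block_size) - 1
# memory compress tokens are padded in at position -1, visible to every query
ck_seq = F.pad(ck_seq, (num_mem_compress_kv, 0), value = -1)   # -> [-1, 3, 7]

# cmask[i, j] = ck_seq[j] < cq_seq[i], i.e. a plain broadcasted less-than
cmask = einx.less('j, i -> i j', ck_seq, cq_seq)
assert torch.equal(cmask, ck_seq[None, :] < cq_seq[:, None])

# query 3 (last token of block 0) sees the memory token but not its own block
print(cmask[3])   # tensor([ True, False, False])
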
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "native-sparse-attention-pytorch"
-version = "0.0.21"
+version = "0.0.22"
 description = "Native Sparse Attention"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
