
Commit

complete the fine attention masking with flex attention, not wired up
lucidrains committed Feb 20, 2025
1 parent 14c90bd commit fd0c756
Showing 2 changed files with 12 additions and 5 deletions.
native_sparse_attention_pytorch/native_sparse_attention.py: 11 additions & 4 deletions
@@ -67,17 +67,24 @@ def compress_mask(_, __, q_idx, kv_idx):
 
 
 def create_fine_mask(selected_block_indices: Tensor, seq_len, fine_block_size):
     device = selected_block_indices.device
     batch, heads = selected_block_indices.shape[:2]
 
+    one_hot_selected_block_indices = torch.zeros((*selected_block_indices.shape[:-1], seq_len // fine_block_size), device = device, dtype = torch.bool)
+    one_hot_selected_block_indices.scatter_(-1, selected_block_indices, True)
+
     def fine_mask(b_idx, h_idx, q_idx, kv_idx):
-        selected_indices = selected_block_indices[b_idx, h_idx]
 
-        # todo - fill in logic for creating the selected kv ranges per query
+        compressed_q_idx = q_idx // fine_block_size
+        compressed_kv_idx = kv_idx // fine_block_size
+
+        block_causal_mask = compressed_q_idx > compressed_kv_idx
+        is_selected = one_hot_selected_block_indices[b_idx, h_idx, q_idx, compressed_kv_idx]
+
         causal_mask = q_idx >= kv_idx
-        block_diagonal = (q_idx // fine_block_size) == (kv_idx // fine_block_size)
+        block_diagonal = compressed_q_idx == compressed_kv_idx
 
-        return (block_diagonal & causal_mask)
+        return (causal_mask & block_diagonal) | (block_causal_mask & is_selected)
 
     block_mask = create_block_mask(fine_mask, B = batch, H = heads, Q_LEN = seq_len, KV_LEN = seq_len, _compile = True)
     return block_mask
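
Note on the new mask: each query may attend causally within its own fine block, plus any strictly earlier fine block that this query selected (via the scattered one-hot of selected_block_indices). The commit message says the mask is not yet wired up, so below is a minimal usage sketch, not the repository's own integration. Assumptions: PyTorch 2.5+ with torch.nn.attention.flex_attention, a CUDA device (create_block_mask above uses its default "cuda" device), demo sizes chosen arbitrarily, and a per-query selected_block_indices of shape (batch, heads, seq_len, num_selected) holding coarse block indices, as read off the indexing inside fine_mask.

import torch
from torch.nn.attention.flex_attention import flex_attention
from native_sparse_attention_pytorch.native_sparse_attention import create_fine_mask

# demo sizes - assumptions, not from the commit
batch, heads, seq_len, dim_head = 1, 2, 128, 16
fine_block_size = 16
num_selected = 2
num_fine_blocks = seq_len // fine_block_size

# per-query indices of the fine kv blocks each query attends to,
# shape (batch, heads, seq_len, num_selected), values in [0, num_fine_blocks)
selected_block_indices = torch.randint(0, num_fine_blocks, (batch, heads, seq_len, num_selected), device = 'cuda')

# block mask: own fine block (causal) plus selected earlier fine blocks
block_mask = create_fine_mask(selected_block_indices, seq_len, fine_block_size)

q, k, v = (torch.randn(batch, heads, seq_len, dim_head, device = 'cuda') for _ in range(3))

out = flex_attention(q, k, v, block_mask = block_mask)  # (batch, heads, seq_len, dim_head)
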
pyproject.toml: 1 addition & 1 deletion
@@ -1,6 +1,6 @@
 [project]
 name = "native-sparse-attention-pytorch"
-version = "0.0.23"
+version = "0.0.24"
 description = "Native Sparse Attention"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
