copy paste in a working impl of flash attention from ring-attention-pytorch repo for modification. get basic scaffolding ready
lucidrains committed Feb 23, 2025
1 parent 582b844 commit b1dee31
Showing 3 changed files with 1,172 additions and 1 deletion.
15 changes: 14 additions & 1 deletion native_sparse_attention_pytorch/native_sparse_attention.py
@@ -191,6 +191,7 @@ def __init__(
num_compressed_mem_kv = 1,
norm = True,
use_diff_topk = False,
use_triton_kernel = False,
interpolated_importance_score = False,
query_heads_share_selected_kv = True, # if set to True, importance score is averaged across query heads to select top-n buckets of kv per kv head - but can be set to False for each query head within a group to look at different sets of kv buckets. will be more memory and compute of course
compress_mlp: Module | None = None,
@@ -287,6 +288,8 @@ def __init__(

self.num_selected_blocks = num_selected_blocks

self.use_triton_kernel = use_triton_kernel

# they combine the three sparse branches through a learned combine with sigmoid activation

if not exists(strategy_combine_mlp):
@@ -438,7 +441,17 @@ def forward(
gates = gates.cumprod(dim = -1)[..., -1]
gates = repeat(gates, 'b h ... -> b (h qh) ...', qh = fine_num_grouped_queries)

if exists(fine_selection_flex_mask):
if self.use_triton_kernel:
from native_sparse_attention_pytorch.triton_native_sparse_attention import native_sparse_attend

fine_attn_out = native_sparse_attend(
fq, fk, fv,
self.selection_block_size,
selected_block_indices,
fine_num_grouped_queries
)

elif exists(fine_selection_flex_mask):
# flex attention for the selection for fine attention

fine_block_mask = fine_selection_flex_mask(selected_block_indices, num_grouped_queries = fine_num_grouped_queries)
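
In the hunk above, the new Triton path takes precedence over the flex-attention path whenever use_triton_kernel is set. As a rough picture of what native_sparse_attend is expected to compute, here is a dense pure-PyTorch reference sketch, not the repo's kernel: the (b, heads, n, dim) tensor layout, the (h g) grouped-query head ordering, and the causal handling are all assumptions, and the real kernel would avoid materializing the full n-by-n mask.

import torch
from einops import repeat

def selected_block_attend_reference(fq, fk, fv, block_size, selected_block_indices, num_grouped_queries):
    # fq: (b, h_q, n, d), fk / fv: (b, h_kv, n, d), with h_q = h_kv * num_grouped_queries (assumed layout)
    # selected_block_indices: (b, h_kv, n, s), the s kv-block ids each query position may attend to
    n, d = fq.shape[-2:]
    scale = d ** -0.5

    # block id of every key position
    key_block_ids = torch.arange(n, device = fq.device) // block_size                 # (n,)

    # allow key j for query i iff key j's block id is among query i's selected blocks
    allowed = (selected_block_indices[..., None] == key_block_ids).any(dim = -2)      # (b, h_kv, n, n)

    # causal constraint so queries never attend to future positions
    causal = torch.ones(n, n, device = fq.device, dtype = torch.bool).tril()
    allowed = allowed & causal

    # broadcast kv heads (and their masks) across the grouped query heads
    # the (h g) grouping order is an assumption about the repo's head layout
    fk = repeat(fk, 'b h n d -> b (h g) n d', g = num_grouped_queries)
    fv = repeat(fv, 'b h n d -> b (h g) n d', g = num_grouped_queries)
    allowed = repeat(allowed, 'b h i j -> b (h g) i j', g = num_grouped_queries)

    sim = torch.einsum('bhid,bhjd->bhij', fq, fk) * scale
    sim = sim.masked_fill(~allowed, torch.finfo(sim.dtype).min)
    attn = sim.softmax(dim = -1)
    return torch.einsum('bhij,bhjd->bhid', attn, fv)

Presumably the point of the Triton kernel being scaffolded here is to produce the same result while only ever touching the selected KV blocks, instead of building the full mask as this reference does.
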
7 changes: 7 additions & 0 deletions native_sparse_attention_pytorch/transformer.py
@@ -35,6 +35,9 @@ def exists(v):
def default(v, d):
return v if exists(v) else d

def at_most_one_of(*bools):
return sum([*map(int, bools)]) <= 1

# attention

class Attention(Module):
@@ -123,6 +126,7 @@ def __init__(
use_sparse_attn = False,
use_flex_sliding_window = False,
use_flex_fine_selection = False,
use_triton_fine_selection = False,
sparse_attn_kwargs: dict = dict(
sliding_window_size = 32,
compress_block_size = 4,
@@ -131,6 +135,8 @@
)
):
super().__init__()
assert at_most_one_of(use_flex_fine_selection, use_triton_fine_selection), 'either using flex or custom triton kernel for fine attn, but not both'

self.token_emb = nn.Embedding(num_tokens, dim)

if use_flex_sliding_window or use_flex_fine_selection:
@@ -149,6 +155,7 @@ def __init__(
dim_head = dim_head,
heads = heads,
kv_heads = kv_heads,
use_triton_kernel = use_triton_fine_selection,
**sparse_attn_kwargs
)
else:
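
On the transformer.py side, the new use_triton_fine_selection flag simply plumbs through to the sparse attention module's use_triton_kernel argument, guarded so it cannot be combined with the flex-attention selection path. A hypothetical usage sketch follows; the wrapper class name Transformer, the depth argument, and the forward return value are assumptions not visible in this diff, while the remaining arguments appear in the surrounding context lines.

import torch
from native_sparse_attention_pytorch.transformer import Transformer   # class name assumed

model = Transformer(
    num_tokens = 256,                     # vocab size for the token embedding shown above
    dim = 512,
    depth = 2,                            # assumed argument, not visible in this diff
    dim_head = 64,
    heads = 8,
    kv_heads = 4,
    use_sparse_attn = True,
    use_triton_fine_selection = True,     # new flag: fine selection runs through the Triton kernel
    sparse_attn_kwargs = dict(
        sliding_window_size = 32,
        compress_block_size = 4
    )
)

ids = torch.randint(0, 256, (1, 512))
logits = model(ids)                       # forward signature / return value assumed
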
1,151 changes: 1,151 additions & 0 deletions native_sparse_attention_pytorch/triton_native_sparse_attention.py
Large diffs are not rendered by default.
