Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] Support variable-length sequences for mamba block #244

Open
wants to merge 37 commits into
base: main
Choose a base branch
from
Open
Changes from 1 commit
Commits
Show all changes
37 commits
Select commit Hold shift + click to select a range
d28e1b0
add cu_seqlens support and ensure numerical equality
zigzagcai Mar 8, 2024
a78a9eb
add notes for variable length sequences
zigzagcai Mar 14, 2024
e223353
fix typos
zigzagcai Mar 15, 2024
5955450
fix typos
zigzagcai Mar 18, 2024
ca189f6
Merge branch 'main' into feat/add-cu_seqlens
zigzagcai Mar 18, 2024
c2d5b88
fix typos
Dmovic Mar 18, 2024
db0dd09
fix typos
zigzagcai Mar 18, 2024
842bef5
Merge branch 'main' into feat/add-cu_seqlens
zigzagcai Mar 18, 2024
e7774aa
refine cu_seqlens implementation
zigzagcai Mar 18, 2024
1ccc60f
Merge branch 'feat/add-cu_seqlens' into feat/add-cu_seqlens
Dmovic Mar 19, 2024
4bf2697
Merge pull request #1 from Dmovic/feat/add-cu_seqlens
zigzagcai Mar 19, 2024
f357c44
add unit test for variable length
Dmovic Mar 19, 2024
6b98161
update unit test
Dmovic Mar 19, 2024
e4af927
fix typos
zigzagcai Mar 19, 2024
4221d48
update selective scan
zigzagcai Mar 25, 2024
934c0e6
Add logic for variable-length sequences
wang-zerui Mar 25, 2024
63b646d
Merge branch 'main' into feat/add-cu_seqlens
zigzagcai Apr 18, 2024
f6bb7e2
add example test to prove the mathematical equivalence of cu_seqlens …
zigzagcai Apr 26, 2024
bffcd97
fix typos
zigzagcai Apr 26, 2024
e3cab98
add cu_seqlens support for MixerModel
zigzagcai Apr 26, 2024
2f01ede
code refine for tests
zigzagcai Apr 30, 2024
f0a6508
refine code for tests
zigzagcai Apr 30, 2024
623d246
update API notes
zigzagcai Apr 30, 2024
ef3f760
update test code
zigzagcai Apr 30, 2024
71c77b1
Merge remote-tracking branch 'origin/main' into feat/add-cu_seqlens
zigzagcai Jun 6, 2024
2d27ccc
fix conflicts with latest main branch
zigzagcai Jun 6, 2024
f802627
Merge remote-tracking branch 'origin/main' into feat/add-cu_seqlens
zigzagcai Jul 16, 2024
596943c
fix unittest for test_selective_state_update_with_heads
zigzagcai Jul 16, 2024
6961faa
Merge branch 'state-spaces:main' into feat/add-cu_seqlens
zigzagcai Jul 18, 2024
b69b957
migrate to tridao's native varlen causal_conv1d kernel for speedup
zigzagcai Jul 19, 2024
50bffae
Merge branch 'state-spaces:main' into feat/add-cu_seqlens
zigzagcai Jul 22, 2024
909f970
typo fix
zigzagcai Jul 23, 2024
8174c45
use seq_idx if provided, or compute it by cu_seqlens
zigzagcai Aug 5, 2024
59be631
use seq_idx if provided, or compute it by cu_seqlens
zigzagcai Aug 5, 2024
3bc4a51
Merge branch 'state-spaces:main' into feat/add-cu_seqlens
zigzagcai Aug 6, 2024
210b6f6
mv cu_seqlens in ssm kernel to smem
zigzagcai Aug 7, 2024
cda4b5a
remove smem implementation because const vals and bi-search is enough
zigzagcai Aug 7, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
fix typos
  • Loading branch information
zigzagcai committed Mar 19, 2024
commit e4af927d7accfddf9ca0ec0d96b8b812c49324ac
3 changes: 2 additions & 1 deletion mamba_ssm/ops/selective_scan_interface.py
Original file line number Diff line number Diff line change
@@ -257,7 +257,7 @@ def forward(ctx, xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weigh
conv1d_out, delta = None, None
ctx.save_for_backward(xz, conv1d_weight, conv1d_bias, x_dbl, x_proj_weight,
delta_proj_weight, out_proj_weight, conv1d_out, delta,
A, B, C, D, delta_bias, scan_intermediates, out, cu_seqlens, d_conv)
A, B, C, D, delta_bias, scan_intermediates, out, cu_seqlens, torch.tensor([d_conv]))
return F.linear(rearrange(out_z, "b d l -> b l d"), out_proj_weight, out_proj_bias)

@staticmethod
@@ -267,6 +267,7 @@ def backward(ctx, dout):
assert causal_conv1d_cuda is not None, "causal_conv1d_cuda is not available. Please install causal-conv1d."
(xz, conv1d_weight, conv1d_bias, x_dbl, x_proj_weight, delta_proj_weight, out_proj_weight,
conv1d_out, delta, A, B, C, D, delta_bias, scan_intermediates, out, cu_seqlens, d_conv) = ctx.saved_tensors
d_conv = d_conv.item()
L = xz.shape[-1]
delta_rank = delta_proj_weight.shape[1]
d_state = A.shape[-1] * (1 if not A.is_complex() else 2)