Integrate with Mamba #2

Open
jeromeku opened this issue Jan 13, 2024 · 12 comments

Comments

@jeromeku

@proger

Awesome work! Always appreciate the wonderful contributions of OSS advancing the frontiers of research.

I know you've done a number of experiments comparing various scan implementations in your other repo nanokitchen -- would it make sense to integrate accelerated-scan as an alternative backend to Mamba? Would be happy to work on this if you think it makes sense.

@proger
Owner

proger commented Jan 13, 2024

@jeromeku thank you for the kind words! Glad you checked out nanokitchen as well.

It would indeed be possible to use Accelerated Scan for Mamba as is, but it would work best for experimental purposes: the Mamba kernel is already designed to achieve the best performance for that architecture.

Concretely, the Mamba kernel fuses cub::BlockScan with the SSM state expansion operations: the A matrix expands every gate dimension (a gate is called delta in Mamba, and those deltas are expected to be stored in log space) into a 16-dimensional SSM state, B correspondingly expands every input token dimension to match the gate expansion, and C collapses every SSM state back. Accelerated Scan would have to accept the expanded SSM states as inputs and waste precious memory bandwidth.
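For reference, a rough, unfused PyTorch-style sketch of the recurrence being described (shapes and names here are illustrative, not the actual kernel API):

```python
import torch

def selective_scan_reference(x, log_delta, A, B, C):
    # x:         (batch, length, dim)   input tokens
    # log_delta: (batch, length, dim)   gates (deltas), stored in log space
    # A:         (dim, state)           state = 16 in Mamba
    # B, C:      (batch, length, state)
    batch, length, dim = x.shape
    delta = log_delta.exp()
    h = x.new_zeros(batch, dim, A.shape[1])              # expanded SSM state
    ys = []
    for t in range(length):
        a_bar = torch.exp(delta[:, t, :, None] * A)      # A expands each gate dim into 16 states
        b_x = delta[:, t, :, None] * B[:, t, None, :] * x[:, t, :, None]  # B expands each token dim
        h = a_bar * h + b_x                              # the scanned first-order recurrence
        ys.append(torch.einsum("bdn,bn->bd", h, C[:, t]))  # C collapses the state back
    return torch.stack(ys, dim=1)                        # (batch, length, dim)
```

The fused kernel performs the h update as a block scan without ever materializing the (batch, length, dim, state) tensor; driving Accelerated Scan with it would mean materializing exactly that expanded tensor.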

Mamba's review hints that the memory footprint of that kernel could still be improved; a good direction would be to understand why that is the case. Reference: https://openreview.net/forum?id=AL1fq05o7H&noteId=T6WJZb30sz

@proger
Owner

proger commented Jan 13, 2024

I found that @srush has done this exact fusion of the SSM bits into the Triton forward kernel here: srush/annotated-mamba#1 (comment)

@srush

srush commented Jan 13, 2024

Yeah, thanks! Your repo was super helpful for that; we couldn't figure out how to do the two-value scan.

Unfortunately I'm now stuck on the backward pass, which needs to run the scan right-to-left. I see that you do it by loading values in reverse order, but for the backward we would need to reverse the tensor in local memory (or repeat a lot of computation).

Any ideas? I think I might try making an LxL matrix and doing a dot? It seems like overkill, but I'm stuck for other methods.
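For a plain sum, the LxL idea would look something like the sketch below; for the actual recurrence the matrix entries become cumulative products of the gates rather than ones, which is where the overkill comes in.

```python
import torch

# Right-to-left inclusive scan via an L x L matrix, shown for a plain sum:
# out[t] = sum of x[t:], i.e. the same as x.flip(0).cumsum(0).flip(0).
L = 8
x = torch.randn(L)
mask = torch.triu(torch.ones(L, L))   # row t keeps columns s >= t
rev_cumsum = mask @ x
assert torch.allclose(rev_cumsum, x.flip(0).cumsum(0).flip(0))
```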

@jeromeku
Author

@proger @srush
will take a closer look and report back...

@proger
Owner

proger commented Jan 14, 2024

There's some discussion about making a reverse tl.associative_scan in triton-lang/triton#2930

@srush

srush commented Jan 14, 2024

Yes, that issue is from me as well.

The trick of reading the memory in reverse in your codebase is nice. The problem is that for Mamba, the backward scan has to run over an intermediate tensor that is too large to store, so you either need to reverse it in memory or have a reverse associative scan.

@jeromeku
Author

@proger

Any luck integrating a reverse option into the Triton backend?

Trying to get up to speed with MLIR :)

@srush

srush commented Jan 19, 2024

I sent them a PR for a flip function at the Triton level, which should be okay: triton-lang/triton#2954. It would be interesting to do something more low-level, though.
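Assuming the flip ends up callable as something like tl.flip, a right-to-left scan inside a single block would look roughly like this (a sketch with a two-value combine, not the actual kernel):

```python
import triton
import triton.language as tl

@triton.jit
def combine(a_l, x_l, a_r, x_r):
    # combine op for the first-order linear recurrence h = a * h_prev + x
    return a_r * a_l, a_r * x_l + x_r

@triton.jit
def reverse_scan_kernel(gates_ptr, tokens_ptr, out_ptr, T: tl.constexpr):
    offs = tl.arange(0, T)
    a = tl.load(gates_ptr + offs)
    x = tl.load(tokens_ptr + offs)
    # flip in registers, scan left-to-right, then flip the result back
    a, x = tl.flip(a, 0), tl.flip(x, 0)            # assumes the flip from the PR above
    a, x = tl.associative_scan((a, x), 0, combine)
    tl.store(out_ptr + offs, tl.flip(x, 0))
```

The appeal of doing the flip at the Triton level is that the values stay in registers and never have to be written back to global memory in reverse order.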

@jeromeku
Author

@srush

Thanks -- saw that PR and agree that a more low-level approach would be a worthwhile exercise. It always helps to understand how things work under the hood. MLIR is a bit of a beast.

FYI, this series of tutorials is a great intro to MLIR. Also, NVIDIA's cutlass library has abstractions similar to triton's (i.e., the GEMM hierarchy), though triton is clearly more extensible to a wider variety of problems and backends.

@Jokeren

Jokeren commented Jan 19, 2024

It might be something similar to a convert_layout from a distributed layout to a distributed layout in most cases. Feel free to take a look at the relevant code.

@Jokeren

Jokeren commented Jan 19, 2024

I do think the current python solution is more elegant though.
