Integrate with Mamba #2
@jeromeku thank you for the kind words! Glad you checked out nanokitchen as well. It would indeed be possible to use Accelerated Scan for Mamba as is; however, it would work best for experimental purposes, since the Mamba kernel is already designed to achieve the best performance for that architecture. Concretely, the Mamba kernel fuses several steps of the computation. Mamba's review gives a hint that the memory footprint could be improved for that kernel; a good direction would be to understand why that is the case. Reference: https://openreview.net/forum?id=AL1fq05o7H&noteId=T6WJZb30sz
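For context, a minimal unfused reference of the recurrence that the fused Mamba kernel computes can be sketched in a few lines. The names and shapes below follow the Mamba paper's discretized diagonal SSM and are my assumptions for illustration, not code from either repo:

```python
import numpy as np

def selective_scan_ref(A_bar, B_bar_u, C):
    """Unfused reference of the selective scan:
    h[t] = A_bar[t] * h[t-1] + B_bar_u[t],  y[t] = <C[t], h[t]>.
    Shapes: A_bar, B_bar_u, C are all (T, N); A is diagonal, so the
    update is elementwise."""
    T, N = A_bar.shape
    h = np.zeros(N)
    y = np.empty(T)
    for t in range(T):
        h = A_bar[t] * h + B_bar_u[t]  # elementwise state update
        y[t] = C[t] @ h                # readout
    return y
```

The fused kernel computes this without ever materialising `h` for all timesteps in global memory, which is where the footprint discussion above comes from.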
I found that @srush has done this exact fusion of the SSM bits into the Triton forward kernel here: srush/annotated-mamba#1 (comment)
Yeah thanks! Your repo was super helpful for that; we couldn't figure out how to do the two-value scan. Unfortunately I'm stuck now on the backwards pass. Need to do the scan right-to-left. I see that you do it by loading values in reverse order. Unfortunately we need to reverse the tensor in local memory (or repeat a lot of computation). Any ideas? I think I might try making an LxL matrix and doing a dot? It seems like overkill, but I'm stuck for other methods.
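The LxL-matrix idea can be sanity-checked outside Triton. For a right-to-left cumulative sum, multiplying by an upper-triangular ones matrix gives the reversed scan without ever flipping the data in memory. A numpy sketch (plain summation for simplicity; the real Mamba scan is a gated first-order recurrence, not a sum):

```python
import numpy as np

L = 8
x = np.arange(1.0, L + 1)

# Reference: reverse (suffix) cumulative sum via two flips.
ref = np.flip(np.cumsum(np.flip(x)))

# LxL trick: an upper-triangular ones matrix turns a matmul into a
# suffix sum, avoiding any in-memory reverse at O(L^2) cost.
U = np.triu(np.ones((L, L)))
via_dot = U @ x  # via_dot[i] = sum of x[i:]

assert np.allclose(ref, via_dot)
```

On a GPU the O(L^2) matmul is not necessarily overkill, since it maps onto tensor cores, which is presumably why it is on the table here at all.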
There's some discussion about making a reverse associative scan.
Yes, that issue is from me as well. The read-the-memory-in-reverse trick in your codebase is nice. The problem is that for Mamba, you need to run the backward scan on an intermediately calculated tensor that is too large to store. Therefore you need to either reverse it in memory or have a reverse associative scan.
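A sequential sketch of the read-in-reverse alternative: run the backward recurrence by indexing the inputs back-to-front, so nothing has to be reversed in memory. The first-order linear recurrence form below is assumed for illustration:

```python
import numpy as np

def reverse_linear_scan(a, x):
    """Right-to-left recurrence h[t] = a[t] * h[t+1] + x[t], with
    h[T] = 0, computed by reading a and x back-to-front instead of
    materialising a reversed copy of a potentially huge tensor."""
    T = len(x)
    h = np.empty(T)
    acc = 0.0
    for t in range(T - 1, -1, -1):  # iterate indices in reverse
        acc = a[t] * acc + x[t]
        h[t] = acc
    return h
```

A parallel kernel would do the same thing per chunk: each block loads its chunk with reversed indices, scans it left-to-right, and the chunks are combined in reverse order.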
Any luck integrating a reverse scan? Trying to get up to speed with Triton.
I sent them a PR for a flip function at the Triton level, which should be okay: triton-lang/triton#2954. Although it would be interesting to do something more low-level.
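As lower-level intuition for what a flip can compile to: for a power-of-two length, a reverse decomposes into log2(L) XOR-indexed permutation rounds, each of which maps onto cheap register shuffles inside a GPU kernel. This is a generic sketch of that decomposition, not the actual lowering used in the PR:

```python
import numpy as np

def flip_pow2(x):
    """Reverse a power-of-two-length array as log2(L) rounds of
    XOR-indexed swaps; composing the rounds sends index i to
    i ^ (L - 1), i.e. a full reverse."""
    L = len(x)
    assert L > 0 and L & (L - 1) == 0, "L must be a power of two"
    out = np.asarray(x)
    bit = 1
    while bit < L:
        out = out[np.arange(L) ^ bit]  # swap elements `bit` apart
        bit <<= 1
    return out
```

Each round only exchanges data between lanes a fixed distance apart, which is exactly the access pattern warp shuffle instructions support.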
Thanks -- saw that PR and agree that a more low-level approach would be a worthwhile exercise. It always helps to understand how things work under the hood. MLIR is a bit of a beast. FYI, this series of tutorials is a great intro to MLIR. Also, NVIDIA's cutlass library has similar abstractions (i.e., a GEMM hierarchy) as Triton.
It might be something similar to convert_layout from distributed to distributed layouts in most cases. Feel free to take a look at the relevant code.
I do think the current python solution is more elegant though. |
@proger
Awesome work! Always appreciate the wonderful contributions of OSS advancing the frontiers of research.
I know you've done a number of experiments comparing various scan implementations in your other repo nanokitchen -- would it make sense to integrate `accelerated-scan` as an alternative backend to `Mamba`? Would be happy to work on this if you think it makes sense.