Support Scaled matmul and scaled dot_general in jax.nn #26345
PTAL the test errors - looks like … Also, is it our common practice to add purely-cuDNN APIs to jax.nn? Adding maintainers who know more about JAX's lower-level common practices: @hawkinsp @superbobry
This dtype should already be in JAX (see here). The dtype was first defined in ml_dtypes 0.5.0 and imported into JAX.
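A minimal sketch of checking for the dtype at runtime; the `hasattr` guard is an illustration, not code from this PR, and assumes JAX only re-exports the dtype when the installed ml_dtypes version provides it.

```python
# Illustrative guard (not from this PR): jax.numpy re-exports float8_e8m0fnu
# only when the installed ml_dtypes version (>= 0.5.0) defines it.
import jax.numpy as jnp

if hasattr(jnp, "float8_e8m0fnu"):
    scale_type = jnp.float8_e8m0fnu
else:
    raise ImportError(
        "float8_e8m0fnu is unavailable; install ml_dtypes >= 0.5.0")
```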
cc @reedwm
Please also make sure all CI tests pass.
mode='mxfp8',
block_size=32,
data_type=jnp.float8_e4m3fn,
scale_type=jnp.float8_e8m0fnu,
Note that JAX doesn't require ml_dtypes >= 0.5.0, so jax.numpy.float8_e8m0fnu is not guaranteed to exist.
Please also avoid defining a default fp8 config as a global instance - you can always pass it in via the jax.nn-level API.
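A sketch of the suggested alternative: build the fp8 configuration at the call site and pass it through the jax.nn-level API instead of keeping a module-level default. `ScaledMatmulConfig` and the `config=` keyword are placeholder names for illustration, not the PR's actual signature.

```python
import dataclasses
from typing import Any

import jax.numpy as jnp

@dataclasses.dataclass(frozen=True)
class ScaledMatmulConfig:  # placeholder name, not the class from this PR
    mode: str = "mxfp8"
    block_size: int = 32
    data_type: Any = jnp.float8_e4m3fn
    # Assumes ml_dtypes >= 0.5.0 so that the e8m0 scale dtype exists.
    scale_type: Any = jnp.float8_e8m0fnu

# Constructed per call instead of as a global default instance.
cfg = ScaledMatmulConfig()
# out = jax.nn.scaled_dot_general(lhs, rhs, dimension_numbers, config=cfg)
```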
_scaled_matmul_p.def_abstract_eval(_scaled_matmul_abstract)

mlir.register_lowering(
See the CI test failures - this code currently only runs on GPU.
You will need to either:
- add a lowering fallback to XLA for other platforms such as CPU and TPU, or
- explicitly state that your implementation only works on certain platforms and prevent invalid usage: add a platform check or require an explicit platform argument in the API (like this cudnn implementation), and disable the tests on other platforms (see the sketch after this list).
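A minimal sketch of the platform-gating option, reusing the `_scaled_matmul_p` primitive from this PR; `_scaled_matmul_cuda_lowering`, the CPU/TPU handling, and the error text are illustrative, not the PR's actual code.

```python
from jax.interpreters import mlir

# Register the cuDNN-backed lowering for the CUDA platform only.
mlir.register_lowering(
    _scaled_matmul_p, _scaled_matmul_cuda_lowering, platform="cuda")

def _scaled_matmul_unsupported_lowering(ctx, *args, **kwargs):
    # Fail loudly at lowering time instead of silently miscompiling
    # on backends that have no scaled-matmul support.
    raise NotImplementedError(
        "scaled_matmul is only implemented for GPU via cuDNN")

for _platform in ("cpu", "tpu"):
    mlir.register_lowering(
        _scaled_matmul_p, _scaled_matmul_unsupported_lowering,
        platform=_platform)
```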
This PR introduces two new nn APIs:
- scaled_matmul – performs matrix multiplication with scaling factors.
- scaled_dot_general – implements a generalized dot-product operation with quantization, matmul, and dequantization.

Currently, only the e8m0fnu-type scaling factor is supported.
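A hedged usage sketch of the proposed API; the exact argument names, dimension-number plumbing, and config handling are assumptions based on the description above, not the final signature, so the call itself is left commented out.

```python
import jax
import jax.numpy as jnp

k1, k2 = jax.random.split(jax.random.key(0))
a = jax.random.normal(k1, (128, 64), dtype=jnp.bfloat16)
b = jax.random.normal(k2, (64, 256), dtype=jnp.bfloat16)

# Assumed call shape: quantize inputs to mxfp8 blocks, matmul with per-block
# e8m0fnu scales, then dequantize back to the preferred output type.
# out = jax.nn.scaled_dot_general(a, b, (((1,), (0,)), ((), ())))

out_ref = jnp.dot(a, b)  # full-precision reference for comparison
```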
@kaixih @nouiz