
Support Scaled matmul and scaled dot_general in jax.nn #26345

Open
wants to merge 6 commits into main

Conversation

wenscarl
Contributor

@wenscarl wenscarl commented Feb 5, 2025

This PR introduces two new nn APIs:

  1. scaled_matmul – Performs matrix multiplication with scaling factors.
  2. scaled_dot_general – Implements a generalized dot product operation with quantization, matmul, and dequantization.
    Currently, only the e8m0fnu scale type is supported.
    @kaixih @nouiz
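
A minimal usage sketch of the higher-level API, assuming the jax.nn entry points described in this PR (the exact argument names and defaults may differ from the version under review):

```python
# Minimal usage sketch, assuming the jax.nn API proposed in this PR; the exact
# signature and defaults may differ from the version under review.
# Note: per the review discussion below, this currently requires a GPU backend.
import jax
import jax.numpy as jnp

k1, k2 = jax.random.split(jax.random.PRNGKey(0))
lhs = jax.random.normal(k1, (16, 64), dtype=jnp.bfloat16)
rhs = jax.random.normal(k2, (64, 32), dtype=jnp.bfloat16)

# scaled_dot_general quantizes the inputs block-wise (mxfp8 with e8m0fnu scales),
# runs the scaled matmul, and dequantizes the result.
out = jax.nn.scaled_dot_general(
    lhs, rhs,
    (((1,), (0,)), ((), ())),  # dot_general-style dimension_numbers
)
print(out.shape)  # (16, 32)
```

scaled_matmul, by contrast, is the lower-level entry point that takes the operands together with explicit scaling factors; scaled_dot_general wraps the quantize → matmul → dequantize sequence around it.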

@wenscarl wenscarl marked this pull request as ready for review February 13, 2025 04:36
@wenscarl wenscarl changed the title Scaled matmul for mxfp8 Scaled matmul and dot_general Feb 13, 2025
@wenscarl wenscarl changed the title Scaled matmul and dot_general Support Scaled matmul and scaled dot_general in nn Feb 13, 2025
@wenscarl wenscarl changed the title Support Scaled matmul and scaled dot_general in nn Support Scaled matmul and scaled dot_general in jax.nn Feb 13, 2025
@kaixih
Contributor

kaixih commented Feb 13, 2025

@IvyZX @sdasgup3 can you help review or find other reviewers?

@IvyZX
Collaborator

IvyZX commented Feb 14, 2025

PTAL at the test errors - looks like e8m0fnu is not a type that NumPy supports. I'm not sure if we should adopt this type in JAX.

Also, is it our common practice to add purely-cudnn APIs to jax.nn? Looks like in #21371 an XLA implementation is provided as an alternative for users without cudnn.

Adding maintainers who know more about JAX's lower level common practices: @hawkinsp @superbobry

@kaixih
Contributor

kaixih commented Feb 14, 2025

> PTAL at the test errors - looks like e8m0fnu is not a type that NumPy supports. I'm not sure if we should adopt this type in JAX.

This dtype is already in JAX (see here). It was first defined in ml_dtypes 0.5.0 and is imported into JAX.
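
A quick local check (illustrative; requires ml_dtypes >= 0.5.0):

```python
# Illustrative check that the scale dtype is available: float8_e8m0fnu is
# defined in ml_dtypes >= 0.5.0 and exposed via jax.numpy when present.
import ml_dtypes
import jax.numpy as jnp

print(ml_dtypes.__version__)  # needs >= 0.5.0 for float8_e8m0fnu

# e8m0 is an exponent-only format, so powers of two round-trip exactly.
scales = jnp.asarray([0.5, 1.0, 2.0], dtype=jnp.float8_e8m0fnu)
print(scales.dtype)  # float8_e8m0fnu
```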

> Also, is it our common practice to add purely-cudnn APIs to jax.nn? Looks like in #21371 an XLA implementation is provided as an alternative for users without cudnn.

scaled_matmul and scaled_dot_general are meant to be general, not cudnn-specific. Yes, both of them call into cudnn_scaled_matmul, which appears to be a purely-cudnn API, but cudnn_scaled_matmul will actually fall back to an XLA implementation if cudnn doesn't support the case (see this XLA pass). That explains why we don't see the XLA implementation in the Python code.

@sdasgup3
Contributor

cc @reedwm

Collaborator

@IvyZX IvyZX left a comment

Please also make sure all CI tests pass.

mode='mxfp8',
block_size=32,
data_type=jnp.float8_e4m3fn,
scale_type=jnp.float8_e8m0fnu,
Collaborator

Note that JAX doesn't require ml_dtypes >= 0.5.0, so jax.numpy.float8_e8m0fnu is not guaranteed to exist.

Please also avoid defining a default fp8 config as a global instance - you can always pass it in via the jax.nn-level API.
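
A hedged sketch of what that could look like (the function and field names here are illustrative, mirroring the snippet above, not the PR's actual code):

```python
# Hedged sketch of the suggestion above; make_mxfp8_config and its field names
# are illustrative, not the PR's actual code.
import jax.numpy as jnp

def make_mxfp8_config(block_size=32):
    # Guard the optional dtype: it only exists when ml_dtypes >= 0.5.0 is installed.
    if not hasattr(jnp, "float8_e8m0fnu"):
        raise NotImplementedError(
            "mxfp8 scaling requires ml_dtypes >= 0.5.0 (float8_e8m0fnu)")
    return dict(
        mode="mxfp8",
        block_size=block_size,
        data_type=jnp.float8_e4m3fn,
        scale_type=jnp.float8_e8m0fnu,
    )

# The config is then built at call time and passed through the jax.nn-level API,
# rather than instantiated as a module-level default.
```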

_scaled_matmul_p.def_abstract_eval(_scaled_matmul_abstract)


mlir.register_lowering(
Collaborator

See CI test failures - this code only runs on GPU now.

You will need to either:

  • add a lowering fallback to XLA for other platforms such as CPU and TPU, or
  • explicitly state that your implementation only works on certain platforms and prevent invalid usage: add a platform check or require an explicit platform argument in the API (like this cudnn implementation), and disable tests on the other platforms (a rough sketch of this option follows the list).
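
A rough sketch of the second option (platform-restricted lowering registration); the primitive name is taken from the diff above, while the lowering-rule names are placeholders:

```python
# Rough sketch of the second option. `_scaled_matmul_p` comes from the diff
# above; `_scaled_matmul_cuda_lowering` is a placeholder for the PR's GPU rule.
from jax.interpreters import mlir

# Register the cuDNN-backed rule for CUDA only ...
mlir.register_lowering(
    _scaled_matmul_p, _scaled_matmul_cuda_lowering, platform="cuda")

# ... and fail with a clear error elsewhere instead of an opaque lowering failure.
def _scaled_matmul_unsupported_lowering(ctx, *args, **kwargs):
    raise NotImplementedError(
        "scaled_matmul is only implemented for the CUDA backend")

for _platform in ("cpu", "tpu", "rocm"):
    mlir.register_lowering(
        _scaled_matmul_p, _scaled_matmul_unsupported_lowering, platform=_platform)
```

Tests for the op would then be skipped on the platforms that only have the error-raising rule registered.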
