[JAX] Expose sliding window attn to TE-JAX API #1205
base: main
Conversation
Signed-off-by: Hua Huang <huah@nvidia.com>
for more information, see https://pre-commit.ci
Thank you for this PR! I have a few small questions/comments.
Could you port the SWA to
Signed-off-by: Hua Huang <huah@nvidia.com>
LGTM! Please address @mingxu1067's comments and it will be in good shape.
Signed-off-by: Hua Huang <huah@nvidia.com>
for more information, see https://pre-commit.ci
Hi, MaxText is using the
Hi @kocchop, I am working on it; I should be able to commit new changes next week.
Will update tests/jax/test_praxis_layers.py next.
Signed-off-by: Hua Huang <huah@nvidia.com>
for more information, see https://pre-commit.ci
Signed-off-by: Hua Huang <huah@nvidia.com>
PR #1212 affects this PR. Maybe we should wait for that PR.
LGTM, pending CI and #1212.
@zlsh80826, could you help review this PR as well? Thank you.
Signed-off-by: Hua Huang <huah@nvidia.com>
for more information, see https://pre-commit.ci
Sure, I will review it by this week.
@@ -943,7 +971,14 @@ def __post_init__(self):

    @nn.compact
    def __call__(self, inputs, encoder_mask=None, deterministic=False):
        del self.self_attn_mask_type  # dummy, just align to TE's impl
        # Currently cuDNN backend only supports SWA for causal/padding_causal, follow this
        if self.self_attn_mask_type in ["causal", "padding_causal"] and self.window_size[0] > 0:
What is the expected behavior if self.self_attn_mask_type == 'padding' and self.window_size[0] > 0?
In test_fused_attn.py, if we use a padding mask and a window_size[0] > 0, no supported C++ backend is available, and the tests are all skipped. The EncoderLayer and DecoderLayer classes in this utils.py will only be used in test_layer.py. Added some code in test_layer.py to skip if window_size[0] > 0 and not using a causal/padding_causal mask.
The logic between test_layer.py and utils.py should be separated. We don't know when other people will make changes to test_layer.py or utils.py standalone. If someone doesn't know the relationship between them, they might change only one side and spend lots of time on debugging.
Currently, I noticed that this call doesn't handle self.self_attn_mask_type == 'padding', nor does it provide any warning or error. It's important not to assume that this function will only be used by test_layer.py or that padding will never be passed, as new developers may not be aware of these assumptions. If padding is not supported at the moment, it would be best to raise an exception within this function to prevent unintended behavior.
Does the following code look good?
if self.window_size[0] > 0:
if self.self_attn_mask_type in ["causal", "padding_causal"]:
encoder_mask = apply_swa_mask(
self.self_attn_mask_type,
encoder_mask,
self.window_size,
)
else:
raise NotImplementedError("cuDNN only supports SWA for causal and padding_causal")
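For reference, the effect of a causal sliding-window mask can be sketched in plain NumPy. This is only an illustration of the masking semantics being discussed, not TE's actual apply_swa_mask implementation:

```python
import numpy as np

def swa_causal_mask(seq_len, left_window):
    """Boolean mask: True where a query position may attend to a key position."""
    q = np.arange(seq_len)[:, None]  # query index, one per row
    k = np.arange(seq_len)[None, :]  # key index, one per column
    causal = k <= q                   # never attend to future positions
    window = (q - k) <= left_window   # key within `left_window` of the query
    return causal & window

mask = swa_causal_mask(5, 2)
# Row i attends to columns max(0, i - 2) .. i.
```

With a left window of 2, each query sees at most itself plus the two preceding positions, which is why only causal-style masks combine naturally with a right window of 0.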
Looks good
Signed-off-by: Hua Huang <huah@nvidia.com>
for more information, see https://pre-commit.ci
/te-ci jax
@@ -1042,6 +1056,10 @@ class FusedAttnCPWithAllGatherFwdPrimitive(FusedAttnFwdPrimitive):
    def partition(config, mesh, arg_infos, result_infos):
        # Call base implementation for non-context parallel mesh to avoid unecessary work.
        is_context_parallel = get_mesh_axis_size(config.cp_axis, mesh) > 1
        if is_context_parallel and config.window_size[0] > -1:
            assert (
I think raise NotImplementedError is more suitable here. Actually, asserts may be ignored in optimized mode in Python or release mode in C++.
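The point about Python's optimized mode is easy to demonstrate: under python -O, assert statements are compiled out entirely, so a failing assertion silently passes. A small self-contained check:

```python
import subprocess
import sys

snippet = "assert False, 'should fail'"

# Normal mode: the assert fires and the process exits non-zero.
normal = subprocess.run([sys.executable, "-c", snippet], capture_output=True)

# Optimized mode (-O): the assert is stripped, so the process exits cleanly.
optimized = subprocess.run([sys.executable, "-O", "-c", snippet], capture_output=True)

print(normal.returncode, optimized.returncode)  # non-zero vs. 0
```

This is why asserts are best reserved for internal invariants rather than user-facing input validation.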
I use assert here to follow the earlier code in this file, for example, the asserts in fused_attn_fwd() (lines 1353 to 1366). Do I only need to replace the asserts in my commits with raise NotImplementedError?
OK, let's just keep the assert style. But the if statement here doesn't look needed.

if is_context_parallel and config.window_size[0] > -1:
    assert (
        is_context_parallel and config.window_size[0] == -1
    ), "Sliding window attention is not supported when context parallelism is enabled"

is equivalent to

assert not (
    is_context_parallel and config.window_size[0] > -1
), "Sliding window attention is not supported when context parallelism is enabled"
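For what it's worth, the two forms are interchangeable only when the single assert's condition is the negation of the if guard. A quick standalone check of that equivalence (variable names follow the diff above; the real config object is assumed away):

```python
def guarded(is_context_parallel, left_window):
    # Form in the diff: the assert is only reachable when the guard is true,
    # so it always fires there.
    if is_context_parallel and left_window > -1:
        assert False, "Sliding window attention is not supported with context parallelism"

def bare(is_context_parallel, left_window):
    # Single-assert form: the condition is the negation of the guard.
    assert not (
        is_context_parallel and left_window > -1
    ), "Sliding window attention is not supported with context parallelism"

# Both behave identically for every combination of inputs.
for is_cp in (True, False):
    for left in (-1, 0, 128):
        outcomes = []
        for fn in (guarded, bare):
            try:
                fn(is_cp, left)
                outcomes.append("ok")
            except AssertionError:
                outcomes.append("raises")
        assert outcomes[0] == outcomes[1]
```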
@@ -1136,6 +1154,10 @@ class FusedAttnCPWithAllGatherBwdPrimitive(FusedAttnBwdPrimitive):
    def partition(config, mesh, arg_infos, result_infos):
        # Call base implementation for non-context parallel mesh to avoid unecessary work.
        is_context_parallel = get_mesh_axis_size(config.cp_axis, mesh) > 1
        if is_context_parallel and config.window_size[0] > -1:
same
Same as above
def convert_to_softmax_type(attn_mask_type, mask):
    """Convert the attn_mask_type to SoftmaxType"""
    # mask is ignored for no_mask and causal_mask
    if attn_mask_type in [AttnMaskType.NO_MASK, AttnMaskType.CAUSAL_MASK]:
    # mask is ignored for no_mask and causal_mask without sliding window
We need to raise a ValueError when we get self.window_size[0] >= -1 and AttnMaskType is NO_MASK or PADDING_MASK.
attn_mask_type                   | window_size
---------------------------------|------------------------
NO_MASK, PADDING_MASK            | (-1, -1) or (>=0, >=0)
CAUSAL_MASK                      | (-1, 0) or (>=0, 0)
PADDING_CAUSAL_MASK              | (-1, 0) or (>=0, 0)
CAUSAL_BOTTOM_RIGHT_MASK         | (-1, 0) or (>=0, 0)
PADDING_CAUSAL_BOTTOM_RIGHT_MASK | (-1, 0) or (>=0, 0)
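A validation helper implementing this table might look like the following hypothetical sketch. The mask names are plain strings here for illustration; TE itself uses the AttnMaskType enum:

```python
def check_window_size(attn_mask_type, window_size):
    """Raise ValueError if `window_size` is inconsistent with `attn_mask_type`,
    following the mask-type/window table above."""
    left, right = window_size
    if attn_mask_type in ("no_mask", "padding"):
        # Either no sliding window at all, or a window on both sides.
        ok = (left, right) == (-1, -1) or (left >= 0 and right >= 0)
    else:
        # All causal-style masks require the right window to be 0.
        ok = right == 0 and (left == -1 or left >= 0)
    if not ok:
        raise ValueError(
            f"window_size={window_size} is invalid for attn_mask_type={attn_mask_type!r}"
        )
```

Centralizing the check like this would avoid scattering the -1 comparisons the reviewers mention across call sites.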
Hi @cyanguwa, do you think using None for no sliding window is a better idea? I found there is a lot of logic comparing against -1, and it does not look easy to maintain. I think using None for no sliding window can improve the code's readability and maintainability.
I have used None for the higher-level APIs, for example, 1, 2, 3, 4, but I've used a check_set_window_size function to make sure window_size is consistent with the mask type before passing it further down.
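One way to reconcile the two suggestions is to accept None at the API boundary and canonicalize it to the internal (-1, -1) sentinel. A hypothetical sketch of that boundary step (TE's actual check_set_window_size also validates against the mask type, which is omitted here):

```python
def canonicalize_window_size(window_size):
    """Map the user-facing value to the internal (left, right) sentinel pair."""
    if window_size is None:
        # None means "no sliding window", i.e. attend over the full sequence.
        return (-1, -1)
    left, right = window_size  # also rejects malformed inputs early
    return (left, right)
```

Internal code can then keep comparing against -1 unchanged, while users never need to know about the sentinel.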
Description
Recent models employ sliding window attention (SWA). Some frameworks use cuDNN fused attention through the TE-JAX Flash Attention API, but SWA support has not yet been exposed through this API, even though the TE backend already supports it. This PR exposes SWA support in the Flash Attention API.
Type of change
Changes
Please list the changes introduced in this PR:
Expose sliding window attention to the TE-JAX API
Checklist: