
[JAX] Expose sliding window attn to TE-JAX API #1205

Open · wants to merge 12 commits into base: main

Conversation


@huanghua1994 huanghua1994 commented Sep 25, 2024

Description

Recent models employ sliding window attention (SWA). Some frameworks use cuDNN fused attention through the TE-JAX Flash Attention API, but SWA support has not yet been exposed through this API, even though the TE backend already supports it. This PR exposes SWA support in the Flash Attention API.
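For illustration, a hypothetical call sketch is shown below. Only the window_size=(left, right) keyword reflects what this PR adds; fused_attn and the surrounding argument names are assumptions made to keep the example self-contained.

# Hypothetical sketch: only window_size is the new option from this PR.
# window_size = (left, right) lets each query attend to at most `left`
# tokens before it and `right` tokens after it; (-1, -1) disables SWA.
from transformer_engine.jax.attention import AttnMaskType

output = fused_attn(                       # assumed TE-JAX fused attention entry point
    qkv,                                   # packed query/key/value tensors
    bias=None,
    mask=mask,
    seed=None,
    attn_mask_type=AttnMaskType.PADDING_CAUSAL_MASK,
    window_size=(127, 0),                  # new: causal sliding window of 128 tokens
    scaling_factor=head_dim**-0.5,
    dropout_probability=0.0,
    is_training=True,
)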

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactor

Changes

Please list the changes introduced in this PR:

Expose sliding window attention to the TE-JAX API

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Signed-off-by: Hua Huang <huah@nvidia.com>

@mgoldfarb-nvidia mgoldfarb-nvidia left a comment


Thank you for this PR! Have a few small questions/comments.

tests/jax/test_fused_attn.py (outdated, resolved)
transformer_engine/jax/cpp_extensions/attention.py (outdated, resolved)
tests/jax/test_fused_attn.py (resolved)
@mingxu1067

Could you port the SWA to flax and praxis modules as well?

Signed-off-by: Hua Huang <huah@nvidia.com>
Copy link
Collaborator

@mgoldfarb-nvidia mgoldfarb-nvidia left a comment


LGTM! Please address @mingxu1067's comments and this will be in good shape.

Hua Huang and others added 2 commits September 27, 2024 10:56

kocchop commented Sep 29, 2024

Hi, MaxText is using the DotProductAttention API from transformer_engine/jax/flax/transformer.py. It'd be super useful to expose SWA through this API as well.
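For reference, a hedged sketch of what that exposure could look like; DotProductAttention is the existing flax module, but the window_size constructor argument here is an assumption based on the direction of this PR.

# Hypothetical sketch: window_size as a module argument is an assumption.
from transformer_engine.jax.flax import DotProductAttention

attn = DotProductAttention(
    head_dim=128,
    num_attention_heads=16,
    attn_mask_type="padding_causal",   # cuDNN SWA currently requires a causal mask
    window_size=(4096, 0),             # assumed new argument: 4096-token left window
)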

@huanghua1994
Collaborator Author

Hi @kocchop, I am working on it; I should be able to commit new changes next week.

Hua Huang and others added 3 commits September 30, 2024 16:30
Will update tests/jax/test_praxis_layers.py next

Signed-off-by: Hua Huang <huah@nvidia.com>
Signed-off-by: Hua Huang <huah@nvidia.com>
@huanghua1994 changed the title from "Expose sliding window attn to TE-JAX API" to "[JAX] Expose sliding window attn to TE-JAX API" on Oct 1, 2024
@huanghua1994
Collaborator Author

PR #1212 affects this PR. Maybe we should wait for #1212 to merge first.


@mingxu1067 mingxu1067 left a comment


LGTM, pending CI and #1212.

transformer_engine/jax/attention.py (outdated, resolved)
tests/jax/test_fused_attn.py (outdated, resolved)
@mingxu1067

@zlsh80826, could you help review this PR as well? Thank you.

Hua Huang and others added 2 commits October 1, 2024 15:28
@zlsh80826

@zlsh80826, could you help review this PR as well? Thank you.

Sure, I will review it this week.

tests/jax/test_fused_attn.py (outdated, resolved)
@@ -943,7 +971,14 @@ def __post_init__(self):

@nn.compact
def __call__(self, inputs, encoder_mask=None, deterministic=False):
del self.self_attn_mask_type # dummy, just align to TE's impl
# Currently cuDNN backend only supports SWA for causal/padding_causal, follow this
if self.self_attn_mask_type in ["causal", "padding_causal"] and self.window_size[0] > 0:
Collaborator


What is the expected behavior if self.self_attn_mask_type == 'padding' and self.window_size[0] > 0?

Collaborator Author


In test_fused_attn.py, if we use a padding mask with window_size[0] > 0, no supported C++ backend is available and the tests are all skipped. The EncoderLayer and DecoderLayer classes in this utils.py are only used by test_layer.py. I added code in test_layer.py to skip tests when window_size[0] > 0 and the mask type is not causal/padding_causal.

Collaborator


The logic in test_layer.py and utils.py should be kept separate. We don't know when other people will change test_layer.py or utils.py independently. If someone doesn't know the relationship between them, they might change only one side and spend a lot of time debugging.

Currently, I noticed that this call doesn't handle self.self_attn_mask_type == 'padding', nor does it provide any warning or error. It's important not to assume that this function will only be used by test_layer.py or that padding will never be passed, as new developers may not be aware of these assumptions. If padding is not supported at the moment, it would be best to raise an exception within this function to prevent unintended behavior.

Collaborator Author


Does the following code look good?

if self.window_size[0] > 0:
    if self.self_attn_mask_type in ["causal", "padding_causal"]:
        encoder_mask = apply_swa_mask(
            self.self_attn_mask_type,
            encoder_mask,
            self.window_size,
        )
    else:
        raise NotImplementedError("cuDNN only supports SWA for causal and padding_causal")

Collaborator


Looks good

tests/jax/test_fused_attn.py (outdated, resolved)
Hua Huang and others added 2 commits October 3, 2024 10:18
Signed-off-by: Hua Huang <huah@nvidia.com>
@zlsh80826
Collaborator

/te-ci jax

@@ -1042,6 +1056,10 @@ class FusedAttnCPWithAllGatherFwdPrimitive(FusedAttnFwdPrimitive):
def partition(config, mesh, arg_infos, result_infos):
# Call base implementation for non-context parallel mesh to avoid unnecessary work.
is_context_parallel = get_mesh_axis_size(config.cp_axis, mesh) > 1
if is_context_parallel and config.window_size[0] > -1:
assert (
Collaborator


I think raise NotImplementedError is more suitable here. An assert may be ignored in Python's optimized mode, just as asserts are in C++ release builds.
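For illustration, a minimal sketch of the difference (variable names follow the diff above):

# Under `python -O`, asserts are compiled out, so this guard silently vanishes:
assert not (
    is_context_parallel and config.window_size[0] > -1
), "Sliding window attention is not supported when context parallelism is enabled"

# An explicit raise survives optimized mode:
if is_context_parallel and config.window_size[0] > -1:
    raise NotImplementedError(
        "Sliding window attention is not supported when context parallelism is enabled"
    )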

Collaborator Author


I used assert here to follow the earlier code in this file, for example, the asserts in fused_attn_fwd() (lines 1353 to 1366). Should I replace only the asserts in my commits with raise NotImplementedError?

Collaborator


OK, let's just keep the assert style. But the if statement here doesn't look needed:

        if is_context_parallel and config.window_size[0] > -1:
            assert (
                is_context_parallel and config.window_size[0] == -1
            ), "Sliding window attention is not supported when context parallelism is enabled"

is equivalent to

        assert not (
            is_context_parallel and config.window_size[0] > -1
        ), "Sliding window attention is not supported when context parallelism is enabled"

@@ -1136,6 +1154,10 @@ class FusedAttnCPWithAllGatherBwdPrimitive(FusedAttnBwdPrimitive):
def partition(config, mesh, arg_infos, result_infos):
# Call base implementation for non-context parallel mesh to avoid unnecessary work.
is_context_parallel = get_mesh_axis_size(config.cp_axis, mesh) > 1
if is_context_parallel and config.window_size[0] > -1:
Collaborator


same

Collaborator Author


Same as above

def convert_to_softmax_type(attn_mask_type, mask):
    """Convert the attn_mask_type to SoftmaxType"""
-   # mask is ignored for no_mask and causal_mask
+   # mask is ignored for no_mask and causal_mask without sliding window
    if attn_mask_type in [AttnMaskType.NO_MASK, AttnMaskType.CAUSAL_MASK]:
Collaborator


We need to raise a ValueError when we get window_size[0] > -1 and the AttnMaskType is NO_MASK or PADDING_MASK.
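A minimal sketch of the requested check (how window_size reaches convert_to_softmax_type is an assumption for illustration):

# Hypothetical placement inside convert_to_softmax_type:
if window_size is not None and window_size[0] > -1:
    if attn_mask_type in [AttnMaskType.NO_MASK, AttnMaskType.PADDING_MASK]:
        raise ValueError(
            f"Sliding window attention is not supported with {attn_mask_type}"
        )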

transformer_engine/jax/flax/transformer.py (3 resolved review threads)
Comment on lines +98 to +104
attn_mask_type | window_size
-------------------------------------------------------------------------
NO_MASK, PADDING_MASK | (-1, -1) or (>=0, >=0)
CAUSAL_MASK | (-1, 0) or (>=0, 0)
PADDING_CAUSAL_MASK | (-1, 0) or (>=0, 0)
CAUSAL_BOTTOM_RIGHT_MASK | (-1, 0) or (>=0, 0)
PADDING_CAUSAL_BOTTOM_RIGHT_MASK | (-1, 0) or (>=0, 0)
Collaborator


Hi @cyanguwa, do you think using None for no sliding window would be a better idea? I found there is a lot of logic comparing window_size against -1, and it does not look easy to maintain. I think using None for no sliding window would improve readability and maintainability.

Collaborator


I have used None for the higher-level APIs, for example, 1, 2, 3, 4, but I've used a check_set_window_size function to make sure window_size is consistent with the mask type before passing it further down.
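For illustration, a minimal sketch of such a helper following the mask-type/window-size table above (the real check_set_window_size may differ; this is an assumption):

# Hypothetical sketch mirroring the mask-type/window-size table above.
def check_set_window_size(attn_mask_type: str, window_size=None):
    """Normalize window_size so that None means 'no sliding window'."""
    if "causal" in attn_mask_type:
        # Causal mask types require the right bound to be 0.
        return (-1, 0) if window_size is None else (window_size[0], 0)
    # no_mask / padding: unrestricted unless a window is given.
    return (-1, -1) if window_size is None else window_size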

Labels: enhancement (New feature or request), jax

6 participants