-
Notifications
You must be signed in to change notification settings - Fork 308
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[JAX] Expose sliding window attn to TE-JAX API #1205
Open
huanghua1994
wants to merge
12
commits into
NVIDIA:main
Choose a base branch
from
huanghua1994:sliding-window-attn-dev
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
+485
−70
Open
Changes from 10 commits
Commits
Show all changes
12 commits
Select commit
Hold shift + click to select a range
7163946
Expose JAX sliding window attn API
6a5b20c
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] 3dae9ad
No SWA in context parallel; fix RNG seed in test
813287e
Handle SAW API discrepancy in cuDNN and Python
f9ee6e5
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] f56b28a
Add SAW API for flax, all tests passed
6f4df48
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] 6440665
Update test_praxis_layers.py for SWA, test passed
52a2d1c
Use tuple window_size; update for PR #1212
a4e647e
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] b420504
Add and adjust some pytest.skip
ddb3fe7
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
@@ -6,6 +6,7 @@ | |||||||||||||||||||||||||||||
from dataclasses import dataclass | ||||||||||||||||||||||||||||||
from functools import partial | ||||||||||||||||||||||||||||||
from math import sqrt | ||||||||||||||||||||||||||||||
from typing import Tuple | ||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||
import jax | ||||||||||||||||||||||||||||||
import jax.numpy as jnp | ||||||||||||||||||||||||||||||
|
@@ -27,6 +28,8 @@ | |||||||||||||||||||||||||||||
fused_attn, | ||||||||||||||||||||||||||||||
fused_attn_thd, | ||||||||||||||||||||||||||||||
get_qkv_format, | ||||||||||||||||||||||||||||||
check_set_window_size, | ||||||||||||||||||||||||||||||
get_swa_mask, | ||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||
from transformer_engine.jax.cpp_extensions import FusedAttnHelper | ||||||||||||||||||||||||||||||
from transformer_engine.transformer_engine_jax import ( | ||||||||||||||||||||||||||||||
|
@@ -123,6 +126,7 @@ def make_mask( | |||||||||||||||||||||||||||||
segment_pad_q: ArrayLike, | ||||||||||||||||||||||||||||||
segment_pad_kv: ArrayLike, | ||||||||||||||||||||||||||||||
attn_mask_type: AttnMaskType, | ||||||||||||||||||||||||||||||
window_size: Tuple[int, int], | ||||||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||||||||||||||||||||||
) -> Array: | ||||||||||||||||||||||||||||||
""" | ||||||||||||||||||||||||||||||
Create attention mask based on mask type. A `True` value in the mask means | ||||||||||||||||||||||||||||||
|
@@ -140,6 +144,15 @@ def make_mask( | |||||||||||||||||||||||||||||
segment_pad_q, segment_pad_kv, lambda x, y: jnp.logical_and(x != 1, y != 1) | ||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||
inv_mask = combine_masks(inv_pad_mask, inv_mask) | ||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||
if window_size[0] >= 0: | ||||||||||||||||||||||||||||||
max_seqlen_q = inv_mask.shape[-2] | ||||||||||||||||||||||||||||||
max_seqlen_kv = inv_mask.shape[-1] | ||||||||||||||||||||||||||||||
swa_mask = get_swa_mask(window_size, max_seqlen_q, max_seqlen_kv, attn_mask_type) | ||||||||||||||||||||||||||||||
swa_mask_bcast = jnp.broadcast_to(swa_mask, inv_mask.shape) | ||||||||||||||||||||||||||||||
# In swa_mask and inv_mask 0 is masked out | ||||||||||||||||||||||||||||||
inv_mask = jnp.where(inv_mask != 0, swa_mask_bcast, inv_mask) | ||||||||||||||||||||||||||||||
Comment on lines
+148
to
+154
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||
mask = jnp.logical_not(inv_mask) | ||||||||||||||||||||||||||||||
return mask | ||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||
|
@@ -274,6 +287,7 @@ class FusedAttnRunner: | |||||||||||||||||||||||||||||
is_training: bool | ||||||||||||||||||||||||||||||
qkv_layout: QKVLayout | ||||||||||||||||||||||||||||||
bias_shape: BiasShape | ||||||||||||||||||||||||||||||
window_size: Tuple[int, int] = (-1, -1) | ||||||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||
# See https://docs.nvidia.com/deeplearning/cudnn/latest/release-notes.html#cudnn-9-4-0 for known issue | ||||||||||||||||||||||||||||||
# generating zero-length ragged tensors. This setting adjusts the test to avoid the zero-length cases. | ||||||||||||||||||||||||||||||
|
@@ -310,6 +324,7 @@ def _check_configs(self): | |||||||||||||||||||||||||||||
self.max_seqlen_q, | ||||||||||||||||||||||||||||||
self.max_seqlen_kv, | ||||||||||||||||||||||||||||||
self.head_dim, | ||||||||||||||||||||||||||||||
self.window_size, | ||||||||||||||||||||||||||||||
).get_fused_attn_backend() | ||||||||||||||||||||||||||||||
if self.backend == NVTE_Fused_Attn_Backend.NVTE_No_Backend: | ||||||||||||||||||||||||||||||
pytest.skip("Unsupported inputs combination or device compute capability.") | ||||||||||||||||||||||||||||||
|
@@ -456,6 +471,7 @@ def generate_random_segment_ids( | |||||||||||||||||||||||||||||
self.segment_pad_q, | ||||||||||||||||||||||||||||||
self.segment_pad_kv, | ||||||||||||||||||||||||||||||
self.attn_mask_type, | ||||||||||||||||||||||||||||||
self.window_size, | ||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||
if get_qkv_format(self.qkv_layout) == QKVFormat.THD: | ||||||||||||||||||||||||||||||
|
@@ -500,6 +516,7 @@ def test_forward(self): | |||||||||||||||||||||||||||||
"is_training": self.is_training, | ||||||||||||||||||||||||||||||
"qkv_layout": self.qkv_layout, | ||||||||||||||||||||||||||||||
"max_segments_per_seq": self._get_max_segments_per_sequence(), | ||||||||||||||||||||||||||||||
"window_size": self.window_size, | ||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||
# Convert the outputs to float32 for the elementwise comparison | ||||||||||||||||||||||||||||||
|
@@ -557,6 +574,7 @@ def grad_func(func, *args, **kwargs): | |||||||||||||||||||||||||||||
"is_training": self.is_training, | ||||||||||||||||||||||||||||||
"qkv_layout": self.qkv_layout, | ||||||||||||||||||||||||||||||
"max_segments_per_seq": self._get_max_segments_per_sequence(), | ||||||||||||||||||||||||||||||
"window_size": self.window_size, | ||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||
# We can compute dBias only for the [1, h, s, s] layout | ||||||||||||||||||||||||||||||
|
@@ -668,7 +686,7 @@ def check_dqkv(primitive, reference, pad): | |||||||||||||||||||||||||||||
pytest.param(4, 128, 128, 16, 16, 64, jnp.bfloat16, id="4-128-128-16-16-64-BF16-SELF"), | ||||||||||||||||||||||||||||||
pytest.param(4, 128, 128, 16, 16, 64, jnp.float16, id="4-128-128-16-16-64-FP16-SELF"), | ||||||||||||||||||||||||||||||
pytest.param(2, 2048, 2048, 12, 12, 64, jnp.bfloat16, id="2-2048-2048-12-12-64-BF16-SELF"), | ||||||||||||||||||||||||||||||
pytest.param(4, 512, 128, 16, 16, 64, jnp.bfloat16, id="4-512-128-16-16-64-BF16-CROSS"), | ||||||||||||||||||||||||||||||
pytest.param(4, 128, 256, 16, 16, 64, jnp.bfloat16, id="4-128-256-16-16-64-BF16-CROSS"), | ||||||||||||||||||||||||||||||
mgoldfarb-nvidia marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||||||||||||||||||||||||||
pytest.param( | ||||||||||||||||||||||||||||||
2, | ||||||||||||||||||||||||||||||
2048, | ||||||||||||||||||||||||||||||
|
@@ -677,7 +695,7 @@ def check_dqkv(primitive, reference, pad): | |||||||||||||||||||||||||||||
12, | ||||||||||||||||||||||||||||||
64, | ||||||||||||||||||||||||||||||
jnp.bfloat16, | ||||||||||||||||||||||||||||||
id="2-2048-1048-12-12-64-BF16-CROSS", | ||||||||||||||||||||||||||||||
id="2-2048-1024-12-12-64-BF16-CROSS", | ||||||||||||||||||||||||||||||
), | ||||||||||||||||||||||||||||||
pytest.param(4, 128, 128, 16, 8, 64, jnp.bfloat16, id="4-128-128-16-8-64-BF16-GQA"), | ||||||||||||||||||||||||||||||
pytest.param(2, 2048, 2048, 12, 6, 64, jnp.bfloat16, id="2-2048-2048-12-6-64-BF16-GQA"), | ||||||||||||||||||||||||||||||
|
@@ -690,6 +708,13 @@ def check_dqkv(primitive, reference, pad): | |||||||||||||||||||||||||||||
pytest.param(0.1, id="DROP_0.1"), | ||||||||||||||||||||||||||||||
], | ||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||
@pytest.mark.parametrize( | ||||||||||||||||||||||||||||||
"swa", | ||||||||||||||||||||||||||||||
[ | ||||||||||||||||||||||||||||||
pytest.param(False, id="NO_SWA"), | ||||||||||||||||||||||||||||||
pytest.param(True, id="SWA"), | ||||||||||||||||||||||||||||||
], | ||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||
class TestFusedAttn: | ||||||||||||||||||||||||||||||
""" | ||||||||||||||||||||||||||||||
Fused attention tester | ||||||||||||||||||||||||||||||
|
@@ -717,12 +742,20 @@ def _test_forward( | |||||||||||||||||||||||||||||
is_training, | ||||||||||||||||||||||||||||||
qkv_layout, | ||||||||||||||||||||||||||||||
bias_shape, | ||||||||||||||||||||||||||||||
swa, | ||||||||||||||||||||||||||||||
): | ||||||||||||||||||||||||||||||
""" | ||||||||||||||||||||||||||||||
Test forward with parameterized configs | ||||||||||||||||||||||||||||||
This test is not intended to run automatically during CI as it is time-consuming | ||||||||||||||||||||||||||||||
It is kept for development and debugging | ||||||||||||||||||||||||||||||
""" | ||||||||||||||||||||||||||||||
window_size = (-1, -1) | ||||||||||||||||||||||||||||||
if swa: | ||||||||||||||||||||||||||||||
window_size = (s_kv // 10, 0) | ||||||||||||||||||||||||||||||
if s_q > s_kv: | ||||||||||||||||||||||||||||||
pytest.skip( | ||||||||||||||||||||||||||||||
"seqlen_q > seqlen_kv is not supported with sliding window attention in cuDNN" | ||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||
zlsh80826 marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||||||||||||||||||||||||||
runner = FusedAttnRunner( | ||||||||||||||||||||||||||||||
b, | ||||||||||||||||||||||||||||||
s_q, | ||||||||||||||||||||||||||||||
|
@@ -737,6 +770,7 @@ def _test_forward( | |||||||||||||||||||||||||||||
is_training, | ||||||||||||||||||||||||||||||
qkv_layout, | ||||||||||||||||||||||||||||||
bias_shape, | ||||||||||||||||||||||||||||||
window_size, | ||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||
runner.test_forward() | ||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||
|
@@ -754,10 +788,18 @@ def test_backward( | |||||||||||||||||||||||||||||
dtype, | ||||||||||||||||||||||||||||||
qkv_layout, | ||||||||||||||||||||||||||||||
bias_shape, | ||||||||||||||||||||||||||||||
swa, | ||||||||||||||||||||||||||||||
): | ||||||||||||||||||||||||||||||
""" | ||||||||||||||||||||||||||||||
Test backward with parameterized configs | ||||||||||||||||||||||||||||||
""" | ||||||||||||||||||||||||||||||
window_size = (-1, -1) | ||||||||||||||||||||||||||||||
if swa: | ||||||||||||||||||||||||||||||
window_size = (s_kv // 10, 0) | ||||||||||||||||||||||||||||||
if s_q > s_kv: | ||||||||||||||||||||||||||||||
pytest.skip( | ||||||||||||||||||||||||||||||
"seqlen_q > seqlen_kv is not supported with sliding window attention in cuDNN" | ||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||
huanghua1994 marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||||||||||||||||||||||||||
runner = FusedAttnRunner( | ||||||||||||||||||||||||||||||
b, | ||||||||||||||||||||||||||||||
s_q, | ||||||||||||||||||||||||||||||
|
@@ -772,5 +814,6 @@ def test_backward( | |||||||||||||||||||||||||||||
True, | ||||||||||||||||||||||||||||||
qkv_layout, | ||||||||||||||||||||||||||||||
bias_shape, | ||||||||||||||||||||||||||||||
window_size, | ||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||
runner.test_backward() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.