[JAX] Expose sliding window attn to TE-JAX API #1205

Open
wants to merge 12 commits into main
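This PR threads a `window_size=(left, right)` pair through the TE-JAX fused-attention path: `(-1, -1)` keeps full attention, while non-negative values restrict each query position `i` to key positions `j` with `i - left <= j <= i + right`. The sketch below illustrates that rule only; the helper name is made up, and the PR's own `get_swa_mask` / `check_set_window_size` remain the source of truth (e.g. for how the window interacts with causal and padding masks).

```python
import jax.numpy as jnp

def sliding_window_keep_map(max_seqlen_q: int, max_seqlen_kv: int,
                            window_size: tuple) -> jnp.ndarray:
    """Illustrative [sq, skv] boolean map: True where query i may attend key j."""
    left, right = window_size
    q_pos = jnp.arange(max_seqlen_q)[:, None]    # [sq, 1]
    kv_pos = jnp.arange(max_seqlen_kv)[None, :]  # [1, skv]
    keep = jnp.ones((max_seqlen_q, max_seqlen_kv), dtype=bool)
    if left >= 0:                                # limit how far back a query looks
        keep = keep & (kv_pos >= q_pos - left)
    if right >= 0:                               # right = 0 gives a causal window
        keep = keep & (kv_pos <= q_pos + right)
    return keep

# window_size=(-1, -1) keeps everything; window_size=(2, 0) is a causal band of width 3.
print(sliding_window_keep_map(4, 4, (2, 0)).astype(int))
```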
47 changes: 45 additions & 2 deletions tests/jax/test_fused_attn.py
@@ -6,6 +6,7 @@
from dataclasses import dataclass
from functools import partial
from math import sqrt
from typing import Tuple

import jax
import jax.numpy as jnp
@@ -27,6 +28,8 @@
fused_attn,
fused_attn_thd,
get_qkv_format,
check_set_window_size,
Collaborator suggested change:
-check_set_window_size,
get_swa_mask,
)
from transformer_engine.jax.cpp_extensions import FusedAttnHelper
from transformer_engine.transformer_engine_jax import (
@@ -123,6 +126,7 @@ def make_mask(
segment_pad_q: ArrayLike,
segment_pad_kv: ArrayLike,
attn_mask_type: AttnMaskType,
window_size: Tuple[int, int],
Collaborator suggested change:
-window_size: Tuple[int, int],
+window_size: Optional[Tuple[int, int]] = None,
) -> Array:
"""
Create attention mask based on mask type. A `True` value in the mask means
@@ -140,6 +144,15 @@ def make_mask(
segment_pad_q, segment_pad_kv, lambda x, y: jnp.logical_and(x != 1, y != 1)
)
inv_mask = combine_masks(inv_pad_mask, inv_mask)

if window_size[0] >= 0:
max_seqlen_q = inv_mask.shape[-2]
max_seqlen_kv = inv_mask.shape[-1]
swa_mask = get_swa_mask(window_size, max_seqlen_q, max_seqlen_kv, attn_mask_type)
swa_mask_bcast = jnp.broadcast_to(swa_mask, inv_mask.shape)
# In swa_mask and inv_mask 0 is masked out
inv_mask = jnp.where(inv_mask != 0, swa_mask_bcast, inv_mask)
Comment on lines +148 to +154, Collaborator suggested change:
-if window_size[0] >= 0:
-    max_seqlen_q = inv_mask.shape[-2]
-    max_seqlen_kv = inv_mask.shape[-1]
-    swa_mask = get_swa_mask(window_size, max_seqlen_q, max_seqlen_kv, attn_mask_type)
-    swa_mask_bcast = jnp.broadcast_to(swa_mask, inv_mask.shape)
-    # In swa_mask and inv_mask 0 is masked out
-    inv_mask = jnp.where(inv_mask != 0, swa_mask_bcast, inv_mask)
+if window_size is not None:
+    max_seqlen_q = inv_mask.shape[-2]
+    max_seqlen_kv = inv_mask.shape[-1]
+    inv_swa_mask = get_swa_mask(window_size, max_seqlen_q, max_seqlen_kv, attn_mask_type)
+    inv_swa_mask = jnp.broadcast_to(inv_swa_mask, inv_mask.shape)
+    # In inv_swa_mask and inv_mask 0 is masked out
+    inv_mask = jnp.where(inv_mask != 0, inv_swa_mask, inv_mask)


mask = jnp.logical_not(inv_mask)
return mask
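A note on the new branch in `make_mask` above: both `inv_mask` and the SWA mask use 0 for "masked out", so `jnp.where(inv_mask != 0, swa_mask_bcast, inv_mask)` is just an element-wise AND of the two keep-maps, i.e. a position survives only if the padding/causal logic and the window both allow it. A toy check of that equivalence (illustrative values, not part of the PR):

```python
import jax.numpy as jnp

# 1 = attend, 0 = masked out, matching the convention in make_mask.
inv_mask = jnp.array([[1, 1, 0],
                      [1, 1, 1]], dtype=jnp.uint8)
swa_mask = jnp.array([[1, 0, 1],
                      [1, 1, 0]], dtype=jnp.uint8)

combined = jnp.where(inv_mask != 0, swa_mask, inv_mask)                  # the test's form
anded = jnp.logical_and(inv_mask != 0, swa_mask != 0).astype(jnp.uint8)  # plain AND
assert jnp.array_equal(combined, anded)                                  # same keep-map
```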

@@ -274,6 +287,7 @@ class FusedAttnRunner:
is_training: bool
qkv_layout: QKVLayout
bias_shape: BiasShape
window_size: Tuple[int, int] = (-1, -1)
Collaborator suggested change:
-window_size: Tuple[int, int] = (-1, -1)
+window_size: Optional[Tuple[int, int]] = None
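On the `Optional` suggestion above: a `None` default reads as "no windowing requested" without overloading a tuple as a sentinel, and the `(-1, -1)` form the backend expects can still be produced at the boundary. A hypothetical normalization helper along those lines (name and placement are assumptions, not code from this PR; the PR's `check_set_window_size` may already cover this):

```python
from typing import Optional, Tuple

def normalize_window_size(window_size: Optional[Tuple[int, int]]) -> Tuple[int, int]:
    """Map the user-facing Optional form onto the (-1, -1) 'no window' sentinel."""
    if window_size is None:
        return (-1, -1)
    left, right = window_size
    return (int(left), int(right))

assert normalize_window_size(None) == (-1, -1)
assert normalize_window_size((128, 0)) == (128, 0)
```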


# See https://docs.nvidia.com/deeplearning/cudnn/latest/release-notes.html#cudnn-9-4-0 for known issue
# generating zero-length ragged tensors. This setting adjusts the test to avoid the zero-length cases.
@@ -310,6 +324,7 @@ def _check_configs(self):
self.max_seqlen_q,
self.max_seqlen_kv,
self.head_dim,
self.window_size,
).get_fused_attn_backend()
if self.backend == NVTE_Fused_Attn_Backend.NVTE_No_Backend:
pytest.skip("Unsupported inputs combination or device compute capability.")
@@ -456,6 +471,7 @@ def generate_random_segment_ids(
self.segment_pad_q,
self.segment_pad_kv,
self.attn_mask_type,
self.window_size,
)

if get_qkv_format(self.qkv_layout) == QKVFormat.THD:
@@ -500,6 +516,7 @@ def test_forward(self):
"is_training": self.is_training,
"qkv_layout": self.qkv_layout,
"max_segments_per_seq": self._get_max_segments_per_sequence(),
"window_size": self.window_size,
}

# Convert the outputs to float32 for the elementwise comparison
@@ -557,6 +574,7 @@ def grad_func(func, *args, **kwargs):
"is_training": self.is_training,
"qkv_layout": self.qkv_layout,
"max_segments_per_seq": self._get_max_segments_per_sequence(),
"window_size": self.window_size,
}

# We can compute dBias only for the [1, h, s, s] layout
@@ -668,7 +686,7 @@ def check_dqkv(primitive, reference, pad):
pytest.param(4, 128, 128, 16, 16, 64, jnp.bfloat16, id="4-128-128-16-16-64-BF16-SELF"),
pytest.param(4, 128, 128, 16, 16, 64, jnp.float16, id="4-128-128-16-16-64-FP16-SELF"),
pytest.param(2, 2048, 2048, 12, 12, 64, jnp.bfloat16, id="2-2048-2048-12-12-64-BF16-SELF"),
pytest.param(4, 512, 128, 16, 16, 64, jnp.bfloat16, id="4-512-128-16-16-64-BF16-CROSS"),
pytest.param(4, 128, 256, 16, 16, 64, jnp.bfloat16, id="4-128-256-16-16-64-BF16-CROSS"),
mgoldfarb-nvidia marked this conversation as resolved.
pytest.param(
2,
2048,
@@ -677,7 +695,7 @@ def check_dqkv(primitive, reference, pad):
12,
64,
jnp.bfloat16,
id="2-2048-1048-12-12-64-BF16-CROSS",
id="2-2048-1024-12-12-64-BF16-CROSS",
),
pytest.param(4, 128, 128, 16, 8, 64, jnp.bfloat16, id="4-128-128-16-8-64-BF16-GQA"),
pytest.param(2, 2048, 2048, 12, 6, 64, jnp.bfloat16, id="2-2048-2048-12-6-64-BF16-GQA"),
@@ -690,6 +708,13 @@ def check_dqkv(primitive, reference, pad):
pytest.param(0.1, id="DROP_0.1"),
],
)
@pytest.mark.parametrize(
"swa",
[
pytest.param(False, id="NO_SWA"),
pytest.param(True, id="SWA"),
],
)
class TestFusedAttn:
"""
Fused attention tester
@@ -717,12 +742,20 @@ def _test_forward(
is_training,
qkv_layout,
bias_shape,
swa,
):
"""
Test forward with parameterized configs
This test is not intended to run automatically during CI as it is time-consuming
It is kept for development and debugging
"""
window_size = (-1, -1)
if swa:
window_size = (s_kv // 10, 0)
if s_q > s_kv:
pytest.skip(
"seqlen_q > seqlen_kv is not supported with sliding window attention in cuDNN"
)
zlsh80826 marked this conversation as resolved.
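For the `SWA` variants, the window is derived from the key length, so the 2048-token cases attend to at most the previous 204 keys and nothing to the right; the `s_q > s_kv` skip mirrors cuDNN's current sliding-window limitation. A quick worked example of that derivation (values taken from the parameterizations above):

```python
s_kv = 2048
left, right = s_kv // 10, 0          # the tests' choice: (204, 0)
keys_per_query = left + 1 + right    # each query sees at most 205 keys
assert (left, right) == (204, 0) and keys_per_query == 205
```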
runner = FusedAttnRunner(
b,
s_q,
@@ -737,6 +770,7 @@
is_training,
qkv_layout,
bias_shape,
window_size,
)
runner.test_forward()

@@ -754,10 +788,18 @@ def test_backward(
dtype,
qkv_layout,
bias_shape,
swa,
):
"""
Test backward with parameterized configs
"""
window_size = (-1, -1)
if swa:
window_size = (s_kv // 10, 0)
if s_q > s_kv:
pytest.skip(
"seqlen_q > seqlen_kv is not supported with sliding window attention in cuDNN"
)
huanghua1994 marked this conversation as resolved.
runner = FusedAttnRunner(
b,
s_q,
@@ -772,5 +814,6 @@
True,
qkv_layout,
bias_shape,
window_size,
)
runner.test_backward()
8 changes: 7 additions & 1 deletion tests/jax/test_layer.py
@@ -4,7 +4,7 @@
"""Test transformer_engine.jax.flax.TransformerLayer"""
import os
from functools import partial
from typing import Dict
from typing import Dict, Tuple

import flax
import jax
@@ -61,6 +61,7 @@ def enable_fused_attn():
_KEY_OF_FLOAT32_ATTENTION_LOGITS = "float32_attention_logits"
_KEY_OF_USE_BIAS = "use_bias"
_KEY_OF_RELATIVE_EMBEDDING = "enable_relative_embedding"
_KEY_OF_WINDOW_SIZE = "window_size"

BASE_ATTRS = {
_KEY_OF_TRANSPOSE_BS: True,
@@ -70,6 +71,7 @@ def enable_fused_attn():
_KEY_OF_INTERMEDIATE_DROPOUT: 0,
_KEY_OF_SELF_ATTN_MASK_TYPE: "padding_causal",
_KEY_OF_LAYERNORM_TYPE: "layernorm",
_KEY_OF_WINDOW_SIZE: (-1, -1),
}

ATTRS = [
@@ -193,6 +195,10 @@ def enable_fused_attn():
{
_KEY_OF_MLP_ACTIVATIONS: (("relu", "relu")),
},
{
_KEY_OF_SELF_ATTN_MASK_TYPE: "causal",
_KEY_OF_WINDOW_SIZE: (64, 0),  # Left window size must be < DATA_SHAPE seqlen
},
]

ATTRS = [{**BASE_ATTRS, **attr} for attr in ATTRS]
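These attrs are forwarded to the `transformer_engine.jax.flax.TransformerLayer` under test, so after this PR a sliding-window layer is configured roughly as below. The hyperparameter sizes and every argument other than `window_size` and `self_attn_mask_type` are illustrative placeholders assumed for this sketch, not values taken from the PR:

```python
from transformer_engine.jax.flax import TransformerLayer

# Causal attention restricted to the previous 64 positions; sizes are placeholders.
layer = TransformerLayer(
    hidden_size=1024,
    mlp_hidden_size=4096,
    num_attention_heads=16,
    self_attn_mask_type="causal",
    window_size=(64, 0),   # new in this PR: left window of 64, no right context
)
```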