Commit

minor fixes
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
cyanguwa committed Oct 7, 2024
1 parent b765f3d commit a4030e8
Showing 1 changed file with 15 additions and 16 deletions.
31 changes: 15 additions & 16 deletions transformer_engine/pytorch/attention.py
@@ -348,7 +348,7 @@ def get_attention_backend(
             use_fused_attention = False
 
     # Filter: Compute capability
-    global _flash_attn_3_plus, _use_flash_attn_3
+    global _use_flash_attn_3
     if device_compute_capability < (8, 0):
         if use_flash_attention:
             logger.debug("Disabling FlashAttention as it requires compute capability sm80+")
@@ -357,7 +357,7 @@ def get_attention_backend(
             logger.debug("Disabling FusedAttention as it requires compute capability sm80+")
             use_fused_attention = False
     if device_compute_capability < (9, 0):
-        if use_flash_attention and _flash_attn_3_plus:
+        if use_flash_attention and _use_flash_attn_3:
             logger.debug("Disabling FlashAttention 3 as it requires compute capability sm90+")
             _use_flash_attn_3 = False
 
@@ -436,15 +436,11 @@ def get_attention_backend(
                 "padding between sequences, i.e. [a, a, PAD, b, b, b, PAD, c, PAD]"
             )
             use_flash_attention = False
-        if use_flash_attention and _use_flash_attn_3 and fp8 and fp8_meta["recipe"].fp8_dpa:
-            logger.debug("Disabling FlashAttention 3 for FP8 and qkv_format = thd")
-            _use_flash_attn_3 = False
 
     # Filter: Dropout
-    if attention_dropout != 0.0 and use_flash_attention:
-        if _flash_attn_3_plus and _use_flash_attn_3:
-            logger.debug("Disabling FlashAttention 3 for dropout")
-            _use_flash_attn_3 = False
+    if attention_dropout != 0.0 and use_flash_attention and _use_flash_attn_3:
+        logger.debug("Disabling FlashAttention 3 for dropout")
+        _use_flash_attn_3 = False
 
     # Filter: Context parallelism
     # qkv_format | attn_mask_type | attn_bias_type | supported backends
@@ -464,7 +460,7 @@ def get_attention_backend(
             )
             use_unfused_attention = False
     if context_parallel and use_flash_attention:
-        if _flash_attn_3_plus and _use_flash_attn_3:
+        if _use_flash_attn_3:
             logger.debug("Disabling FlashAttention 3 for context parallelism")
             _use_flash_attn_3 = False
         if fp8 and fp8_meta["recipe"].fp8_dpa:
@@ -559,7 +555,7 @@ def get_attention_backend(
         use_fused_attention = False
     if (
         use_flash_attention
-        and _flash_attn_3_plus
+        and _use_flash_attn_3
         and attn_mask_type in ["causal", "padding_causal"]
         and max_seqlen_q != max_seqlen_kv
     ):
@@ -593,6 +589,9 @@ def get_attention_backend(
             "https://github.com/Dao-AILab/flash-attention#21-change-behavior-of-causal-flag"
         )
         use_flash_attention = False
+    if use_flash_attention and _use_flash_attn_3 and fp8 and fp8_meta["recipe"].fp8_dpa and "padding" in attn_mask_type:
+        logger.debug("Disabling FlashAttention 3 for FP8 and padding masks")
+        _use_flash_attn_3 = False
 
     # Filter: Sliding window attention
     # backend | window_size | diagonal alignment
@@ -656,12 +655,12 @@ def get_attention_backend(
     # UnfusedDotProductAttention | no_bias, pre/post_scale_bias |
     #                            | alibi/alibi_slopes           | both; converts to a 'post_scale_bias' bias
     if use_flash_attention and core_attention_bias_type == "alibi":
-        if _flash_attn_3_plus and _use_flash_attn_3:
+        if _use_flash_attn_3:
             logger.debug("Disabling FlashAttention 3 for ALiBi")
             _use_flash_attn_3 = False
-        if not _flash_attn_2_4_plus:
-            logger.debug("Disabling FlashAttention for ALiBi")
-            use_flash_attention = False
+        elif not _flash_attn_2_4_plus:
+            logger.debug("Disabling FlashAttention as ALiBi requires flash-attn 2.4+")
+            use_flash_attention = False
 
     if use_flash_attention and (
         core_attention_bias_type not in ["no_bias", "alibi"]
@@ -5011,7 +5010,7 @@ def forward(
         if _flash_attn_2_4_1_plus:
             fa_optional_forward_kwargs["deterministic"] = self.deterministic
         fa_optional_forward_args_thd = []
-        if (qkv_format in ["bshd", "sbhd"] and "padding" not in attn_mask_type) or fp8:
+        if qkv_format in ["bshd", "sbhd"] and "padding" not in attn_mask_type:
             func = flash_attn_func if not _use_flash_attn_3 else flash_attn_func_v3
         else:
             if _flash_attn_2_5_7_plus:
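Taken together, the hunks above collapse the FlashAttention 3 gating onto a single _use_flash_attn_3 flag (dropping the separate _flash_attn_3_plus checks), consolidate the dropout filter, and move the FP8 restriction from qkv_format = thd to padding masks. A minimal, self-contained sketch of this filter pattern follows; the function, parameters, and logger below are simplified stand-ins for illustration, not the actual TransformerEngine API.

# Sketch (assumed/simplified names) of the backend-filter pattern: each check
# can only disable FlashAttention 3, and all gating goes through one flag.
import logging

logger = logging.getLogger(__name__)


def filter_flash_attention_3(
    use_flash_attention: bool,
    use_flash_attn_3: bool,
    device_compute_capability: tuple,
    attention_dropout: float,
    fp8_dpa: bool,
    attn_mask_type: str,
):
    """Apply the FlashAttention 3 filters; return the updated flags."""
    if use_flash_attention and use_flash_attn_3:
        if device_compute_capability < (9, 0):
            logger.debug("Disabling FlashAttention 3 as it requires compute capability sm90+")
            use_flash_attn_3 = False
        elif attention_dropout != 0.0:
            logger.debug("Disabling FlashAttention 3 for dropout")
            use_flash_attn_3 = False
        elif fp8_dpa and "padding" in attn_mask_type:
            logger.debug("Disabling FlashAttention 3 for FP8 and padding masks")
            use_flash_attn_3 = False
    return use_flash_attention, use_flash_attn_3


# Example: FP8 DPA with a padding mask keeps FlashAttention enabled but
# falls back from FlashAttention 3 to FlashAttention 2.
print(filter_flash_attention_3(True, True, (9, 0), 0.0, True, "padding_causal"))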
