From a4030e85a62f99c69d9e375c9245eacf79289929 Mon Sep 17 00:00:00 2001
From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
Date: Mon, 7 Oct 2024 16:39:58 -0700
Subject: [PATCH] minor fixes

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
---
 transformer_engine/pytorch/attention.py | 31 +++++++++++++++----------------
 1 file changed, 15 insertions(+), 16 deletions(-)

diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py
index 9b9a6a860b..763a8b6c5d 100644
--- a/transformer_engine/pytorch/attention.py
+++ b/transformer_engine/pytorch/attention.py
@@ -348,7 +348,7 @@ def get_attention_backend(
         use_fused_attention = False
 
     # Filter: Compute capability
-    global _flash_attn_3_plus, _use_flash_attn_3
+    global _use_flash_attn_3
     if device_compute_capability < (8, 0):
         if use_flash_attention:
             logger.debug("Disabling FlashAttention as it requires compute capability sm80+")
@@ -357,7 +357,7 @@ def get_attention_backend(
             logger.debug("Disabling FusedAttention as it requires compute capability sm80+")
             use_fused_attention = False
     if device_compute_capability < (9, 0):
-        if use_flash_attention and _flash_attn_3_plus:
+        if use_flash_attention and _use_flash_attn_3:
             logger.debug("Disabling FlashAttention 3 as it requires compute capability sm90+")
             _use_flash_attn_3 = False
 
@@ -436,15 +436,11 @@ def get_attention_backend(
                 "padding between sequences, i.e. [a, a, PAD, b, b, b, PAD, c, PAD]"
             )
             use_flash_attention = False
-        if use_flash_attention and _use_flash_attn_3 and fp8 and fp8_meta["recipe"].fp8_dpa:
-            logger.debug("Disabling FlashAttention 3 for FP8 and qkv_format = thd")
-            _use_flash_attn_3 = False
 
     # Filter: Dropout
-    if attention_dropout != 0.0 and use_flash_attention:
-        if _flash_attn_3_plus and _use_flash_attn_3:
-            logger.debug("Disabling FlashAttention 3 for dropout")
-            _use_flash_attn_3 = False
+    if attention_dropout != 0.0 and use_flash_attention and _use_flash_attn_3:
+        logger.debug("Disabling FlashAttention 3 for dropout")
+        _use_flash_attn_3 = False
 
     # Filter: Context parallelism
     # qkv_format | attn_mask_type | attn_bias_type | supported backends
@@ -464,7 +460,7 @@ def get_attention_backend(
             )
             use_unfused_attention = False
     if context_parallel and use_flash_attention:
-        if _flash_attn_3_plus and _use_flash_attn_3:
+        if _use_flash_attn_3:
             logger.debug("Disabling FlashAttention 3 for context parallelism")
             _use_flash_attn_3 = False
         if fp8 and fp8_meta["recipe"].fp8_dpa:
@@ -559,7 +555,7 @@ def get_attention_backend(
         use_fused_attention = False
     if (
         use_flash_attention
-        and _flash_attn_3_plus
+        and _use_flash_attn_3
         and attn_mask_type in ["causal", "padding_causal"]
         and max_seqlen_q != max_seqlen_kv
     ):
@@ -593,6 +589,9 @@ def get_attention_backend(
                 "https://github.com/Dao-AILab/flash-attention#21-change-behavior-of-causal-flag"
             )
             use_flash_attention = False
+    if use_flash_attention and _use_flash_attn_3 and fp8 and fp8_meta["recipe"].fp8_dpa and "padding" in attn_mask_type:
+        logger.debug("Disabling FlashAttention 3 for FP8 and padding masks")
+        _use_flash_attn_3 = False
 
     # Filter: Sliding window attention
     # backend | window_size | diagonal alignment
@@ -656,12 +655,12 @@ def get_attention_backend(
     # UnfusedDotProductAttention | no_bias, pre/post_scale_bias |
     #                            | alibi/alibi_slopes           | both; converts to a 'post_scale_bias' bias
     if use_flash_attention and core_attention_bias_type == "alibi":
-        if _flash_attn_3_plus and _use_flash_attn_3:
+        if _use_flash_attn_3:
             logger.debug("Disabling FlashAttention 3 for ALiBi")
             _use_flash_attn_3 = False
-        if not _flash_attn_2_4_plus:
-            logger.debug("Disabling FlashAttention for ALiBi")
-            use_flash_attention = False
+        elif not _flash_attn_2_4_plus:
+            logger.debug("Disabling FlashAttention as ALiBi requires flash-attn 2.4+")
+            use_flash_attention = False
 
     if use_flash_attention and (
         core_attention_bias_type not in ["no_bias", "alibi"]
@@ -5011,7 +5010,7 @@ def forward(
             if _flash_attn_2_4_1_plus:
                 fa_optional_forward_kwargs["deterministic"] = self.deterministic
             fa_optional_forward_args_thd = []
-            if (qkv_format in ["bshd", "sbhd"] and "padding" not in attn_mask_type) or fp8:
+            if qkv_format in ["bshd", "sbhd"] and "padding" not in attn_mask_type:
                 func = flash_attn_func if not _use_flash_attn_3 else flash_attn_func_v3
             else:
                 if _flash_attn_2_5_7_plus: