diff --git a/tests/pytorch/test_sanity.py b/tests/pytorch/test_sanity.py
index 9b8bfbfb38..f854fc13d6 100644
--- a/tests/pytorch/test_sanity.py
+++ b/tests/pytorch/test_sanity.py
@@ -48,6 +48,7 @@
 # Only run FP8 tests on H100.
 fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
 
+
 def custom_amax_to_scale(
     amax: torch.Tensor,
     scale: torch.Tensor,
@@ -1013,7 +1014,9 @@ def test_sanity_attention_extra_state(model, dtype):
     config = model_configs[model]
     outputs = _run_attention_extra_state(dtype, config, checkpoint=False)
     outputs_checkpoint = _run_attention_extra_state(dtype, config, checkpoint=True)
-    outputs_checkpoint_v1_6 = _run_attention_extra_state(dtype, config, mimic_v1_6=True, checkpoint=True)
+    outputs_checkpoint_v1_6 = _run_attention_extra_state(
+        dtype, config, mimic_v1_6=True, checkpoint=True
+    )
 
     # Check that results match
     tols = dtype_tols(dtype)
@@ -1032,9 +1035,10 @@ def test_sanity_attention_extra_state(model, dtype):
         **tols,
     )
 
+
 def _run_attention_extra_state(dtype, config, checkpoint=False, mimic_v1_6=False):
     steps = 10
-    path = 'checkpoint.pt'
+    path = "checkpoint.pt"
     fp8_enabled = True
     fp8_recipe = recipe.DelayedScaling(
         margin=0,
@@ -1081,7 +1085,9 @@ def get_model(dtype, config):
     if checkpoint:
         sd = block.state_dict()
         if mimic_v1_6:
-            sd["self_attention.core_attention.fused_attention._extra_state"] = sd["self_attention.core_attention._extra_state"]
+            sd["self_attention.core_attention.fused_attention._extra_state"] = sd[
+                "self_attention.core_attention._extra_state"
+            ]
             del sd["self_attention.core_attention._extra_state"]
         torch.save(sd, path)
 
@@ -1105,7 +1111,7 @@ def get_model(dtype, config):
 
         assert not param_grads, "Oops!"
 
-    for i in range((steps+1) // 2):
+    for i in range((steps + 1) // 2):
         with fp8_autocast(enabled=fp8_enabled, fp8_recipe=fp8_recipe):
             output = block(hidden_states, None)
             loss = output.sum()
diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py
index 3fe48ed181..1a7807fede 100644
--- a/transformer_engine/pytorch/attention.py
+++ b/transformer_engine/pytorch/attention.py
@@ -7106,12 +7106,12 @@ def _load_from_state_dict(
         fused_attn_key = False
         dot_product_attn_key = False
         for k in state_dict.keys():
-            if 'core_attention.fused_attention._extra_state' in k:
+            if "core_attention.fused_attention._extra_state" in k:
                 fused_attn_key = True
-            if 'core_attention._extra_state' in k:
+            if "core_attention._extra_state" in k:
                 dot_product_attn_key = True
         if fused_attn_key and not dot_product_attn_key:
-            prefix = prefix + 'fused_attention.'
+            prefix = prefix + "fused_attention."
         super()._load_from_state_dict(
             state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
         )