
Commit cf38d34
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
pre-commit-ci[bot] committed Oct 6, 2024
1 parent 4885fee commit cf38d34
Showing 2 changed files with 13 additions and 7 deletions.
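
The fixes in this commit are mechanical formatting changes: single-quoted strings rewritten with double quotes, long call arguments wrapped onto their own lines, and a second blank line inserted before top-level function definitions. This is the style enforced by the Black formatter, which the pre-commit hooks appear to apply here. A rough sketch of reproducing one such fix locally, assuming Black is installed (the repository's actual hook configuration is not part of this commit):

# Sketch only: reformat a snippet with Black's Python API.
# Assumes the "black" package is installed; the repository's real
# pre-commit hook list is not shown in this commit.
import black

src = "path = 'checkpoint.pt'\n"
fixed = black.format_str(src, mode=black.Mode())
print(fixed)  # prints: path = "checkpoint.pt"

In a local checkout with the project's hook configuration, running `pre-commit run --all-files` applies whatever hooks are configured to every tracked file, which is what pre-commit.ci automates in this commit.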
tests/pytorch/test_sanity.py (14 changes: 10 additions & 4 deletions)

@@ -48,6 +48,7 @@
 # Only run FP8 tests on H100.
 fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()

+
 def custom_amax_to_scale(
     amax: torch.Tensor,
     scale: torch.Tensor,
@@ -1013,7 +1014,9 @@ def test_sanity_attention_extra_state(model, dtype):
     config = model_configs[model]
     outputs = _run_attention_extra_state(dtype, config, checkpoint=False)
     outputs_checkpoint = _run_attention_extra_state(dtype, config, checkpoint=True)
-    outputs_checkpoint_v1_6 = _run_attention_extra_state(dtype, config, mimic_v1_6=True, checkpoint=True)
+    outputs_checkpoint_v1_6 = _run_attention_extra_state(
+        dtype, config, mimic_v1_6=True, checkpoint=True
+    )

     # Check that results match
     tols = dtype_tols(dtype)
@@ -1032,9 +1035,10 @@ def test_sanity_attention_extra_state(model, dtype):
         **tols,
     )

+
 def _run_attention_extra_state(dtype, config, checkpoint=False, mimic_v1_6=False):
     steps = 10
-    path = 'checkpoint.pt'
+    path = "checkpoint.pt"
     fp8_enabled = True
     fp8_recipe = recipe.DelayedScaling(
         margin=0,
@@ -1081,7 +1085,9 @@ def get_model(dtype, config):
     if checkpoint:
         sd = block.state_dict()
         if mimic_v1_6:
-            sd["self_attention.core_attention.fused_attention._extra_state"] = sd["self_attention.core_attention._extra_state"]
+            sd["self_attention.core_attention.fused_attention._extra_state"] = sd[
+                "self_attention.core_attention._extra_state"
+            ]
             del sd["self_attention.core_attention._extra_state"]
         torch.save(sd, path)

@@ -1105,7 +1111,7 @@ def get_model(dtype, config):

         assert not param_grads, "Oops!"

-    for i in range((steps+1) // 2):
+    for i in range((steps + 1) // 2):
         with fp8_autocast(enabled=fp8_enabled, fp8_recipe=fp8_recipe):
             output = block(hidden_states, None)
             loss = output.sum()
transformer_engine/pytorch/attention.py (6 changes: 3 additions & 3 deletions)

@@ -7106,12 +7106,12 @@ def _load_from_state_dict(
         fused_attn_key = False
         dot_product_attn_key = False
         for k in state_dict.keys():
-            if 'core_attention.fused_attention._extra_state' in k:
+            if "core_attention.fused_attention._extra_state" in k:
                 fused_attn_key = True
-            if 'core_attention._extra_state' in k:
+            if "core_attention._extra_state" in k:
                 dot_product_attn_key = True
         if fused_attn_key and not dot_product_attn_key:
-            prefix = prefix + 'fused_attention.'
+            prefix = prefix + "fused_attention."
         super()._load_from_state_dict(
             state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
         )
