Commit

fix extra_state tests
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
cyanguwa committed Oct 6, 2024
1 parent e483b6e · commit 4885fee
Showing 2 changed files with 127 additions and 68 deletions.
159 changes: 98 additions & 61 deletions tests/pytorch/test_sanity.py
@@ -9,6 +9,7 @@
import torch
import pytest
import io
import os

from transformer_engine.pytorch.fp8 import (
fp8_autocast,
@@ -42,11 +43,11 @@
)
from transformer_engine.pytorch.module.base import get_workspace
from test_onnx_export import create_meta
from test_numerics import reset_rng_states, dtype_tols

# Only run FP8 tests on H100.
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()


def custom_amax_to_scale(
amax: torch.Tensor,
scale: torch.Tensor,
@@ -1010,78 +1011,114 @@ def test_sanity_fp8_gemm_with_unalignment(N, datatype):
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
def test_sanity_attention_extra_state(model, dtype):
config = model_configs[model]
outputs = _run_attention_extra_state(dtype, config, checkpoint=False)
outputs_checkpoint = _run_attention_extra_state(dtype, config, checkpoint=True)
outputs_checkpoint_v1_6 = _run_attention_extra_state(dtype, config, mimic_v1_6=True, checkpoint=True)

# Check that results match
tols = dtype_tols(dtype)
if dtype in (torch.float16, torch.bfloat16):
tols.update(dict(rtol=2e-2, atol=2e-3))
for i, (ref, test) in enumerate(zip(outputs, outputs_checkpoint)):
torch.testing.assert_close(
test,
ref,
**tols,
)
for i, (ref, test) in enumerate(zip(outputs, outputs_checkpoint_v1_6)):
torch.testing.assert_close(
test,
ref,
**tols,
)

def _run_attention_extra_state(dtype, config, checkpoint=False, mimic_v1_6=False):
steps = 10
path = 'checkpoint.pt'
fp8_enabled = True
fp8_recipe = recipe.DelayedScaling(
margin=0,
fp8_format=recipe.Format.HYBRID,
amax_history_len=1,
amax_compute_algo="most_recent",
fp8_dpa=True,
fp8_dpa=fp8_enabled,
fp8_mha=False,
)

reset_rng_states()
hidden_states = torch.randn(
(config.seq_len, config.batch_size, config.hidden_size),
dtype=dtype,
device="cuda",
requires_grad=True,
)

with fp8_model_init(enabled=True):
block = TransformerLayer(
config.hidden_size,
4 * config.hidden_size,
config.num_attention_heads,
fuse_qkv_params=True,
params_dtype=dtype,
device="cuda",
)
with fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
output = block(hidden_states, is_first_microbatch=True)
loss = output.sum()
loss.backward()

# call state_dict()
sd = block.state_dict()

# check core_attention._extra_state
attn_extra_state = sd["self_attention.core_attention._extra_state"]
attn_extra_state.seek(0)
attn_extra_state = torch.load(attn_extra_state, map_location="cuda")

# add random core_attention.fused_attention._extra_state
# it should not be loaded or cause any 'unexpected key' errors
random_state = {"a": 1, "b": 2}
fused_attn_extra_state = io.BytesIO()
torch.save(random_state, fused_attn_extra_state)
sd["self_attention.core_attention.fused_attention._extra_state"] = fused_attn_extra_state

# save checkpoint
path = "./checkpoint.pt"
torch.save(sd, path)

# reinit the model
del block
with fp8_model_init(enabled=True):
block_new = TransformerLayer(
config.hidden_size,
4 * config.hidden_size,
config.num_attention_heads,
fuse_qkv_params=True,
params_dtype=dtype,
device="cuda",
)
FP8GlobalStateManager.reset()
def get_model(dtype, config):
sigma = 0.023
init_method = init_method_normal(sigma)
output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)

with fp8_model_init(enabled=fp8_enabled):
block = TransformerLayer(
config.hidden_size,
4 * config.hidden_size,
config.num_attention_heads,
init_method=init_method,
output_layer_init_method=output_layer_init_method,
fuse_qkv_params=True,
params_dtype=dtype,
device="cuda",
)
return block

block = get_model(dtype, config)
for i in range(steps // 2):
with fp8_autocast(enabled=fp8_enabled, fp8_recipe=fp8_recipe):
output = block(hidden_states, None)
loss = output.sum()
loss.backward()

if checkpoint:
sd = block.state_dict()
if mimic_v1_6:
sd["self_attention.core_attention.fused_attention._extra_state"] = sd["self_attention.core_attention._extra_state"]
del sd["self_attention.core_attention._extra_state"]
torch.save(sd, path)

param_grads = []
for p in block.parameters():
if p.requires_grad:
param_grads.append(p.grad.clone())

_cpu_rng_state_new = torch.get_rng_state()
_cuda_rng_state_new = torch.cuda.get_rng_state()

del block
block = get_model(dtype, config)
block.load_state_dict(torch.load(path))
torch.set_rng_state(_cpu_rng_state_new)
torch.cuda.set_rng_state(_cuda_rng_state_new)

for p in block.parameters():
if p.requires_grad:
p.grad = param_grads.pop(0)

assert not param_grads, "Oops!"

for i in range((steps+1) // 2):
with fp8_autocast(enabled=fp8_enabled, fp8_recipe=fp8_recipe):
output = block(hidden_states, None)
loss = output.sum()
loss.backward()

torch.cuda.synchronize()

if os.path.exists(path):
os.remove(path)

outputs = [output, hidden_states.grad]
for p in block.parameters():
if p.requires_grad:
outputs.append(p.grad)

# load from checkpoint
block_new.load_state_dict(torch.load(path))

# check state_dict
sd_new = block_new.state_dict()
attn_extra_state_new = sd_new["self_attention.core_attention._extra_state"]
attn_extra_state_new.seek(0)
attn_extra_state_new = torch.load(attn_extra_state_new, map_location="cuda")
for k, v in attn_extra_state_new.items():
if k != "extra_fp8_variables":
assert torch.equal(v, attn_extra_state[k]), f"{k} is not equal"
else:
for ek, ev in attn_extra_state_new["extra_fp8_variables"].items():
assert ev == attn_extra_state["extra_fp8_variables"][ek], f"{ek} is not equal"
return outputs
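
# Note for readers skimming the test changes: the rewritten _run_attention_extra_state
# exercises a mid-training checkpoint round trip. It runs half the steps, saves the state
# dict (optionally remapping the _extra_state key to mimic Transformer Engine 1.6), stashes
# parameter gradients and RNG state, rebuilds the model, reloads, finishes the run, and
# compares against an uninterrupted run. The sketch below distills that pattern with a toy
# torch.nn.Linear on CPU; this simplification is an assumption for readability, since the
# real test uses TransformerLayer with FP8 enabled on CUDA.

import io

import torch


def run(steps=4, checkpoint=False):
    torch.manual_seed(0)
    model = torch.nn.Linear(8, 8)
    x = torch.randn(4, 8, requires_grad=True)

    # First half of the run.
    for _ in range(steps // 2):
        model(x).sum().backward()

    if checkpoint:
        # Save the state dict, parameter gradients, and RNG state mid-run.
        buf = io.BytesIO()
        torch.save(model.state_dict(), buf)
        grads = [p.grad.clone() for p in model.parameters()]
        rng = torch.get_rng_state()

        # Rebuild the model and restore everything that was saved.
        model = torch.nn.Linear(8, 8)
        buf.seek(0)
        model.load_state_dict(torch.load(buf))
        torch.set_rng_state(rng)
        for p, g in zip(model.parameters(), grads):
            p.grad = g

    # Second half of the run.
    for _ in range((steps + 1) // 2):
        model(x).sum().backward()

    return [model(x).detach(), x.grad] + [p.grad for p in model.parameters()]


# The interrupted-and-restored run should match the uninterrupted one.
for ref, test in zip(run(checkpoint=False), run(checkpoint=True)):
    torch.testing.assert_close(test, ref)
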
36 changes: 29 additions & 7 deletions transformer_engine/pytorch/attention.py
@@ -6614,10 +6614,10 @@ def __init__(
def remove_extra_states_check(self, incompatible_keys): # pylint: disable=unused-argument
"""
Temporarily remove fused_attention._extra_state as a missing key
or an unexpected key when loading TransformerEngine checkpoints.
or an unexpected key when loading Transformer Engine checkpoints.
Please store FP8 metadata as DotProductAttention's _extra_state,
rather than FusedAttention's _extra_state. This hook will be
phased out in TransformerEngine 2.0.
phased out in Transformer Engine 2.0.
"""
for key in incompatible_keys.missing_keys:
if "fused_attention._extra_state" in key:
@@ -6878,7 +6878,7 @@ class DotProductAttention(TransformerEngineBaseModule):
e.g. a different mask for training and inference.
1. For "`no_mask`", no attention mask is applied.
2. For "`causal`", "`causal_bottom_right`", or the causal mask in
"`padding_causal`" and "`padding_causal_bottom_right`", TransformerEngine
"`padding_causal`" and "`padding_causal_bottom_right`", Transformer Engine
calculates and applies an upper triangular mask to the softmax input.
No user input is needed. Causal masks without the "`bottom_right`" appendix align
the diagonal line to the top left corner of the softmax matrix. With
@@ -7085,15 +7085,37 @@ def __init__(
def remove_extra_states_check(self, incompatible_keys): # pylint: disable=unused-argument
"""
Temporarily remove core_attention._extra_state as a missing key
when loading older TransformerEngine checkpoints. Will phase out
this hook in TransformerEngine 2.0.
when loading older Transformer Engine checkpoints. Will phase out
this hook in Transformer Engine 2.0.
"""
for key in incompatible_keys.missing_keys:
if "core_attention._extra_state" in key:
incompatible_keys.missing_keys.remove(key)

self.register_load_state_dict_post_hook(remove_extra_states_check)

def _load_from_state_dict(
self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
):
"""
This function helps to load Transformer Engine 1.6 and 1.7 checkpoints, where FP8 metadata
is stored under the `core_attention.fused_attention._extra_state` key and not the
`core_attention._extra_state` key. For more information, please see
`FP8 checkpoint compatibility <https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/faq.html#fp8-checkpoint-compatibility>`_.
"""
fused_attn_key = False
dot_product_attn_key = False
for k in state_dict.keys():
if 'core_attention.fused_attention._extra_state' in k:
fused_attn_key = True
if 'core_attention._extra_state' in k:
dot_product_attn_key = True
if fused_attn_key and not dot_product_attn_key:
prefix = prefix + 'fused_attention.'
super()._load_from_state_dict(
state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
)

def _checkpointed_attention_forward(
self,
attention_func: Callable,
@@ -7197,14 +7219,14 @@ def forward(
Users can use environment variables :attr:`NVTE_FLASH_ATTN`, :attr:`NVTE_FUSED_ATTN`,
and :attr:`NVTE_FUSED_ATTN_BACKEND` to control which DotProductAttention backend,
and FusedAttention backend if applicable, to use. TransformerEngine prioritizes
and FusedAttention backend if applicable, to use. Transformer Engine prioritizes
FlashAttention over FusedAttention and over UnfusedDotProductAttention.
If FusedAttention is being used, users can also choose to switch to flash-attn's
implementation for backward by setting :attr:`NVTE_FUSED_ATTN_USE_FAv2_BWD=1`
(default: 0), because of the performance differences between various versions of
flash-attn and FusedAttention. Further, :attr:`NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT`
can be used to enable (:attr:`1`) or disable (:attr:`0`) the workspace related
optimizations in FusedAttention. When unset, TransformerEngine determines the code path
optimizations in FusedAttention. When unset, Transformer Engine determines the code path
based on its internal logic. These optimizations trade memory for performance
and should be used with care.
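
# The _load_from_state_dict override added above decides which prefix to hand to PyTorch
# based on which _extra_state keys appear in the incoming checkpoint. Below is a
# self-contained sketch of just that prefix-selection logic, so it can be read and run
# without Transformer Engine installed; the helper name and example keys are illustrative
# only, not library API.


def pick_extra_state_prefix(state_dict_keys, prefix):
    """Choose the prefix under which FP8 metadata (_extra_state) should be loaded."""
    fused_attn_key = any(
        "core_attention.fused_attention._extra_state" in k for k in state_dict_keys
    )
    dot_product_attn_key = any(
        "core_attention._extra_state" in k for k in state_dict_keys
    )
    if fused_attn_key and not dot_product_attn_key:
        # Transformer Engine 1.6/1.7 layout: metadata lives under FusedAttention.
        return prefix + "fused_attention."
    return prefix


# A TE 1.6/1.7-style checkpoint only carries the fused_attention key ...
print(pick_extra_state_prefix(
    ["self_attention.core_attention.fused_attention._extra_state"],
    "self_attention.core_attention.",
))  # -> self_attention.core_attention.fused_attention.

# ... while a current checkpoint keeps the DotProductAttention key, so the prefix is unchanged.
print(pick_extra_state_prefix(
    ["self_attention.core_attention._extra_state"],
    "self_attention.core_attention.",
))  # -> self_attention.core_attention.
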
