fix stale weight bug for FSDP EMA + AutoUnit (#962)
Summary: Pull Request resolved: #962

Reviewed By: anshulverma

Differential Revision: D68450131

fbshipit-source-id: 8f8981f39ea654a9e83af612c7d93880066308e3
JKSenthil authored and facebook-github-bot committed Jan 22, 2025
1 parent 9984243 commit 3232a91
Showing 3 changed files with 149 additions and 3 deletions.
129 changes: 128 additions & 1 deletion tests/framework/test_auto_unit_gpu.py
@@ -10,7 +10,7 @@
import unittest

from copy import deepcopy
from typing import TypeVar
from typing import Tuple, TypeVar
from unittest.mock import MagicMock, patch

import torch
@@ -27,7 +27,9 @@

from torchtnt.framework.auto_unit import AutoPredictUnit, SWALRParams, SWAParams
from torchtnt.framework.evaluate import evaluate
from torchtnt.framework.fit import fit
from torchtnt.framework.predict import predict
from torchtnt.framework.state import ActivePhase, State
from torchtnt.framework.train import train
from torchtnt.utils.distributed import spawn_multi_process
from torchtnt.utils.env import init_from_env, seed
@@ -38,6 +40,25 @@
T = TypeVar("T")


Batch = Tuple[torch.Tensor, torch.Tensor]


class DummySWAAutoUnit(DummyAutoUnit):
    def compute_loss(self, state: State, data: Batch) -> Tuple[torch.Tensor, object]:
        """
        Computes loss for the given batch. In the EVAL or PREDICT phase, uses the swa
        model's output (falling back to self.module if no swa model exists).
        """
        inputs, targets = data
        if state.active_phase == ActivePhase.TRAIN:
            outputs = self.module(inputs)
        else:
            outputs = self.swa_model(inputs) if self.swa_model else self.module(inputs)

        loss = torch.nn.functional.cross_entropy(outputs, targets)

        return loss, outputs


class TestAutoUnitGPU(unittest.TestCase):
    @skip_if_not_gpu
    @skip_if_not_distributed
@@ -184,6 +205,112 @@ def forward(self, x):
            for p1, p2 in zip(swa_params, swa_fsdp_params, strict=True):
                torch.testing.assert_close(p2, p1, check_device=False)

    @skip_if_not_distributed
    @skip_if_not_gpu
    def test_stochastic_weight_averaging_fsdp_with_eval(self) -> None:
        """
        Test that swa params with FSDP are identical to non-FSDP swa params
        """
        spawn_multi_process(
            2,
            "nccl",
            self._test_stochastic_weight_averaging_fsdp_with_eval,
        )

    @staticmethod
    def _test_stochastic_weight_averaging_fsdp_with_eval() -> None:
        """
        Compares the swa model parameters after training without FSDP and with FSDP.
        They should be identical.
        """

        class Net(torch.nn.Module):
            def __init__(self):
                super(Net, self).__init__()
                self.l1 = torch.nn.Linear(2, 2)
                self.b1 = torch.nn.BatchNorm1d(2)
                self.l2 = torch.nn.Linear(2, 2)

            def forward(self, x):
                x = self.l1(x)
                x = self.b1(x)
                x = self.l2(x)
                return x

        # so all ranks start with the same initialized weights
        device = init_from_env()
        seed(0)
        my_module = Net()

        auto_unit = DummySWAAutoUnit(
            module=deepcopy(my_module),
            device=device,
            step_lr_interval="step",
            swa_params=SWAParams(
                warmup_steps_or_epochs=1,
                step_or_epoch_update_freq=1,
                swalr_params=SWALRParams(
                    anneal_steps_or_epochs=3,
                ),
                averaging_method="ema",
            ),
        )

        auto_unit_fsdp = DummySWAAutoUnit(
            module=my_module,
            device=device,
            step_lr_interval="step",
            strategy=FSDPStrategy(),
            swa_params=SWAParams(
                warmup_steps_or_epochs=1,
                step_or_epoch_update_freq=1,
                swalr_params=SWALRParams(
                    anneal_steps_or_epochs=3,
                ),
                averaging_method="ema",
            ),
        )

        input_dim = 2
        dataset_len = 10
        batch_size = 2

        dataloader = generate_random_dataloader(dataset_len, input_dim, batch_size)
        eval_dataloader = generate_random_dataloader(dataset_len, input_dim, batch_size)
        fit(
            auto_unit,
            dataloader,
            eval_dataloader,
            max_epochs=3,
            max_train_steps_per_epoch=5,
            evaluate_every_n_epochs=0,
        )

        fit(
            auto_unit_fsdp,
            dataloader,
            eval_dataloader,
            max_epochs=3,
            max_train_steps_per_epoch=5,
            # this is the key arg: it ensures the swa model is still updated
            # even after the swa model's forward pass is used in eval
            evaluate_every_n_epochs=1,
        )

        swa_params = list(auto_unit.swa_model.parameters())
        swa_buffers = list(auto_unit.swa_model.buffers())
        with FSDP.summon_full_params(auto_unit_fsdp.swa_model):
            swa_fsdp_params = auto_unit_fsdp.swa_model.parameters()
            swa_fsdp_buffers = auto_unit_fsdp.swa_model.buffers()

            # Iterate and compare each parameter
            for p1, p2 in zip(swa_params, swa_fsdp_params, strict=True):
                torch.testing.assert_close(p2, p1, check_device=False)

            # Iterate and compare each buffer
            for b1, b2 in zip(swa_buffers, swa_fsdp_buffers, strict=True):
                torch.testing.assert_close(b2, b1, check_device=False)

    @skip_if_not_gpu
    @patch("torch.autocast")
    def test_eval_mixed_precision_bf16(self, mock_autocast: MagicMock) -> None:
6 changes: 5 additions & 1 deletion torchtnt/framework/auto_unit.py
@@ -461,7 +461,7 @@ class AutoUnit(
        detect_anomaly: whether to enable anomaly detection for the autograd engine https://pytorch.org/docs/stable/autograd.html#anomaly-detection
        clip_grad_norm: max norm of the gradients for clipping https://pytorch.org/docs/stable/generated/torch.nn.utils.clip_grad_norm_.html
        clip_grad_value: max value of the gradients for clipping https://pytorch.org/docs/stable/generated/torch.nn.utils.clip_grad_value_.html
        swa_params: params for stochastic weight averaging https://pytorch.org/docs/stable/optim.html#stochastic-weight-averaging
        swa_params: params for stochastic weight averaging https://pytorch.org/docs/stable/optim.html#stochastic-weight-averaging (please see the note below if using with FSDP)
        torch_compile_params: params for Torch compile https://pytorch.org/docs/stable/generated/torch.compile.html
        activation_checkpoint_params: params for enabling activation checkpointing
        training: if True, the optimizer and optionally LR scheduler will be created after the class is initialized.
@@ -481,6 +481,10 @@ class AutoUnit(
    Note:
        Torch compile support is only available in PyTorch 2.0 or higher.
    Note:
        If using SWA with FSDP, the SWA model will be sharded with the same FSDP configuration as the original model. If you need the SWA model's output in the
        evaluation / prediction step, please call `self.swa_model(inputs, ...)` to ensure all hooks (especially for FSDP) are fired correctly.
    """

    def __init__(
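For illustration, a minimal sketch of the usage pattern the note above describes, mirroring the DummySWAAutoUnit added in the tests. The unit name, the optimizer choice, and the cross-entropy loss are assumptions for the sketch, not part of the library API:

from typing import Optional, Tuple

import torch
from torchtnt.framework.auto_unit import AutoUnit
from torchtnt.framework.state import ActivePhase, State

Batch = Tuple[torch.Tensor, torch.Tensor]


class MyEvalSWAUnit(AutoUnit[Batch]):
    def compute_loss(
        self, state: State, data: Batch
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        inputs, targets = data
        if state.active_phase == ActivePhase.TRAIN or self.swa_model is None:
            outputs = self.module(inputs)
        else:
            # call the AveragedModel wrapper itself (not self.swa_model.module) so all
            # of its forward logic runs, including the FSDP reshard added in swa.py below
            outputs = self.swa_model(inputs)
        loss = torch.nn.functional.cross_entropy(outputs, targets)
        return loss, outputs

    def configure_optimizers_and_lr_scheduler(
        self, module: torch.nn.Module
    ) -> Tuple[torch.optim.Optimizer, Optional[torch.optim.lr_scheduler.LRScheduler]]:
        # hypothetical optimizer choice, just to keep the sketch self-contained
        return torch.optim.SGD(module.parameters(), lr=0.01), None

A unit like this can then be passed to fit() together with swa_params, as test_stochastic_weight_averaging_fsdp_with_eval above does.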
17 changes: 16 additions & 1 deletion torchtnt/utils/swa.py
@@ -6,9 +6,10 @@

# pyre-strict

from typing import Callable, List, Literal, Optional
from typing import Any, Callable, List, Literal, Optional

import torch
from torch.distributed.fsdp import FullyShardedDataParallel, ShardingStrategy

_AVERAGED_MODEL_AVAIL: bool = True

@@ -105,6 +106,20 @@ def __init__(
            use_buffers=use_buffers,
        )

    # pyre-ignore: Missing return annotation [3]: Return type must be specified as type other than `Any`
    def forward(self, *args: Any, **kwargs: Any) -> Any:
        output = self.module(*args, **kwargs)

        # for FSDP modules, we need to manually reshard the swa_model in case its
        # forward was used in the evaluation loop, due to how FSDP manages the param state
        # see https://github.com/pytorch/pytorch/issues/117742
        for m in FullyShardedDataParallel.fsdp_modules(self.module):
            if m._has_params and m.sharding_strategy is not ShardingStrategy.NO_SHARD:
                # pyre-ignore: Incompatible parameter type [6]: In call `torch.distributed.fsdp._runtime_utils._reshard`, for 2nd positional argument, expected `FlatParamHandle` but got `Optional[FlatParamHandle]`.
                torch.distributed.fsdp._runtime_utils._reshard(m, m._handle, True)

        return output

    def update_parameters(self, model: torch.nn.Module) -> None:
        self._num_updates += 1
        if self._use_lit:
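The same resharding pattern can be reused outside the AveragedModel wrapper when an FSDP-wrapped module is run forward-only (for example, during evaluation) and its sharded parameters are later updated in place. A hedged standalone sketch with a hypothetical helper name; it relies on the same private FSDP internals as the override above, so it only applies to torch versions where those internals exist:

import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, ShardingStrategy


def reshard_fsdp_modules(root: torch.nn.Module) -> None:
    # free any materialized (unsharded) parameters so that later in-place writes to the
    # sharded flat params are not masked by stale unsharded views
    # (see https://github.com/pytorch/pytorch/issues/117742)
    for m in FSDP.fsdp_modules(root):
        if m._has_params and m.sharding_strategy is not ShardingStrategy.NO_SHARD:
            torch.distributed.fsdp._runtime_utils._reshard(m, m._handle, True)

In this commit the same logic lives inside AveragedModel.forward, so any caller that goes through self.swa_model(inputs, ...) gets the reshard automatically.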
