Move auto_unit GPU tests to dedicated file (#755)
Summary: Pull Request resolved: #755

Reviewed By: galrotem

Differential Revision: D55224519

fbshipit-source-id: ec179af6db2303ae4e32fb3d1568f7754aa54d90
diego-urgell authored and facebook-github-bot committed Mar 22, 2024
1 parent bc2bf15 commit e806af0
Showing 2 changed files with 350 additions and 307 deletions.
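Every test deleted below is GPU-bound: each carries @skip_if_not_gpu, and the multi-process cases add @skip_if_not_distributed plus spawn_multi_process over the GPU-only nccl backend, which is what justifies collecting them in one dedicated file. A minimal sketch of the single-GPU gate (class and test names here are placeholders; the distributed variant appears in the sketch after the diff):

import unittest

import torch
from torchtnt.utils.test_utils import skip_if_not_gpu


class GatingSketch(unittest.TestCase):
    @skip_if_not_gpu
    def test_needs_cuda(self) -> None:
        # The decorator skips this test on CPU-only hosts, so a module of
        # such tests can be scheduled exclusively on GPU CI workers.
        self.assertTrue(torch.cuda.is_available())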
310 changes: 3 additions & 307 deletions tests/framework/test_auto_unit.py
@@ -12,9 +12,8 @@
from unittest.mock import MagicMock, patch

import torch
from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler
from torchtnt.framework.auto_unit import TrainStepResults
from torchtnt.utils.test_utils import skip_if_not_distributed, skip_if_not_gpu
from torchtnt.utils.test_utils import skip_if_not_distributed

from torchtnt.utils.version import is_torch_version_geq_1_13

@@ -23,12 +22,9 @@
COMPILE_AVAIL = True
import torch._dynamo

from copy import deepcopy

from pyre_extensions import none_throws, ParameterSpecification as ParamSpec

from torch.distributed import GradBucket
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torchtnt.framework._test_utils import (
DummyAutoUnit,
generate_random_dataloader,
@@ -49,9 +45,9 @@
from torchtnt.framework.unit import TPredictData
from torchtnt.utils.device import copy_data_to_device
from torchtnt.utils.distributed import spawn_multi_process
from torchtnt.utils.env import init_from_env, seed
from torchtnt.utils.env import init_from_env
from torchtnt.utils.lr_scheduler import TLRScheduler
from torchtnt.utils.prepare_module import DDPStrategy, FSDPStrategy, TorchCompileParams
from torchtnt.utils.prepare_module import DDPStrategy
from torchtnt.utils.progress import Progress
from torchtnt.utils.timer import Timer

@@ -81,31 +77,6 @@ def test_app_state_mixin(self) -> None:
for key in ("module", "optimizer", "lr_scheduler", "grad_scaler"):
self.assertIn(key, auto_unit.app_state())

@skip_if_not_gpu
@skip_if_not_distributed
def test_fsdp_fp16(self) -> None:
"""
Test that FSDP + FP16 uses ShardedGradScaler
"""
spawn_multi_process(
2,
"nccl",
self._test_fsdp_fp16,
)

@staticmethod
def _test_fsdp_fp16() -> None:
device = init_from_env()
my_module = torch.nn.Linear(2, 2)
auto_unit_fsdp = DummyAutoUnit(
module=my_module,
device=device,
strategy=FSDPStrategy(),
precision="fp16",
)
tc = unittest.TestCase()
tc.assertTrue(isinstance(auto_unit_fsdp.grad_scaler, ShardedGradScaler))

def test_lr_scheduler_step(self) -> None:
"""
Test that the lr scheduler is stepped every optimizer step when step_lr_interval="step"
@@ -150,49 +121,6 @@ def test_lr_scheduler_epoch(self) -> None:
train(auto_unit, train_dataloader=train_dl, max_epochs=max_epochs)
self.assertEqual(auto_unit.lr_scheduler.step.call_count, max_epochs)

@skip_if_not_gpu
@patch("torch.autocast")
def test_mixed_precision_fp16(self, mock_autocast: MagicMock) -> None:
"""
Test that the mixed precision autocast context is called when fp16 precision is set
"""
my_module = torch.nn.Linear(2, 2)
auto_unit = DummyAutoUnit(
module=my_module,
precision="fp16",
)
dummy_iterable = [(torch.ones(2, 2), torch.ones(2, 2))]
state = get_dummy_train_state(dummy_iterable)
auto_unit.train_step(
state=state,
data=auto_unit.get_next_train_batch(state, iter(dummy_iterable)),
)
mock_autocast.assert_called_with(
device_type="cuda", dtype=torch.float16, enabled=True
)

@skip_if_not_gpu
@patch("torch.autocast")
def test_mixed_precision_bf16(self, mock_autocast: MagicMock) -> None:
"""
Test that the mixed precision autocast context is called when bf16 precision is set
"""
my_module = torch.nn.Linear(2, 2)

auto_unit = DummyAutoUnit(
module=my_module,
precision="bf16",
)
dummy_iterable = [(torch.ones(2, 2), torch.ones(2, 2))]
state = get_dummy_train_state(dummy_iterable)
auto_unit.train_step(
state=state,
data=auto_unit.get_next_train_batch(state, iter(dummy_iterable)),
)
mock_autocast.assert_called_with(
device_type="cuda", dtype=torch.bfloat16, enabled=True
)

def test_mixed_precision_invalid_str(self) -> None:
"""
Test that an exception is raised with an invalid precision string
@@ -310,191 +238,6 @@ def test_stochastic_weight_averaging_update_freq(self) -> None:
# epoch 1 is warmup; swa updates occur on epochs 2 and 3 -> 2 calls
self.assertEqual(update_swa_mock.call_count, 2)

@skip_if_not_distributed
@skip_if_not_gpu
def test_stochastic_weight_averaging_fsdp(self) -> None:
"""
Test that swa params with FSDP is identical to non-FSDP swa
"""
spawn_multi_process(
2,
"nccl",
self._test_stochastic_weight_averaging_fsdp,
)

@staticmethod
def _test_stochastic_weight_averaging_fsdp() -> None:
class Net(torch.nn.Module):
def __init__(self):
super(Net, self).__init__()
self.l1 = torch.nn.Linear(2, 2)
self.b1 = torch.nn.BatchNorm1d(2)
self.l2 = torch.nn.Linear(2, 2)

def forward(self, x):
x = self.l1(x)
x = self.b1(x)
x = self.l2(x)
return x

# seed so all ranks start with the same initialized weights
seed(0)
device = init_from_env()
my_module = Net()

auto_unit = DummyAutoUnit(
module=deepcopy(my_module),
device=device,
step_lr_interval="step",
swa_params=SWAParams(
warmup_steps_or_epochs=1,
step_or_epoch_update_freq=1,
swalr_params=SWALRParams(
anneal_steps_or_epochs=3,
),
averaging_method="ema",
),
)

auto_unit_fsdp = DummyAutoUnit(
module=my_module,
device=device,
step_lr_interval="step",
strategy=FSDPStrategy(),
swa_params=SWAParams(
warmup_steps_or_epochs=1,
step_or_epoch_update_freq=1,
swalr_params=SWALRParams(
anneal_steps_or_epochs=3,
),
averaging_method="ema",
),
)

input_dim = 2
dataset_len = 10
batch_size = 2

dataloader = generate_random_dataloader(dataset_len, input_dim, batch_size)
train(auto_unit, dataloader, max_epochs=1, max_steps_per_epoch=5)
train(auto_unit_fsdp, dataloader, max_epochs=1, max_steps_per_epoch=5)

swa_params = list(auto_unit.swa_model.module.parameters())
with FSDP.summon_full_params(auto_unit_fsdp.swa_model):
swa_fsdp_params = list(auto_unit_fsdp.swa_model.module.parameters())

# Iterate and compare each parameter
for p1, p2 in zip(swa_params, swa_fsdp_params, strict=True):
torch.testing.assert_close(p2, p1, check_device=False)

@skip_if_not_gpu
@patch("torch.autocast")
def test_eval_mixed_precision_bf16(self, mock_autocast: MagicMock) -> None:
"""
Test that the mixed precision autocast context is called during evaluate when precision = bf16
"""
my_module = torch.nn.Linear(2, 2)
auto_unit = DummyAutoUnit(
module=my_module,
precision="bf16",
)

input_dim = 2
dataset_len = 8
batch_size = 2

eval_dl = generate_random_dataloader(dataset_len, input_dim, batch_size)
evaluate(auto_unit, eval_dl)
mock_autocast.assert_called_with(
device_type="cuda", dtype=torch.bfloat16, enabled=True
)

@skip_if_not_gpu
@skip_if_not_distributed
def test_no_sync(self) -> None:
"""
Test that the no_sync context manager is correctly applied when using gradient accumulation
"""
spawn_multi_process(
2,
"nccl",
self._test_ddp_no_sync,
)
spawn_multi_process(
2,
"nccl",
self._test_fsdp_no_sync,
)

@staticmethod
def _test_ddp_no_sync() -> None:
"""
Test that the no_sync context manager is correctly applied when using gradient accumulation and DDP
"""

my_module = torch.nn.Linear(2, 2)

auto_unit = DummyAutoUnit(
module=my_module,
strategy=DDPStrategy(),
gradient_accumulation_steps=2,
)

dummy_iterator = iter(
[(torch.ones(2, 2), torch.ones(2, 2)), (torch.ones(2, 2), torch.ones(2, 2))]
)
state = get_dummy_train_state()

# for the first step no_sync should be called since we accumulate gradients
with patch.object(auto_unit.module, "no_sync") as no_sync_mock:
auto_unit.train_step(
state=state, data=auto_unit.get_next_train_batch(state, dummy_iterator)
)
no_sync_mock.assert_called_once()

auto_unit.train_progress.increment_step()
# for the second step no_sync should not be called since we run optimizer step
with patch.object(auto_unit.module, "no_sync") as no_sync_mock:
auto_unit.train_step(
state=state, data=auto_unit.get_next_train_batch(state, dummy_iterator)
)
no_sync_mock.assert_not_called()

@staticmethod
def _test_fsdp_no_sync() -> None:
"""
Test that the no_sync context manager is correctly applied when using gradient accumulation and FSDP
"""
device = init_from_env()
my_module = torch.nn.Linear(2, 2).to(device)

auto_unit = DummyAutoUnit(
module=my_module,
device=device,
strategy=FSDPStrategy(),
gradient_accumulation_steps=2,
)

dummy_iterator = iter(
[(torch.ones(2, 2), torch.ones(2, 2)), (torch.ones(2, 2), torch.ones(2, 2))]
)
state = get_dummy_train_state()

# for the first step no_sync should be called since we accumulate gradients
with patch.object(auto_unit.module, "no_sync") as no_sync_mock:
auto_unit.train_step(
state=state, data=auto_unit.get_next_train_batch(state, dummy_iterator)
)
no_sync_mock.assert_called_once()

auto_unit.train_progress.increment_step()
# for the second step no_sync should not be called since we run optimizer step
with patch.object(auto_unit.module, "no_sync") as no_sync_mock:
auto_unit.train_step(
state=state, data=auto_unit.get_next_train_batch(state, dummy_iterator)
)
no_sync_mock.assert_not_called()

def test_move_data_to_device(self) -> None:
"""
Test that move_data_to_device is called
@@ -746,53 +489,6 @@ def test_auto_unit_timing_predict(self) -> None:
timer=Timer(),
)

@skip_if_not_gpu
@patch("torch.autocast")
def test_predict_mixed_precision_fp16(self, mock_autocast: MagicMock) -> None:
"""
Test that the mixed precision autocast context is called during predict when precision = fp16
"""
my_module = torch.nn.Linear(2, 2)
auto_unit = AutoPredictUnit(module=my_module, precision="fp16")

input_dim = 2
dataset_len = 8
batch_size = 2

predict_dl = generate_random_iterable_dataloader(
dataset_len, input_dim, batch_size
)
predict(auto_unit, predict_dl)
mock_autocast.assert_called_with(
device_type="cuda", dtype=torch.float16, enabled=True
)

@unittest.skipUnless(
condition=COMPILE_AVAIL,
reason="This test needs PyTorch 1.13 or greater to run.",
)
@skip_if_not_gpu
@patch("torch.compile")
def test_compile_predict(self, mock_dynamo: MagicMock) -> None:
"""
e2e torch compile on predict
"""
my_module = torch.nn.Linear(2, 2)
auto_unit = AutoPredictUnit(
module=my_module,
torch_compile_params=TorchCompileParams(backend="eager"),
)

input_dim = 2
dataset_len = 8
batch_size = 2

predict_dl = generate_random_iterable_dataloader(
dataset_len, input_dim, batch_size
)
predict(auto_unit, predict_dl)
mock_dynamo.assert_called()

def test_auto_predict_unit_timing_predict(self) -> None:
"""
Test auto timing in AutoUnit for predict
(second changed file: 347 additions & 0 deletions; diff not rendered)
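The dedicated GPU file is the second change and its diff is not rendered above. As a hedged sketch of what presumably landed there (the file and class names are assumptions; the test body is copied from the deletion above), the moved FSDP + FP16 test would read:

# Hypothetical destination module (actual path not shown in this diff),
# e.g. a GPU-only companion to tests/framework/test_auto_unit.py.
import unittest

import torch
from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler
from torchtnt.framework._test_utils import DummyAutoUnit
from torchtnt.utils.distributed import spawn_multi_process
from torchtnt.utils.env import init_from_env
from torchtnt.utils.prepare_module import FSDPStrategy
from torchtnt.utils.test_utils import skip_if_not_distributed, skip_if_not_gpu


class TestAutoUnitGPU(unittest.TestCase):
    @skip_if_not_gpu
    @skip_if_not_distributed
    def test_fsdp_fp16(self) -> None:
        """Test that FSDP + FP16 uses ShardedGradScaler."""
        spawn_multi_process(2, "nccl", self._test_fsdp_fp16)

    @staticmethod
    def _test_fsdp_fp16() -> None:
        # Each rank builds an FSDP-wrapped AutoUnit; fp16 precision should
        # select ShardedGradScaler rather than the single-device GradScaler.
        device = init_from_env()
        my_module = torch.nn.Linear(2, 2)
        auto_unit_fsdp = DummyAutoUnit(
            module=my_module,
            device=device,
            strategy=FSDPStrategy(),
            precision="fp16",
        )
        tc = unittest.TestCase()
        tc.assertTrue(isinstance(auto_unit_fsdp.grad_scaler, ShardedGradScaler))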
