Skip to content

Commit

Permalink
unit tests - use skip if not gpu/distributed decorators (#684)
Browse files Browse the repository at this point in the history
Summary:

Adopt skip_if_not_distributed, skip_if_not_gpu test decorators across all unit tests

Differential Revision: D52893384
  • Loading branch information
galrotem authored and facebook-github-bot committed Jan 19, 2024
1 parent d236579 commit 32a52b0
Show file tree
Hide file tree
Showing 15 changed files with 105 additions and 243 deletions.
10 changes: 2 additions & 8 deletions examples/torchrec/tests/torchrec_example_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,19 +8,13 @@
import unittest

import torch
from torchtnt.utils.test_utils import skip_if_asan, spawn_multi_process
from torchtnt.utils.test_utils import skip_if_asan, skip_if_not_gpu, spawn_multi_process

from ..main import main


class TorchrecExampleTest(unittest.TestCase):

cuda_available: bool = torch.cuda.is_available()

@skip_if_asan
@unittest.skipUnless(
cuda_available,
"Skip when CUDA is not available",
)
@skip_if_not_gpu
def test_torchrec_example(self) -> None:
spawn_multi_process(2, "nccl", main, [])
18 changes: 8 additions & 10 deletions tests/framework/callbacks/test_base_checkpointer.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,11 @@
from torchtnt.framework.unit import AppStateMixin, TrainUnit, TTrainData
from torchtnt.utils.distributed import get_global_rank
from torchtnt.utils.env import init_from_env
from torchtnt.utils.test_utils import spawn_multi_process
from torchtnt.utils.test_utils import (
skip_if_not_distributed,
skip_if_not_gpu,
spawn_multi_process,
)


class BaseCheckpointSaver(BaseCheckpointer):
Expand Down Expand Up @@ -363,9 +367,7 @@ def test_save_on_train_end(self) -> None:
],
)

@unittest.skipUnless(
condition=distributed_available, reason="Torch distributed is needed to run"
)
@skip_if_not_distributed
def test_directory_sync_collective(self) -> None:
spawn_multi_process(
2,
Expand Down Expand Up @@ -410,12 +412,8 @@ def test_invalid_args(self) -> None:
):
BaseCheckpointSaver(temp_dir, save_every_n_epochs=0)

@unittest.skipUnless(
condition=distributed_available, reason="Torch distributed is needed to run"
)
@unittest.skipUnless(
condition=cuda_available, reason="This test needs a GPU host to run."
)
@skip_if_not_distributed
@skip_if_not_gpu
def test_process_group_plumbing(self) -> None:
"""
Creates a new process group and verifies that it's passed through correctly
Expand Down
20 changes: 8 additions & 12 deletions tests/framework/callbacks/test_checkpoint_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,14 +32,16 @@
from torchtnt.utils.distributed import get_global_rank, PGWrapper
from torchtnt.utils.env import init_from_env
from torchtnt.utils.fsspec import get_filesystem
from torchtnt.utils.test_utils import get_pet_launch_config, spawn_multi_process
from torchtnt.utils.test_utils import (
get_pet_launch_config,
skip_if_not_distributed,
spawn_multi_process,
)

METADATA_FNAME: str = ".metadata"


class CheckpointUtilsTest(unittest.TestCase):
distributed_available: bool = torch.distributed.is_available()

@staticmethod
def _create_snapshot_metadata(output_dir: str) -> None:
path = os.path.join(output_dir, METADATA_FNAME)
Expand Down Expand Up @@ -86,9 +88,7 @@ def test_latest_checkpoint_path(self) -> None:
get_latest_checkpoint_path(temp_dir, METADATA_FNAME), path_2
)

@unittest.skipUnless(
condition=distributed_available, reason="Torch distributed is needed to run"
)
@skip_if_not_distributed
def test_latest_checkpoint_path_distributed(self) -> None:
config = get_pet_launch_config(2)
launcher.elastic_launch(
Expand Down Expand Up @@ -290,9 +290,7 @@ def test_retrieve_checkpoint_dirpaths_with_metrics(self) -> None:
{os.path.join(temp_dir, paths[1])},
)

@unittest.skipUnless(
condition=distributed_available, reason="Torch distributed is needed to run"
)
@skip_if_not_distributed
def test_distributed_get_checkpoint_dirpaths(self) -> None:
spawn_multi_process(2, "gloo", self._distributed_get_checkpoint_dirpaths)

Expand Down Expand Up @@ -425,9 +423,7 @@ def test_get_app_state(self) -> None:
["module", "optimizer", "loss_fn", "train_progress"],
)

@unittest.skipUnless(
condition=distributed_available, reason="Torch distributed is needed to run"
)
@skip_if_not_distributed
def test_rank_zero_read_and_broadcast(self) -> None:
spawn_multi_process(2, "gloo", self._test_rank_zero_read_and_broadcast)

Expand Down
21 changes: 8 additions & 13 deletions tests/framework/callbacks/test_dcp_saver.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,13 +29,14 @@
from torchtnt.framework.train import train
from torchtnt.utils.distributed import get_global_rank
from torchtnt.utils.env import seed
from torchtnt.utils.test_utils import spawn_multi_process
from torchtnt.utils.test_utils import (
skip_if_not_distributed,
skip_if_not_gpu,
spawn_multi_process,
)


class DistributedCheckpointSaverTest(unittest.TestCase):
cuda_available: bool = torch.cuda.is_available()
distributed_available: bool = torch.distributed.is_available()

def test_save_restore(self) -> None:
input_dim = 2
dataset_len = 10
Expand Down Expand Up @@ -223,12 +224,8 @@ def test_save_restore_no_lr_scheduler_restore(
app_state = mock_dist_cp.load_state_dict.call_args.args[0]["app_state"]
self.assertIn("lr_scheduler", app_state)

@unittest.skipUnless(
condition=distributed_available, reason="Torch distributed is needed to run"
)
@unittest.skipUnless(
condition=cuda_available, reason="This test needs a GPU host to run."
)
@skip_if_not_distributed
@skip_if_not_gpu
def test_save_restore_fsdp(self) -> None:
spawn_multi_process(
2,
Expand Down Expand Up @@ -276,9 +273,7 @@ def _save_restore_fsdp() -> None:
if get_global_rank() == 0:
shutil.rmtree(temp_dir) # delete temp directory

@unittest.skipUnless(
condition=distributed_available, reason="Torch distributed is needed to run"
)
@skip_if_not_distributed
def test_save_restore_ddp(self) -> None:
spawn_multi_process(
2,
Expand Down
21 changes: 8 additions & 13 deletions tests/framework/callbacks/test_torchsnapshot_saver.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,13 +33,14 @@
from torchtnt.framework.train import train
from torchtnt.utils.distributed import get_global_rank
from torchtnt.utils.env import seed
from torchtnt.utils.test_utils import spawn_multi_process
from torchtnt.utils.test_utils import (
skip_if_not_distributed,
skip_if_not_gpu,
spawn_multi_process,
)


class TorchSnapshotSaverTest(unittest.TestCase):
cuda_available: bool = torch.cuda.is_available()
distributed_available: bool = torch.distributed.is_available()

def test_save_restore(self) -> None:
input_dim = 2
dataset_len = 10
Expand Down Expand Up @@ -227,12 +228,8 @@ def test_save_restore_no_lr_scheduler_restore(
app_state = mock_torchsnapshot.Snapshot().restore.call_args.args[0]
self.assertIn("lr_scheduler", app_state)

@unittest.skipUnless(
condition=distributed_available, reason="Torch distributed is needed to run"
)
@unittest.skipUnless(
condition=cuda_available, reason="This test needs a GPU host to run."
)
@skip_if_not_distributed
@skip_if_not_gpu
def test_save_restore_fsdp(self) -> None:
spawn_multi_process(
2,
Expand Down Expand Up @@ -281,9 +278,7 @@ def _save_restore_fsdp() -> None:
if get_global_rank() == 0:
shutil.rmtree(temp_dir) # delete temp directory

@unittest.skipUnless(
condition=distributed_available, reason="Torch distributed is needed to run"
)
@skip_if_not_distributed
def test_save_restore_ddp(self) -> None:
spawn_multi_process(
2,
Expand Down
17 changes: 7 additions & 10 deletions tests/framework/test_unit_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,14 @@
)
from torchtnt.framework.state import State
from torchtnt.utils.env import init_from_env
from torchtnt.utils.test_utils import spawn_multi_process
from torchtnt.utils.test_utils import (
skip_if_not_distributed,
skip_if_not_gpu,
spawn_multi_process,
)


class UnitUtilsTest(unittest.TestCase):
cuda_available: bool = torch.cuda.is_available()
distributed_available: bool = torch.distributed.is_available()

def test_step_func_requires_iterator(self) -> None:
class Foo:
def bar(self, state: State, data: object) -> object:
Expand Down Expand Up @@ -56,12 +57,8 @@ def test_find_optimizers_for_module(self) -> None:
optim_name, _ = optimizers[0]
self.assertEqual(optim_name, "optim2")

@unittest.skipUnless(
condition=distributed_available, reason="Torch distributed is needed to run"
)
@unittest.skipUnless(
condition=cuda_available, reason="This test needs a GPU host to run."
)
@skip_if_not_distributed
@skip_if_not_gpu
def test_find_optimizers_for_FSDP_module(self) -> None:
spawn_multi_process(2, "nccl", self._find_optimizers_for_FSDP_module)

Expand Down
13 changes: 3 additions & 10 deletions tests/utils/data/test_data_prefetcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,12 @@
import torch
from torch.utils.data.dataset import Dataset, TensorDataset
from torchtnt.utils.data.data_prefetcher import CudaDataPrefetcher
from torchtnt.utils.test_utils import skip_if_not_gpu

Batch = Tuple[torch.Tensor, torch.Tensor]


class DataTest(unittest.TestCase):

# pyre-fixme[4]: Attribute must be annotated.
cuda_available = torch.cuda.is_available()

def _generate_dataset(self, num_samples: int, input_dim: int) -> Dataset[Batch]:
"""Returns a dataset of random inputs and labels for binary classification."""
data = torch.randn(num_samples, input_dim)
Expand All @@ -39,9 +36,7 @@ def test_cpu_device_data_prefetcher(self) -> None:
with self.assertRaisesRegex(ValueError, "expects a CUDA device"):
_ = CudaDataPrefetcher(dataloader, device, num_prefetch_batches)

@unittest.skipUnless(
condition=cuda_available, reason="This test needs a GPU host to run."
)
@skip_if_not_gpu
def test_num_prefetch_batches_data_prefetcher(self) -> None:
device = torch.device("cuda:0")

Expand All @@ -65,9 +60,7 @@ def test_num_prefetch_batches_data_prefetcher(self) -> None:
_ = CudaDataPrefetcher(dataloader, device, num_prefetch_batches=1)
_ = CudaDataPrefetcher(dataloader, device, num_prefetch_batches=2)

@unittest.skipUnless(
condition=cuda_available, reason="This test needs a GPU host to run."
)
@skip_if_not_gpu
def test_cuda_data_prefetcher(self) -> None:
device = torch.device("cuda:0")

Expand Down
5 changes: 2 additions & 3 deletions tests/utils/data/test_profile_dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from torch.profiler import ProfilerActivity
from torchtnt.utils.data.profile_dataloader import profile_dataloader
from torchtnt.utils.env import init_from_env
from torchtnt.utils.test_utils import skip_if_not_gpu


class DummyIterable:
Expand Down Expand Up @@ -46,9 +47,7 @@ def test_profile_dataloader_profiler(self) -> None:
timer = profile_dataloader(iterable, p)
self.assertEqual(len(timer.recorded_durations["next(iter)"]), max_length)

@unittest.skipUnless(
bool(torch.cuda.is_available()), reason="This test needs a GPU host to run."
)
@skip_if_not_gpu
def test_profile_dataloader_device(self) -> None:
device = init_from_env()
max_length = 10
Expand Down
6 changes: 2 additions & 4 deletions tests/utils/loggers/test_tensorboard.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from torch import distributed as dist

from torchtnt.utils.loggers.tensorboard import TensorBoardLogger
from torchtnt.utils.test_utils import get_pet_launch_config
from torchtnt.utils.test_utils import get_pet_launch_config, skip_if_not_distributed


class TensorBoardLoggerTest(unittest.TestCase):
Expand Down Expand Up @@ -87,9 +87,7 @@ def _test_distributed() -> None:
assert test_path in logger.path
assert invalid_path not in logger.path

@unittest.skipUnless(
bool(dist.is_available()), reason="Torch distributed is needed to run"
)
@skip_if_not_distributed
def test_multiple_workers(self: TensorBoardLoggerTest) -> None:
config = get_pet_launch_config(2)
launcher.elastic_launch(config, entrypoint=self._test_distributed)()
Expand Down
Loading

0 comments on commit 32a52b0

Please sign in to comment.