
Move unit_utils GPU test to specific file (#756)
Summary: Pull Request resolved: #756

Reviewed By: JKSenthil

Differential Revision: D55257789

fbshipit-source-id: 645906442f243fcc872965c80e8d7fcc2e229fbe
diego-urgell authored and facebook-github-bot committed Mar 22, 2024
1 parent e806af0 commit c2dcee9
Showing 2 changed files with 45 additions and 28 deletions.
28 changes: 0 additions & 28 deletions tests/framework/test_unit_utils.py
@@ -11,17 +11,12 @@
from typing import Dict, Iterator

import torch

from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.optim import Optimizer
from torchtnt.framework._unit_utils import (
    _find_optimizers_for_module,
    _step_requires_iterator,
)
from torchtnt.framework.state import State
from torchtnt.utils.distributed import spawn_multi_process
from torchtnt.utils.env import init_from_env
from torchtnt.utils.test_utils import skip_if_not_distributed, skip_if_not_gpu


class UnitUtilsTest(unittest.TestCase):
@@ -55,26 +50,3 @@ def test_find_optimizers_for_module(self) -> None:
        optimizers = _find_optimizers_for_module(module2, opts)
        optim_name, _ = optimizers[0]
        self.assertEqual(optim_name, "optim2")

    @skip_if_not_distributed
    @skip_if_not_gpu
    def test_find_optimizers_for_FSDP_module(self) -> None:
        spawn_multi_process(2, "nccl", self._find_optimizers_for_FSDP_module)

    @staticmethod
    def _find_optimizers_for_FSDP_module() -> None:
        device = init_from_env()
        module1 = FSDP(torch.nn.Linear(10, 10).to(device))
        module2 = torch.nn.Linear(10, 10)
        optim1 = torch.optim.Adam(module1.parameters())
        optim2 = torch.optim.Adagrad(module2.parameters())

        opts: Dict[str, Optimizer] = {"optim1": optim1, "optim2": optim2}
        optim_list = _find_optimizers_for_module(module1, opts)
        optim_name, _ = optim_list[0]

        tc = unittest.TestCase()
        tc.assertEqual(optim_name, "optim1")
        optim_list = _find_optimizers_for_module(module2, opts)
        optim_name, _ = optim_list[0]
        tc.assertEqual(optim_name, "optim2")
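
For context, _find_optimizers_for_module pairs a module with the optimizers that manage its parameters. The sketch below is illustrative only (not the torchtnt implementation); the function name is made up here. It shows the matching idea these tests exercise: an optimizer is reported for a module if it manages at least one of that module's parameters.

# Illustrative sketch only -- not the torchtnt implementation of
# _find_optimizers_for_module.
from typing import Dict, List, Tuple

import torch
from torch.optim import Optimizer


def find_optimizers_for_module_sketch(
    module: torch.nn.Module, optimizers: Dict[str, Optimizer]
) -> List[Tuple[str, Optimizer]]:
    # Collect the identities of the module's parameters.
    module_params = {id(p) for p in module.parameters()}
    matches: List[Tuple[str, Optimizer]] = []
    for name, optimizer in optimizers.items():
        # An optimizer's parameters live in its param_groups.
        optimizer_params = {
            id(p) for group in optimizer.param_groups for p in group["params"]
        }
        if module_params & optimizer_params:
            matches.append((name, optimizer))
    return matches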
45 changes: 45 additions & 0 deletions tests/framework/test_unit_utils_gpu.py
@@ -0,0 +1,45 @@
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

import unittest
from typing import Dict

import torch

from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.optim import Optimizer
from torchtnt.framework._unit_utils import _find_optimizers_for_module
from torchtnt.utils.distributed import spawn_multi_process
from torchtnt.utils.env import init_from_env
from torchtnt.utils.test_utils import skip_if_not_distributed, skip_if_not_gpu


class UnitUtilsGPUTest(unittest.TestCase):
    @skip_if_not_distributed
    @skip_if_not_gpu
    def test_find_optimizers_for_FSDP_module(self) -> None:
        spawn_multi_process(2, "nccl", self._find_optimizers_for_FSDP_module)

    @staticmethod
    def _find_optimizers_for_FSDP_module() -> None:
        device = init_from_env()
        module1 = FSDP(torch.nn.Linear(10, 10).to(device))
        module2 = torch.nn.Linear(10, 10)
        optim1 = torch.optim.Adam(module1.parameters())
        optim2 = torch.optim.Adagrad(module2.parameters())

        opts: Dict[str, Optimizer] = {"optim1": optim1, "optim2": optim2}
        optim_list = _find_optimizers_for_module(module1, opts)
        optim_name, _ = optim_list[0]

        tc = unittest.TestCase()
        tc.assertEqual(optim_name, "optim1")
        optim_list = _find_optimizers_for_module(module2, opts)
        optim_name, _ = optim_list[0]
        tc.assertEqual(optim_name, "optim2")
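
The GPU test relies on spawn_multi_process(2, "nccl", fn) to run the static test body once per rank. A rough, hypothetical stand-in using plain torch.multiprocessing and torch.distributed is sketched below; the rendezvous address, port, and world size are placeholder assumptions, and torchtnt's spawn_multi_process may differ in detail.

# Hypothetical stand-in for what spawn_multi_process(2, "nccl", fn) arranges:
# start two worker processes, initialize a NCCL process group in each,
# run the per-rank function, then tear the group down.
import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp


def _run_per_rank(rank: int, world_size: int) -> None:
    os.environ["MASTER_ADDR"] = "localhost"  # placeholder rendezvous settings
    os.environ["MASTER_PORT"] = "29500"
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)
    try:
        # The per-rank test body (e.g. _find_optimizers_for_FSDP_module) runs here.
        pass
    finally:
        dist.destroy_process_group()


if __name__ == "__main__":
    mp.spawn(_run_per_rank, args=(2,), nprocs=2)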
