diff --git a/tests/utils/test_distributed_gpu.py b/tests/utils/test_distributed_gpu.py
index d1b06f2cc5..dec90b8ed0 100644
--- a/tests/utils/test_distributed_gpu.py
+++ b/tests/utils/test_distributed_gpu.py
@@ -14,6 +14,7 @@
 from torchtnt.utils.device import get_device_from_env
 from torchtnt.utils.distributed import (
     all_gather_tensors,
+    get_global_rank,
     get_local_rank,
     PGWrapper,
     spawn_multi_process,
@@ -67,3 +68,13 @@ def _test_pg_wrapper_scatter_object_list(
         )
         tc = unittest.TestCase()
         tc.assertEqual(output_list[0], get_local_rank() + 1)
+
+    @staticmethod
+    def _test_method(offset_arg: int, offset_kwarg: int) -> int:
+        return get_global_rank() + offset_arg - offset_kwarg
+
+    @skip_if_not_gpu
+    @skip_if_not_distributed
+    def test_spawn_multi_process(self) -> None:
+        mp_list = spawn_multi_process(2, "nccl", self._test_method, 3, offset_kwarg=2)
+        self.assertEqual(mp_list, [1, 2])
diff --git a/torchtnt/utils/distributed.py b/torchtnt/utils/distributed.py
index 263421f549..b65dfbfb96 100644
--- a/torchtnt/utils/distributed.py
+++ b/torchtnt/utils/distributed.py
@@ -523,9 +523,9 @@ class ProcessGroupSetupParams:
 def spawn_multi_process(
     world_size: int,
     backend: str,
-    test_method: Callable[TParams, TReturn],
-    *test_method_args: Any,
-    **test_method_kwargs: Any,
+    method: Callable[TParams, TReturn],
+    *method_args: Any,
+    **method_kwargs: Any,
 ) -> List[TReturn]:
     """
     Spawn single node, multi-rank function.
@@ -534,12 +534,12 @@
     Args:
         world_size: number of processes
         backend: backend to use. for example, "nccl", "gloo", etc
-        test_method: callable to spawn. first 3 arguments are rank, world_size and mp output dict
-        test_method_args: args for the test method
-        test_method_kwargs: kwargs for the test method
+        method: callable to spawn.
+        method_args: args for the method
+        method_kwargs: kwargs for the method
 
     Returns:
-        A list, l, where l[i] is the return value of test_method on rank i
+        A list, l, where l[i] is the return value of method(*method_args, **method_kwargs) on rank i
     """
     manager = multiprocessing.Manager()
     mp_output_dict = manager.dict()
@@ -548,13 +548,13 @@
     torch.multiprocessing.spawn(
         # torch.multiprocessing.spawn sends rank as the first param
        # https://pytorch.org/docs/stable/multiprocessing.html#torch.multiprocessing.spawn
-        _init_pg_and_rank_and_launch_test,
+        _init_pg_and_rank_and_launch_method,
         args=(
             ProcessGroupSetupParams(backend=backend, port=port, world_size=world_size),
             mp_output_dict,
-            test_method,
-            test_method_args,
-            test_method_kwargs,
+            method,
+            method_args,
+            method_kwargs,
         ),
         nprocs=world_size,
     )
@@ -566,11 +566,11 @@
     return output_list
 
 
-def _init_pg_and_rank_and_launch_test(
+def _init_pg_and_rank_and_launch_method(
     rank: int,
     pg_setup_params: ProcessGroupSetupParams,
     mp_output_dict: Dict[int, object],
-    test_method: Callable[TParams, TReturn],
+    method: Callable[TParams, TReturn],
     args: List[object],
     kwargs: Dict[str, object],
 ) -> None:
@@ -586,7 +586,7 @@
     )
     try:
         # pyre-ignore: spawn_multi_process uses unsafe types to begin with
-        mp_output_dict[rank] = test_method(*args, **kwargs)
+        mp_output_dict[rank] = method(*args, **kwargs)
 
     finally:
         destroy_process_group()
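Usage sketch (not part of the patch): after the rename, callers pass the target callable followed by its positional and keyword arguments under the new method, method_args, and method_kwargs names, and spawn_multi_process returns one result per rank, indexed by rank. The helper name _echo_rank and the 2-process NCCL setup below are illustrative assumptions mirroring the test added above, not an official example.

    # Minimal sketch of calling spawn_multi_process after the rename.
    # Assumes a host with 2 GPUs so the "nccl" backend can initialize.
    from torchtnt.utils.distributed import get_global_rank, spawn_multi_process


    def _echo_rank(offset: int) -> int:
        # Runs once in each spawned process; the process group is already set up
        # by _init_pg_and_rank_and_launch_method before this callable is invoked.
        return get_global_rank() + offset


    if __name__ == "__main__":
        # method=_echo_rank, method_args=(10,); the result list is ordered by rank.
        results = spawn_multi_process(2, "nccl", _echo_rank, 10)
        print(results)  # expected: [10, 11]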