rename params in spawn_mp (#739)
Summary: Pull Request resolved: #739

Reviewed By: JKSenthil

Differential Revision: D54863367

fbshipit-source-id: d33a804578dfb29bb67a48ea098bb0d60b35810e
galrotem authored and facebook-github-bot committed Mar 13, 2024
1 parent 67159a4 commit ae05491
Showing 2 changed files with 25 additions and 14 deletions.
11 changes: 11 additions & 0 deletions tests/utils/test_distributed_gpu.py
@@ -14,6 +14,7 @@
 from torchtnt.utils.device import get_device_from_env
 from torchtnt.utils.distributed import (
     all_gather_tensors,
+    get_global_rank,
     get_local_rank,
     PGWrapper,
     spawn_multi_process,
@@ -67,3 +68,13 @@ def _test_pg_wrapper_scatter_object_list(
         )
         tc = unittest.TestCase()
         tc.assertEqual(output_list[0], get_local_rank() + 1)
+
+    @staticmethod
+    def _test_method(offset_arg: int, offset_kwarg: int) -> int:
+        return get_global_rank() + offset_arg - offset_kwarg
+
+    @skip_if_not_gpu
+    @skip_if_not_distributed
+    def test_spawn_multi_process(self) -> None:
+        mp_list = spawn_multi_process(2, "nccl", self._test_method, 3, offset_kwarg=2)
+        self.assertEqual(mp_list, [1, 2])
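
The expected output follows from the test's arithmetic: with world_size=2, rank r returns r + 3 - 2 = r + 1, so ranks 0 and 1 produce [1, 2]. Below is a minimal standalone sketch of the same call outside the test class, assuming the "gloo" backend so it can run without GPUs; the name offset_rank is illustrative, not part of this commit.

from torchtnt.utils.distributed import get_global_rank, spawn_multi_process


def offset_rank(offset_arg: int, offset_kwarg: int) -> int:
    # Each spawned rank returns its global rank adjusted by the offsets.
    return get_global_rank() + offset_arg - offset_kwarg


if __name__ == "__main__":
    # spawn_multi_process collects per-rank return values into a list
    # indexed by rank; with world_size=2 this yields [1, 2].
    results = spawn_multi_process(2, "gloo", offset_rank, 3, offset_kwarg=2)
    assert results == [1, 2]
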
28 changes: 14 additions & 14 deletions torchtnt/utils/distributed.py
@@ -523,9 +523,9 @@ class ProcessGroupSetupParams:
 def spawn_multi_process(
     world_size: int,
     backend: str,
-    test_method: Callable[TParams, TReturn],
-    *test_method_args: Any,
-    **test_method_kwargs: Any,
+    method: Callable[TParams, TReturn],
+    *method_args: Any,
+    **method_kwargs: Any,
 ) -> List[TReturn]:
     """
     Spawn single node, multi-rank function.
@@ -534,12 +534,12 @@ def spawn_multi_process(
     Args:
         world_size: number of processes
         backend: backend to use. for example, "nccl", "gloo", etc
-        test_method: callable to spawn. first 3 arguments are rank, world_size and mp output dict
-        test_method_args: args for the test method
-        test_method_kwargs: kwargs for the test method
+        method: callable to spawn.
+        method_args: args for the method
+        method_kwargs: kwargs for the method

     Returns:
-        A list, l, where l[i] is the return value of test_method on rank i
+        A list, l, where l[i] is the return value of method(*method_args, **method_kwargs) on rank i
     """
     manager = multiprocessing.Manager()
     mp_output_dict = manager.dict()
@@ -548,13 +548,13 @@
     torch.multiprocessing.spawn(
         # torch.multiprocessing.spawn sends rank as the first param
         # https://pytorch.org/docs/stable/multiprocessing.html#torch.multiprocessing.spawn
-        _init_pg_and_rank_and_launch_test,
+        _init_pg_and_rank_and_launch_method,
         args=(
             ProcessGroupSetupParams(backend=backend, port=port, world_size=world_size),
             mp_output_dict,
-            test_method,
-            test_method_args,
-            test_method_kwargs,
+            method,
+            method_args,
+            method_kwargs,
         ),
         nprocs=world_size,
     )
@@ -566,11 +566,11 @@
     return output_list


-def _init_pg_and_rank_and_launch_test(
+def _init_pg_and_rank_and_launch_method(
     rank: int,
     pg_setup_params: ProcessGroupSetupParams,
     mp_output_dict: Dict[int, object],
-    test_method: Callable[TParams, TReturn],
+    method: Callable[TParams, TReturn],
     args: List[object],
     kwargs: Dict[str, object],
 ) -> None:
@@ -586,7 +586,7 @@ def _init_pg_and_rank_and_launch_test(
     )
     try:
         # pyre-ignore: spawn_multi_process uses unsafe types to begin with
-        mp_output_dict[rank] = test_method(*args, **kwargs)
+        mp_output_dict[rank] = method(*args, **kwargs)

     finally:
         destroy_process_group()
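
For call sites, the rename matters only if the callable was passed by keyword; positional calls are unaffected, since *method_args and **method_kwargs capture extra arguments the same way as before. A hedged migration sketch, where my_fn stands in for a hypothetical caller-side function:

# Before this commit a caller could write:
#     spawn_multi_process(2, "gloo", test_method=my_fn)
# That keyword no longer exists; use the new parameter name
# (or pass the callable positionally):
results = spawn_multi_process(2, "gloo", method=my_fn)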
