Skip to content

Commit

Permalink
adjust spawn_multi_process
Browse files Browse the repository at this point in the history
Summary: Adjusting spawn_multi_process to work properly on gloo

Differential Revision: D52546790
  • Loading branch information
galrotem authored and facebook-github-bot committed Jan 4, 2024
1 parent 28e657d commit 892ce89
Show file tree
Hide file tree
Showing 2 changed files with 70 additions and 31 deletions.
24 changes: 24 additions & 0 deletions tests/utils/test_test_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import unittest

from torchtnt.utils.distributed import get_global_rank

from torchtnt.utils.test_utils import skip_if_not_distributed, spawn_multi_process


class TestUtilsTest(unittest.TestCase):
@staticmethod
def _test_method(offset_arg: int, offset_kwarg: int) -> int:
return get_global_rank() + offset_arg - offset_kwarg

@skip_if_not_distributed
def test_spawn_multi_process(self) -> None:
ret_dict = spawn_multi_process(2, "gloo", self._test_method, 3, offset_kwarg=2)
self.assertEqual(ret_dict[0], 1)
self.assertEqual(ret_dict[1], 2)
77 changes: 46 additions & 31 deletions torchtnt/utils/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,16 +12,29 @@
import unittest
import uuid
from contextlib import contextmanager
from datetime import timedelta
from functools import wraps
from io import StringIO
from typing import Any, Callable, Dict, Generator, Optional, TextIO, Tuple, TypeVar
from typing import (
Any,
Callable,
Dict,
Generator,
List,
Optional,
TextIO,
Tuple,
TypeVar,
)

import torch

import torch.distributed.launcher as pet
from pyre_extensions import ParameterSpecification
from torch import distributed as dist, multiprocessing

from torch.distributed.elastic.utils.distributed import get_free_port
from torchtnt.utils.distributed import destroy_process_group


TParams = ParameterSpecification("TParams")
Expand Down Expand Up @@ -89,6 +102,7 @@ def spawn_multi_process(
backend: str,
test_method: Callable[TParams, TReturn],
*args: Any,
**kwargs: Any,
) -> Dict[int, TReturn]:
"""
Spawn single node, multi-rank function.
Expand All @@ -103,48 +117,49 @@ def spawn_multi_process(
Returns:
A dictionary of rank -> func return value
"""
os.environ["MASTER_PORT"] = str(get_free_port())
os.environ["MASTER_ADDR"] = "127.0.0.1"

processes = []
manager = multiprocessing.Manager()
mp_output_dict = manager.dict()
tc = unittest.TestCase()
ctx = multiprocessing.get_context("spawn")
for rank in range(world_size):
p = ctx.Process(
target=init_pg_and_rank_and_launch_test,
args=(
test_method,
rank,
world_size,
backend,
mp_output_dict,
*args,
),
)
p.start()
processes.append(p)

for p in processes:
p.join()
tc.assertEqual(p.exitcode, 0)

port = str(get_free_port())
torch.multiprocessing.spawn(
init_pg_and_rank_and_launch_test,
args=(test_method, world_size, backend, mp_output_dict, port, args, kwargs),
nprocs=world_size,
join=True,
daemon=False,
start_method="spawn",
)

return mp_output_dict


def init_pg_and_rank_and_launch_test(
test_method: Callable[TParams, TReturn],
rank: int,
test_method: Callable[TParams, TReturn],
world_size: int,
backend: str,
# pyre-fixme[2]
mp_output_dict: Dict[int, Any],
*args: Any,
mp_output_dict: Dict[int, object],
port: str,
args: List[object],
kwargs: Dict[str, object],
) -> None:
dist.init_process_group(rank=rank, world_size=world_size, backend=backend)
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = port
os.environ["WORLD_SIZE"] = str(world_size)
os.environ["LOCAL_RANK"] = str(rank)
mp_output_dict[rank] = test_method(*args) # pyre-fixme[29]
dist.init_process_group(
rank=rank,
world_size=world_size,
backend=backend,
init_method="env://",
timeout=timedelta(seconds=10),
)
try:
mp_output_dict[rank] = test_method(*args, **kwargs) # pyre-fixme
return

finally:
destroy_process_group()


@contextmanager
Expand Down

0 comments on commit 892ce89

Please sign in to comment.