
Commit 3a85c6b
add timeout to spawn_multi_process
Summary: Add timeout to pg operations in spawn_multi_process

Differential Revision: D52558219
galrotem authored and facebook-github-bot committed Jan 9, 2024
1 parent d9a4a4a commit 3a85c6b
Showing 4 changed files with 28 additions and 6 deletions.
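
For orientation, here is a minimal usage sketch of the updated API (an illustration based on the signature change in torchtnt/utils/test_utils.py below, not code from the commit; _worker is a hypothetical function): timeout_seconds becomes the fourth positional parameter, ahead of *args, which is why every call site in this diff gains a 10 before its worker arguments.

# Sketch only: calling spawn_multi_process after this change.
from torchtnt.utils.test_utils import spawn_multi_process

def _worker(offset: int) -> int:
    # Runs in each spawned process; the process group is already initialized.
    return offset + 1

# world_size=2, backend="gloo", timeout_seconds=10, then the worker's args.
mp_dict = spawn_multi_process(2, "gloo", _worker, 10, 3)
print(mp_dict)  # {0: 4, 1: 4}, one entry per rank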
examples/torchrec/tests/torchrec_example_test.py (2 changes: 1 addition & 1 deletion)

@@ -23,4 +23,4 @@ class TorchrecExampleTest(unittest.TestCase):
         "Skip when CUDA is not available",
     )
     def test_torchrec_example(self) -> None:
-        spawn_multi_process(2, "nccl", main, [])
+        spawn_multi_process(2, "nccl", main, 10, [])
tests/utils/test_test_utils.py (4 changes: 3 additions & 1 deletion)

@@ -19,5 +19,7 @@ def _test_method(offset_arg: int, offset_kwarg: int) -> int:
 
     @skip_if_not_distributed
     def test_spawn_multi_process(self) -> None:
-        mp_dict = spawn_multi_process(2, "gloo", self._test_method, 3, offset_kwarg=2)
+        mp_dict = spawn_multi_process(
+            2, "gloo", self._test_method, 10, 3, offset_kwarg=2
+        )
         self.assertEqual(mp_dict, {0: 1, 1: 2})
tests/utils/test_timer.py (12 changes: 9 additions & 3 deletions)

@@ -287,18 +287,24 @@ def test_full_sync_pt_multi_process_check_false(self) -> None:
         self.assertDictEqual(mp_dict, {0: False, 1: False})
 
     def test_full_sync_pt_multi_process_check_true(self) -> None:
-        mp_dict = spawn_multi_process(2, "gloo", self._full_sync_worker_with_timeout, 8)
+        mp_dict = spawn_multi_process(
+            2, "gloo", self._full_sync_worker_with_timeout, 10, 8
+        )
         # Both processes should return True
         self.assertDictEqual(mp_dict, {0: True, 1: True})
 
     def test_full_sync_pt_multi_process_edgecase(self) -> None:
-        mp_dict = spawn_multi_process(2, "gloo", self._full_sync_worker_with_timeout, 5)
+        mp_dict = spawn_multi_process(
+            2, "gloo", self._full_sync_worker_with_timeout, 10, 5
+        )
 
         # Both processes should return True
         self.assertDictEqual(mp_dict, {0: True, 1: True})
 
         # Launch 2 worker processes. Each will check time diff >= interval threshold
-        mp_dict = spawn_multi_process(2, "gloo", self._full_sync_worker_with_timeout, 4)
+        mp_dict = spawn_multi_process(
+            2, "gloo", self._full_sync_worker_with_timeout, 10, 4
+        )
 
         # Both processes should return False
         self.assertDictEqual(mp_dict, {0: False, 1: False})
torchtnt/utils/test_utils.py (16 changes: 15 additions & 1 deletion)

@@ -12,6 +12,7 @@
 import unittest
 import uuid
 from contextlib import contextmanager
+from datetime import timedelta
 from functools import wraps
 from io import StringIO
 from typing import (
@@ -100,6 +101,7 @@ def spawn_multi_process(
     world_size: int,
     backend: str,
     test_method: Callable[TParams, TReturn],
+    timeout_seconds: int = 10,
     *args: Any,
     **kwargs: Any,
 ) -> Dict[int, TReturn]:
@@ -111,6 +113,7 @@
         world_size: number of processes
         backend: backend to use. for example, "nccl", "gloo", etc
         test_method: callable to spawn. first 3 arguments are rank, world_size and mp output dict
+        timeout_seconds: timeout in seconds for the distributed collective operations, defaults to 10
         args: args for the test method
         kwargs: kwargs for the test method
@@ -123,7 +126,16 @@
     port = str(get_free_port())
     torch.multiprocessing.spawn(
         _init_pg_and_rank_and_launch_test,
-        args=(test_method, world_size, backend, port, mp_output_dict, args, kwargs),
+        args=(
+            test_method,
+            world_size,
+            backend,
+            port,
+            timeout_seconds,
+            mp_output_dict,
+            args,
+            kwargs,
+        ),
         nprocs=world_size,
         join=True,
         daemon=False,
@@ -139,6 +151,7 @@ def _init_pg_and_rank_and_launch_test(
     world_size: int,
     backend: str,
     port: str,
+    timeout_seconds: int,
     mp_output_dict: Dict[int, object],
     args: List[object],
     kwargs: Dict[str, object],
@@ -151,6 +164,7 @@
         rank=rank,
         world_size=world_size,
         backend=backend,
+        timeout=timedelta(seconds=timeout_seconds),
     )
     try:
         mp_output_dict[rank] = test_method(*args, **kwargs)  # pyre-fixme
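
A note on what the timeout buys (an illustrative sketch, not part of the commit): the timedelta passed to init_process_group bounds the process group's collective operations, so a rank that stalls should surface as a timeout error in its peers rather than hanging the test run indefinitely. With the gloo backend, a blocked collective such as barrier raises once the timeout elapses; _stalling_worker below is hypothetical.

# Hypothetical repro of the failure mode the timeout guards against.
import time
import torch.distributed as dist

def _stalling_worker() -> bool:
    if dist.get_rank() == 0:
        time.sleep(60)  # simulate a hung rank
    # With gloo, rank 1's barrier should raise after the process-group
    # timeout elapses instead of blocking the whole spawn forever.
    dist.barrier()
    return True

# spawn_multi_process(2, "gloo", _stalling_worker, 5)  # timeout_seconds=5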
