Skip to content

Commit

Permalink
spawn multi process to return pythonic dict
Browse files Browse the repository at this point in the history
Summary: Return pythonic dict from spawn_multi_process

Differential Revision: D52558016
  • Loading branch information
galrotem authored and facebook-github-bot committed Jan 9, 2024
1 parent ba203c0 commit 1e232cb
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 9 deletions.
12 changes: 4 additions & 8 deletions tests/utils/test_timer.py
Original file line number Diff line number Diff line change
Expand Up @@ -284,25 +284,21 @@ def _full_sync_worker_with_timeout(cls, timeout: int) -> bool:
def test_full_sync_pt_multi_process_check_false(self) -> None:
mp_dict = spawn_multi_process(2, "gloo", self._full_sync_worker_without_timeout)
# Both processes should return False
self.assertFalse(mp_dict[0])
self.assertFalse(mp_dict[1])
self.assertDictEqual(mp_dict, {0: False, 1: False})

def test_full_sync_pt_multi_process_check_true(self) -> None:
mp_dict = spawn_multi_process(2, "gloo", self._full_sync_worker_with_timeout, 8)
# Both processes should return True
self.assertTrue(mp_dict[0])
self.assertTrue(mp_dict[1])
self.assertDictEqual(mp_dict, {0: True, 1: True})

def test_full_sync_pt_multi_process_edgecase(self) -> None:
mp_dict = spawn_multi_process(2, "gloo", self._full_sync_worker_with_timeout, 5)

# Both processes should return True
self.assertTrue(mp_dict[0])
self.assertTrue(mp_dict[1])
self.assertDictEqual(mp_dict, {0: True, 1: True})

# Launch 2 worker processes. Each will check time diff >= interval threshold
mp_dict = spawn_multi_process(2, "gloo", self._full_sync_worker_with_timeout, 4)

# Both processes should return False
self.assertFalse(mp_dict[0])
self.assertFalse(mp_dict[1])
self.assertDictEqual(mp_dict, {0: False, 1: False})
2 changes: 1 addition & 1 deletion torchtnt/utils/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ def spawn_multi_process(
p.join()
tc.assertEqual(p.exitcode, 0)

return mp_output_dict
return mp_output_dict.copy()


def init_pg_and_rank_and_launch_test(
Expand Down

0 comments on commit 1e232cb

Please sign in to comment.