diff --git a/tests/utils/test_timer.py b/tests/utils/test_timer.py index bf27ae87f3..93c65d3fe8 100644 --- a/tests/utils/test_timer.py +++ b/tests/utils/test_timer.py @@ -284,25 +284,21 @@ def _full_sync_worker_with_timeout(cls, timeout: int) -> bool: def test_full_sync_pt_multi_process_check_false(self) -> None: mp_dict = spawn_multi_process(2, "gloo", self._full_sync_worker_without_timeout) # Both processes should return False - self.assertFalse(mp_dict[0]) - self.assertFalse(mp_dict[1]) + self.assertDictEqual(mp_dict, {0: False, 1: False}) def test_full_sync_pt_multi_process_check_true(self) -> None: mp_dict = spawn_multi_process(2, "gloo", self._full_sync_worker_with_timeout, 8) # Both processes should return True - self.assertTrue(mp_dict[0]) - self.assertTrue(mp_dict[1]) + self.assertDictEqual(mp_dict, {0: True, 1: True}) def test_full_sync_pt_multi_process_edgecase(self) -> None: mp_dict = spawn_multi_process(2, "gloo", self._full_sync_worker_with_timeout, 5) # Both processes should return True - self.assertTrue(mp_dict[0]) - self.assertTrue(mp_dict[1]) + self.assertDictEqual(mp_dict, {0: True, 1: True}) # Launch 2 worker processes. Each will check time diff >= interval threshold mp_dict = spawn_multi_process(2, "gloo", self._full_sync_worker_with_timeout, 4) # Both processes should return False - self.assertFalse(mp_dict[0]) - self.assertFalse(mp_dict[1]) + self.assertDictEqual(mp_dict, {0: False, 1: False}) diff --git a/torchtnt/utils/test_utils.py b/torchtnt/utils/test_utils.py index a0af23394a..eb4135d371 100644 --- a/torchtnt/utils/test_utils.py +++ b/torchtnt/utils/test_utils.py @@ -130,7 +130,7 @@ def spawn_multi_process( p.join() tc.assertEqual(p.exitcode, 0) - return mp_output_dict + return mp_output_dict.copy() def init_pg_and_rank_and_launch_test(