diff --git a/test/nodes/test_multi_node_weighted_sampler.py b/test/nodes/test_multi_node_weighted_sampler.py index b9f00d926..4958f84df 100644 --- a/test/nodes/test_multi_node_weighted_sampler.py +++ b/test/nodes/test_multi_node_weighted_sampler.py @@ -24,7 +24,6 @@ def setUp(self) -> None: self._num_datasets = 4 self._weights_fn = lambda i: 0.1 * (i + 1) self._num_epochs = 3 - self.datasets = { f"ds{i}": IterableWrapper(DummyIterableDataset(self._num_samples, f"ds{i}")) for i in range(self._num_datasets) @@ -38,11 +37,13 @@ def test_torchdata_nodes_imports(self) -> None: except ImportError: self.fail("MultiNodeWeightedSampler or StopCriteria failed to import") - def _setup_multi_node_weighted_sampler(self, num_samples, num_datasets, weights_fn, stop_criteria) -> Prefetcher: + def _setup_multi_node_weighted_sampler( + self, num_samples, num_datasets, weights_fn, stop_criteria, seed=0 + ) -> Prefetcher: datasets = {f"ds{i}": IterableWrapper(DummyIterableDataset(num_samples, f"ds{i}")) for i in range(num_datasets)} weights = {f"ds{i}": weights_fn(i) for i in range(num_datasets)} - node = MultiNodeWeightedSampler(datasets, weights, stop_criteria) + node = MultiNodeWeightedSampler(datasets, weights, stop_criteria, seed=seed) return Prefetcher(node, prefetch_factor=3) def test_multi_node_weighted_sampler_weight_sampler_keys_mismatch(self) -> None: @@ -86,17 +87,19 @@ def test_multi_node_weighted_batch_sampler_zero_weights( weights={f"ds{i}": 10 * i for i in range(self._num_datasets)}, ) - def test_multi_node_weighted_sampler_first_exhausted(self) -> None: + @parameterized.expand(range(10)) + def test_multi_node_weighted_sampler_first_exhausted(self, seed) -> None: """Test MultiNodeWeightedSampler with stop criteria FIRST_DATASET_EXHAUSTED""" - mixer = self._setup_multi_node_weighted_sampler( + node = self._setup_multi_node_weighted_sampler( self._num_samples, self._num_datasets, self._weights_fn, stop_criteria=StopCriteria.FIRST_DATASET_EXHAUSTED, + seed=seed, ) for _ in range(self._num_epochs): - results = list(mixer) + results = list(node) datasets_in_results = [result["name"] for result in results] dataset_counts_in_results = [datasets_in_results.count(f"ds{i}") for i in range(self._num_datasets)] @@ -104,21 +107,23 @@ def test_multi_node_weighted_sampler_first_exhausted(self) -> None: # Check max item count for dataset is exactly _num_samples self.assertEqual(max(dataset_counts_in_results), self._num_samples) - # Check only one dataset has been exhausted - self.assertEqual(dataset_counts_in_results.count(self._num_samples), 1) - mixer.reset() + # Check that max items are taken from at least one dataset + self.assertGreaterEqual(dataset_counts_in_results.count(self._num_samples), 1) + node.reset() - def test_multi_node_weighted_sampler_all_dataset_exhausted(self) -> None: + @parameterized.expand(range(10)) + def test_multi_node_weighted_sampler_all_dataset_exhausted(self, seed) -> None: """Test MultiNodeWeightedSampler with stop criteria ALL_DATASETS_EXHAUSTED""" - mixer = self._setup_multi_node_weighted_sampler( + node = self._setup_multi_node_weighted_sampler( self._num_samples, self._num_datasets, self._weights_fn, stop_criteria=StopCriteria.ALL_DATASETS_EXHAUSTED, + seed=seed, ) for _ in range(self._num_epochs): - results = list(mixer) + results = list(node) datasets_in_results = [result["name"] for result in results] dataset_counts_in_results = [datasets_in_results.count(f"ds{i}") for i in range(self._num_datasets)] @@ -131,56 +136,67 @@ def test_multi_node_weighted_sampler_all_dataset_exhausted(self) -> None: # check that all datasets are exhausted self.assertEqual(sorted(set(datasets_in_results)), ["ds0", "ds1", "ds2", "ds3"]) - mixer.reset() + node.reset() - def test_multi_node_weighted_sampler_cycle_until_all_exhausted(self) -> None: + @parameterized.expand(range(10)) + def test_multi_node_weighted_sampler_cycle_until_all_exhausted(self, seed) -> None: """Test MultiNodeWeightedSampler with stop criteria CYCLE_UNTIL_ALL_DATASETS_EXHAUSTED""" - mixer = self._setup_multi_node_weighted_sampler( + node = self._setup_multi_node_weighted_sampler( self._num_samples, self._num_datasets, self._weights_fn, stop_criteria=StopCriteria.CYCLE_UNTIL_ALL_DATASETS_EXHAUSTED, + seed=seed, ) for _ in range(self._num_epochs): - results = list(mixer) + results = list(node) datasets_in_results = {result["name"] for result in results} # check that all datasets are exhausted self.assertEqual(sorted(datasets_in_results), ["ds0", "ds1", "ds2", "ds3"]) - mixer.reset() + node.reset() - def test_multi_node_weighted_sampler_cycle_forever(self) -> None: + @parameterized.expand(range(10)) + def test_multi_node_weighted_sampler_cycle_forever(self, seed) -> None: """Test MultiNodeWeightedSampler with stop criteria CYCLE_FOREVER""" - mixer = MultiNodeWeightedSampler( - self.datasets, - self.equal_weights, - stop_criteria=StopCriteria.CYCLE_FOREVER, + node = MultiNodeWeightedSampler( + self.datasets, self.equal_weights, stop_criteria=StopCriteria.CYCLE_FOREVER, seed=seed ) num_yielded = 0 - _it = iter(mixer) + _it = iter(node) while num_yielded < 256: # total number of samples is 4 * 10 = 40, 256 is an arbitrary larger number next(_it) num_yielded += 1 - mixer_num_yielded = mixer.get_state()[MultiNodeWeightedSampler.NUM_YIELDED_KEY] - self.assertEqual(mixer_num_yielded, num_yielded) + node_num_yielded = node.get_state()[MultiNodeWeightedSampler.NUM_YIELDED_KEY] + self.assertEqual(node_num_yielded, num_yielded) @parameterized.expand([(1, 8), (8, 32)]) def test_multi_node_weighted_batch_sampler_set_rank_world_size(self, rank, world_size): """Test MultiNodeWeightedSampler with different rank and world size""" - mixer = MultiNodeWeightedSampler(self.datasets, self.weights, rank=rank, world_size=world_size) - self.assertEqual(mixer.rank, rank) - self.assertEqual(mixer.world_size, world_size) + node = MultiNodeWeightedSampler( + self.datasets, + self.weights, + rank=rank, + world_size=world_size, + ) + self.assertEqual(node.rank, rank) + self.assertEqual(node.world_size, world_size) def test_multi_node_weighted_batch_sampler_results_for_ranks(self): """Test MultiNodeWeightedSampler with different results for different ranks""" world_size = 8 global_results = [] for rank in range(world_size): - mixer = MultiNodeWeightedSampler(self.datasets, self.weights, rank=rank, world_size=world_size) - results = list(mixer) + node = MultiNodeWeightedSampler( + self.datasets, + self.weights, + rank=rank, + world_size=world_size, + ) + results = list(node) global_results.append(results) unique_results = [] @@ -192,17 +208,17 @@ def test_multi_node_weighted_batch_sampler_results_for_ranks(self): def test_multi_node_weighted_batch_sampler_results_for_multiple_epochs(self): """Test MultiNodeWeightedSampler with different results in each epoch""" - # Check for the mixer node only - mixer = MultiNodeWeightedSampler( + # Check for the MultiNodeWeightedSampler node only + node = MultiNodeWeightedSampler( self.datasets, self.weights, ) overall_results = [] for _ in range(self._num_epochs): - results = list(mixer) + results = list(node) overall_results.append(results) - mixer.reset() + node.reset() unique_results = [] for results in overall_results: @@ -211,7 +227,7 @@ def test_multi_node_weighted_batch_sampler_results_for_multiple_epochs(self): self.assertEqual(unique_results, overall_results) - # Check for mixer along with Prefetcher node + # Check for MultiNodeWeightedSampler node along with Prefetcher node node = self._setup_multi_node_weighted_sampler( self._num_samples, self._num_datasets, @@ -242,9 +258,13 @@ def test_multi_node_weighted_batch_sampler_results_for_multiple_epochs(self): ], ) ) - def test_save_load_state_mixer_over_multiple_epochs(self, midpoint: int, stop_criteria: str): + def test_save_load_state_mds_node_over_multiple_epochs(self, midpoint: int, stop_criteria: str): """Test MultiNodeWeightedSampler with saving and loading of state across multiple epochs""" - node = MultiNodeWeightedSampler(self.datasets, self.weights, stop_criteria) + node = MultiNodeWeightedSampler( + self.datasets, + self.weights, + stop_criteria, + ) run_test_save_load_state(self, node, midpoint) @parameterized.expand( @@ -257,7 +277,7 @@ def test_save_load_state_mixer_over_multiple_epochs(self, midpoint: int, stop_cr ], ) ) - def test_save_load_state_mixer_over_multiple_epochs_with_prefetcher(self, midpoint: int, stop_criteria: str): + def test_save_load_state_mds_node_over_multiple_epochs_with_prefetcher(self, midpoint: int, stop_criteria: str): node = self._setup_multi_node_weighted_sampler( self._num_samples, self._num_datasets, @@ -281,10 +301,10 @@ def test_multi_node_weighted_large_sample_size_with_prefetcher(self, midpoint, s num_samples = 1500 num_datasets = 5 - mixer = self._setup_multi_node_weighted_sampler( + node = self._setup_multi_node_weighted_sampler( num_samples, num_datasets, self._weights_fn, stop_criteria, ) - run_test_save_load_state(self, mixer, midpoint) + run_test_save_load_state(self, node, midpoint)