diff --git a/test/nodes/test_mds_sampler.py b/test/nodes/test_mds_sampler.py
new file mode 100644
index 000000000..5ef79d059
--- /dev/null
+++ b/test/nodes/test_mds_sampler.py
@@ -0,0 +1,39 @@
+import itertools
+
+from parameterized import parameterized
+from torch.testing._internal.common_utils import TestCase
+from torchdata.nodes.adapters import IterableWrapper
+from torchdata.nodes.prefetch import Prefetcher
+
+from torchdata.nodes.samplers.multi_node_weighted_sampler import MultiNodeWeightedSampler
+from torchdata.nodes.samplers.stop_criteria import StopCriteria
+
+from utils import DummyIterableDataset, run_test_save_load_state
+
+
+num_samples = 5
+num_datasets = 4
+
+datasets = {
+    f"ds{i}": IterableWrapper(DummyIterableDataset(num_samples, f"ds{i}"))
+    for i in range(num_datasets)
+}
+
+weights_fn = lambda i: 0.1 * (i + 1)
+weights = {f"ds{i}": weights_fn(i) for i in range(num_datasets)}
+mixer = MultiNodeWeightedSampler(datasets, weights, StopCriteria.FIRST_DATASET_EXHAUSTED, seed=42)
+
+num_epochs = 1
+
+for epoch in range(num_epochs):
+    results = list(mixer)
+
+    datasets_in_results = [result["name"] for result in results]
+
+    dataset_counts_in_results = [datasets_in_results.count(f"ds{i}") for i in range(num_datasets)]
+    elements = [[result["name"], result["step"]] for result in results]
+    print(elements)
+    print(datasets_in_results)
+    print(dataset_counts_in_results)
+
+    mixer.reset()
diff --git a/test/nodes/test_multi_node_weighted_sampler.py b/test/nodes/test_multi_node_weighted_sampler.py
index b9f00d926..fa02eb272 100644
--- a/test/nodes/test_multi_node_weighted_sampler.py
+++ b/test/nodes/test_multi_node_weighted_sampler.py
@@ -88,12 +88,7 @@ def test_multi_node_weighted_batch_sampler_zero_weights(
 
     def test_multi_node_weighted_sampler_first_exhausted(self) -> None:
         """Test MultiNodeWeightedSampler with stop criteria FIRST_DATASET_EXHAUSTED"""
-        mixer = self._setup_multi_node_weighted_sampler(
-            self._num_samples,
-            self._num_datasets,
-            self._weights_fn,
-            stop_criteria=StopCriteria.FIRST_DATASET_EXHAUSTED,
-        )
+        mixer = MultiNodeWeightedSampler(self.datasets, self.weights, StopCriteria.FIRST_DATASET_EXHAUSTED, seed=42)
 
         for _ in range(self._num_epochs):
             results = list(mixer)
diff --git a/torchdata/nodes/samplers/multi_node_weighted_sampler.py b/torchdata/nodes/samplers/multi_node_weighted_sampler.py
index 572b57b26..fd39bf5fb 100644
--- a/torchdata/nodes/samplers/multi_node_weighted_sampler.py
+++ b/torchdata/nodes/samplers/multi_node_weighted_sampler.py
@@ -158,17 +158,21 @@ def _check_for_stop_iteration(self) -> None:
         # the first dataset is exhausted. Doing this to correctly catch StopIteration
         # when trying next(it) on already exhausted iterator
         if self.stop_criteria == StopCriteria.FIRST_DATASET_EXHAUSTED and any(self._datasets_exhausted.values()):
+            print(self._datasets_exhausted)
+            print("Stopping: First dataset exhausted.")
             raise StopIteration()
 
         return
 
     def next(self) -> T:
         self._started = True
+
         while True:
             self._check_for_stop_iteration()
 
             # Fetch the next item's key from the weighted sampler
             key = next(self._weighted_sampler)
+            print(f"The next key selected is {key}")
             try:
                 if self._datasets_exhausted[key] and self.stop_criteria == StopCriteria.ALL_DATASETS_EXHAUSTED:
                     # Before fetching a new item check if key corresponds to an already
@@ -190,6 +194,7 @@ def next(self) -> T:
                 # reset the iterator and try again
                 self.source_nodes[key].reset()
                 item = next(self.source_nodes[key])
+                print(self._datasets_exhausted)
                 break
 
             # If we didn't throw StopIteration, increment the number of items yielded and return the item
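
For reference, a minimal standalone sketch (not part of the patch) of the FIRST_DATASET_EXHAUSTED behavior that the new script and the debug prints are probing. It uses only the public torchdata.nodes APIs already imported above; the inline item dicts with "name"/"step" keys imitate what DummyIterableDataset appears to yield and are an assumption:

    from torchdata.nodes.adapters import IterableWrapper
    from torchdata.nodes.samplers.multi_node_weighted_sampler import MultiNodeWeightedSampler
    from torchdata.nodes.samplers.stop_criteria import StopCriteria

    # Two sources of very different lengths; items are tagged so their origin is visible.
    datasets = {
        "short": IterableWrapper([{"name": "short", "step": i} for i in range(3)]),
        "long": IterableWrapper([{"name": "long", "step": i} for i in range(100)]),
    }
    weights = {"short": 0.5, "long": 0.5}

    mixer = MultiNodeWeightedSampler(datasets, weights, StopCriteria.FIRST_DATASET_EXHAUSTED, seed=42)

    results = list(mixer)
    counts = {name: sum(r["name"] == name for r in results) for name in datasets}
    print(counts)
    # With FIRST_DATASET_EXHAUSTED the epoch ends as soon as any source runs dry:
    # "short" can contribute at most its 3 items, and "long" is cut off early even
    # though it still has items left. The exact mix of items is seed-dependent;
    # only the early cutoff is guaranteed.
    assert counts["short"] <= 3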