From b5db9eca57e43c896c231c237c2ed110298eb5c5 Mon Sep 17 00:00:00 2001 From: Ramanish Singh Date: Tue, 21 Jan 2025 14:36:16 -0800 Subject: [PATCH 1/2] remove prefetcher in mixer creation --- test/nodes/test_multi_node_weighted_sampler.py | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/test/nodes/test_multi_node_weighted_sampler.py b/test/nodes/test_multi_node_weighted_sampler.py index b9f00d926..b0b837e42 100644 --- a/test/nodes/test_multi_node_weighted_sampler.py +++ b/test/nodes/test_multi_node_weighted_sampler.py @@ -88,12 +88,7 @@ def test_multi_node_weighted_batch_sampler_zero_weights( def test_multi_node_weighted_sampler_first_exhausted(self) -> None: """Test MultiNodeWeightedSampler with stop criteria FIRST_DATASET_EXHAUSTED""" - mixer = self._setup_multi_node_weighted_sampler( - self._num_samples, - self._num_datasets, - self._weights_fn, - stop_criteria=StopCriteria.FIRST_DATASET_EXHAUSTED, - ) + mixer = MultiNodeWeightedSampler(self.datasets, self.weights, StopCriteria.FIRST_DATASET_EXHAUSTED) for _ in range(self._num_epochs): results = list(mixer) From 65664c7a9b6cc78d404e7a29feba5d77250d4912 Mon Sep 17 00:00:00 2001 From: Ramanish Singh Date: Tue, 21 Jan 2025 16:23:14 -0800 Subject: [PATCH 2/2] local repro --- test/nodes/test_mds_sampler.py | 39 +++++++++++++++++++ .../nodes/test_multi_node_weighted_sampler.py | 2 +- .../samplers/multi_node_weighted_sampler.py | 5 +++ 3 files changed, 45 insertions(+), 1 deletion(-) create mode 100644 test/nodes/test_mds_sampler.py diff --git a/test/nodes/test_mds_sampler.py b/test/nodes/test_mds_sampler.py new file mode 100644 index 000000000..5ef79d059 --- /dev/null +++ b/test/nodes/test_mds_sampler.py @@ -0,0 +1,39 @@ +import itertools + +from parameterized import parameterized +from torch.testing._internal.common_utils import TestCase +from torchdata.nodes.adapters import IterableWrapper +from torchdata.nodes.prefetch import Prefetcher + +from torchdata.nodes.samplers.multi_node_weighted_sampler import MultiNodeWeightedSampler +from torchdata.nodes.samplers.stop_criteria import StopCriteria + +from utils import DummyIterableDataset, run_test_save_load_state + + +num_samples = 5 +num_datasets = 4 + +datasets = { + f"ds{i}": IterableWrapper(DummyIterableDataset(num_samples, f"ds{i}")) + for i in range(num_datasets) +} + +weights_fn = lambda i: 0.1 * (i + 1) +weights = {f"ds{i}": weights_fn(i) for i in range(num_datasets)} +mixer = MultiNodeWeightedSampler(datasets, weights, StopCriteria.FIRST_DATASET_EXHAUSTED, seed=42) + +num_epochs = 1 + +for epoch in range(num_epochs): + results = list(mixer) + + datasets_in_results = [result["name"] for result in results] + + dataset_counts_in_results = [datasets_in_results.count(f"ds{i}") for i in range(num_datasets)] + elements = [[result["name"], result["step"]] for result in results] + print(elements) + print(datasets_in_results) + print(dataset_counts_in_results) + + mixer.reset() diff --git a/test/nodes/test_multi_node_weighted_sampler.py b/test/nodes/test_multi_node_weighted_sampler.py index b0b837e42..fa02eb272 100644 --- a/test/nodes/test_multi_node_weighted_sampler.py +++ b/test/nodes/test_multi_node_weighted_sampler.py @@ -88,7 +88,7 @@ def test_multi_node_weighted_batch_sampler_zero_weights( def test_multi_node_weighted_sampler_first_exhausted(self) -> None: """Test MultiNodeWeightedSampler with stop criteria FIRST_DATASET_EXHAUSTED""" - mixer = MultiNodeWeightedSampler(self.datasets, self.weights, StopCriteria.FIRST_DATASET_EXHAUSTED) + mixer = MultiNodeWeightedSampler(self.datasets, self.weights, StopCriteria.FIRST_DATASET_EXHAUSTED, seed=42) for _ in range(self._num_epochs): results = list(mixer) diff --git a/torchdata/nodes/samplers/multi_node_weighted_sampler.py b/torchdata/nodes/samplers/multi_node_weighted_sampler.py index 572b57b26..fd39bf5fb 100644 --- a/torchdata/nodes/samplers/multi_node_weighted_sampler.py +++ b/torchdata/nodes/samplers/multi_node_weighted_sampler.py @@ -158,17 +158,21 @@ def _check_for_stop_iteration(self) -> None: # the first dataset is exhausted. Doing this to correctly catch StopIteration # when trying next(it) on already exhausted iterator if self.stop_criteria == StopCriteria.FIRST_DATASET_EXHAUSTED and any(self._datasets_exhausted.values()): + print(self._datasets_exhausted) + print("Stopping: First dataset exhausted.") raise StopIteration() return def next(self) -> T: self._started = True + while True: self._check_for_stop_iteration() # Fetch the next item's key from the weighted sampler key = next(self._weighted_sampler) + print(f"The next key selected is {key}") try: if self._datasets_exhausted[key] and self.stop_criteria == StopCriteria.ALL_DATASETS_EXHAUSTED: # Before fetching a new item check if key corresponds to an already @@ -190,6 +194,7 @@ def next(self) -> T: # reset the iterator and try again self.source_nodes[key].reset() item = next(self.source_nodes[key]) + print(self._datasets_exhausted) break # If we did't throw StopIteration, increment the number of items yielded and return the item