Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[draft] revert changes to test tests #1425

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 39 additions & 0 deletions test/nodes/test_mds_sampler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
"""Ad-hoc driver for MultiNodeWeightedSampler with FIRST_DATASET_EXHAUSTED.

Builds four dummy datasets with weights 0.1, 0.2, 0.3, 0.4, drains the mixer
for one epoch, and prints which datasets contributed how many items so the
mixing behavior can be inspected by eye.

NOTE(review): this module defines no TestCase and runs at import time, so the
test runner will execute the prints during collection but assert nothing --
consider folding it into test_multi_node_weighted_sampler.py as a real test.
"""
from torchdata.nodes.adapters import IterableWrapper

from torchdata.nodes.samplers.multi_node_weighted_sampler import MultiNodeWeightedSampler
from torchdata.nodes.samplers.stop_criteria import StopCriteria

from utils import DummyIterableDataset


num_samples = 5  # items yielded by each dummy dataset
num_datasets = 4

# One wrapped dummy dataset per key "ds0".."ds3"; the key doubles as the
# "name" field of every item the dataset yields.
datasets = {
    f"ds{i}": IterableWrapper(DummyIterableDataset(num_samples, f"ds{i}"))
    for i in range(num_datasets)
}


def weights_fn(i: int) -> float:
    """Sampling weight for dataset *i*: 0.1, 0.2, 0.3, 0.4 (not normalized)."""
    return 0.1 * (i + 1)


weights = {f"ds{i}": weights_fn(i) for i in range(num_datasets)}
mixer = MultiNodeWeightedSampler(datasets, weights, StopCriteria.FIRST_DATASET_EXHAUSTED, seed=42)

num_epochs = 1

for epoch in range(num_epochs):
    # Drain the mixer; with FIRST_DATASET_EXHAUSTED it stops as soon as any
    # one source dataset runs out.
    results = list(mixer)

    datasets_in_results = [result["name"] for result in results]

    dataset_counts_in_results = [datasets_in_results.count(f"ds{i}") for i in range(num_datasets)]
    elements = [[result["name"], result["step"]] for result in results]
    print(elements)
    print(datasets_in_results)
    print(dataset_counts_in_results)

    mixer.reset()
7 changes: 1 addition & 6 deletions test/nodes/test_multi_node_weighted_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,12 +88,7 @@ def test_multi_node_weighted_batch_sampler_zero_weights(

def test_multi_node_weighted_sampler_first_exhausted(self) -> None:
"""Test MultiNodeWeightedSampler with stop criteria FIRST_DATASET_EXHAUSTED"""
mixer = self._setup_multi_node_weighted_sampler(
self._num_samples,
self._num_datasets,
self._weights_fn,
stop_criteria=StopCriteria.FIRST_DATASET_EXHAUSTED,
)
mixer = MultiNodeWeightedSampler(self.datasets, self.weights, StopCriteria.FIRST_DATASET_EXHAUSTED, seed=42)

for _ in range(self._num_epochs):
results = list(mixer)
Expand Down
5 changes: 5 additions & 0 deletions torchdata/nodes/samplers/multi_node_weighted_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,17 +158,21 @@ def _check_for_stop_iteration(self) -> None:
# the first dataset is exhausted. Doing this to correctly catch StopIteration
# when trying next(it) on already exhausted iterator
if self.stop_criteria == StopCriteria.FIRST_DATASET_EXHAUSTED and any(self._datasets_exhausted.values()):
print(self._datasets_exhausted)
print("Stopping: First dataset exhausted.")
raise StopIteration()

return

def next(self) -> T:
self._started = True

while True:
self._check_for_stop_iteration()

# Fetch the next item's key from the weighted sampler
key = next(self._weighted_sampler)
print(f"The next key selected is {key}")
try:
if self._datasets_exhausted[key] and self.stop_criteria == StopCriteria.ALL_DATASETS_EXHAUSTED:
# Before fetching a new item check if key corresponds to an already
Expand All @@ -190,6 +194,7 @@ def next(self) -> T:
# reset the iterator and try again
self.source_nodes[key].reset()
item = next(self.source_nodes[key])
print(self._datasets_exhausted)
break

        # If we didn't throw StopIteration, increment the number of items yielded and return the item
Expand Down
Loading