Skip to content

Commit

Permalink
update test
Browse files Browse the repository at this point in the history
  • Loading branch information
ramanishsingh committed Jan 22, 2025
1 parent 2859e90 commit 99fa7e7
Showing 1 changed file with 9 additions and 2 deletions.
11 changes: 9 additions & 2 deletions test/nodes/test_multi_node_weighted_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ def test_multi_node_weighted_sampler_first_exhausted(self) -> None:
seed=self._seed
)

for _ in range(self._num_epochs):
for _ in range(1): #only running for one epoch as the number of samples taken from each dataset is stochastic and is epoch dependent
results = list(mixer)

datasets_in_results = [result["name"] for result in results]
Expand All @@ -109,10 +109,17 @@ def test_multi_node_weighted_sampler_first_exhausted(self) -> None:
# Check max item count for dataset is exactly _num_samples
self.assertEqual(max(dataset_counts_in_results), self._num_samples)

# Check only one dataset has been exhausted
# Check that the max number of samples (10) have been taken from two datasets (ds2 and ds3)
self.assertEqual(dataset_counts_in_results.count(self._num_samples), 2)
# The number of datasets from which max number of samples is taken is 2 because StopIteration is called after next is called
# on the dataset node which is on its last element. We do not have a way to preemptively tell if a node is at its last element without
# calling next on it. Thus, during multi dataset sampling, multiple dataset nodes can be at their last element and when next is called
# on any one of them, it raises StopIteration. Thus, multiple datasets can yield max number of elements.
# Check that the number of samples taken from each dataset
self.assertEqual(dataset_counts_in_results, [4, 8, 10, 10])
mixer.reset()


def test_multi_node_weighted_sampler_all_dataset_exhausted(self) -> None:
"""Test MultiNodeWeightedSampler with stop criteria ALL_DATASETS_EXHAUSTED"""
mixer = self._setup_multi_node_weighted_sampler(
Expand Down

0 comments on commit 99fa7e7

Please sign in to comment.