Fix MultiNodeWeightedSampler tests (#1426)
* fix seed

* change test check

* update test

* add multiple seeds

* run precommit

* fix epochs

* update seeds

* update mixer name to node
ramanishsingh authored Jan 23, 2025
1 parent daafee4 commit d6232f1
Showing 1 changed file with 60 additions and 40 deletions.
test/nodes/test_multi_node_weighted_sampler.py (100 changes: 60 additions & 40 deletions)
@@ -24,7 +24,6 @@ def setUp(self) -> None:
self._num_datasets = 4
self._weights_fn = lambda i: 0.1 * (i + 1)
self._num_epochs = 3

self.datasets = {
f"ds{i}": IterableWrapper(DummyIterableDataset(self._num_samples, f"ds{i}"))
for i in range(self._num_datasets)
@@ -38,11 +37,13 @@ def test_torchdata_nodes_imports(self) -> None:
except ImportError:
self.fail("MultiNodeWeightedSampler or StopCriteria failed to import")

def _setup_multi_node_weighted_sampler(self, num_samples, num_datasets, weights_fn, stop_criteria) -> Prefetcher:
def _setup_multi_node_weighted_sampler(
self, num_samples, num_datasets, weights_fn, stop_criteria, seed=0
) -> Prefetcher:

datasets = {f"ds{i}": IterableWrapper(DummyIterableDataset(num_samples, f"ds{i}")) for i in range(num_datasets)}
weights = {f"ds{i}": weights_fn(i) for i in range(num_datasets)}
node = MultiNodeWeightedSampler(datasets, weights, stop_criteria)
node = MultiNodeWeightedSampler(datasets, weights, stop_criteria, seed=seed)
return Prefetcher(node, prefetch_factor=3)

def test_multi_node_weighted_sampler_weight_sampler_keys_mismatch(self) -> None:
@@ -86,39 +87,43 @@ def test_multi_node_weighted_batch_sampler_zero_weights(
weights={f"ds{i}": 10 * i for i in range(self._num_datasets)},
)

def test_multi_node_weighted_sampler_first_exhausted(self) -> None:
@parameterized.expand(range(10))
def test_multi_node_weighted_sampler_first_exhausted(self, seed) -> None:
"""Test MultiNodeWeightedSampler with stop criteria FIRST_DATASET_EXHAUSTED"""
mixer = self._setup_multi_node_weighted_sampler(
node = self._setup_multi_node_weighted_sampler(
self._num_samples,
self._num_datasets,
self._weights_fn,
stop_criteria=StopCriteria.FIRST_DATASET_EXHAUSTED,
seed=seed,
)

for _ in range(self._num_epochs):
results = list(mixer)
results = list(node)

datasets_in_results = [result["name"] for result in results]
dataset_counts_in_results = [datasets_in_results.count(f"ds{i}") for i in range(self._num_datasets)]

# Check max item count for dataset is exactly _num_samples
self.assertEqual(max(dataset_counts_in_results), self._num_samples)

# Check only one dataset has been exhausted
self.assertEqual(dataset_counts_in_results.count(self._num_samples), 1)
mixer.reset()
# Check that max items are taken from at least one dataset
self.assertGreaterEqual(dataset_counts_in_results.count(self._num_samples), 1)
node.reset()

def test_multi_node_weighted_sampler_all_dataset_exhausted(self) -> None:
@parameterized.expand(range(10))
def test_multi_node_weighted_sampler_all_dataset_exhausted(self, seed) -> None:
"""Test MultiNodeWeightedSampler with stop criteria ALL_DATASETS_EXHAUSTED"""
mixer = self._setup_multi_node_weighted_sampler(
node = self._setup_multi_node_weighted_sampler(
self._num_samples,
self._num_datasets,
self._weights_fn,
stop_criteria=StopCriteria.ALL_DATASETS_EXHAUSTED,
seed=seed,
)

for _ in range(self._num_epochs):
results = list(mixer)
results = list(node)
datasets_in_results = [result["name"] for result in results]
dataset_counts_in_results = [datasets_in_results.count(f"ds{i}") for i in range(self._num_datasets)]

@@ -131,56 +136,67 @@ def test_multi_node_weighted_sampler_all_dataset_exhausted(self) -> None:

# check that all datasets are exhausted
self.assertEqual(sorted(set(datasets_in_results)), ["ds0", "ds1", "ds2", "ds3"])
mixer.reset()
node.reset()

def test_multi_node_weighted_sampler_cycle_until_all_exhausted(self) -> None:
@parameterized.expand(range(10))
def test_multi_node_weighted_sampler_cycle_until_all_exhausted(self, seed) -> None:
"""Test MultiNodeWeightedSampler with stop criteria CYCLE_UNTIL_ALL_DATASETS_EXHAUSTED"""
mixer = self._setup_multi_node_weighted_sampler(
node = self._setup_multi_node_weighted_sampler(
self._num_samples,
self._num_datasets,
self._weights_fn,
stop_criteria=StopCriteria.CYCLE_UNTIL_ALL_DATASETS_EXHAUSTED,
seed=seed,
)

for _ in range(self._num_epochs):
results = list(mixer)
results = list(node)
datasets_in_results = {result["name"] for result in results}

# check that all datasets are exhausted
self.assertEqual(sorted(datasets_in_results), ["ds0", "ds1", "ds2", "ds3"])
mixer.reset()
node.reset()

def test_multi_node_weighted_sampler_cycle_forever(self) -> None:
@parameterized.expand(range(10))
def test_multi_node_weighted_sampler_cycle_forever(self, seed) -> None:
"""Test MultiNodeWeightedSampler with stop criteria CYCLE_FOREVER"""
mixer = MultiNodeWeightedSampler(
self.datasets,
self.equal_weights,
stop_criteria=StopCriteria.CYCLE_FOREVER,
node = MultiNodeWeightedSampler(
self.datasets, self.equal_weights, stop_criteria=StopCriteria.CYCLE_FOREVER, seed=seed
)

num_yielded = 0
_it = iter(mixer)
_it = iter(node)
while num_yielded < 256: # total number of samples is 4 * 10 = 40, 256 is an arbitrary larger number
next(_it)
num_yielded += 1

mixer_num_yielded = mixer.get_state()[MultiNodeWeightedSampler.NUM_YIELDED_KEY]
self.assertEqual(mixer_num_yielded, num_yielded)
node_num_yielded = node.get_state()[MultiNodeWeightedSampler.NUM_YIELDED_KEY]
self.assertEqual(node_num_yielded, num_yielded)

@parameterized.expand([(1, 8), (8, 32)])
def test_multi_node_weighted_batch_sampler_set_rank_world_size(self, rank, world_size):
"""Test MultiNodeWeightedSampler with different rank and world size"""
mixer = MultiNodeWeightedSampler(self.datasets, self.weights, rank=rank, world_size=world_size)
self.assertEqual(mixer.rank, rank)
self.assertEqual(mixer.world_size, world_size)
node = MultiNodeWeightedSampler(
self.datasets,
self.weights,
rank=rank,
world_size=world_size,
)
self.assertEqual(node.rank, rank)
self.assertEqual(node.world_size, world_size)

def test_multi_node_weighted_batch_sampler_results_for_ranks(self):
"""Test MultiNodeWeightedSampler with different results for different ranks"""
world_size = 8
global_results = []
for rank in range(world_size):
mixer = MultiNodeWeightedSampler(self.datasets, self.weights, rank=rank, world_size=world_size)
results = list(mixer)
node = MultiNodeWeightedSampler(
self.datasets,
self.weights,
rank=rank,
world_size=world_size,
)
results = list(node)
global_results.append(results)

unique_results = []
@@ -192,17 +208,17 @@ def test_multi_node_weighted_batch_sampler_results_for_ranks(self):
def test_multi_node_weighted_batch_sampler_results_for_multiple_epochs(self):
"""Test MultiNodeWeightedSampler with different results in each epoch"""

# Check for the mixer node only
mixer = MultiNodeWeightedSampler(
# Check for the MultiNodeWeightedSampler node only
node = MultiNodeWeightedSampler(
self.datasets,
self.weights,
)

overall_results = []
for _ in range(self._num_epochs):
results = list(mixer)
results = list(node)
overall_results.append(results)
mixer.reset()
node.reset()

unique_results = []
for results in overall_results:
@@ -211,7 +227,7 @@ def test_multi_node_weighted_batch_sampler_results_for_multiple_epochs(self):

self.assertEqual(unique_results, overall_results)

# Check for mixer along with Prefetcher node
# Check for MultiNodeWeightedSampler node along with Prefetcher node
node = self._setup_multi_node_weighted_sampler(
self._num_samples,
self._num_datasets,
@@ -242,9 +258,13 @@ def test_multi_node_weighted_batch_sampler_results_for_multiple_epochs(
],
)
)
def test_save_load_state_mixer_over_multiple_epochs(self, midpoint: int, stop_criteria: str):
def test_save_load_state_mds_node_over_multiple_epochs(self, midpoint: int, stop_criteria: str):
"""Test MultiNodeWeightedSampler with saving and loading of state across multiple epochs"""
node = MultiNodeWeightedSampler(self.datasets, self.weights, stop_criteria)
node = MultiNodeWeightedSampler(
self.datasets,
self.weights,
stop_criteria,
)
run_test_save_load_state(self, node, midpoint)

@parameterized.expand(
@@ -257,7 +277,7 @@ def test_save_load_state_mixer_over_multiple_epochs(self, midpoint: int, stop_cr
],
)
)
def test_save_load_state_mixer_over_multiple_epochs_with_prefetcher(self, midpoint: int, stop_criteria: str):
def test_save_load_state_mds_node_over_multiple_epochs_with_prefetcher(self, midpoint: int, stop_criteria: str):
node = self._setup_multi_node_weighted_sampler(
self._num_samples,
self._num_datasets,
@@ -281,10 +301,10 @@ def test_multi_node_weighted_large_sample_size_with_prefetcher(self, midpoint, s
num_samples = 1500
num_datasets = 5

mixer = self._setup_multi_node_weighted_sampler(
node = self._setup_multi_node_weighted_sampler(
num_samples,
num_datasets,
self._weights_fn,
stop_criteria,
)
run_test_save_load_state(self, mixer, midpoint)
run_test_save_load_state(self, node, midpoint)
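
For reference, here is a minimal standalone sketch of the seeded, parameterized test pattern this commit introduces. It is not part of the diff: make_dataset and SeededSamplerExample are hypothetical names standing in for the suite's DummyIterableDataset and test class, and the top-level torchdata.nodes imports are assumed to match the submodule imports used in the test file.

# Sketch only: run the same check under several seeds so sampling behaviour is
# exercised deterministically rather than depending on one lucky default seed.
from unittest import TestCase

from parameterized import parameterized
from torchdata.nodes import IterableWrapper, MultiNodeWeightedSampler, StopCriteria


def make_dataset(name: str, num_samples: int) -> IterableWrapper:
    # Hypothetical stand-in for DummyIterableDataset: yields dicts with a "name" key.
    return IterableWrapper([{"name": name, "i": i} for i in range(num_samples)])


class SeededSamplerExample(TestCase):
    @parameterized.expand(range(10))  # one test case per seed, as in the diff above
    def test_first_exhausted(self, seed: int) -> None:
        datasets = {f"ds{i}": make_dataset(f"ds{i}", 10) for i in range(4)}
        weights = {f"ds{i}": 0.1 * (i + 1) for i in range(4)}
        node = MultiNodeWeightedSampler(
            datasets,
            weights,
            stop_criteria=StopCriteria.FIRST_DATASET_EXHAUSTED,
            seed=seed,  # explicit seed keeps each parameterized case reproducible
        )
        names = [item["name"] for item in node]
        counts = [names.count(f"ds{i}") for i in range(4)]
        # Mirrors the relaxed check above: at least one dataset was fully consumed.
        self.assertGreaterEqual(counts.count(10), 1)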
