Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix MultiNodeWeightedSampler tests #1426

Merged
merged 8 commits into from
Jan 23, 2025
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
100 changes: 60 additions & 40 deletions test/nodes/test_multi_node_weighted_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@ def setUp(self) -> None:
self._num_datasets = 4
self._weights_fn = lambda i: 0.1 * (i + 1)
self._num_epochs = 3

self.datasets = {
f"ds{i}": IterableWrapper(DummyIterableDataset(self._num_samples, f"ds{i}"))
for i in range(self._num_datasets)
Expand All @@ -38,11 +37,13 @@ def test_torchdata_nodes_imports(self) -> None:
except ImportError:
self.fail("MultiNodeWeightedSampler or StopCriteria failed to import")

def _setup_multi_node_weighted_sampler(self, num_samples, num_datasets, weights_fn, stop_criteria) -> Prefetcher:
def _setup_multi_node_weighted_sampler(
self, num_samples, num_datasets, weights_fn, stop_criteria, seed=0
) -> Prefetcher:

datasets = {f"ds{i}": IterableWrapper(DummyIterableDataset(num_samples, f"ds{i}")) for i in range(num_datasets)}
weights = {f"ds{i}": weights_fn(i) for i in range(num_datasets)}
node = MultiNodeWeightedSampler(datasets, weights, stop_criteria)
node = MultiNodeWeightedSampler(datasets, weights, stop_criteria, seed=seed)
return Prefetcher(node, prefetch_factor=3)

def test_multi_node_weighted_sampler_weight_sampler_keys_mismatch(self) -> None:
Expand Down Expand Up @@ -86,39 +87,43 @@ def test_multi_node_weighted_batch_sampler_zero_weights(
weights={f"ds{i}": 10 * i for i in range(self._num_datasets)},
)

def test_multi_node_weighted_sampler_first_exhausted(self) -> None:
@parameterized.expand(range(10))
def test_multi_node_weighted_sampler_first_exhausted(self, seed) -> None:
"""Test MultiNodeWeightedSampler with stop criteria FIRST_DATASET_EXHAUSTED"""
mixer = self._setup_multi_node_weighted_sampler(
node = self._setup_multi_node_weighted_sampler(
self._num_samples,
self._num_datasets,
self._weights_fn,
stop_criteria=StopCriteria.FIRST_DATASET_EXHAUSTED,
seed=seed,
)

for _ in range(self._num_epochs):
results = list(mixer)
results = list(node)

datasets_in_results = [result["name"] for result in results]
dataset_counts_in_results = [datasets_in_results.count(f"ds{i}") for i in range(self._num_datasets)]

# Check max item count for dataset is exactly _num_samples
self.assertEqual(max(dataset_counts_in_results), self._num_samples)

# Check only one dataset has been exhausted
self.assertEqual(dataset_counts_in_results.count(self._num_samples), 1)
mixer.reset()
# Check that max items are taken from at least one dataset
self.assertGreaterEqual(dataset_counts_in_results.count(self._num_samples), 1)
node.reset()

def test_multi_node_weighted_sampler_all_dataset_exhausted(self) -> None:
@parameterized.expand(range(10))
def test_multi_node_weighted_sampler_all_dataset_exhausted(self, seed) -> None:
"""Test MultiNodeWeightedSampler with stop criteria ALL_DATASETS_EXHAUSTED"""
mixer = self._setup_multi_node_weighted_sampler(
node = self._setup_multi_node_weighted_sampler(
self._num_samples,
self._num_datasets,
self._weights_fn,
stop_criteria=StopCriteria.ALL_DATASETS_EXHAUSTED,
seed=seed,
)

for _ in range(self._num_epochs):
results = list(mixer)
results = list(node)
datasets_in_results = [result["name"] for result in results]
dataset_counts_in_results = [datasets_in_results.count(f"ds{i}") for i in range(self._num_datasets)]

Expand All @@ -131,56 +136,67 @@ def test_multi_node_weighted_sampler_all_dataset_exhausted(self) -> None:

# check that all datasets are exhausted
self.assertEqual(sorted(set(datasets_in_results)), ["ds0", "ds1", "ds2", "ds3"])
mixer.reset()
node.reset()

def test_multi_node_weighted_sampler_cycle_until_all_exhausted(self) -> None:
@parameterized.expand(range(10))
def test_multi_node_weighted_sampler_cycle_until_all_exhausted(self, seed) -> None:
"""Test MultiNodeWeightedSampler with stop criteria CYCLE_UNTIL_ALL_DATASETS_EXHAUSTED"""
mixer = self._setup_multi_node_weighted_sampler(
node = self._setup_multi_node_weighted_sampler(
self._num_samples,
self._num_datasets,
self._weights_fn,
stop_criteria=StopCriteria.CYCLE_UNTIL_ALL_DATASETS_EXHAUSTED,
seed=seed,
)

for _ in range(self._num_epochs):
results = list(mixer)
results = list(node)
datasets_in_results = {result["name"] for result in results}

# check that all datasets are exhausted
self.assertEqual(sorted(datasets_in_results), ["ds0", "ds1", "ds2", "ds3"])
mixer.reset()
node.reset()

def test_multi_node_weighted_sampler_cycle_forever(self) -> None:
@parameterized.expand(range(10))
def test_multi_node_weighted_sampler_cycle_forever(self, seed) -> None:
"""Test MultiNodeWeightedSampler with stop criteria CYCLE_FOREVER"""
mixer = MultiNodeWeightedSampler(
self.datasets,
self.equal_weights,
stop_criteria=StopCriteria.CYCLE_FOREVER,
node = MultiNodeWeightedSampler(
self.datasets, self.equal_weights, stop_criteria=StopCriteria.CYCLE_FOREVER, seed=seed
)

num_yielded = 0
_it = iter(mixer)
_it = iter(node)
while num_yielded < 256: # total number of samples is 4 * 10 = 40, 256 is an arbitrary larger number
next(_it)
num_yielded += 1

mixer_num_yielded = mixer.get_state()[MultiNodeWeightedSampler.NUM_YIELDED_KEY]
self.assertEqual(mixer_num_yielded, num_yielded)
node_num_yielded = node.get_state()[MultiNodeWeightedSampler.NUM_YIELDED_KEY]
self.assertEqual(node_num_yielded, num_yielded)

@parameterized.expand([(1, 8), (8, 32)])
def test_multi_node_weighted_batch_sampler_set_rank_world_size(self, rank, world_size):
"""Test MultiNodeWeightedSampler with different rank and world size"""
mixer = MultiNodeWeightedSampler(self.datasets, self.weights, rank=rank, world_size=world_size)
self.assertEqual(mixer.rank, rank)
self.assertEqual(mixer.world_size, world_size)
node = MultiNodeWeightedSampler(
self.datasets,
self.weights,
rank=rank,
world_size=world_size,
)
self.assertEqual(node.rank, rank)
self.assertEqual(node.world_size, world_size)

def test_multi_node_weighted_batch_sampler_results_for_ranks(self):
"""Test MultiNodeWeightedSampler with different results for different ranks"""
world_size = 8
global_results = []
for rank in range(world_size):
mixer = MultiNodeWeightedSampler(self.datasets, self.weights, rank=rank, world_size=world_size)
results = list(mixer)
node = MultiNodeWeightedSampler(
self.datasets,
self.weights,
rank=rank,
world_size=world_size,
)
results = list(node)
global_results.append(results)

unique_results = []
Expand All @@ -192,17 +208,17 @@ def test_multi_node_weighted_batch_sampler_results_for_ranks(self):
def test_multi_node_weighted_batch_sampler_results_for_multiple_epochs(self):
"""Test MultiNodeWeightedSampler with different results in each epoch"""

# Check for the mixer node only
mixer = MultiNodeWeightedSampler(
# Check for the MultiNodeWeightedSampler node only
node = MultiNodeWeightedSampler(
self.datasets,
self.weights,
)

overall_results = []
for _ in range(self._num_epochs):
results = list(mixer)
results = list(node)
overall_results.append(results)
mixer.reset()
node.reset()

unique_results = []
for results in overall_results:
Expand All @@ -211,7 +227,7 @@ def test_multi_node_weighted_batch_sampler_results_for_multiple_epochs(self):

self.assertEqual(unique_results, overall_results)

# Check for mixer along with Prefetcher node
# Check for MultiNodeWeightedSampler node along with Prefetcher node
node = self._setup_multi_node_weighted_sampler(
self._num_samples,
self._num_datasets,
Expand Down Expand Up @@ -242,9 +258,13 @@ def test_multi_node_weighted_batch_sampler_results_for_multiple_epochs(self):
],
)
)
def test_save_load_state_mixer_over_multiple_epochs(self, midpoint: int, stop_criteria: str):
def test_save_load_state_mds_node_over_multiple_epochs(self, midpoint: int, stop_criteria: str):
"""Test MultiNodeWeightedSampler with saving and loading of state across multiple epochs"""
node = MultiNodeWeightedSampler(self.datasets, self.weights, stop_criteria)
node = MultiNodeWeightedSampler(
self.datasets,
self.weights,
stop_criteria,
)
run_test_save_load_state(self, node, midpoint)

@parameterized.expand(
Expand All @@ -257,7 +277,7 @@ def test_save_load_state_mixer_over_multiple_epochs(self, midpoint: int, stop_cr
],
)
)
def test_save_load_state_mixer_over_multiple_epochs_with_prefetcher(self, midpoint: int, stop_criteria: str):
def test_save_load_state_mds_node_over_multiple_epochs_with_prefetcher(self, midpoint: int, stop_criteria: str):
node = self._setup_multi_node_weighted_sampler(
self._num_samples,
self._num_datasets,
Expand All @@ -281,10 +301,10 @@ def test_multi_node_weighted_large_sample_size_with_prefetcher(self, midpoint, s
num_samples = 1500
num_datasets = 5

mixer = self._setup_multi_node_weighted_sampler(
node = self._setup_multi_node_weighted_sampler(
num_samples,
num_datasets,
self._weights_fn,
stop_criteria,
)
run_test_save_load_state(self, mixer, midpoint)
run_test_save_load_state(self, node, midpoint)
Loading