Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix MultiNodeWeightedSampler tests #1426

Merged
merged 8 commits into from
Jan 23, 2025
Merged

Fix MultiNodeWeightedSampler tests #1426

merged 8 commits into from
Jan 23, 2025

Conversation

ramanishsingh
Copy link
Contributor

@ramanishsingh ramanishsingh commented Jan 22, 2025

Due to the stochastic nature of the MultiNodeWeightedSampler and the way StopIteration is called when multiple datasets are involved, some test results are dependent on the OS/seed. Trying to fix seed and change testing conditions with this PR.

@ramanishsingh ramanishsingh marked this pull request as draft January 22, 2025 04:11
@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Jan 22, 2025
Copy link

pytorch-bot bot commented Jan 22, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/data/1426

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 029221d with merge base daafee4 (image):
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@ramanishsingh ramanishsingh changed the title Fix mds tests Fix seed in MultiNodeWeightedSampler tests Jan 22, 2025
@ramanishsingh ramanishsingh changed the title Fix seed in MultiNodeWeightedSampler tests Fix MultiNodeWeightedSampler tests Jan 22, 2025
@ramanishsingh ramanishsingh marked this pull request as ready for review January 22, 2025 17:42
@@ -86,13 +87,15 @@ def test_multi_node_weighted_batch_sampler_zero_weights(
weights={f"ds{i}": 10 * i for i in range(self._num_datasets)},
)

def test_multi_node_weighted_sampler_first_exhausted(self) -> None:
@parameterized.expand(range(10))
def test_multi_node_weighted_sampler_first_exhausted(self, seed) -> None:
"""Test MultiNodeWeightedSampler with stop criteria FIRST_DATASET_EXHAUSTED"""
mixer = self._setup_multi_node_weighted_sampler(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

one tiny nit: can you make this node = self._setup_multi_node_weighted_sampler( here as well as in other places we call _setup_multi_node_weighted_sampler, thanks!

Copy link
Contributor

@divyanshk divyanshk left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM! Thanks for the fix!

@ramanishsingh ramanishsingh merged commit d6232f1 into main Jan 23, 2025
39 checks passed
@ramanishsingh ramanishsingh deleted the fix_mds_tests branch January 27, 2025 19:14
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed.
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants