
fix: Use default torch timeout for nccl watchdog unless overridden #521

Open: booxter wants to merge 1 commit into main from revert-to-default-timeout-for-pytorch

Conversation

@booxter booxter (Contributor) commented May 2, 2025

The default value is recommended, and we should not change it in
production. The knob may still be useful for debugging or testing
purposes though.

Signed-off-by: Ihar Hrachyshka <ihar.hrachyshka@gmail.com>

@mergify mergify bot added the testing and ci-failure labels May 2, 2025
@booxter booxter force-pushed the revert-to-default-timeout-for-pytorch branch from 2ab4c09 to 0cbbe2b May 2, 2025 13:59
@mergify mergify bot removed the ci-failure label May 2, 2025
@RobotSail RobotSail (Member) left a comment

We just need some comments explaining why we're doing it this way, otherwise I think it's good. 👍

@booxter booxter force-pushed the revert-to-default-timeout-for-pytorch branch from 0cbbe2b to ca87077 May 5, 2025 23:01
@booxter booxter requested a review from RobotSail May 6, 2025 02:13
@JamesKunstle JamesKunstle (Contributor) left a comment

I like this solution.
I'm curious whether the torch environment variable TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC from the docs overrides this, or whether this overrides that variable.
https://pytorch.org/docs/stable/torch_nccl_environment_variables.html

@mergify mergify bot added the one-approval label May 6, 2025
@booxter booxter (Contributor, Author) commented May 6, 2025

@JamesKunstle below is my understanding of how this works. I may be wrong since I'm new to the topic, but I'll try to link the relevant code for reference. Please double-check me: it's important that we understand how this works.


So, there are three separate entities: a process group timeout, an NCCL watchdog, and an NCCL monitoring thread. They serve separate needs.


The process group timeout is what you configure when passing timeout= to init_process_group. Torch uses this timeout (or its own default value if not passed) to monitor collectives making progress (each collective is allocated the specified time).

Each backend implements it in its own way; for example, NCCL assigns the timeout to each work item.

This parameter is not backend-specific. If we ever use a different backend (not nccl), the timeout should still be honored. (This is why I'm planning to rename the variable in a follow-up patch: it currently has an NCCL-specific name, and a better name would be e.g. INSTRUCTLAB_PROCESS_GROUP_TIMEOUT_SEC rather than INSTRUCTLAB_NCCL_TIMEOUT_SEC.)
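
For concreteness, here is a minimal sketch of the approach the knob enables (this is not the actual patch, and the helper name is made up): pass timeout= to init_process_group only when the environment variable is set, and otherwise let torch apply its own recommended default.

```python
# Hypothetical helper, not the actual training code: honor the override knob
# if present, otherwise fall back to torch's default process group timeout.
import os
from datetime import timedelta

import torch.distributed as dist

def init_process_group_with_optional_timeout(backend: str = "nccl") -> None:
    kwargs = {}
    timeout_sec = os.environ.get("INSTRUCTLAB_NCCL_TIMEOUT_SEC")
    if timeout_sec:
        # debugging/testing override only; not recommended in production
        kwargs["timeout"] = timedelta(seconds=int(timeout_sec))
    # with no timeout= argument, torch uses its recommended default
    dist.init_process_group(backend=backend, **kwargs)
```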


Now, to watchdogs.

These are backend-specific. It's up to the backend to run a watchdog or some other mechanism to implement the PG timeout.

Each NCCL rank starts a separate native thread running a peer watchdog. The watchdog thread will periodically check on each worker to see if it failed or timed out, and report back. It has its own sleep timer between iterations (just 100ms).


There's also an NCCL monitoring thread, which is also backend-specific. This thread runs separately from the watchdog and monitors the watchdog itself. Specifically, it watches for its heartbeats. If a single heartbeat is not detected within TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC seconds (8 minutes by default), torch will fail the job, because something really bad must have happened to the watchdog if it couldn't even loop through a busy-wait iteration in such a long time (side note: I find the default 8-minute value very conservative).

In case you wonder, the monitoring thread has been enabled by default since 2.3.0. It is controlled by TORCH_NCCL_ENABLE_MONITORING, which you can use to disable the monitoring thread (why would you, though?).

What's a heartbeat? Just an atomic integer increment on a shared variable that is executed by the watchdog thread and watched by the monitoring thread.

Important to note: for a watchdog thread heartbeat to happen, currently running collectives don't have to make any progress; as long as the watchdog thread is alive, it will heartbeat. So tuning TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC up is not very helpful.
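
As a rough illustration of that heartbeat mechanism (torch's real implementation is C++ code inside the NCCL backend; this Python sketch only shows the relationship between the two threads, with a plain int standing in for the atomic counter):

```python
import threading
import time

heartbeat = 0                      # bumped by the watchdog, read by the monitor
stop = threading.Event()

def watchdog_loop():
    global heartbeat
    while not stop.is_set():
        # ... check outstanding collectives against the PG timeout ...
        heartbeat += 1             # the heartbeat: proves the thread is alive,
        time.sleep(0.1)            # says nothing about collective progress

def monitor_loop(timeout_sec: float = 480.0):  # TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC
    last_seen = heartbeat
    while not stop.is_set():
        time.sleep(timeout_sec)
        if heartbeat == last_seen:
            # the real monitoring thread would dump debug info and fail the job
            raise RuntimeError("watchdog thread appears stuck")
        last_seen = heartbeat
```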


The watchdog may even be disabled with TORCH_NCCL_BLOCKING_WAIT, in which case no separate watchdog thread is running (and since the watchdog thread is what starts the monitoring thread, neither is the latter). Even then, the PG timeout set for each work item should apply. I believe the blocking mode is only used for debugging purposes, to surface errors from collectives synchronously (for better error traces, or to attach a debugger right where it fails).
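
Pulling the environment variables together, a hedged example of how one might set them for a debugging run (the values are illustrative only, and they need to be set before the process group is initialized):

```python
import os

# All of these are documented in the torch NCCL environment variable docs
# linked above; the values here are just for illustration.
os.environ["TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC"] = "480"  # monitor tolerance for a stuck watchdog (8 min default)
os.environ["TORCH_NCCL_ENABLE_MONITORING"] = "1"        # monitoring thread, on by default since 2.3.0
os.environ["TORCH_NCCL_BLOCKING_WAIT"] = "1"            # debug only: surface collective errors synchronously
```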

Labels: one-approval, testing
Projects: none yet
3 participants