Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Simplify WaitAndLogIfStuck() to use absl::Barrier instead of tsl::BlockingCounter. #20516

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion xla/python/ifrt/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -285,7 +285,6 @@ cc_library(
"@com_google_absl//absl/synchronization",
"@com_google_absl//absl/types:span",
"@tsl//tsl/platform:statusor",
"@tsl//tsl/platform:test",
],
)

Expand Down
1 change: 0 additions & 1 deletion xla/python/ifrt/test_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@ limitations under the License.
#include "xla/python/ifrt/shape.h"
#include "xla/tsl/concurrency/ref_count.h"
#include "tsl/platform/statusor.h"
#include "tsl/platform/test.h"

namespace xla {
namespace ifrt {
Expand Down
3 changes: 2 additions & 1 deletion xla/service/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -5442,8 +5442,9 @@ cc_library(
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/synchronization",
"@com_google_absl//absl/time",
"@com_google_absl//absl/types:span",
"@tsl//tsl/platform:blocking_counter",
"@tsl//tsl/platform:statusor",
],
)
Expand Down
83 changes: 53 additions & 30 deletions xla/service/collective_ops_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,9 @@ limitations under the License.
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/string_view.h"
#include "absl/synchronization/blocking_counter.h"
#include "absl/synchronization/notification.h"
#include "absl/time/time.h"
#include "absl/types/span.h"
#include "xla/executable_run_options.h"
#include "xla/hlo/ir/collective_device_list.h"
Expand All @@ -38,7 +41,6 @@ limitations under the License.
#include "xla/service/global_device_id.h"
#include "xla/service/pattern_matcher.h"
#include "xla/stream_executor/device_memory.h"
#include "tsl/platform/blocking_counter.h"

namespace xla {

Expand Down Expand Up @@ -133,7 +135,7 @@ const std::vector<ReplicaGroup>& GetCollectiveReplicaGroups(
const HloInstruction* hlo);

// Returns the group formation mode of instr, assuming that instr is, or is
// dervied from, an HloAllGatherInstruction, HloAllReduceInstructionBase,
// derived from, an HloAllGatherInstruction, HloAllReduceInstructionBase,
// HloAllToAllInstruction, HloCollectiveBroadcastInstruction or
// HloCollectivePermuteInstruction.
absl::StatusOr<CollectiveOpGroupMode> GetCollectiveOpGroupMode(
Expand Down Expand Up @@ -334,21 +336,42 @@ struct RendezvousKey {
int64_t op_id;
};

template <typename DescFn>
void WaitAndLogIfStuck(tsl::BlockingCounter* counter, const DescFn& desc_fn) {
VLOG(3) << "Begin: " << desc_fn();
const std::chrono::milliseconds timeout(5000);
bool ok = counter->WaitFor(timeout);
if (ok) {
VLOG(3) << "Finished: " << desc_fn();
return;
class TimeoutLoggingBarrier {
public:
explicit TimeoutLoggingBarrier(int num_threads) : counter_(num_threads) {}

template <typename DescFn>
void BlockAndWarnAfterTimeout(const DescFn& desc_fn, absl::Duration timeout) {
VLOG(3) << "Begin: " << desc_fn();
bool is_last = counter_.DecrementCount();
if (is_last) {
finished_.Notify();
// This call to `Wait()` is not expected to block. Calling `Wait()` here
// allows us to satisfy `BlockingCounter`'s requirement: "When `Wait()`
// returns, it is legal to destroy the `BlockingCounter`.".
counter_.Wait();
}
if (finished_.WaitForNotificationWithTimeout(timeout)) {
VLOG(3) << "Finished: " << desc_fn();
return;
}
LOG(ERROR) << "This thread has been waiting for " << timeout
<< "ms for and may be stuck: " << desc_fn();
finished_.WaitForNotification();
LOG(ERROR) << "Thread is unstuck! Warning above was a false-positive. "
"Perhaps the timeout is too short: "
<< desc_fn();
}
LOG(ERROR) << "This thread has been waiting for " << timeout.count()
<< "ms for and may be stuck: " << desc_fn();
counter->Wait();
LOG(ERROR) << "Thread is unstuck! Warning above was a false-positive. "
"Perhaps the timeout is too short: "
<< desc_fn();

private:
absl::BlockingCounter counter_;
absl::Notification finished_;
};

template <typename DescFn>
void WaitAndLogIfStuck(TimeoutLoggingBarrier* barrier, const DescFn& desc_fn) {
constexpr absl::Duration kTimeout = absl::Milliseconds(5000);
return barrier->BlockAndWarnAfterTimeout(desc_fn, kTimeout);
}

// Participant data for each rendezvous.
Expand Down Expand Up @@ -399,16 +422,17 @@ class Rendezvous {
// An alternative way of accomplishing this goal would be to implement
// RefcountingHashMap::erase() and call it during SubmitParticipant. But
// erase() is deceptively complex to implement correctly.
std::shared_ptr<tsl::BlockingCounter> blocking_counter = p.second;
std::shared_ptr<TimeoutLoggingBarrier> barrier = std::get<1>(p);
uintptr_t rendezvous_address =
reinterpret_cast<uintptr_t>(rendezvous.get());
rendezvous.reset();
blocking_counter->DecrementCount();
xla::WaitAndLogIfStuck(blocking_counter.get(), [&] {
xla::WaitAndLogIfStuck(barrier.get(), [&] {
return absl::StrFormat(
"participant waiting for all threads to drop their reference to the "
"rendezvous: %p",
rendezvous.get());
"rendezvous: %#x",
rendezvous_address);
});
return std::move(p.first);
return std::move(std::get<0>(p));
}

protected:
Expand All @@ -430,7 +454,7 @@ class Rendezvous {
// - a BlockingCounter initialized to the number of participants, so that
// the caller can coordinate with the participants one last time if it
// chooses. This is useful for coordinating destruction of the Rendezvous.
absl::StatusOr<std::pair<O, std::shared_ptr<tsl::BlockingCounter>>>
absl::StatusOr<std::pair<O, std::shared_ptr<TimeoutLoggingBarrier>>>
SubmitParticipant(const I& participant) {
{
absl::MutexLock lock(&mu_);
Expand All @@ -439,25 +463,24 @@ class Rendezvous {
}

// Wait for all participants to arrive.
all_participants_present_.DecrementCount();
WaitAndLogIfStuck(&all_participants_present_, [&] {
WaitAndLogIfStuck(&arrival_barrier_, [&] {
return absl::StrFormat(
"participant %s waiting for all participants to arrive at rendezvous "
"%s",
participant.ToString(), key_.ToString());
});

TF_ASSIGN_OR_RETURN(O output, RunCollectiveOp(participant));
return std::make_pair(std::move(output), returned_blocking_counter_);
return std::make_pair(std::move(output), returned_barrier_);
}

const RendezvousKey key_;

tsl::BlockingCounter all_participants_present_{key_.num_local_participants};
TimeoutLoggingBarrier arrival_barrier_{key_.num_local_participants};

// tsl::BlockingCounter returned by SubmitParticipant.
std::shared_ptr<tsl::BlockingCounter> returned_blocking_counter_{
std::make_shared<tsl::BlockingCounter>(key_.num_local_participants)};
// TimeoutLoggingBarrier returned by SubmitParticipant.
std::shared_ptr<TimeoutLoggingBarrier> returned_barrier_ =
std::make_shared<TimeoutLoggingBarrier>(key_.num_local_participants);
};

// We only pipeline Send-Recv chains with channel_id > 0, where each chain
Expand Down