Skip to content

Commit

Permalink
Simplify WaitAndLogIfStuck() to use absl::Barrier instead of `tsl…
Browse files Browse the repository at this point in the history
…::BlockingCounter`.

This gets rid of the shared_ptr.

PiperOrigin-RevId: 705916112
  • Loading branch information
majnemer authored and Google-ML-Automation committed Dec 17, 2024
1 parent d903f75 commit d5f12e6
Show file tree
Hide file tree
Showing 4 changed files with 76 additions and 39 deletions.
1 change: 0 additions & 1 deletion xla/python/ifrt/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -285,7 +285,6 @@ cc_library(
"@com_google_absl//absl/synchronization",
"@com_google_absl//absl/types:span",
"@tsl//tsl/platform:statusor",
"@tsl//tsl/platform:test",
],
)

Expand Down
1 change: 0 additions & 1 deletion xla/python/ifrt/test_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@ limitations under the License.
#include "xla/python/ifrt/shape.h"
#include "xla/tsl/concurrency/ref_count.h"
#include "tsl/platform/statusor.h"
#include "tsl/platform/test.h"

namespace xla {
namespace ifrt {
Expand Down
4 changes: 3 additions & 1 deletion xla/service/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -5436,14 +5436,16 @@ cc_library(
"//xla/hlo/ir:hlo",
"//xla/service/gpu:backend_configs_cc",
"//xla/stream_executor:device_memory",
"@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/functional:function_ref",
"@com_google_absl//absl/log:check",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/synchronization",
"@com_google_absl//absl/time",
"@com_google_absl//absl/types:span",
"@tsl//tsl/platform:blocking_counter",
"@tsl//tsl/platform:statusor",
],
)
Expand Down
109 changes: 73 additions & 36 deletions xla/service/collective_ops_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,17 +17,23 @@ limitations under the License.
#define XLA_SERVICE_COLLECTIVE_OPS_UTILS_H_

#include <cstdint>
#include <functional>
#include <memory>
#include <optional>
#include <string>
#include <tuple>
#include <type_traits>
#include <utility>
#include <vector>

#include "absl/base/attributes.h"
#include "absl/functional/function_ref.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/string_view.h"
#include "absl/synchronization/blocking_counter.h"
#include "absl/synchronization/notification.h"
#include "absl/time/time.h"
#include "absl/types/span.h"
#include "xla/executable_run_options.h"
#include "xla/hlo/ir/collective_device_list.h"
Expand All @@ -38,7 +44,6 @@ limitations under the License.
#include "xla/service/global_device_id.h"
#include "xla/service/pattern_matcher.h"
#include "xla/stream_executor/device_memory.h"
#include "tsl/platform/blocking_counter.h"

namespace xla {

Expand Down Expand Up @@ -133,7 +138,7 @@ const std::vector<ReplicaGroup>& GetCollectiveReplicaGroups(
const HloInstruction* hlo);

// Returns the group formation mode of instr, assuming that instr is, or is
// dervied from, an HloAllGatherInstruction, HloAllReduceInstructionBase,
// derived from, an HloAllGatherInstruction, HloAllReduceInstructionBase,
// HloAllToAllInstruction, HloCollectiveBroadcastInstruction or
// HloCollectivePermuteInstruction.
absl::StatusOr<CollectiveOpGroupMode> GetCollectiveOpGroupMode(
Expand Down Expand Up @@ -334,21 +339,45 @@ struct RendezvousKey {
int64_t op_id;
};

template <typename DescFn>
void WaitAndLogIfStuck(tsl::BlockingCounter* counter, const DescFn& desc_fn) {
VLOG(3) << "Begin: " << desc_fn();
const std::chrono::milliseconds timeout(5000);
bool ok = counter->WaitFor(timeout);
if (ok) {
VLOG(3) << "Finished: " << desc_fn();
return;
class TimeoutLoggingBarrier {
public:
explicit TimeoutLoggingBarrier(int num_threads) : counter_(num_threads) {}

template <typename DescFn>
ABSL_MUST_USE_RESULT bool BlockAndWarnAfterTimeout(
const std::function<std::string()>& desc_fn, absl::Duration timeout) {
VLOG(3) << "Begin: " << desc_fn();
bool is_last = counter_.DecrementCount();
if (is_last) {
finished_.Notify();
// This call to `Wait()` is not expected to block. Calling `Wait()` here
// allows us to satisfy `BlockingCounter`'s requirement: "When `Wait()`
// returns, it is legal to destroy the `BlockingCounter`.".
counter_.Wait();
}
if (finished_.WaitForNotificationWithTimeout(timeout)) {
VLOG(3) << "Finished: " << desc_fn();
return is_last;
}
LOG(ERROR) << "This thread has been waiting for " << timeout
<< "ms for and may be stuck: " << desc_fn();
finished_.WaitForNotification();
LOG(ERROR) << "Thread is unstuck! Warning above was a false-positive. "
"Perhaps the timeout is too short: "
<< desc_fn();
return is_last;
}
LOG(ERROR) << "This thread has been waiting for " << timeout.count()
<< "ms for and may be stuck: " << desc_fn();
counter->Wait();
LOG(ERROR) << "Thread is unstuck! Warning above was a false-positive. "
"Perhaps the timeout is too short: "
<< desc_fn();

private:
absl::BlockingCounter counter_;
absl::Notification finished_;
};

template <typename DescFn>
ABSL_MUST_USE_RESULT bool WaitAndLogIfStuck(TimeoutLoggingBarrier* barrier,
const DescFn& desc_fn) {
constexpr absl::Duration kTimeout = absl::Milliseconds(5000);
return barrier->BlockAndWarnAfterTimeout(desc_fn, kTimeout);
}

// Participant data for each rendezvous.
Expand Down Expand Up @@ -399,16 +428,20 @@ class Rendezvous {
// An alternative way of accomplishing this goal would be to implement
// RefcountingHashMap::erase() and call it during SubmitParticipant. But
// erase() is deceptively complex to implement correctly.
std::shared_ptr<tsl::BlockingCounter> blocking_counter = p.second;
TimeoutLoggingBarrier* barrier = std::get<1>(p);
uintptr_t rendezvous_address =
reinterpret_cast<uintptr_t>(rendezvous.get());
rendezvous.reset();
blocking_counter->DecrementCount();
xla::WaitAndLogIfStuck(blocking_counter.get(), [&] {
bool is_last = xla::WaitAndLogIfStuck(barrier, [&] {
return absl::StrFormat(
"participant waiting for all threads to drop their reference to the "
"rendezvous: %p",
rendezvous.get());
"rendezvous: %#x",
rendezvous_address);
});
return std::move(p.first);
if (is_last) {
delete barrier;
}
return std::move(std::get<0>(p));
}

protected:
Expand All @@ -430,34 +463,38 @@ class Rendezvous {
// - a TimeoutLoggingBarrier initialized to the number of participants, so
// that the caller can coordinate with the participants one last time if it
// chooses. This is useful for coordinating destruction of the Rendezvous.
absl::StatusOr<std::pair<O, std::shared_ptr<tsl::BlockingCounter>>>
SubmitParticipant(const I& participant) {
absl::StatusOr<std::tuple<O, TimeoutLoggingBarrier*>> SubmitParticipant(
const I& participant) {
{
absl::MutexLock lock(&mu_);
CHECK(!participants_[participant.local_rank].has_value());
participants_[participant.local_rank] = participant;
}

// Wait for all participants to arrive.
all_participants_present_.DecrementCount();
WaitAndLogIfStuck(&all_participants_present_, [&] {
return absl::StrFormat(
"participant %s waiting for all participants to arrive at rendezvous "
"%s",
participant.ToString(), key_.ToString());
});
bool is_last_thread_to_enter_barrier =
WaitAndLogIfStuck(&arrival_barrier_, [&] {
return absl::StrFormat(
"participant %s waiting for all participants to arrive at "
"rendezvous "
"%s",
participant.ToString(), key_.ToString());
});
// In this case, we don't need to know whether this is the last thread to
// enter the barrier.
(void)is_last_thread_to_enter_barrier;

TF_ASSIGN_OR_RETURN(O output, RunCollectiveOp(participant));
return std::make_pair(std::move(output), returned_blocking_counter_);
return std::make_tuple(std::move(output), returned_barrier_);
}

const RendezvousKey key_;

tsl::BlockingCounter all_participants_present_{key_.num_local_participants};
TimeoutLoggingBarrier arrival_barrier_{key_.num_local_participants};

// tsl::BlockingCounter returned by SubmitParticipant.
std::shared_ptr<tsl::BlockingCounter> returned_blocking_counter_{
std::make_shared<tsl::BlockingCounter>(key_.num_local_participants)};
// TimeoutLoggingBarrier returned by SubmitParticipant.
TimeoutLoggingBarrier* returned_barrier_ =
new TimeoutLoggingBarrier(key_.num_local_participants);
};

// We only pipeline Send-Recv chains with channel_id > 0, where each chain
Expand Down

0 comments on commit d5f12e6

Please sign in to comment.