diff --git a/xla/backends/cpu/collectives/BUILD b/xla/backends/cpu/collectives/BUILD index be10b5cafa1250..4363e15a7e13f1 100644 --- a/xla/backends/cpu/collectives/BUILD +++ b/xla/backends/cpu/collectives/BUILD @@ -14,6 +14,59 @@ package_group( ], ) +cc_library( + name = "cpu_clique_key", + srcs = ["cpu_clique_key.cc"], + hdrs = ["cpu_clique_key.h"], + deps = [ + "//xla/core/collectives:clique_key", + "//xla/service:global_device_id", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/hash", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@tsl//tsl/platform:casts", + ], +) + +cc_library( + name = "cpu_clique", + srcs = ["cpu_clique.cc"], + hdrs = ["cpu_clique.h"], + deps = [ + ":cpu_clique_key", + "//xla/core/collectives:clique", + "//xla/core/collectives:communicator", + "//xla/core/collectives:rank_id", + "//xla/tsl/platform:logging", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + ], +) + +cc_library( + name = "cpu_cliques", + srcs = ["cpu_cliques.cc"], + hdrs = ["cpu_cliques.h"], + deps = [ + ":cpu_clique", + ":cpu_clique_key", + ":cpu_collectives", + "//xla:util", + "//xla/core/collectives:clique", + "//xla/core/collectives:communicator", + "//xla/core/collectives:rank_id", + "//xla/tsl/platform:errors", + "//xla/tsl/platform:logging", + "//xla/tsl/platform:statusor", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/container:node_hash_map", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/synchronization", + ], +) + cc_library( name = "cpu_collectives", srcs = ["cpu_collectives.cc"], @@ -23,14 +76,17 @@ cc_library( "//xla:util", "//xla:xla_data_proto_cc", "//xla/core/collectives", + "//xla/core/collectives:clique_id", "//xla/core/collectives:collectives_registry", "//xla/core/collectives:communicator", + "//xla/core/collectives:rank_id", "//xla/service:collective_ops_utils", "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/time", + "@com_google_absl//absl/types:span", "@tsl//tsl/platform:casts", ], ) diff --git a/xla/backends/cpu/collectives/cpu_clique.cc b/xla/backends/cpu/collectives/cpu_clique.cc new file mode 100644 index 00000000000000..a81dd80392f9f1 --- /dev/null +++ b/xla/backends/cpu/collectives/cpu_clique.cc @@ -0,0 +1,59 @@ +/* Copyright 2025 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/backends/cpu/collectives/cpu_clique.h" + +#include +#include +#include + +#include "absl/status/status.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" +#include "xla/backends/cpu/collectives/cpu_clique_key.h" +#include "xla/core/collectives/clique.h" +#include "xla/core/collectives/communicator.h" +#include "xla/core/collectives/rank_id.h" +#include "xla/tsl/platform/logging.h" + +namespace xla::cpu { + +CpuClique::CpuClique(CpuCliqueKey key) : Clique({}), key_(std::move(key)) {} + +std::string CpuClique::DebugString() const { + std::string out = + absl::StrFormat("key: %s; size: %d; communicators: ", key_.ToString(), + num_communicators()); + int32_t cnt = 0; + ForEachComm([&](RankId rank, Communicator* comm) { + if (cnt++) absl::StrAppend(&out, ", "); + absl::StrAppendFormat(&out, "[rank=%d, comm=%s]", rank.value(), + comm->ToString()); + }); + return out; +} + +absl::Status CpuClique::HealthCheck() const { + absl::Status health_check = absl::OkStatus(); + ForEachComm([&health_check](RankId rank, Communicator* comm) { + if (auto s = comm->HealthCheck(); !s.ok()) { + LOG(ERROR) << "CPU communicator error (rank " << rank << "): " << s; + if (health_check.ok()) health_check = std::move(s); // return first error + } + }); + return health_check; +} + +} // namespace xla::cpu diff --git a/xla/backends/cpu/collectives/cpu_clique.h b/xla/backends/cpu/collectives/cpu_clique.h new file mode 100644 index 00000000000000..e1ff3025a955b0 --- /dev/null +++ b/xla/backends/cpu/collectives/cpu_clique.h @@ -0,0 +1,42 @@ +/* Copyright 2025 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_BACKENDS_CPU_COLLECTIVES_CPU_CLIQUE_H_ +#define XLA_BACKENDS_CPU_COLLECTIVES_CPU_CLIQUE_H_ + +#include + +#include "absl/status/status.h" +#include "xla/backends/cpu/collectives/cpu_clique_key.h" +#include "xla/core/collectives/clique.h" + +namespace xla::cpu { + +// A group of CPU communicators making up a clique. +class CpuClique final : public Clique { + public: + explicit CpuClique(CpuCliqueKey key); + + absl::Status HealthCheck() const final; + + std::string DebugString() const final; + + private: + CpuCliqueKey key_; +}; + +} // namespace xla::cpu + +#endif // XLA_BACKENDS_CPU_COLLECTIVES_CPU_CLIQUE_H_ diff --git a/xla/backends/cpu/collectives/cpu_clique_key.cc b/xla/backends/cpu/collectives/cpu_clique_key.cc new file mode 100644 index 00000000000000..b66c844d4983ed --- /dev/null +++ b/xla/backends/cpu/collectives/cpu_clique_key.cc @@ -0,0 +1,59 @@ +/* Copyright 2025 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/backends/cpu/collectives/cpu_clique_key.h" + +#include +#include + +#include "absl/algorithm/container.h" +#include "absl/hash/hash.h" +#include "absl/strings/str_format.h" +#include "xla/core/collectives/clique_key.h" +#include "xla/service/global_device_id.h" +#include "tsl/platform/casts.h" + +namespace xla::cpu { + +bool CpuCliqueKey::IsSubsetOf(const CliqueKey& other) const { + auto* other_cpu = tsl::down_cast(&other); + if (other_cpu == nullptr) return false; + + return absl::c_all_of(devices(), [&](GlobalDeviceId id) { + return absl::c_linear_search(other_cpu->devices(), id); + }); +} + +std::string CpuCliqueKey::ToString() const { + return absl::StrFormat("devices=[%s]", GlobalDeviceIdsToString(devices())); +} + +void CpuCliqueKey::HashValue(absl::HashState state) const { + absl::HashState::combine(std::move(state), devices()); +} + +bool operator==(const CpuCliqueKey& a, const CpuCliqueKey& b) { + return a.devices() == b.devices(); +} + +bool operator<(const CpuCliqueKey& a, const CpuCliqueKey& b) { + return a.devices() < b.devices(); +} + +bool operator>(const CpuCliqueKey& a, const CpuCliqueKey& b) { + return a.devices() > b.devices(); +} + +} // namespace xla::cpu diff --git a/xla/backends/cpu/collectives/cpu_clique_key.h b/xla/backends/cpu/collectives/cpu_clique_key.h new file mode 100644 index 00000000000000..30b257c1a0d0c0 --- /dev/null +++ b/xla/backends/cpu/collectives/cpu_clique_key.h @@ -0,0 +1,44 @@ +/* Copyright 2025 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_BACKENDS_CPU_COLLECTIVES_CPU_CLIQUE_KEY_H_ +#define XLA_BACKENDS_CPU_COLLECTIVES_CPU_CLIQUE_KEY_H_ + +#include + +#include "absl/hash/hash.h" +#include "xla/core/collectives/clique_key.h" + +namespace xla::cpu { + +// Clique key for identifying a particular CPU collectives clique. +class CpuCliqueKey final : public CliqueKey { + public: + using CliqueKey::CliqueKey; + + bool IsSubsetOf(const CliqueKey& other) const final; + std::string ToString() const final; + + friend bool operator==(const CpuCliqueKey& a, const CpuCliqueKey& b); + friend bool operator<(const CpuCliqueKey& a, const CpuCliqueKey& b); + friend bool operator>(const CpuCliqueKey& a, const CpuCliqueKey& b); + + private: + void HashValue(absl::HashState state) const final; +}; + +} // namespace xla::cpu + +#endif // XLA_BACKENDS_CPU_COLLECTIVES_CPU_CLIQUE_KEY_H_ diff --git a/xla/backends/cpu/collectives/cpu_cliques.cc b/xla/backends/cpu/collectives/cpu_cliques.cc new file mode 100644 index 00000000000000..6e6c437256ad12 --- /dev/null +++ b/xla/backends/cpu/collectives/cpu_cliques.cc @@ -0,0 +1,122 @@ +/* Copyright 2025 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/backends/cpu/collectives/cpu_cliques.h" + +#include +#include +#include +#include + +#include "absl/base/thread_annotations.h" +#include "absl/container/node_hash_map.h" +#include "absl/status/statusor.h" +#include "absl/synchronization/mutex.h" +#include "xla/backends/cpu/collectives/cpu_clique.h" +#include "xla/backends/cpu/collectives/cpu_clique_key.h" +#include "xla/backends/cpu/collectives/cpu_collectives.h" +#include "xla/core/collectives/communicator.h" +#include "xla/core/collectives/rank_id.h" +#include "xla/tsl/platform/errors.h" +#include "xla/tsl/platform/logging.h" +#include "xla/tsl/platform/statusor.h" +#include "xla/util.h" + +namespace xla::cpu { + +//===----------------------------------------------------------------------===// +// ProcessCpuCliques +//===----------------------------------------------------------------------===// + +namespace { + +// CpuClique is not thread-safe, so we wrap it in a thread-safe container as we +// create new communicators lazily and potentially from multiple threads. +struct ThreadSafeClique { + explicit ThreadSafeClique(CpuCliqueKey key) : clique(key) {} + + absl::Mutex mu; + CpuClique clique ABSL_GUARDED_BY(mu); +}; + +// Container for initialized and ready to use CPU cliques. In contrast to GPU +// cliques, CPU cliques are not lockable, and we create communicators lazily +// when needed. +struct ProcessCpuCliques { + absl::Mutex mu; + absl::node_hash_map map ABSL_GUARDED_BY(mu); +}; +} // namespace + +// Returns process-local CPU cliques. +static ProcessCpuCliques& GetProcessCpuCliques() { + static auto* cliques = new ProcessCpuCliques; + return *cliques; +} + +//===----------------------------------------------------------------------===// + +// TODO(b/380457503): Consider switching to a lockable CPU clique model similar +// to GPU cliques, and creating all communicators upfront. +absl::StatusOr AcquireCommunicator( + CpuCollectives* collectives, const CpuCliqueKey& clique_key, RankId rank) { + VLOG(3) << "Acquire communicator for clique key " << clique_key.ToString() + << " and rank " << rank; + + ProcessCpuCliques& cliques = GetProcessCpuCliques(); + + // Synchronize access to the process cliques. + ThreadSafeClique& thread_safe_clique = [&]() -> ThreadSafeClique& { + absl::MutexLock lock(&cliques.mu); + auto [it, emplaced] = cliques.map.try_emplace(clique_key, clique_key); + return it->second; + }(); + + // Check if we already have a communicator for this rank. + std::optional comm = [&]() -> std::optional { + absl::MutexLock lock(&thread_safe_clique.mu); + return thread_safe_clique.clique.comm(rank); + }(); + + if (comm.has_value()) return *comm; + + VLOG(3) << "Create a new communicator for clique key " + << clique_key.ToString() << " and rank " << rank; + + // Create a new communicator and add it to the clique. + CpuCollectives::DeviceRank device_rank(/*device=*/nullptr, rank); + CpuCollectives::Config config; + + TF_ASSIGN_OR_RETURN( + std::vector> communicators, + collectives->CreateCommunicators(clique_key.num_devices(), clique_key, + std::nullopt, {device_rank}, config)); + + // We expect to create communicators lazily on at a time. + if (communicators.size() != 1) { + return Internal( + "Expected to create a single communicator for a clique key %s and rank " + "%d, but got %d", + clique_key.ToString(), rank.value(), communicators.size()); + } + + absl::MutexLock lock(&thread_safe_clique.mu); + TF_RETURN_IF_ERROR(thread_safe_clique.clique.AddComm( + rank, std::move(communicators.front()))); + + return *thread_safe_clique.clique.comm(rank); +} + +} // namespace xla::cpu diff --git a/xla/backends/cpu/collectives/cpu_cliques.h b/xla/backends/cpu/collectives/cpu_cliques.h new file mode 100644 index 00000000000000..b42774619fe4b2 --- /dev/null +++ b/xla/backends/cpu/collectives/cpu_cliques.h @@ -0,0 +1,33 @@ +/* Copyright 2025 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_BACKENDS_CPU_COLLECTIVES_CPU_CLIQUES_H_ +#define XLA_BACKENDS_CPU_COLLECTIVES_CPU_CLIQUES_H_ + +#include "absl/status/statusor.h" +#include "xla/backends/cpu/collectives/cpu_clique_key.h" +#include "xla/backends/cpu/collectives/cpu_collectives.h" +#include "xla/core/collectives/communicator.h" +#include "xla/core/collectives/rank_id.h" + +namespace xla::cpu { + +// Returns a communicator for a given clique key and rank. +absl::StatusOr AcquireCommunicator( + CpuCollectives* collectives, const CpuCliqueKey& clique_key, RankId rank); + +} // namespace xla::cpu + +#endif // XLA_BACKENDS_CPU_COLLECTIVES_CPU_CLIQUES_H_ diff --git a/xla/backends/cpu/collectives/cpu_collectives.h b/xla/backends/cpu/collectives/cpu_collectives.h index a728e7cd3a399d..330b35f52146d1 100644 --- a/xla/backends/cpu/collectives/cpu_collectives.h +++ b/xla/backends/cpu/collectives/cpu_collectives.h @@ -16,11 +16,19 @@ limitations under the License. #ifndef XLA_BACKENDS_CPU_COLLECTIVES_CPU_COLLECTIVES_H_ #define XLA_BACKENDS_CPU_COLLECTIVES_CPU_COLLECTIVES_H_ +#include +#include +#include + #include "absl/status/statusor.h" #include "absl/time/time.h" +#include "absl/types/span.h" +#include "xla/core/collectives/clique_id.h" #include "xla/core/collectives/collectives.h" #include "xla/core/collectives/communicator.h" +#include "xla/core/collectives/rank_id.h" #include "xla/service/collective_ops_utils.h" +#include "xla/util.h" #include "xla/xla_data.pb.h" namespace xla::cpu { @@ -50,6 +58,17 @@ class CpuCollectives : public Collectives { absl::Duration timeout_; }; + absl::StatusOr CreateUniqueCliqueId() const final { + return Unimplemented("CPU collectives do not support clique ids"); + } + + absl::StatusOr>> SplitCommunicators( + absl::Span comms, int32_t color, + absl::Span keys, const Config& config) final { + return Unimplemented( + "CPU collectives do not support communicator splitting"); + } + // Tries to cast a Collectives::Device to a CpuCollectives::Device. static absl::StatusOr TryCast( const Collectives::Device* device); diff --git a/xla/backends/cpu/runtime/BUILD b/xla/backends/cpu/runtime/BUILD index a83a5e51dca28d..cd1e7b89e9c1ae 100644 --- a/xla/backends/cpu/runtime/BUILD +++ b/xla/backends/cpu/runtime/BUILD @@ -145,6 +145,8 @@ cc_library( ":resource_use", "//xla:executable_run_options", "//xla:util", + "//xla/backends/cpu/collectives:cpu_collectives", + "//xla/core/collectives", "//xla/ffi:execution_context", "//xla/runtime:buffer_use", "//xla/service:global_device_id", @@ -155,11 +157,12 @@ cc_library( "//xla/stream_executor:stream", "//xla/stream_executor:stream_executor_h", "//xla/tsl/concurrency:async_value", + "//xla/tsl/platform:logging", + "//xla/tsl/platform:statusor", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/status:statusor", - "@tsl//tsl/platform:logging", - "@tsl//tsl/platform:statusor", + "@com_google_absl//absl/strings:string_view", "@tsl//tsl/profiler/lib:traceme", "@tsl//tsl/profiler/lib:traceme_encode", ], @@ -593,6 +596,11 @@ cc_library( "//xla:status_macros", "//xla:util", "//xla:xla_data_proto_cc", + "//xla/backends/cpu/collectives:cpu_clique_key", + "//xla/backends/cpu/collectives:cpu_cliques", + "//xla/backends/cpu/collectives:cpu_collectives", + "//xla/core/collectives:communicator", + "//xla/core/collectives:rank_id", "//xla/runtime:buffer_use", "//xla/service:buffer_assignment", "//xla/service:collective_ops_utils", @@ -601,6 +609,9 @@ cc_library( "//xla/service/cpu:collectives_interface", "//xla/stream_executor:device_memory", "//xla/tsl/concurrency:async_value", + "//xla/tsl/platform:errors", + "//xla/tsl/platform:logging", + "//xla/tsl/platform:statusor", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/functional:any_invocable", @@ -610,9 +621,6 @@ cc_library( "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/time", "@com_google_absl//absl/types:span", - "@tsl//tsl/platform:errors", - "@tsl//tsl/platform:logging", - "@tsl//tsl/platform:statusor", ], ) diff --git a/xla/backends/cpu/runtime/collective_thunk.cc b/xla/backends/cpu/runtime/collective_thunk.cc index f838fb0e49acd1..35a6f72fb9671d 100644 --- a/xla/backends/cpu/runtime/collective_thunk.cc +++ b/xla/backends/cpu/runtime/collective_thunk.cc @@ -32,23 +32,27 @@ limitations under the License. #include "absl/strings/str_join.h" #include "absl/time/time.h" #include "absl/types/span.h" +#include "xla/backends/cpu/collectives/cpu_clique_key.h" +#include "xla/backends/cpu/collectives/cpu_cliques.h" +#include "xla/backends/cpu/collectives/cpu_collectives.h" #include "xla/backends/cpu/runtime/resource_use.h" #include "xla/backends/cpu/runtime/thunk.h" +#include "xla/core/collectives/communicator.h" +#include "xla/core/collectives/rank_id.h" #include "xla/runtime/buffer_use.h" #include "xla/service/buffer_assignment.h" #include "xla/service/collective_ops_utils.h" #include "xla/service/computation_placer.h" -#include "xla/service/cpu/collectives_interface.h" #include "xla/service/global_device_id.h" #include "xla/shape.h" #include "xla/status_macros.h" #include "xla/stream_executor/device_memory.h" #include "xla/tsl/concurrency/async_value_ref.h" +#include "xla/tsl/platform/errors.h" +#include "xla/tsl/platform/logging.h" +#include "xla/tsl/platform/statusor.h" #include "xla/util.h" #include "xla/xla_data.pb.h" -#include "tsl/platform/errors.h" -#include "tsl/platform/logging.h" -#include "tsl/platform/statusor.h" namespace xla::cpu { @@ -172,7 +176,7 @@ CollectiveThunk::ExecuteWithCommunicator( TF_RET_CHECK(params) << "Collective parameters are not set for collective operation"; - CollectivesInterface* collectives = params->collectives; + CpuCollectives* collectives = params->collectives; TF_RET_CHECK(collectives) << "Collectives interface is not set for collective operation"; @@ -183,8 +187,10 @@ CollectiveThunk::ExecuteWithCommunicator( VLOG(3) << absl::StreamFormat(" rank=%d, key=%s", rank, key.ToString()); - TF_ASSIGN_OR_RETURN(std::shared_ptr communicator, - collectives->GetCommunicator(key.global_devices, rank)); + CpuCliqueKey clique_key(key.global_devices); + TF_ASSIGN_OR_RETURN( + Communicator * communicator, + AcquireCommunicator(collectives, clique_key, RankId(rank))); TF_RETURN_IF_ERROR(callback(key, *communicator)); diff --git a/xla/backends/cpu/runtime/thunk.cc b/xla/backends/cpu/runtime/thunk.cc index 8dab085b47fb6b..a17de11724bda3 100644 --- a/xla/backends/cpu/runtime/thunk.cc +++ b/xla/backends/cpu/runtime/thunk.cc @@ -22,6 +22,8 @@ limitations under the License. #include #include +#include "absl/strings/string_view.h" +#include "xla/backends/cpu/collectives/cpu_collectives.h" #include "xla/executable_run_options.h" #include "xla/service/cpu/collectives_interface.h" #include "xla/service/cpu/cpu_executable_run_options.h" @@ -30,7 +32,7 @@ limitations under the License. #include "xla/stream_executor/stream.h" #include "xla/stream_executor/stream_executor.h" #include "xla/tsl/concurrency/async_value_ref.h" -#include "tsl/platform/logging.h" +#include "xla/tsl/platform/logging.h" #include "tsl/profiler/lib/traceme.h" #include "tsl/profiler/lib/traceme_encode.h" @@ -121,8 +123,7 @@ Thunk::CollectiveExecuteParams::Create( Thunk::CollectiveExecuteParams::CollectiveExecuteParams( RunId run_id, int64_t local_device_ordinal, GlobalDeviceId global_device_id, - const DeviceAssignment* device_assignment, - CollectivesInterface* collectives) + const DeviceAssignment* device_assignment, CpuCollectives* collectives) : run_id(run_id), local_device_ordinal(local_device_ordinal), global_device_id(global_device_id), diff --git a/xla/backends/cpu/runtime/thunk.h b/xla/backends/cpu/runtime/thunk.h index 38d3f41d6a75b3..2c86db92517745 100644 --- a/xla/backends/cpu/runtime/thunk.h +++ b/xla/backends/cpu/runtime/thunk.h @@ -28,21 +28,20 @@ limitations under the License. #include "absl/container/inlined_vector.h" #include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "xla/backends/cpu/collectives/cpu_collectives.h" #include "xla/backends/cpu/runtime/buffer_allocations.h" #include "xla/backends/cpu/runtime/function_library.h" -#include "xla/backends/cpu/runtime/kernel_c_api.h" #include "xla/backends/cpu/runtime/resource_use.h" #include "xla/executable_run_options.h" #include "xla/ffi/execution_context.h" #include "xla/runtime/buffer_use.h" -#include "xla/service/cpu/collectives_interface.h" #include "xla/service/cpu/xfeed_manager.h" #include "xla/service/global_device_id.h" #include "xla/tsl/concurrency/async_value_ref.h" #include "xla/tsl/concurrency/chain.h" -#include "xla/util.h" -#include "tsl/platform/logging.h" -#include "tsl/platform/statusor.h" +#include "xla/tsl/platform/logging.h" +#include "xla/tsl/platform/statusor.h" namespace Eigen { struct ThreadPoolDevice; @@ -164,13 +163,13 @@ class Thunk { GlobalDeviceId global_device_id; const DeviceAssignment* device_assignment = nullptr; - CollectivesInterface* collectives = nullptr; + CpuCollectives* collectives = nullptr; private: CollectiveExecuteParams(RunId run_id, int64_t local_device_ordinal, GlobalDeviceId global_device_id, const DeviceAssignment* device_assignment, - CollectivesInterface* collectives); + CpuCollectives* collectives); }; //===--------------------------------------------------------------------===// diff --git a/xla/core/collectives/clique.cc b/xla/core/collectives/clique.cc index 6eb73c1ea91cba..1a0a5d659aecba 100644 --- a/xla/core/collectives/clique.cc +++ b/xla/core/collectives/clique.cc @@ -21,8 +21,10 @@ limitations under the License. #include "absl/container/btree_map.h" #include "absl/functional/function_ref.h" +#include "absl/status/status.h" #include "xla/core/collectives/communicator.h" #include "xla/core/collectives/rank_id.h" +#include "xla/util.h" namespace xla { @@ -44,4 +46,13 @@ void Clique::ForEachComm( } } +absl::Status Clique::AddComm(RankId rank, + std::unique_ptr communicator) { + auto emplaced = communicators_.emplace(rank, std::move(communicator)); + if (!emplaced.second) { + return InvalidArgument("Rank %d already exists in clique", rank.value()); + } + return absl::OkStatus(); +} + } // namespace xla diff --git a/xla/core/collectives/clique.h b/xla/core/collectives/clique.h index 69705ccfa524c5..24f80a3f1682c9 100644 --- a/xla/core/collectives/clique.h +++ b/xla/core/collectives/clique.h @@ -49,6 +49,9 @@ class Clique { // Returns a communicator for a given rank if it's in a clique. std::optional comm(RankId rank) const; + // Adds a communicator to the clique. + absl::Status AddComm(RankId rank, std::unique_ptr communicator); + // Calls `fn` for each communicator in the clique. void ForEachComm(absl::FunctionRef fn) const; @@ -61,8 +64,8 @@ class Clique { size_t num_communicators() const { return communicators_.size(); } private: - // We keep communicators in a sorted order by rank to guarantee deterministic - // traversal order in `ForEachComm`. + // We keep communicators in a sorted order by rank to guarantee + // deterministic traversal order in `ForEachComm`. absl::btree_map> communicators_; }; diff --git a/xla/core/collectives/clique_key.cc b/xla/core/collectives/clique_key.cc index 2da8d6651c3548..92749633bb91ad 100644 --- a/xla/core/collectives/clique_key.cc +++ b/xla/core/collectives/clique_key.cc @@ -15,6 +15,7 @@ limitations under the License. #include "xla/core/collectives/clique_key.h" +#include #include #include #include @@ -31,6 +32,8 @@ CliqueKey::CliqueKey(std::vector devices) absl::Span CliqueKey::devices() const { return devices_; } +size_t CliqueKey::num_devices() const { return devices_.size(); } + std::optional CliqueKey::rank(GlobalDeviceId id) const { if (auto it = absl::c_find(devices_, id); it != devices_.end()) { return RankId(it - devices_.begin()); diff --git a/xla/core/collectives/clique_key.h b/xla/core/collectives/clique_key.h index 05411773431507..37e16d5fb774ae 100644 --- a/xla/core/collectives/clique_key.h +++ b/xla/core/collectives/clique_key.h @@ -16,6 +16,7 @@ limitations under the License. #ifndef XLA_CORE_COLLECTIVES_CLIQUE_KEY_H_ #define XLA_CORE_COLLECTIVES_CLIQUE_KEY_H_ +#include #include #include #include @@ -52,6 +53,7 @@ class CliqueKey { std::optional rank(GlobalDeviceId id) const; absl::Span devices() const; + size_t num_devices() const; // Returns true if this clique is a subset of `other`. virtual bool IsSubsetOf(const CliqueKey& other) const = 0; diff --git a/xla/service/cpu/BUILD b/xla/service/cpu/BUILD index 4ea228a0c63300..012112662640a2 100644 --- a/xla/service/cpu/BUILD +++ b/xla/service/cpu/BUILD @@ -1961,12 +1961,17 @@ cc_library( name = "collectives_interface", hdrs = ["collectives_interface.h"], deps = [ + "//xla:util", "//xla:xla_data_proto_cc", + "//xla/backends/cpu/collectives:cpu_collectives", + "//xla/core/collectives:clique_id", + "//xla/core/collectives:clique_key", "//xla/core/collectives:communicator", "//xla/core/collectives:rank_id", "//xla/service:collective_ops_utils", "//xla/service:global_device_id", "//xla/stream_executor:device_memory", + "//xla/tsl/platform:statusor", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/time", diff --git a/xla/service/cpu/collectives_interface.h b/xla/service/cpu/collectives_interface.h index cfa3b11f36513a..77e159e1535bc4 100644 --- a/xla/service/cpu/collectives_interface.h +++ b/xla/service/cpu/collectives_interface.h @@ -17,22 +17,108 @@ limitations under the License. #define XLA_SERVICE_CPU_COLLECTIVES_INTERFACE_H_ #include +#include #include #include +#include +#include +#include #include "absl/status/status.h" #include "absl/status/statusor.h" -#include "absl/time/time.h" #include "absl/types/span.h" +#include "xla/backends/cpu/collectives/cpu_collectives.h" +#include "xla/core/collectives/clique_id.h" +#include "xla/core/collectives/clique_key.h" #include "xla/core/collectives/communicator.h" +#include "xla/core/collectives/rank_id.h" #include "xla/service/collective_ops_utils.h" #include "xla/service/global_device_id.h" #include "xla/stream_executor/device_memory.h" +#include "xla/tsl/platform/statusor.h" +#include "xla/util.h" #include "xla/xla_data.pb.h" namespace xla::cpu { -class CollectivesInterface { +namespace internal { + +// An adapter from a shared_ptr to a Communicator. +class CommunicatorWrapper final : public Communicator { + public: + explicit CommunicatorWrapper(std::shared_ptr comm) + : comm_(std::move(comm)) {} + + absl::Status AllReduce(stream_executor::DeviceMemoryBase send_buffer, + stream_executor::DeviceMemoryBase recv_buffer, + PrimitiveType dtype, size_t count, + ReductionKind reduction_kind, + const Executor& executor) final { + return comm_->AllReduce(send_buffer, recv_buffer, dtype, count, + reduction_kind, executor); + } + + absl::Status Broadcast(se::DeviceMemoryBase send_buffer, + se::DeviceMemoryBase recv_buffer, PrimitiveType dtype, + size_t count, RankId root, + const Executor& executor) final { + return comm_->Broadcast(send_buffer, recv_buffer, dtype, count, root, + executor); + } + + absl::Status ReduceScatter(se::DeviceMemoryBase send_buffer, + se::DeviceMemoryBase recv_buffer, + PrimitiveType dtype, size_t count, + ReductionKind reduction_kind, + const Executor& executor) final { + return comm_->ReduceScatter(send_buffer, recv_buffer, dtype, count, + reduction_kind, executor); + } + + absl::Status AllGather(se::DeviceMemoryBase send_buffer, + se::DeviceMemoryBase recv_buffer, PrimitiveType dtype, + size_t count, const Executor& executor) final { + return comm_->AllGather(send_buffer, recv_buffer, dtype, count, executor); + } + + absl::Status CollectivePermute(se::DeviceMemoryBase send_buffer, + se::DeviceMemoryBase recv_buffer, + PrimitiveType dtype, size_t count, + std::optional source_rank, + absl::Span target_ranks, + const Executor& executor) final { + return comm_->CollectivePermute(send_buffer, recv_buffer, dtype, count, + source_rank, target_ranks, executor); + } + + absl::Status AllToAll(absl::Span send_buffers, + absl::Span recv_buffers, + PrimitiveType dtype, size_t count, + const Executor& executor) final { + return comm_->AllToAll(send_buffers, recv_buffers, dtype, count, executor); + } + + absl::Status Send(se::DeviceMemoryBase send_buffer, PrimitiveType dtype, + size_t count, RankId peer, const Executor& executor) final { + return comm_->Send(send_buffer, dtype, count, peer, executor); + } + + absl::Status Recv(se::DeviceMemoryBase recv_buffer, PrimitiveType dtype, + size_t count, RankId peer, const Executor& executor) final { + return comm_->Recv(recv_buffer, dtype, count, peer, executor); + } + + absl::StatusOr NumRanks() const final { return comm_->NumRanks(); } + + std::string ToString() const final { return comm_->ToString(); } + + private: + std::shared_ptr comm_; +}; + +} // namespace internal + +class CollectivesInterface : public CpuCollectives { public: virtual ~CollectivesInterface() = default; @@ -42,6 +128,25 @@ class CollectivesInterface { // rank: the rank of this process. virtual absl::StatusOr> GetCommunicator( absl::Span devices, int rank) = 0; + + absl::StatusOr>> + CreateCommunicators(int32_t nranks, const CliqueKey& clique_key, + const std::optional& clique_id, + absl::Span ranks, + const Config& config) final { + // We expect to create CPU communicators lazily one at a time. + if (ranks.size() != 1) { + return InvalidArgument("Expected 1 rank, got %d", ranks.size()); + } + + TF_ASSIGN_OR_RETURN(auto comm, GetCommunicator(clique_key.devices(), + ranks[0].rank.value())); + + std::vector> comms; + comms.reserve(1); + comms.push_back(std::make_unique(comm)); + return comms; + } }; } // namespace xla::cpu