diff --git a/xla/service/gpu/gpu_executable.cc b/xla/service/gpu/gpu_executable.cc
index dde57439a5fe8d..3f0b8b3928a81e 100644
--- a/xla/service/gpu/gpu_executable.cc
+++ b/xla/service/gpu/gpu_executable.cc
@@ -541,8 +541,18 @@ absl::Status ExecuteThunks(
   TF_RETURN_IF_ERROR(thunk_sequence.ExecuteOnStream(execute_params));
 
-  return MaybeSyncAndProfile(run_options, execution_timer.get(),
-                             block_host_until_done ? main_stream : nullptr);
+  auto status =
+      MaybeSyncAndProfile(run_options, execution_timer.get(),
+                          block_host_until_done ? main_stream : nullptr);
+
+  Thunk::CleanupParams cleanup_params{
+      executor,
+      &collective_params,
+      &collective_cliques,
+  };
+  TF_RETURN_IF_ERROR(thunk_sequence.Cleanup(cleanup_params));
+
+  return status;
 }
 
 namespace {
diff --git a/xla/service/gpu/runtime/nccl_collective_permute_thunk.cc b/xla/service/gpu/runtime/nccl_collective_permute_thunk.cc
index b2a046321efe33..8c213386471121 100644
--- a/xla/service/gpu/runtime/nccl_collective_permute_thunk.cc
+++ b/xla/service/gpu/runtime/nccl_collective_permute_thunk.cc
@@ -161,13 +161,49 @@ absl::Status NcclCollectivePermuteStartThunk::Initialize(
   if (p2p_memcpy_enabled_) {
     TF_ASSIGN_OR_RETURN(const int64_t current_id,
                         GetCurrentId(params.collective_params, config_));
+    absl::MutexLock lock(&barrier_mutex_);
+    if (barrier_flags_.find(current_id) == barrier_flags_.end()) {
+      if (!params.stream->parent()->HostMemoryRegister(
+              &barrier_flags_[current_id], sizeof(uint8_t))) {
+        LOG(ERROR) << "Registering barrier flag failed.";
+      }
+    }
+
+    TF_ASSIGN_OR_RETURN(
+        std::vector<DeviceBufferPair> device_buffers,
+        ConvertToDeviceBuffers(params.buffer_allocations, {buffer_},
+                               config_.config.operand_element_type));
+    TF_RET_CHECK(device_buffers.size() == 1) << "Expected one buffer pair.";
+    DeviceBufferPair& buffer = device_buffers[0];
+    const NcclP2PConfig::SourceTargetMapEntry source_target =
+        NcclP2PConfig::GetSourceTarget(config_.id_to_source_target, current_id);
+
+    const std::optional<int64_t> source_id = source_target.source;
+    se::DeviceMemoryBase dest_addr = buffer.destination_buffer;
     TF_RETURN_IF_ERROR(recv_ptr_map_.InitializeId(current_id));
+
+    if (source_id) {
+      TF_RETURN_IF_ERROR(
+          recv_ptr_map_.PutRecvPtr(current_id, dest_addr.opaque()));
+    }
   }
   return absl::OkStatus();
 }
 
+absl::Status NcclCollectivePermuteStartThunk::Cleanup(
+    const CleanupParams& params) {
+  TF_ASSIGN_OR_RETURN(const int64_t current_id,
+                      GetCurrentId(params.collective_params, config_));
+
+  absl::MutexLock lock(&barrier_mutex_);
+  if (!params.executor->HostMemoryUnregister(&barrier_flags_[current_id])) {
+    LOG(ERROR) << "Unregistering barrier flag failed.";
+  }
+  return absl::OkStatus();
+}
+
 absl::Status NcclCollectivePermuteStartThunk::RunNcclCollective(
     const ExecuteParams& params, se::Stream& stream,
     CommunicatorHandle comm_handle) {
@@ -190,6 +226,14 @@ absl::Status NcclCollectivePermuteStartThunk::RunNcclCollective(
                     p2p_memcpy_enabled_;
 
   TF_ASSIGN_OR_RETURN(GpuCollectives * collectives, GetGpuCollectives(params));
+  if (use_memcpy) {
+    se::DeviceMemoryBase sync_var_address =
+        se::DeviceMemoryBase((void*)(&barrier_flags_[current_id]));
+    TF_RETURN_IF_ERROR(comm_handle.comm->AllReduce(
+        sync_var_address, sync_var_address, PrimitiveType::U8, 1,
+        ReductionKind::MIN, GpuCollectives::On(stream)));
+  }
+
   return ::xla::gpu::RunCollectivePermute(
       collectives, source_target, device_buffers[0], stream, comm_handle.comm,
       device_string, current_id, use_memcpy, recv_ptr_map_);
@@ -241,16 +285,7 @@ absl::Status RunCollectivePermute(
                                 device_string, current_id,
                                 source_id.value_or(-1), target_id.value_or(-1));
 
-  // If all peers are local, only get/send device pointer values and invoke
-  // memcpy.
-  if (use_memcpy) {
-    // If sending to another peer, get the pointer value of the src addr.
-    // Only change the pointer value when it's different from stored one.
-    if (source_id) {
-      TF_RETURN_IF_ERROR(
-          recv_ptr_map.PutRecvPtr(current_id, dest_addr.opaque()));
-    }
-  } else {
+  if (!use_memcpy) {
     // GroupStart/End API is needed only if we will issue both send & recv
     // calls.
     const bool is_nccl_group_needed = (target_id && source_id);
@@ -284,10 +319,6 @@ absl::Status RunCollectivePermute(
   }
 
   if (use_memcpy && target_id) {
     TF_ASSIGN_OR_RETURN(auto recv_ptr, recv_ptr_map.GetRecvPtr(*target_id));
-    if (recv_ptr.IsUnavailable()) {
-      // TODO make BlockUntilReady support AsyncValueRef directly.
-      BlockUntilReady(recv_ptr.GetAsyncValue());
-    }
     VLOG(3) << "Using memcpy, received target pointer: " << recv_ptr.get()
             << " current_id " << current_id << " target_id: " << *target_id;
diff --git a/xla/service/gpu/runtime/nccl_collective_permute_thunk.h b/xla/service/gpu/runtime/nccl_collective_permute_thunk.h
index bcc124b3dafcd1..8753df53eb6562 100644
--- a/xla/service/gpu/runtime/nccl_collective_permute_thunk.h
+++ b/xla/service/gpu/runtime/nccl_collective_permute_thunk.h
@@ -52,9 +52,7 @@ class NcclCollectivePermuteStartThunk : public NcclCollectiveThunk {
   absl::Status InitializeId(int64_t current_id) {
     absl::MutexLock lock(&mutex_);
-    if (recv_ptrs_.find(current_id) == recv_ptrs_.end()) {
-      recv_ptrs_[current_id] = tsl::MakeUnconstructedAsyncValueRef<void*>();
-    }
+    recv_ptrs_[current_id] = tsl::MakeUnconstructedAsyncValueRef<void*>();
     return absl::OkStatus();
   }
 
@@ -102,6 +100,7 @@ class NcclCollectivePermuteStartThunk : public NcclCollectiveThunk {
                                   int64_t partition_count, const Buffer& buffer,
                                   bool p2p_memcpy_enabled);
 
   absl::Status Initialize(const InitializeParams& params) override;
+  absl::Status Cleanup(const CleanupParams& params) override;
 
   static const char* GetHloOpName() { return "collective-permute-start"; }
 
@@ -115,6 +114,8 @@ class NcclCollectivePermuteStartThunk : public NcclCollectiveThunk {
   const NcclP2PConfig config_;
   const Buffer buffer_;
   RecvPtrMap recv_ptr_map_;
+  absl::Mutex barrier_mutex_;
+  std::unordered_map<int64_t, uint8_t> barrier_flags_;
   bool p2p_memcpy_enabled_ = false;
   int64_t device_count_;
 };
diff --git a/xla/tests/collective_ops_e2e_test.cc b/xla/tests/collective_ops_e2e_test.cc
index 97a7afa7f1d137..2eb8f2ed2d4b91 100644
--- a/xla/tests/collective_ops_e2e_test.cc
+++ b/xla/tests/collective_ops_e2e_test.cc
@@ -1801,5 +1801,123 @@ XLA_TEST_P(RaggedAllToAllTest, RaggedAllToAll_8GPUs) {
 INSTANTIATE_TEST_SUITE_P(RaggedAllToAllTest, RaggedAllToAllTest,
                          ::testing::Bool());
 
+TEST_F(CollectiveOpsTestE2E, MemcpyP2pWhileLoopCorrectness) {
+  absl::string_view hlo_string = R"(
+HloModule MemcpyP2pWhileLoopCorrectness, entry_computation_layout={(bf16[128,96]{1,0})->(bf16[32,384]{1,0}, bf16[32,384]{1,0})}, allow_spmd_sharding_propagation_to_output={true,true}, num_partitions=4
+
+None.4 {
+  Arg_1.6 = bf16[32,96]{1,0} parameter(1)
+  Arg_0.5 = bf16[32,96]{1,0} parameter(0)
+  collective-permute.9 = bf16[32,96]{1,0} collective-permute(Arg_0.5), channel_id=1, source_target_pairs={{0,1},{1,2},{2,3},{3,0}}
+  constant.7 = bf16[] constant(2)
+  broadcast.8 = bf16[32,96]{1,0} broadcast(constant.7), dimensions={}
+  multiply.10 = bf16[32,96]{1,0} multiply(Arg_0.5, broadcast.8)
+  ROOT tuple.11 = (bf16[32,96]{1,0}, bf16[32,96]{1,0}) tuple(collective-permute.9, multiply.10)
+} // None.4
+
+region_0.12 {
+  arg_tuple.13 = (s32[], bf16[32,96]{1,0}, bf16[32,96]{1,0}) parameter(0)
+  get-tuple-element.14 = s32[] get-tuple-element(arg_tuple.13), index=0
+  constant.17 = s32[] constant(1)
+  add.21 = s32[] add(get-tuple-element.14, constant.17)
+  get-tuple-element.15 = bf16[32,96]{1,0} get-tuple-element(arg_tuple.13), index=1
+  get-tuple-element.16 = bf16[32,96]{1,0} get-tuple-element(arg_tuple.13), index=2
+  call.18 = (bf16[32,96]{1,0}, bf16[32,96]{1,0}) call(get-tuple-element.15, get-tuple-element.16), to_apply=None.4
+  get-tuple-element.19 = bf16[32,96]{1,0} get-tuple-element(call.18), index=0
+  get-tuple-element.20 = bf16[32,96]{1,0} get-tuple-element(call.18), index=1
+  ROOT tuple.22 = (s32[], bf16[32,96]{1,0}, bf16[32,96]{1,0}) tuple(add.21, get-tuple-element.19, get-tuple-element.20)
+} // region_0.12
+
+region_1.23 {
+  arg_tuple.24 = (s32[], bf16[32,96]{1,0}, bf16[32,96]{1,0}) parameter(0)
+  get-tuple-element.26 = bf16[32,96]{1,0} get-tuple-element(arg_tuple.24), index=1
+  get-tuple-element.27 = bf16[32,96]{1,0} get-tuple-element(arg_tuple.24), index=2
+  get-tuple-element.25 = s32[] get-tuple-element(arg_tuple.24), index=0
+  constant.28 = s32[] constant(3)
+  ROOT compare.29 = pred[] compare(get-tuple-element.25, constant.28), direction=LT
+} // region_1.23
+
+shmap_body.30 {
+  constant.32 = s32[] constant(0)
+  Arg_0.31 = bf16[32,96]{1,0} parameter(0)
+  constant.33 = bf16[] constant(0)
+  broadcast.34 = bf16[32,96]{1,0} broadcast(constant.33), dimensions={}
+  tuple.35 = (s32[], bf16[32,96]{1,0}, bf16[32,96]{1,0}) tuple(constant.32, Arg_0.31, broadcast.34)
+  while.36 = (s32[], bf16[32,96]{1,0}, bf16[32,96]{1,0}) while(tuple.35), condition=region_1.23, body=region_0.12
+  get-tuple-element.37 = s32[] get-tuple-element(while.36), index=0
+  get-tuple-element.38 = bf16[32,96]{1,0} get-tuple-element(while.36), index=1
+  get-tuple-element.39 = bf16[32,96]{1,0} get-tuple-element(while.36), index=2
+  ROOT tuple.40 = (bf16[32,96]{1,0}, bf16[32,96]{1,0}) tuple(get-tuple-element.38, get-tuple-element.39)
+} // shmap_body.30
+
+ENTRY main.49 {
+  Arg_0.1 = bf16[128,96]{1,0} parameter(0), sharding={devices=[4,1]<=[4]}
+  custom-call.2 = bf16[128,96]{1,0} custom-call(Arg_0.1), custom_call_target="Sharding", sharding={devices=[4,1]<=[4]}
+  custom-call.3 = bf16[32,96]{1,0} custom-call(custom-call.2), custom_call_target="SPMDFullToShardShape", sharding={manual}
+  call.41 = (bf16[32,96]{1,0}, bf16[32,96]{1,0}) call(custom-call.3), to_apply=shmap_body.30
+  get-tuple-element.42 = bf16[32,96]{1,0} get-tuple-element(call.41), index=0
+  custom-call.44 = bf16[32,96]{1,0} custom-call(get-tuple-element.42), custom_call_target="Sharding", sharding={manual}
+  custom-call.45 = bf16[32,384]{1,0} custom-call(custom-call.44), custom_call_target="SPMDShardToFullShape", sharding={devices=[1,4]<=[4]}
+  get-tuple-element.43 = bf16[32,96]{1,0} get-tuple-element(call.41), index=1
+  custom-call.46 = bf16[32,96]{1,0} custom-call(get-tuple-element.43), custom_call_target="Sharding", sharding={manual}
+  custom-call.47 = bf16[32,384]{1,0} custom-call(custom-call.46), custom_call_target="SPMDShardToFullShape", sharding={devices=[1,4]<=[4]}
+  ROOT tuple.48 = (bf16[32,384]{1,0}, bf16[32,384]{1,0}) tuple(custom-call.45, custom-call.47)
+} // main.49
+)";
+
+  const int64_t kNumReplicas = 1;
+  const int64_t kNumPartitions = 4;
+  SKIP_TEST_IF_NUM_DEVICES_LESS_THAN(kNumReplicas * kNumPartitions);
+
+  HloModuleConfig config = GetModuleConfigForTest(kNumReplicas, kNumPartitions);
+  auto opts = GetDebugOptionsForTest();
+  opts.set_xla_gpu_use_memcpy_local_p2p(true);
+  config.set_debug_options(opts);
+  TF_ASSERT_OK_AND_ASSIGN(auto module,
+                          ParseAndReturnVerifiedModule(hlo_string, config));
+  auto fake_arguments = xla::MakeFakeArguments(module.get()).value();
+  std::vector<Literal*> fake_ptrs(fake_arguments.size());
+  for (int i = 0; i < fake_arguments.size(); ++i) {
+    fake_ptrs[i] = &fake_arguments[i];
+  }
+
+  DeviceAssignment assn(/*replica_count=*/kNumReplicas,
+                        /*computation_count=*/kNumPartitions);
+  for (int64_t i = 0; i < kNumPartitions; ++i) {
+    assn(0, i) = i;
+  }
+
+  TF_ASSERT_OK_AND_ASSIGN(
+      std::vector<Literal> results,
+      HloTestBase::ExecuteReplicated(
+          std::move(module), fake_ptrs, kNumPartitions, &assn,
+          /*run_hlo_passes=*/true, /*use_threads=*/true));
+  ASSERT_EQ(results.size(), kNumPartitions);
+
+  HloModuleConfig ref_config =
+      GetModuleConfigForTest(kNumReplicas, kNumPartitions);
+  auto ref_opts = GetDebugOptionsForTest();
+  ref_opts.set_xla_gpu_use_memcpy_local_p2p(false);
+  ref_config.set_debug_options(ref_opts);
+  TF_ASSERT_OK_AND_ASSIGN(auto ref_module,
+                          ParseAndReturnVerifiedModule(hlo_string, ref_config));
+  auto fake_ref_arguments = xla::MakeFakeArguments(ref_module.get()).value();
+  std::vector<Literal*> ref_fake_ptrs(fake_ref_arguments.size());
+  for (int i = 0; i < fake_ref_arguments.size(); ++i) {
+    ref_fake_ptrs[i] = &fake_ref_arguments[i];
+  }
+
+  TF_ASSERT_OK_AND_ASSIGN(
+      std::vector<Literal> ref_results,
+      HloTestBase::ExecuteReplicated(
+          std::move(ref_module), ref_fake_ptrs, kNumPartitions, &assn,
+          /*run_hlo_passes=*/true, /*use_threads=*/true));
+  ASSERT_EQ(ref_results.size(), kNumPartitions);
+  ErrorSpec error_spec{1e-5, 1e-5};
+  // Expect the same results with and without memcpy-based local P2P.
+  for (int i = 0; i < kNumPartitions; ++i) {
+    EXPECT_TRUE(LiteralTestUtil::Near(ref_results[i], results[i], error_spec));
+  }
+}
 }  // namespace
 }  // namespace xla
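
Note on the synchronization pattern in the patch above: with the memcpy path, each rank now publishes its receive-buffer pointer during Initialize, and the one-element U8 AllReduce issued before RunCollectivePermute is used purely for its barrier side effect (the reduced value is never read), so every peer's pointer is published before any rank dereferences a remote one; this is what lets the BlockUntilReady wait be removed. The sketch below is a minimal, self-contained illustration of that publish-then-barrier-then-copy ordering; it is not XLA code, and it uses plain C++ threads plus std::barrier as stand-ins for the per-rank thunks and the NCCL barrier.

// Minimal sketch (assumed illustration, not from the patch): each "rank"
// publishes a pointer, all ranks pass a barrier, and only then does any rank
// write through a peer's pointer, mirroring Initialize -> AllReduce(barrier)
// -> memcpy-based collective-permute. Requires C++20 for <barrier>.
#include <array>
#include <barrier>
#include <cstdio>
#include <thread>
#include <vector>

int main() {
  constexpr int kRanks = 4;
  std::array<int, kRanks> buffers{};     // stands in for per-rank device buffers
  std::array<int*, kRanks> recv_ptrs{};  // stands in for the RecvPtrMap
  std::barrier sync_point(kRanks);       // stands in for the AllReduce barrier

  auto rank_fn = [&](int rank) {
    // "Initialize": publish this rank's receive pointer before execution.
    recv_ptrs[rank] = &buffers[rank];
    // "Barrier": no rank proceeds until every pointer has been published.
    sync_point.arrive_and_wait();
    // "Collective-permute via memcpy": write into the next rank's buffer.
    const int target = (rank + 1) % kRanks;
    *recv_ptrs[target] = rank;
  };

  std::vector<std::thread> threads;
  for (int r = 0; r < kRanks; ++r) threads.emplace_back(rank_fn, r);
  for (auto& t : threads) t.join();

  for (int r = 0; r < kRanks; ++r) {
    std::printf("buffer[%d] = %d\n", r, buffers[r]);
  }
  return 0;
}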