From 570a4d842f94f3fdd4ea578046a95e2211bb50cd Mon Sep 17 00:00:00 2001 From: Eugene Zhulenev Date: Thu, 4 Apr 2024 10:25:10 -0700 Subject: [PATCH] [pjrt] NFC: Rename HostBufferSemantics::kZeroCopy to kImmutableZeroCopy PiperOrigin-RevId: 621900021 --- xla/pjrt/c/CHANGELOG.md | 2 ++ xla/pjrt/c/pjrt_c_api.h | 12 +++++++----- xla/pjrt/c/pjrt_c_api_helpers.cc | 13 +++++++------ xla/pjrt/cpu/abstract_tfrt_cpu_buffer.cc | 5 +++-- xla/pjrt/pjrt_c_api_client.cc | 8 ++++---- xla/pjrt/pjrt_client.h | 15 ++++++++------- xla/pjrt/pjrt_client_test.cc | 14 +++++++------- xla/pjrt/pjrt_stream_executor_client.cc | 2 +- xla/python/ifrt/array_impl_test_lib.cc | 16 ++++++++-------- xla/python/py_array.cc | 4 ++-- xla/python/py_client.cc | 8 ++++---- xla/python/py_values.cc | 3 ++- xla/python/xla.cc | 2 +- 13 files changed, 56 insertions(+), 48 deletions(-) diff --git a/xla/pjrt/c/CHANGELOG.md b/xla/pjrt/c/CHANGELOG.md index cb9cb750d8194..ec1f828b3b98d 100644 --- a/xla/pjrt/c/CHANGELOG.md +++ b/xla/pjrt/c/CHANGELOG.md @@ -2,6 +2,8 @@ ## 0.47 * Added ``PJRT_Extension_Type::PJRT_Extension_Type_Custom_Partitioner``. +* Renamed host buffer semantics enum from ``PJRT_HostBufferSemantics_kZeroCopy`` + to ``PJRT_HostBufferSemantics_kImmutableZeroCopy``. ## 0.46 (Feb 29, 2024) * Update outdated struct sizes from previous changes to diff --git a/xla/pjrt/c/pjrt_c_api.h b/xla/pjrt/c/pjrt_c_api.h index da1934e64f2e6..670501205c35c 100644 --- a/xla/pjrt/c/pjrt_c_api.h +++ b/xla/pjrt/c/pjrt_c_api.h @@ -652,11 +652,13 @@ typedef enum { PJRT_HostBufferSemantics_kImmutableUntilTransferCompletes, // The PjRtBuffer may alias `data` internally and the runtime may use the - // `data` contents as long as the buffer is alive. The caller promises to - // keep `data` alive and not to mutate its contents as long as the buffer is - // alive; to notify the caller that the buffer may be freed, the runtime - // will call `done_with_host_buffer` when the PjRtBuffer is freed. - PJRT_HostBufferSemantics_kZeroCopy, + // `data` contents as long as the buffer is alive. The runtime promises not + // to mutate contents of the buffer (i.e. it will not use it for aliased + // output buffers). The caller promises to keep `data` alive and not to mutate + // its contents as long as the buffer is alive; to notify the caller that the + // buffer may be freed, the runtime will call `done_with_host_buffer` when the + // PjRtBuffer is freed. + PJRT_HostBufferSemantics_kImmutableZeroCopy, } PJRT_HostBufferSemantics; typedef enum { diff --git a/xla/pjrt/c/pjrt_c_api_helpers.cc b/xla/pjrt/c/pjrt_c_api_helpers.cc index 8cabd1e477d2a..aade90c1b534e 100644 --- a/xla/pjrt/c/pjrt_c_api_helpers.cc +++ b/xla/pjrt/c/pjrt_c_api_helpers.cc @@ -349,8 +349,8 @@ const char* HostBufferSemanticsToString( switch (h) { case xla::PjRtClient::HostBufferSemantics::kImmutableOnlyDuringCall: return "xla::PjRtClient::HostBufferSemantics::kImmutableOnlyDuringCall"; - case xla::PjRtClient::HostBufferSemantics::kZeroCopy: - return "xla::PjRtClient::HostBufferSemantics::kZeroCopy"; + case xla::PjRtClient::HostBufferSemantics::kImmutableZeroCopy: + return "xla::PjRtClient::HostBufferSemantics::kImmutableZeroCopy"; case xla::PjRtClient::HostBufferSemantics::kImmutableUntilTransferCompletes: return "xla::PjRtClient::HostBufferSemantics::" "kImmutableUntilTransferCompletes"; @@ -366,8 +366,9 @@ PJRT_HostBufferSemantics ConvertToPjRtHostBufferSemantics( case xla::PjRtClient::HostBufferSemantics::kImmutableUntilTransferCompletes: return PJRT_HostBufferSemantics:: PJRT_HostBufferSemantics_kImmutableUntilTransferCompletes; - case xla::PjRtClient::HostBufferSemantics::kZeroCopy: - return PJRT_HostBufferSemantics::PJRT_HostBufferSemantics_kZeroCopy; + case xla::PjRtClient::HostBufferSemantics::kImmutableZeroCopy: + return PJRT_HostBufferSemantics:: + PJRT_HostBufferSemantics_kImmutableZeroCopy; default: CHECK(false) << "Input host buffer semantics is not supported in C API layer: " @@ -385,8 +386,8 @@ xla::PjRtClient::HostBufferSemantics ConvertFromPjRtHostBufferSemantics( PJRT_HostBufferSemantics_kImmutableUntilTransferCompletes: return xla::PjRtClient::HostBufferSemantics:: kImmutableUntilTransferCompletes; - case PJRT_HostBufferSemantics::PJRT_HostBufferSemantics_kZeroCopy: - return xla::PjRtClient::HostBufferSemantics::kZeroCopy; + case PJRT_HostBufferSemantics::PJRT_HostBufferSemantics_kImmutableZeroCopy: + return xla::PjRtClient::HostBufferSemantics::kImmutableZeroCopy; } } diff --git a/xla/pjrt/cpu/abstract_tfrt_cpu_buffer.cc b/xla/pjrt/cpu/abstract_tfrt_cpu_buffer.cc index 31ee420427707..26ec0703cd4db 100644 --- a/xla/pjrt/cpu/abstract_tfrt_cpu_buffer.cc +++ b/xla/pjrt/cpu/abstract_tfrt_cpu_buffer.cc @@ -474,7 +474,7 @@ AbstractTfrtCpuBuffer::CopyToDeviceAcrossClients(PjRtDevice* dst_device) { return dst_device->client()->BufferFromHostBuffer( literal_pointer->untyped_data(), literal_pointer->shape().element_type(), literal_pointer->shape().dimensions(), byte_strides, - PjRtClient::HostBufferSemantics::kZeroCopy, + PjRtClient::HostBufferSemantics::kImmutableZeroCopy, [literal{std::move(literal)}]() { /* frees literal */ }, dst_device); } @@ -695,7 +695,8 @@ AbstractTfrtCpuBuffer::BufferFromHostBufferHelper( // code which requires it. bool can_use_zero_copy = has_default_layout && !is_int4 && - host_buffer_semantics == PjRtClient::HostBufferSemantics::kZeroCopy && + host_buffer_semantics == + PjRtClient::HostBufferSemantics::kImmutableZeroCopy && ((absl::bit_cast(data) & (cpu_function_runtime::MinAlign() - 1)) == 0); absl::InlinedVector, 4> buffers; diff --git a/xla/pjrt/pjrt_c_api_client.cc b/xla/pjrt/pjrt_c_api_client.cc index 503bea10614f8..11affaff461e6 100644 --- a/xla/pjrt/pjrt_c_api_client.cc +++ b/xla/pjrt/pjrt_c_api_client.cc @@ -484,13 +484,13 @@ PjRtCApiClient::BufferFromHostBufferInternalImpl( std::variant device_or_memory, const Layout* device_layout) { if (host_buffer_semantics != HostBufferSemantics::kImmutableOnlyDuringCall && - host_buffer_semantics != HostBufferSemantics::kZeroCopy && + host_buffer_semantics != HostBufferSemantics::kImmutableZeroCopy && host_buffer_semantics != HostBufferSemantics::kImmutableUntilTransferCompletes) { return Unimplemented( "PJRT C API does not support HostBufferSemantics other than " "HostBufferSemantics::kImmutableOnlyDuringCall, " - "HostBufferSemantics::kZeroCopy and " + "HostBufferSemantics::kImmutableZeroCopy and " "HostBufferSemantics::kImmutableUntilTransferCompletes."); } @@ -1943,7 +1943,7 @@ StatusOr> PjRtCApiBuffer::CopyToDevice( literal_pointer->untyped_data(), literal_pointer->shape().element_type(), literal_pointer->shape().dimensions(), byte_strides, - PjRtClient::HostBufferSemantics::kZeroCopy, + PjRtClient::HostBufferSemantics::kImmutableZeroCopy, [literal{std::move(literal)}]() { /* frees literal */ }, dst_device); } } @@ -1975,7 +1975,7 @@ StatusOr> PjRtCApiBuffer::CopyToMemorySpace( literal_pointer->untyped_data(), literal_pointer->shape().element_type(), literal_pointer->shape().dimensions(), byte_strides, - PjRtClient::HostBufferSemantics::kZeroCopy, + PjRtClient::HostBufferSemantics::kImmutableZeroCopy, [literal{std::move(literal)}]() { /* frees literal */ }, dst_memory, /*device_layout=*/nullptr); } diff --git a/xla/pjrt/pjrt_client.h b/xla/pjrt/pjrt_client.h index 0299ca8c9fada..3f167e0fb265d 100644 --- a/xla/pjrt/pjrt_client.h +++ b/xla/pjrt/pjrt_client.h @@ -813,13 +813,14 @@ class PjRtClient { kImmutableUntilTransferCompletes, // The PjRtBuffer may alias `data` internally and the runtime may use the - // `data` contents as long as the buffer is alive. The caller promises to - // keep `data` alive and not to mutate its contents as long as the buffer is - // alive; to notify the caller that the buffer may be freed, the runtime - // will call `on_done_with_host_buffer` when the PjRtBuffer is freed. On - // non-CPU platforms this acts identically to - // kImmutableUntilTransferCompletes. - kZeroCopy, + // `data` contents as long as the buffer is alive. The runtime promises not + // to mutate contents of the buffer (i.e. it will not use it for aliased + // output buffers). The caller promises to keep `data` alive and also not to + // mutate its contents as long as the buffer is alive; to notify the caller + // that the buffer may be freed, the runtime will call + // `on_done_with_host_buffer` when the PjRtBuffer is freed. On non-CPU + // platforms this acts identically to kImmutableUntilTransferCompletes. + kImmutableZeroCopy, }; // on_done_with_host_buffer is optional and may be null. diff --git a/xla/pjrt/pjrt_client_test.cc b/xla/pjrt/pjrt_client_test.cc index fbb2b18e7f679..1bc5d6704abbf 100644 --- a/xla/pjrt/pjrt_client_test.cc +++ b/xla/pjrt/pjrt_client_test.cc @@ -176,7 +176,7 @@ TEST_P(PjRtClientTest, ExecuteWithTupleZeroCopy) { /*byte_strides=*/std::nullopt, // Use kZeroCopy to test the correctness of // `on_done_with_host_buffer`. - PjRtClient::HostBufferSemantics::kZeroCopy, + PjRtClient::HostBufferSemantics::kImmutableZeroCopy, /*on_done_with_host_buffer=*/ [&data]() { // Deliberately modifying the content of `data`. A @@ -216,8 +216,8 @@ TEST_P(PjRtClientTest, ExecuteWithDonation) { auto buffer, client->BufferFromHostBuffer( data.data(), shape.element_type(), shape.dimensions(), /*byte_strides=*/std::nullopt, - PjRtClient::HostBufferSemantics::kZeroCopy, nullptr, - client->addressable_devices()[0])); + PjRtClient::HostBufferSemantics::kImmutableZeroCopy, + nullptr, client->addressable_devices()[0])); ExecuteOptions options; options.execution_mode = GetParam(); @@ -249,8 +249,8 @@ TEST_P(PjRtClientTest, ExecuteWithDonationAbort) { auto buffer, client->BufferFromHostBuffer( data.data(), shape.element_type(), shape.dimensions(), /*byte_strides=*/std::nullopt, - PjRtClient::HostBufferSemantics::kZeroCopy, nullptr, - client->addressable_devices()[0])); + PjRtClient::HostBufferSemantics::kImmutableZeroCopy, + nullptr, client->addressable_devices()[0])); auto external_reference = buffer->AcquireExternalReference(); @@ -323,8 +323,8 @@ TEST_P(PjRtClientTest, ExecuteWithConcurrentUsageAndDonation) { auto buffer, client->BufferFromHostBuffer( data.data(), shape.element_type(), shape.dimensions(), /*byte_strides=*/std::nullopt, - PjRtClient::HostBufferSemantics::kZeroCopy, nullptr, - client->addressable_devices()[0])); + PjRtClient::HostBufferSemantics::kImmutableZeroCopy, + nullptr, client->addressable_devices()[0])); ExecuteOptions options; options.execution_mode = GetParam(); diff --git a/xla/pjrt/pjrt_stream_executor_client.cc b/xla/pjrt/pjrt_stream_executor_client.cc index 83261bcd7f98a..4ce9b9f2e57c0 100644 --- a/xla/pjrt/pjrt_stream_executor_client.cc +++ b/xla/pjrt/pjrt_stream_executor_client.cc @@ -1805,7 +1805,7 @@ StatusOr> PjRtStreamExecutorBuffer::CopyToDevice( literal_pointer->untyped_data(), literal_pointer->shape().element_type(), literal_pointer->shape().dimensions(), byte_strides, - PjRtStreamExecutorClient::HostBufferSemantics::kZeroCopy, + PjRtStreamExecutorClient::HostBufferSemantics::kImmutableZeroCopy, [literal{std::move(literal)}]() { /* frees literal */ }, dst_device); } diff --git a/xla/python/ifrt/array_impl_test_lib.cc b/xla/python/ifrt/array_impl_test_lib.cc index 022ad3d777bdb..ec1f81424274f 100644 --- a/xla/python/ifrt/array_impl_test_lib.cc +++ b/xla/python/ifrt/array_impl_test_lib.cc @@ -83,7 +83,7 @@ TEST_P(ArrayImplWithHostBufferSemanticsTest, // Regardless of the host buffer semantics chosen, the host buffer must not be // used by the runtime once `on_done_with_host_buffer` has been called. - if (semantics == Client::HostBufferSemantics::kZeroCopy) { + if (semantics == Client::HostBufferSemantics::kImmutableZeroCopy) { // `on_done_with_host_buffer` is called only when the `Array` is destroyed // if the runtime implements `kZeroCopy`. A deadlock will occur if we keep // the `Array` instance. @@ -108,7 +108,7 @@ INSTANTIATE_TEST_CASE_P( testing::Values( Client::HostBufferSemantics::kImmutableOnlyDuringCall, Client::HostBufferSemantics::kImmutableUntilTransferCompletes, - Client::HostBufferSemantics::kZeroCopy)); + Client::HostBufferSemantics::kImmutableZeroCopy)); TEST(ArrayImplTest, MakeArrayFromHostBufferImmutableOnlyDuringCall) { TF_ASSERT_OK_AND_ASSIGN(auto client, test_util::GetClient()); @@ -184,12 +184,12 @@ TEST(ArrayImplTest, MakeArrayFromHostBufferZeroCopy) { std::shared_ptr sharding = SingleDeviceSharding::Create(device, MemoryKind()); - TF_ASSERT_OK_AND_ASSIGN( - auto array, - client->MakeArrayFromHostBuffer(data->data(), dtype, shape, - /*byte_strides=*/std::nullopt, sharding, - Client::HostBufferSemantics::kZeroCopy, - /*on_done_with_host_buffer=*/nullptr)); + TF_ASSERT_OK_AND_ASSIGN(auto array, + client->MakeArrayFromHostBuffer( + data->data(), dtype, shape, + /*byte_strides=*/std::nullopt, sharding, + Client::HostBufferSemantics::kImmutableZeroCopy, + /*on_done_with_host_buffer=*/nullptr)); // The `Array` may alias the host buffer, but once the transfer is done and // the `Array` is destroyed, the host buffer is not accessed. This test would diff --git a/xla/python/py_array.cc b/xla/python/py_array.cc index 5be4df4674ea4..62674088729bd 100644 --- a/xla/python/py_array.cc +++ b/xla/python/py_array.cc @@ -1083,8 +1083,8 @@ StatusOr PyArray::BatchedDevicePut( DevicePutOptions options; options.squash_64bit_types = !jax_enable_x64; options.allow_zero_copy = - (!force_copy && - (host_buffer_semantics == ifrt::Client::HostBufferSemantics::kZeroCopy)); + (!force_copy && (host_buffer_semantics == + ifrt::Client::HostBufferSemantics::kImmutableZeroCopy)); nb::list owning_pylist; std::vector> ifrt_arrays; diff --git a/xla/python/py_client.cc b/xla/python/py_client.cc index 0afd053313a66..efce0ebc1e589 100644 --- a/xla/python/py_client.cc +++ b/xla/python/py_client.cc @@ -316,8 +316,8 @@ absl::Status PyClient::Defragment() { DevicePutOptions options; options.squash_64bit_types = false; options.allow_zero_copy = - (!force_copy && - (host_buffer_semantics == ifrt::Client::HostBufferSemantics::kZeroCopy)); + (!force_copy && (host_buffer_semantics == + ifrt::Client::HostBufferSemantics::kImmutableZeroCopy)); // TODO(phawkins): remove .ptr() after nanobind transition is complete. TF_ASSIGN_OR_RETURN(DevicePutResult put, DevicePut(argument.ptr(), client->ifrt_client_.get(), @@ -710,7 +710,7 @@ PyType_Slot PyClient::slots_[] = { PjRtClient::HostBufferSemantics::kImmutableOnlyDuringCall) .value("IMMUTABLE_UNTIL_TRANSFER_COMPLETES", PjRtClient::HostBufferSemantics::kImmutableUntilTransferCompletes) - .value("ZERO_COPY", PjRtClient::HostBufferSemantics::kZeroCopy); + .value("ZERO_COPY", PjRtClient::HostBufferSemantics::kImmutableZeroCopy); nb::class_ py_local_client(m, "Client", nb::is_weak_referenceable(), nb::type_slots(PyClient::slots_)); @@ -742,7 +742,7 @@ PyType_Slot PyClient::slots_[] = { nb::arg("argument"), nb::arg("device").none() = nullptr, nb::arg("force_copy") = false, nb::arg("host_buffer_semantics") = - PjRtClient::HostBufferSemantics::kZeroCopy) + PjRtClient::HostBufferSemantics::kImmutableZeroCopy) .def( "make_cross_host_receive_buffers", [](nb_class_ptr client, absl::Span shapes, diff --git a/xla/python/py_values.cc b/xla/python/py_values.cc index c308adbb943d5..ae20a409a29e1 100644 --- a/xla/python/py_values.cc +++ b/xla/python/py_values.cc @@ -257,7 +257,8 @@ absl::StatusOr HandleNumpyArray( on_done_with_host_buffer = [py_buffer_ref{ std::move(py_buffer_ref)}]() { /* keeps py_buffer_ref alive */ }; - host_buffer_semantics = ifrt::Client::HostBufferSemantics::kZeroCopy; + host_buffer_semantics = + ifrt::Client::HostBufferSemantics::kImmutableZeroCopy; } // Must release the GIL before BufferFromHostBuffer because backends may // decide to block/sleep for device buffer allocation. diff --git a/xla/python/xla.cc b/xla/python/xla.cc index a9f45f5753996..85acf36fe17ae 100644 --- a/xla/python/xla.cc +++ b/xla/python/xla.cc @@ -933,7 +933,7 @@ NB_MODULE(xla_extension, m_nb) { nb::arg("aval"), nb::arg("sharding"), nb::arg("xs"), nb::arg("devices"), nb::arg("committed") = true, nb::arg("force_copy") = false, nb::arg("host_buffer_semantics") = - PjRtClient::HostBufferSemantics::kZeroCopy); + PjRtClient::HostBufferSemantics::kImmutableZeroCopy); m_nb.def("batched_block_until_ready", [](std::vector xs) { ThrowIfError(PyArray::BatchedBlockUntilReady(std::move(xs)));