Skip to content

Commit

Permalink
[pjrt] NFC: Rename HostBufferSemantics::kZeroCopy to kImmutableZeroCopy
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 621900021
  • Loading branch information
ezhulenev authored and copybara-github committed Apr 4, 2024
1 parent 3be7909 commit 570a4d8
Show file tree
Hide file tree
Showing 13 changed files with 56 additions and 48 deletions.
2 changes: 2 additions & 0 deletions xla/pjrt/c/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

## 0.47
* Added ``PJRT_Extension_Type::PJRT_Extension_Type_Custom_Partitioner``.
* Renamed host buffer semantics enum value from ``PJRT_HostBufferSemantics_kZeroCopy``
to ``PJRT_HostBufferSemantics_kImmutableZeroCopy``.

## 0.46 (Feb 29, 2024)
* Update outdated struct sizes from previous changes to
Expand Down
12 changes: 7 additions & 5 deletions xla/pjrt/c/pjrt_c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -652,11 +652,13 @@ typedef enum {
PJRT_HostBufferSemantics_kImmutableUntilTransferCompletes,

// The PjRtBuffer may alias `data` internally and the runtime may use the
// `data` contents as long as the buffer is alive. The caller promises to
// keep `data` alive and not to mutate its contents as long as the buffer is
// alive; to notify the caller that the buffer may be freed, the runtime
// will call `done_with_host_buffer` when the PjRtBuffer is freed.
PJRT_HostBufferSemantics_kZeroCopy,
// `data` contents as long as the buffer is alive. The runtime promises not
// to mutate contents of the buffer (i.e. it will not use it for aliased
// output buffers). The caller promises to keep `data` alive and not to mutate
// its contents as long as the buffer is alive; to notify the caller that the
// buffer may be freed, the runtime will call `done_with_host_buffer` when the
// PjRtBuffer is freed.
PJRT_HostBufferSemantics_kImmutableZeroCopy,
} PJRT_HostBufferSemantics;

typedef enum {
Expand Down
13 changes: 7 additions & 6 deletions xla/pjrt/c/pjrt_c_api_helpers.cc
Original file line number Diff line number Diff line change
Expand Up @@ -349,8 +349,8 @@ const char* HostBufferSemanticsToString(
switch (h) {
case xla::PjRtClient::HostBufferSemantics::kImmutableOnlyDuringCall:
return "xla::PjRtClient::HostBufferSemantics::kImmutableOnlyDuringCall";
case xla::PjRtClient::HostBufferSemantics::kZeroCopy:
return "xla::PjRtClient::HostBufferSemantics::kZeroCopy";
case xla::PjRtClient::HostBufferSemantics::kImmutableZeroCopy:
return "xla::PjRtClient::HostBufferSemantics::kImmutableZeroCopy";
case xla::PjRtClient::HostBufferSemantics::kImmutableUntilTransferCompletes:
return "xla::PjRtClient::HostBufferSemantics::"
"kImmutableUntilTransferCompletes";
Expand All @@ -366,8 +366,9 @@ PJRT_HostBufferSemantics ConvertToPjRtHostBufferSemantics(
case xla::PjRtClient::HostBufferSemantics::kImmutableUntilTransferCompletes:
return PJRT_HostBufferSemantics::
PJRT_HostBufferSemantics_kImmutableUntilTransferCompletes;
case xla::PjRtClient::HostBufferSemantics::kZeroCopy:
return PJRT_HostBufferSemantics::PJRT_HostBufferSemantics_kZeroCopy;
case xla::PjRtClient::HostBufferSemantics::kImmutableZeroCopy:
return PJRT_HostBufferSemantics::
PJRT_HostBufferSemantics_kImmutableZeroCopy;
default:
CHECK(false)
<< "Input host buffer semantics is not supported in C API layer: "
Expand All @@ -385,8 +386,8 @@ xla::PjRtClient::HostBufferSemantics ConvertFromPjRtHostBufferSemantics(
PJRT_HostBufferSemantics_kImmutableUntilTransferCompletes:
return xla::PjRtClient::HostBufferSemantics::
kImmutableUntilTransferCompletes;
case PJRT_HostBufferSemantics::PJRT_HostBufferSemantics_kZeroCopy:
return xla::PjRtClient::HostBufferSemantics::kZeroCopy;
case PJRT_HostBufferSemantics::PJRT_HostBufferSemantics_kImmutableZeroCopy:
return xla::PjRtClient::HostBufferSemantics::kImmutableZeroCopy;
}
}

Expand Down
5 changes: 3 additions & 2 deletions xla/pjrt/cpu/abstract_tfrt_cpu_buffer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -474,7 +474,7 @@ AbstractTfrtCpuBuffer::CopyToDeviceAcrossClients(PjRtDevice* dst_device) {
return dst_device->client()->BufferFromHostBuffer(
literal_pointer->untyped_data(), literal_pointer->shape().element_type(),
literal_pointer->shape().dimensions(), byte_strides,
PjRtClient::HostBufferSemantics::kZeroCopy,
PjRtClient::HostBufferSemantics::kImmutableZeroCopy,
[literal{std::move(literal)}]() { /* frees literal */ }, dst_device);
}

Expand Down Expand Up @@ -695,7 +695,8 @@ AbstractTfrtCpuBuffer::BufferFromHostBufferHelper(
// code which requires it.
bool can_use_zero_copy =
has_default_layout && !is_int4 &&
host_buffer_semantics == PjRtClient::HostBufferSemantics::kZeroCopy &&
host_buffer_semantics ==
PjRtClient::HostBufferSemantics::kImmutableZeroCopy &&
((absl::bit_cast<std::uintptr_t>(data) &
(cpu_function_runtime::MinAlign() - 1)) == 0);
absl::InlinedVector<std::shared_ptr<MaybeOwningCpuMemory>, 4> buffers;
Expand Down
8 changes: 4 additions & 4 deletions xla/pjrt/pjrt_c_api_client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -484,13 +484,13 @@ PjRtCApiClient::BufferFromHostBufferInternalImpl(
std::variant<PjRtDevice*, PjRtMemorySpace*> device_or_memory,
const Layout* device_layout) {
if (host_buffer_semantics != HostBufferSemantics::kImmutableOnlyDuringCall &&
host_buffer_semantics != HostBufferSemantics::kZeroCopy &&
host_buffer_semantics != HostBufferSemantics::kImmutableZeroCopy &&
host_buffer_semantics !=
HostBufferSemantics::kImmutableUntilTransferCompletes) {
return Unimplemented(
"PJRT C API does not support HostBufferSemantics other than "
"HostBufferSemantics::kImmutableOnlyDuringCall, "
"HostBufferSemantics::kZeroCopy and "
"HostBufferSemantics::kImmutableZeroCopy and "
"HostBufferSemantics::kImmutableUntilTransferCompletes.");
}

Expand Down Expand Up @@ -1943,7 +1943,7 @@ StatusOr<std::unique_ptr<PjRtBuffer>> PjRtCApiBuffer::CopyToDevice(
literal_pointer->untyped_data(),
literal_pointer->shape().element_type(),
literal_pointer->shape().dimensions(), byte_strides,
PjRtClient::HostBufferSemantics::kZeroCopy,
PjRtClient::HostBufferSemantics::kImmutableZeroCopy,
[literal{std::move(literal)}]() { /* frees literal */ }, dst_device);
}
}
Expand Down Expand Up @@ -1975,7 +1975,7 @@ StatusOr<std::unique_ptr<PjRtBuffer>> PjRtCApiBuffer::CopyToMemorySpace(
literal_pointer->untyped_data(),
literal_pointer->shape().element_type(),
literal_pointer->shape().dimensions(), byte_strides,
PjRtClient::HostBufferSemantics::kZeroCopy,
PjRtClient::HostBufferSemantics::kImmutableZeroCopy,
[literal{std::move(literal)}]() { /* frees literal */ }, dst_memory,
/*device_layout=*/nullptr);
}
Expand Down
15 changes: 8 additions & 7 deletions xla/pjrt/pjrt_client.h
Original file line number Diff line number Diff line change
Expand Up @@ -813,13 +813,14 @@ class PjRtClient {
kImmutableUntilTransferCompletes,

// The PjRtBuffer may alias `data` internally and the runtime may use the
// `data` contents as long as the buffer is alive. The caller promises to
// keep `data` alive and not to mutate its contents as long as the buffer is
// alive; to notify the caller that the buffer may be freed, the runtime
// will call `on_done_with_host_buffer` when the PjRtBuffer is freed. On
// non-CPU platforms this acts identically to
// kImmutableUntilTransferCompletes.
kZeroCopy,
// `data` contents as long as the buffer is alive. The runtime promises not
// to mutate contents of the buffer (i.e. it will not use it for aliased
// output buffers). The caller promises to keep `data` alive and also not to
// mutate its contents as long as the buffer is alive; to notify the caller
// that the buffer may be freed, the runtime will call
// `on_done_with_host_buffer` when the PjRtBuffer is freed. On non-CPU
// platforms this acts identically to kImmutableUntilTransferCompletes.
kImmutableZeroCopy,
};

// on_done_with_host_buffer is optional and may be null.
Expand Down
14 changes: 7 additions & 7 deletions xla/pjrt/pjrt_client_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,7 @@ TEST_P(PjRtClientTest, ExecuteWithTupleZeroCopy) {
/*byte_strides=*/std::nullopt,
// Use kImmutableZeroCopy to test the correctness of
// `on_done_with_host_buffer`.
PjRtClient::HostBufferSemantics::kZeroCopy,
PjRtClient::HostBufferSemantics::kImmutableZeroCopy,
/*on_done_with_host_buffer=*/
[&data]() {
// Deliberately modifying the content of `data`. A
Expand Down Expand Up @@ -216,8 +216,8 @@ TEST_P(PjRtClientTest, ExecuteWithDonation) {
auto buffer, client->BufferFromHostBuffer(
data.data(), shape.element_type(), shape.dimensions(),
/*byte_strides=*/std::nullopt,
PjRtClient::HostBufferSemantics::kZeroCopy, nullptr,
client->addressable_devices()[0]));
PjRtClient::HostBufferSemantics::kImmutableZeroCopy,
nullptr, client->addressable_devices()[0]));

ExecuteOptions options;
options.execution_mode = GetParam();
Expand Down Expand Up @@ -249,8 +249,8 @@ TEST_P(PjRtClientTest, ExecuteWithDonationAbort) {
auto buffer, client->BufferFromHostBuffer(
data.data(), shape.element_type(), shape.dimensions(),
/*byte_strides=*/std::nullopt,
PjRtClient::HostBufferSemantics::kZeroCopy, nullptr,
client->addressable_devices()[0]));
PjRtClient::HostBufferSemantics::kImmutableZeroCopy,
nullptr, client->addressable_devices()[0]));

auto external_reference = buffer->AcquireExternalReference();

Expand Down Expand Up @@ -323,8 +323,8 @@ TEST_P(PjRtClientTest, ExecuteWithConcurrentUsageAndDonation) {
auto buffer, client->BufferFromHostBuffer(
data.data(), shape.element_type(), shape.dimensions(),
/*byte_strides=*/std::nullopt,
PjRtClient::HostBufferSemantics::kZeroCopy, nullptr,
client->addressable_devices()[0]));
PjRtClient::HostBufferSemantics::kImmutableZeroCopy,
nullptr, client->addressable_devices()[0]));

ExecuteOptions options;
options.execution_mode = GetParam();
Expand Down
2 changes: 1 addition & 1 deletion xla/pjrt/pjrt_stream_executor_client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1805,7 +1805,7 @@ StatusOr<std::unique_ptr<PjRtBuffer>> PjRtStreamExecutorBuffer::CopyToDevice(
literal_pointer->untyped_data(),
literal_pointer->shape().element_type(),
literal_pointer->shape().dimensions(), byte_strides,
PjRtStreamExecutorClient::HostBufferSemantics::kZeroCopy,
PjRtStreamExecutorClient::HostBufferSemantics::kImmutableZeroCopy,
[literal{std::move(literal)}]() { /* frees literal */ }, dst_device);
}

Expand Down
16 changes: 8 additions & 8 deletions xla/python/ifrt/array_impl_test_lib.cc
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ TEST_P(ArrayImplWithHostBufferSemanticsTest,

// Regardless of the host buffer semantics chosen, the host buffer must not be
// used by the runtime once `on_done_with_host_buffer` has been called.
if (semantics == Client::HostBufferSemantics::kZeroCopy) {
if (semantics == Client::HostBufferSemantics::kImmutableZeroCopy) {
// `on_done_with_host_buffer` is called only when the `Array` is destroyed
// if the runtime implements `kImmutableZeroCopy`. A deadlock will occur if we keep
// the `Array` instance.
Expand All @@ -108,7 +108,7 @@ INSTANTIATE_TEST_CASE_P(
testing::Values(
Client::HostBufferSemantics::kImmutableOnlyDuringCall,
Client::HostBufferSemantics::kImmutableUntilTransferCompletes,
Client::HostBufferSemantics::kZeroCopy));
Client::HostBufferSemantics::kImmutableZeroCopy));

TEST(ArrayImplTest, MakeArrayFromHostBufferImmutableOnlyDuringCall) {
TF_ASSERT_OK_AND_ASSIGN(auto client, test_util::GetClient());
Expand Down Expand Up @@ -184,12 +184,12 @@ TEST(ArrayImplTest, MakeArrayFromHostBufferZeroCopy) {
std::shared_ptr<const Sharding> sharding =
SingleDeviceSharding::Create(device, MemoryKind());

TF_ASSERT_OK_AND_ASSIGN(
auto array,
client->MakeArrayFromHostBuffer(data->data(), dtype, shape,
/*byte_strides=*/std::nullopt, sharding,
Client::HostBufferSemantics::kZeroCopy,
/*on_done_with_host_buffer=*/nullptr));
TF_ASSERT_OK_AND_ASSIGN(auto array,
client->MakeArrayFromHostBuffer(
data->data(), dtype, shape,
/*byte_strides=*/std::nullopt, sharding,
Client::HostBufferSemantics::kImmutableZeroCopy,
/*on_done_with_host_buffer=*/nullptr));

// The `Array` may alias the host buffer, but once the transfer is done and
// the `Array` is destroyed, the host buffer is not accessed. This test would
Expand Down
4 changes: 2 additions & 2 deletions xla/python/py_array.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1083,8 +1083,8 @@ StatusOr<PyArray> PyArray::BatchedDevicePut(
DevicePutOptions options;
options.squash_64bit_types = !jax_enable_x64;
options.allow_zero_copy =
(!force_copy &&
(host_buffer_semantics == ifrt::Client::HostBufferSemantics::kZeroCopy));
(!force_copy && (host_buffer_semantics ==
ifrt::Client::HostBufferSemantics::kImmutableZeroCopy));

nb::list owning_pylist;
std::vector<tsl::RCReference<ifrt::Array>> ifrt_arrays;
Expand Down
8 changes: 4 additions & 4 deletions xla/python/py_client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -316,8 +316,8 @@ absl::Status PyClient::Defragment() {
DevicePutOptions options;
options.squash_64bit_types = false;
options.allow_zero_copy =
(!force_copy &&
(host_buffer_semantics == ifrt::Client::HostBufferSemantics::kZeroCopy));
(!force_copy && (host_buffer_semantics ==
ifrt::Client::HostBufferSemantics::kImmutableZeroCopy));
// TODO(phawkins): remove .ptr() after nanobind transition is complete.
TF_ASSIGN_OR_RETURN(DevicePutResult put,
DevicePut(argument.ptr(), client->ifrt_client_.get(),
Expand Down Expand Up @@ -710,7 +710,7 @@ PyType_Slot PyClient::slots_[] = {
PjRtClient::HostBufferSemantics::kImmutableOnlyDuringCall)
.value("IMMUTABLE_UNTIL_TRANSFER_COMPLETES",
PjRtClient::HostBufferSemantics::kImmutableUntilTransferCompletes)
.value("ZERO_COPY", PjRtClient::HostBufferSemantics::kZeroCopy);
.value("ZERO_COPY", PjRtClient::HostBufferSemantics::kImmutableZeroCopy);

nb::class_<PyClient> py_local_client(m, "Client", nb::is_weak_referenceable(),
nb::type_slots(PyClient::slots_));
Expand Down Expand Up @@ -742,7 +742,7 @@ PyType_Slot PyClient::slots_[] = {
nb::arg("argument"), nb::arg("device").none() = nullptr,
nb::arg("force_copy") = false,
nb::arg("host_buffer_semantics") =
PjRtClient::HostBufferSemantics::kZeroCopy)
PjRtClient::HostBufferSemantics::kImmutableZeroCopy)
.def(
"make_cross_host_receive_buffers",
[](nb_class_ptr<PyClient> client, absl::Span<const Shape> shapes,
Expand Down
3 changes: 2 additions & 1 deletion xla/python/py_values.cc
Original file line number Diff line number Diff line change
Expand Up @@ -257,7 +257,8 @@ absl::StatusOr<DevicePutResult> HandleNumpyArray(
on_done_with_host_buffer =
[py_buffer_ref{
std::move(py_buffer_ref)}]() { /* keeps py_buffer_ref alive */ };
host_buffer_semantics = ifrt::Client::HostBufferSemantics::kZeroCopy;
host_buffer_semantics =
ifrt::Client::HostBufferSemantics::kImmutableZeroCopy;
}
// Must release the GIL before BufferFromHostBuffer because backends may
// decide to block/sleep for device buffer allocation.
Expand Down
2 changes: 1 addition & 1 deletion xla/python/xla.cc
Original file line number Diff line number Diff line change
Expand Up @@ -933,7 +933,7 @@ NB_MODULE(xla_extension, m_nb) {
nb::arg("aval"), nb::arg("sharding"), nb::arg("xs"), nb::arg("devices"),
nb::arg("committed") = true, nb::arg("force_copy") = false,
nb::arg("host_buffer_semantics") =
PjRtClient::HostBufferSemantics::kZeroCopy);
PjRtClient::HostBufferSemantics::kImmutableZeroCopy);

m_nb.def("batched_block_until_ready", [](std::vector<nb::object> xs) {
ThrowIfError(PyArray::BatchedBlockUntilReady(std::move(xs)));
Expand Down

0 comments on commit 570a4d8

Please sign in to comment.