From 03c4c282cd0d3032fcaff1efec5234a797a64b76 Mon Sep 17 00:00:00 2001 From: xla authors Date: Tue, 17 Dec 2024 10:23:48 -0800 Subject: [PATCH] In progress experimentation for supporting JAX Arrays with variable-width strings (i.e., with dtype = StringDType). PiperOrigin-RevId: 707153208 --- xla/python/BUILD | 2 + xla/python/py_array.cc | 83 ++++++++++++++++++++++++++++ xla/python/py_array.h | 12 +++++ xla/python/py_values.cc | 96 +++++++++++++++++++++++++++++++++ xla/tsl/python/lib/core/numpy.h | 8 +-- 5 files changed, 198 insertions(+), 3 deletions(-) diff --git a/xla/python/BUILD b/xla/python/BUILD index 2aef58884209c..05e8b0c623ac8 100644 --- a/xla/python/BUILD +++ b/xla/python/BUILD @@ -352,6 +352,7 @@ cc_library( # placeholder for index annotation deps "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/base", + "@com_google_absl//absl/cleanup", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/container:inlined_vector", @@ -373,6 +374,7 @@ cc_library( "@llvm-project//mlir:Pass", "@nanobind", "@shardy//shardy/dialect/sdy/ir:dialect", + "//third_party/py/numpy:headers", "@local_config_python//:python_headers", # buildcleaner: keep "//xla:comparison_util", "//xla:literal", diff --git a/xla/python/py_array.cc b/xla/python/py_array.cc index 5dd40c177a42f..d891d3e3f112c 100644 --- a/xla/python/py_array.cc +++ b/xla/python/py_array.cc @@ -36,6 +36,7 @@ limitations under the License. #include "absl/base/casts.h" #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" +#include "absl/log/log.h" #include "absl/status/status.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" @@ -88,6 +89,7 @@ limitations under the License. 
#include "xla/shape_util.h" #include "xla/status_macros.h" #include "xla/tsl/concurrency/ref_count.h" +#include "xla/tsl/python/lib/core/numpy.h" // IWYU pragma: keep #include "xla/util.h" #include "xla/xla_data.pb.h" #include "tsl/platform/errors.h" @@ -1563,15 +1565,96 @@ absl::StatusOr PyHostValue::AsNumPyArray( } else { TF_RETURN_IF_ERROR(ready_.Await()); } + + TF_RETURN_IF_ERROR(ConvertStringArrayContentsToNumpyArray(ifrt_array)); return value_; } +absl::Status PyHostValue::ConvertStringArrayContentsToNumpyArray( + ifrt::Array* ifrt_array) { + if (string_array_contents_ == nullptr) { + return absl::OkStatus(); + } + +#ifndef NPY_2_0_API_VERSION + return absl::FailedPreconditionError( + "String arrays are not supported in this NumPy version."); +#else + + auto numpy_dtype = nb::steal( + reinterpret_cast(PyArray_DescrFromType(NPY_VSTRING))); + value_ = nb_numpy_ndarray(numpy_dtype, ifrt_array->shape().dims(), + /*strides=*/std::nullopt); + + auto dst_py_array_obj = reinterpret_cast<::PyArrayObject*>(value_.ptr()); + auto iter = + nb::steal(PyArray_IterNew(reinterpret_cast(dst_py_array_obj))); + for (auto& cord : *string_array_contents_) { + absl::string_view input_str_view = cord.Flatten(); + auto py_unicode = nb::steal(PyUnicode_FromStringAndSize( + input_str_view.data(), input_str_view.size())); + if (py_unicode.ptr() == nullptr) { + return absl::InternalError("PyUnicode_FromStringAndSize failed"); + } + if (PyArray_SETITEM(dst_py_array_obj, + static_cast(PyArray_ITER_DATA(iter.ptr())), + py_unicode.ptr()) != 0) { + return absl::InternalError("PyArray_SETITEM failed"); + } + PyArray_ITER_NEXT(iter.ptr()); + } + + value_.attr("flags").attr("writeable") = nb::bool_(false); + + string_array_contents_.reset(); + + return absl::OkStatus(); + +#endif +} + +absl::Status PyHostValue::CopyStringArrayToHostAsync( + std::optional& dynamic_shape_holder, ifrt::Array* ifrt_array) { + auto transfer_guard_formatter = [ifrt_array] { + return absl::StrCat( + "shape=(", 
absl::StrJoin(ifrt_array->shape().dims(), ","), + "), dtype=", ifrt_array->dtype().DebugString(), ", device=", + ifrt_array->sharding().devices()->devices().front()->DebugString()); + }; + TF_RETURN_IF_ERROR( + jax::ApplyTransferGuardToDeviceToHost(transfer_guard_formatter)); + + TF_ASSIGN_OR_RETURN(nb_dtype dtype, IfrtDtypeToNbDtype(ifrt_array->dtype())); + auto shape = ifrt_array->shape(); + + // Allocate a vector of cords to hold the contents of the array until + // they are ultimately converted to a numpy array as part of the + // `AsNumPyArray` call. + string_array_contents_ = + std::make_shared>(shape.num_elements()); + ready_ = ifrt_array->CopyToHostBuffer(string_array_contents_->data(), + /*byte_strides=*/std::nullopt, + ifrt::ArrayCopySemantics::kAlwaysCopy); + + ready_.OnReady( + [string_array_contents = string_array_contents_](absl::Status) { + }); // Keeps the cords alive until the copy is done. + + return absl::OkStatus(); +} + absl::Status PyHostValue::CopyToHostAsync( std::optional& dynamic_shape_holder, ifrt::Array* ifrt_array) { if (ready_.IsValid()) { // The array value has been populated, so CopyToHostAsync has been called. return absl::OkStatus(); } + + // Copying in Arrays of type kString requires some special handling. + if (ifrt_array->dtype().kind() == ifrt::DType::kString) { + return CopyStringArrayToHostAsync(dynamic_shape_holder, ifrt_array); + } + auto* arr = llvm::dyn_cast_or_null(ifrt_array); if (arr != nullptr && !arr->pjrt_buffers().front()->IsTuple() && IsZeroCopyableCpuBuffer(arr->pjrt_buffers().front().get())) { diff --git a/xla/python/py_array.h b/xla/python/py_array.h index 46c2279224b81..9233341417b0d 100644 --- a/xla/python/py_array.h +++ b/xla/python/py_array.h @@ -29,6 +29,7 @@ limitations under the License. 
#include "absl/log/check.h" #include "absl/status/status.h" #include "absl/status/statusor.h" +#include "absl/strings/cord.h" #include "absl/types/span.h" #include "llvm/Support/Casting.h" #include "nanobind/nanobind.h" @@ -69,8 +70,19 @@ class PyHostValue { std::optional& dynamic_shape_holder, ifrt::Array* ifrt_array); private: + absl::Status CopyStringArrayToHostAsync( + std::optional& dynamic_shape_holder, ifrt::Array* ifrt_array); + + absl::Status ConvertStringArrayContentsToNumpyArray(ifrt::Array* ifrt_array); + ifrt::Future<> ready_; nb_numpy_ndarray value_; + + // Optional field, only used for arrays of type kString. This vector of cords + // serves as input buffer for the CopyToHostBuffer call. It holds these + // contents until they are lazily converted to a numpy array when the user + // calls `AsNumPyArray`. + std::shared_ptr> string_array_contents_; }; // Private to PyArray, but you cannot forward declare member classes. diff --git a/xla/python/py_values.cc b/xla/python/py_values.cc index 631b0bcb9b956..391b80b47c944 100644 --- a/xla/python/py_values.cc +++ b/xla/python/py_values.cc @@ -21,14 +21,20 @@ limitations under the License. 
#include #include #include +#include #include #include #include #include +#include +#include "absl/cleanup/cleanup.h" #include "absl/container/flat_hash_map.h" #include "absl/container/inlined_vector.h" +#include "absl/log/log.h" +#include "absl/status/status.h" #include "absl/status/statusor.h" +#include "absl/strings/cord.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_join.h" #include "absl/strings/string_view.h" @@ -66,6 +72,53 @@ namespace xla { namespace { +absl::StatusOr> StringDTypeArrayToCords( + PyArrayObject* py_array_obj) { + if (PyArray_SIZE(py_array_obj) == 0) { + return absl::InvalidArgumentError("empty numpy array"); + } + + NpyIter* iter = + NpyIter_New(py_array_obj, + NPY_ITER_READONLY | NPY_ITER_EXTERNAL_LOOP | NPY_ITER_REFS_OK, + NPY_CORDER, NPY_NO_CASTING, nullptr); + if (iter == nullptr) { + return absl::InternalError("failed to make numpy iterator"); + } + absl::Cleanup cleanup = [iter] { NpyIter_Deallocate(iter); }; + + NpyIter_IterNextFunc* iternext = NpyIter_GetIterNext(iter, nullptr); + if (iternext == nullptr) { + // `cleanup` deallocates `iter`; freeing it here too would double-free. + return absl::InternalError("failed to get the next iterator-function"); + } + + // Pointers to the data, stride and inner_size in the iterator. Updated when + // the iterator is advanced. 
+ char** dataptr = NpyIter_GetDataPtrArray(iter); + npy_intp* strideptr = NpyIter_GetInnerStrideArray(iter); + npy_intp* innersizeptr = NpyIter_GetInnerLoopSizePtr(iter); + + std::vector cords; + cords.reserve(PyArray_SIZE(py_array_obj)); + + do { + char* data = *dataptr; + npy_intp stride = *strideptr; + npy_intp count = *innersizeptr; + + while (count--) { + auto py_unicode = PyArray_GETITEM(py_array_obj, data); + Py_ssize_t len; + auto str = PyUnicode_AsUTF8AndSize(py_unicode, &len); + cords.push_back(absl::Cord(absl::string_view(str, len))); + data += stride; + } + } while (iternext(iter)); + + return cords; +} + using DevicePutFunc = std::function( nb::handle, ifrt::Client*, ifrt::Device*, const DevicePutOptions& options, ifrt::MemoryKind to_memory_kind)>; @@ -246,10 +299,53 @@ absl::StatusOr HandleNumpyScalar( }; } +absl::StatusOr HandleStringNumpyArray( + nb::handle h, ifrt::Client* client, ifrt::Device* to_device, + const DevicePutOptions& options, ifrt::MemoryKind to_memory_kind) { + xla::nb_numpy_ndarray array = nb::cast(h); + auto py_array_obj = reinterpret_cast(array.ptr()); + TF_ASSIGN_OR_RETURN(auto cords, StringDTypeArrayToCords(py_array_obj)); + for (const auto& cord : cords) { + LOG(INFO) << "2D cord: " << cord; + } + + // Assemble all the parameters of MakeArrayFromHostBuffer + void* data = cords.data(); + ifrt::Shape shape( + absl::MakeSpan(static_cast(array.shape()), array.ndim())); + std::shared_ptr sharding = + xla::ifrt::SingleDeviceSharding::Create(to_device, to_memory_kind); + + auto on_done_with_host_buffer = [cords = std::move(cords)] {}; + + return [client, data = data, shape = std::move(shape), + sharding = std::move(sharding), + on_done_with_host_buffer = + std::move(on_done_with_host_buffer)]() mutable + -> absl::StatusOr { + TF_ASSIGN_OR_RETURN( + auto ifrt_array, + client->MakeArrayFromHostBuffer( + data, ifrt::DType(ifrt::DType::kString), std::move(shape), + /*byte_strides=*/std::nullopt, std::move(sharding), + 
ifrt::Client::HostBufferSemantics::kImmutableUntilTransferCompletes, + std::move(on_done_with_host_buffer))); + + return DevicePutResult(std::move(ifrt_array), /*weak_type=*/false); + }; +} + absl::StatusOr HandleNumpyArray( nb::handle h, ifrt::Client* client, ifrt::Device* to_device, const DevicePutOptions& options, ifrt::MemoryKind to_memory_kind) { xla::nb_numpy_ndarray array = nb::cast(h); + + // String numpy arrays require substantially different processing. + if (array.dtype().char_() == 'T') { + return HandleStringNumpyArray(h, client, to_device, options, + to_memory_kind); + } + TF_ASSIGN_OR_RETURN(PrimitiveType type, DtypeToPrimitiveType(array.dtype())); PrimitiveType squashed_type; diff --git a/xla/tsl/python/lib/core/numpy.h b/xla/tsl/python/lib/core/numpy.h index ca57a0370548e..307c253d111fc 100644 --- a/xla/tsl/python/lib/core/numpy.h +++ b/xla/tsl/python/lib/core/numpy.h @@ -43,9 +43,11 @@ limitations under the License. #include // clang-format on -#include "numpy/arrayobject.h" // IWYU pragma: export -#include "numpy/npy_common.h" // IWYU pragma: export -#include "numpy/ufuncobject.h" // IWYU pragma: export +#include "numpy/arrayobject.h" // IWYU pragma: export +#include "numpy/ndarraytypes.h" // IWYU pragma: export +#include "numpy/npy_common.h" // IWYU pragma: export +#include "numpy/numpyconfig.h" // IWYU pragma: export +#include "numpy/ufuncobject.h" // IWYU pragma: export namespace tsl {