Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

In-progress experimentation for supporting JAX Arrays with variable-width strings (i.e., with dtype = StringDType). #20642

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions xla/python/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -352,6 +352,7 @@ cc_library(
# placeholder for index annotation deps
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/base",
"@com_google_absl//absl/cleanup",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/container:inlined_vector",
Expand All @@ -373,6 +374,7 @@ cc_library(
"@llvm-project//mlir:Pass",
"@nanobind",
"@shardy//shardy/dialect/sdy/ir:dialect",
"//third_party/py/numpy:headers",
"@local_config_python//:python_headers", # buildcleaner: keep
"//xla:comparison_util",
"//xla:literal",
Expand Down
83 changes: 83 additions & 0 deletions xla/python/py_array.cc
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ limitations under the License.
#include "absl/base/casts.h"
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/log/log.h"
#include "absl/status/status.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h"
Expand Down Expand Up @@ -88,6 +89,7 @@ limitations under the License.
#include "xla/shape_util.h"
#include "xla/status_macros.h"
#include "xla/tsl/concurrency/ref_count.h"
#include "xla/tsl/python/lib/core/numpy.h" // IWYU pragma: keep
#include "xla/util.h"
#include "xla/xla_data.pb.h"
#include "tsl/platform/errors.h"
Expand Down Expand Up @@ -1563,15 +1565,96 @@ absl::StatusOr<nb::object> PyHostValue::AsNumPyArray(
} else {
TF_RETURN_IF_ERROR(ready_.Await());
}

TF_RETURN_IF_ERROR(ConvertStringArrayContentsToNumpyArray(ifrt_array));
return value_;
}

// Converts the pending string contents (if any) into a NumPy array of dtype
// NPY_VSTRING and stores it in `value_`.
//
// Returns OK without doing anything when `string_array_contents_` is null.
// When built against a NumPy without NPY_2_0_API_VERSION, returns
// FailedPreconditionError instead. On success the resulting array is marked
// read-only and `string_array_contents_` is released. On failure `value_` and
// `string_array_contents_` are left untouched, so no partially filled array is
// ever published.
//
// `ifrt_array` is only consulted for its shape, which determines the shape of
// the NumPy array that is created.
absl::Status PyHostValue::ConvertStringArrayContentsToNumpyArray(
    ifrt::Array* ifrt_array) {
  if (string_array_contents_ == nullptr) {
    return absl::OkStatus();
  }

#ifndef NPY_2_0_API_VERSION
  return absl::FailedPreconditionError(
      "String arrays are not supported in this NumPy version.");
#else

  PyArray_Descr* descr = PyArray_DescrFromType(NPY_VSTRING);
  if (descr == nullptr) {
    return absl::InternalError("PyArray_DescrFromType(NPY_VSTRING) failed");
  }
  auto numpy_dtype = nb::steal<nb_dtype>(reinterpret_cast<PyObject*>(descr));

  // Build into a local array first; `value_` is only replaced once every
  // element has been converted successfully.
  nb_numpy_ndarray array =
      nb_numpy_ndarray(numpy_dtype, ifrt_array->shape().dims(),
                       /*strides=*/std::nullopt);
  auto dst_py_array_obj = reinterpret_cast<::PyArrayObject*>(array.ptr());

  // The loop below writes one element per cord, so the destination must hold
  // exactly as many elements as there are cords.
  if (static_cast<size_t>(PyArray_SIZE(dst_py_array_obj)) !=
      string_array_contents_->size()) {
    return absl::InternalError(
        "String array contents do not match the array shape");
  }

  auto iter =
      nb::steal(PyArray_IterNew(reinterpret_cast<PyObject*>(dst_py_array_obj)));
  if (iter.ptr() == nullptr) {
    return absl::InternalError("PyArray_IterNew failed");
  }

  for (auto& cord : *string_array_contents_) {
    // Flatten() yields a contiguous view over the cord's bytes.
    absl::string_view input_str_view = cord.Flatten();
    auto py_unicode = nb::steal(PyUnicode_FromStringAndSize(
        input_str_view.data(), input_str_view.size()));
    if (py_unicode.ptr() == nullptr) {
      return absl::InternalError("PyUnicode_FromStringAndSize failed");
    }
    if (PyArray_SETITEM(dst_py_array_obj,
                        static_cast<char*>(PyArray_ITER_DATA(iter.ptr())),
                        py_unicode.ptr()) != 0) {
      return absl::InternalError("PyArray_SETITEM failed");
    }
    PyArray_ITER_NEXT(iter.ptr());
  }

  array.attr("flags").attr("writeable") = nb::bool_(false);

  value_ = array;
  string_array_contents_.reset();

  return absl::OkStatus();

#endif
}

// Starts an asynchronous device-to-host copy of a string array.
//
// The transfer guard is applied first. A vector of cords with one entry per
// array element is then allocated and handed to `CopyToHostBuffer` as the
// destination. The resulting future is stored in `ready_`, and the cords are
// stored in `string_array_contents_` until `AsNumPyArray` converts them into a
// NumPy array. `dynamic_shape_holder` is not used by this function.
absl::Status PyHostValue::CopyStringArrayToHostAsync(
    std::optional<Shape>& dynamic_shape_holder, ifrt::Array* ifrt_array) {
  // Produces the human-readable description passed to the transfer guard.
  auto describe_array = [ifrt_array] {
    return absl::StrCat(
        "shape=(", absl::StrJoin(ifrt_array->shape().dims(), ","),
        "), dtype=", ifrt_array->dtype().DebugString(), ", device=",
        ifrt_array->sharding().devices()->devices().front()->DebugString());
  };
  TF_RETURN_IF_ERROR(jax::ApplyTransferGuardToDeviceToHost(describe_array));

  // The converted dtype itself is not used below; the call is kept so that a
  // conversion error is still returned to the caller.
  TF_ASSIGN_OR_RETURN(nb_dtype dtype, IfrtDtypeToNbDtype(ifrt_array->dtype()));

  // Host-side destination: one cord per element of the array.
  auto array_shape = ifrt_array->shape();
  auto host_cords =
      std::make_shared<std::vector<absl::Cord>>(array_shape.num_elements());
  string_array_contents_ = host_cords;

  ready_ = ifrt_array->CopyToHostBuffer(host_cords->data(),
                                        /*byte_strides=*/std::nullopt,
                                        ifrt::ArrayCopySemantics::kAlwaysCopy);

  // The callback does nothing; it only holds a reference to the cords so they
  // stay alive until the copy is done.
  ready_.OnReady([string_array_contents = host_cords](absl::Status) {});

  return absl::OkStatus();
}

absl::Status PyHostValue::CopyToHostAsync(
std::optional<Shape>& dynamic_shape_holder, ifrt::Array* ifrt_array) {
if (ready_.IsValid()) {
// The array value has been populated, so CopyToHostAsync has been called.
return absl::OkStatus();
}

// Copying in Arrays of type kString requires some special handling
if (ifrt_array->dtype().kind() == ifrt::DType::kString) {
return CopyStringArrayToHostAsync(dynamic_shape_holder, ifrt_array);
}

auto* arr = llvm::dyn_cast_or_null<ifrt::PjRtCompatibleArray>(ifrt_array);
if (arr != nullptr && !arr->pjrt_buffers().front()->IsTuple() &&
IsZeroCopyableCpuBuffer(arr->pjrt_buffers().front().get())) {
Expand Down
12 changes: 12 additions & 0 deletions xla/python/py_array.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ limitations under the License.
#include "absl/log/check.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/cord.h"
#include "absl/types/span.h"
#include "llvm/Support/Casting.h"
#include "nanobind/nanobind.h"
Expand Down Expand Up @@ -69,8 +70,19 @@ class PyHostValue {
std::optional<Shape>& dynamic_shape_holder, ifrt::Array* ifrt_array);

private:
absl::Status CopyStringArrayToHostAsync(
std::optional<Shape>& dynamic_shape_holder, ifrt::Array* ifrt_array);

absl::Status ConvertStringArrayContentsToNumpyArray(ifrt::Array* ifrt_array);

ifrt::Future<> ready_;
nb_numpy_ndarray value_;

// Optional field, only used for arrays of type kString. This vector of cords
// serves as the destination buffer for the CopyToHostBuffer call. It holds the
// copied contents until they are lazily converted to a numpy array when the
// user calls `AsNumPyArray`.
std::shared_ptr<std::vector<absl::Cord>> string_array_contents_;
};

// Private to PyArray, but you cannot forward declare member classes.
Expand Down
96 changes: 96 additions & 0 deletions xla/python/py_values.cc
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,20 @@ limitations under the License.
#include <exception>
#include <functional>
#include <memory>
#include <optional>
#include <string>
#include <type_traits>
#include <utility>
#include <variant>
#include <vector>

#include "absl/cleanup/cleanup.h"
#include "absl/container/flat_hash_map.h"
#include "absl/container/inlined_vector.h"
#include "absl/log/log.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/cord.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "absl/strings/string_view.h"
Expand Down Expand Up @@ -66,6 +72,53 @@ namespace xla {

namespace {

// Copies every element of a NumPy string array into an `absl::Cord`.
//
// Elements are visited in C order (NPY_CORDER). Each one is fetched as a
// Python object and its UTF-8 bytes are copied into a new cord.
//
// Returns:
//   - InvalidArgumentError if the array has no elements.
//   - InternalError if the NumPy iterator cannot be created or used, or if an
//     element cannot be fetched or converted to UTF-8.
//   - Otherwise, one cord per array element.
absl::StatusOr<std::vector<absl::Cord>> StringDTypeArrayToCords(
    PyArrayObject* py_array_obj) {
  if (PyArray_SIZE(py_array_obj) == 0) {
    return absl::InvalidArgumentError("empty numpy array");
  }

  NpyIter* iter =
      NpyIter_New(py_array_obj,
                  NPY_ITER_READONLY | NPY_ITER_EXTERNAL_LOOP | NPY_ITER_REFS_OK,
                  NPY_CORDER, NPY_NO_CASTING, nullptr);
  if (iter == nullptr) {
    return absl::InternalError("failed to make numpy iterator");
  }
  // Sole owner of the iterator: every exit path below releases it exactly
  // once through this cleanup, so no path may deallocate it manually.
  absl::Cleanup cleanup = [iter] { NpyIter_Deallocate(iter); };

  NpyIter_IterNextFunc* iternext = NpyIter_GetIterNext(iter, nullptr);
  if (iternext == nullptr) {
    return absl::InternalError("failed to get the next iterator-function");
  }

  // Pointers to the data, stride and inner_size in the iterator. Updated when
  // the iterator is advanced.
  char** dataptr = NpyIter_GetDataPtrArray(iter);
  npy_intp* strideptr = NpyIter_GetInnerStrideArray(iter);
  npy_intp* innersizeptr = NpyIter_GetInnerLoopSizePtr(iter);

  std::vector<absl::Cord> cords;
  cords.reserve(PyArray_SIZE(py_array_obj));

  do {
    char* data = *dataptr;
    npy_intp stride = *strideptr;
    npy_intp count = *innersizeptr;

    while (count--) {
      // PyArray_GETITEM returns a new reference; `item` drops it at the end
      // of each iteration.
      nb::object item = nb::steal(PyArray_GETITEM(py_array_obj, data));
      if (item.ptr() == nullptr) {
        return absl::InternalError("PyArray_GETITEM failed");
      }
      Py_ssize_t len = 0;
      const char* str = PyUnicode_AsUTF8AndSize(item.ptr(), &len);
      if (str == nullptr) {
        return absl::InternalError("PyUnicode_AsUTF8AndSize failed");
      }
      // The cord copies the bytes, so it does not depend on `item` staying
      // alive.
      cords.emplace_back(absl::string_view(str, len));
      data += stride;
    }
  } while (iternext(iter));

  return cords;
}

using DevicePutFunc = std::function<absl::StatusOr<DevicePutResultFn>(
nb::handle, ifrt::Client*, ifrt::Device*, const DevicePutOptions& options,
ifrt::MemoryKind to_memory_kind)>;
Expand Down Expand Up @@ -246,10 +299,53 @@ absl::StatusOr<DevicePutResultFn> HandleNumpyScalar(
};
}

// Prepares the transfer of a NumPy string array to `to_device`.
//
// The array's elements are first copied into a vector of cords. The returned
// callable then creates an IFRT array of dtype kString from that vector via
// `MakeArrayFromHostBuffer`, using a single-device sharding built from
// `to_device` and `to_memory_kind`. `options` is not used by this function.
//
// Returns an error if the array contents cannot be converted to cords.
absl::StatusOr<DevicePutResultFn> HandleStringNumpyArray(
    nb::handle h, ifrt::Client* client, ifrt::Device* to_device,
    const DevicePutOptions& options, ifrt::MemoryKind to_memory_kind) {
  xla::nb_numpy_ndarray array = nb::cast<xla::nb_numpy_ndarray>(h);
  auto py_array_obj = reinterpret_cast<PyArrayObject*>(array.ptr());
  TF_ASSIGN_OR_RETURN(auto cords, StringDTypeArrayToCords(py_array_obj));

  // Assemble all the parameters of MakeArrayFromHostBuffer
  void* data = cords.data();
  ifrt::Shape shape(
      absl::MakeSpan(static_cast<const int64_t*>(array.shape()), array.ndim()));
  std::shared_ptr<xla::ifrt::Sharding> sharding =
      xla::ifrt::SingleDeviceSharding::Create(to_device, to_memory_kind);

  // The callback owns the cords so the buffer behind `data` stays alive for as
  // long as the callback does. Moving a std::vector hands over its existing
  // storage, so `data` still points at the elements after the move.
  auto on_done_with_host_buffer = [cords = std::move(cords)] {};

  return [client, data = data, shape = std::move(shape),
          sharding = std::move(sharding),
          on_done_with_host_buffer =
              std::move(on_done_with_host_buffer)]() mutable
             -> absl::StatusOr<DevicePutResult> {
    TF_ASSIGN_OR_RETURN(
        auto ifrt_array,
        client->MakeArrayFromHostBuffer(
            data, ifrt::DType(ifrt::DType::kString), std::move(shape),
            /*byte_strides=*/std::nullopt, std::move(sharding),
            ifrt::Client::HostBufferSemantics::kImmutableUntilTransferCompletes,
            std::move(on_done_with_host_buffer)));

    return DevicePutResult(std::move(ifrt_array), /*weak_type=*/false);
  };
}

absl::StatusOr<DevicePutResultFn> HandleNumpyArray(
nb::handle h, ifrt::Client* client, ifrt::Device* to_device,
const DevicePutOptions& options, ifrt::MemoryKind to_memory_kind) {
xla::nb_numpy_ndarray array = nb::cast<xla::nb_numpy_ndarray>(h);

// String numpy arrays require substantially different processing.
if (array.dtype().char_() == 'T') {
return HandleStringNumpyArray(h, client, to_device, options,
to_memory_kind);
}

TF_ASSIGN_OR_RETURN(PrimitiveType type, DtypeToPrimitiveType(array.dtype()));

PrimitiveType squashed_type;
Expand Down
8 changes: 5 additions & 3 deletions xla/tsl/python/lib/core/numpy.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,11 @@ limitations under the License.
#include <Python.h>
// clang-format on

#include "numpy/arrayobject.h" // IWYU pragma: export
#include "numpy/npy_common.h" // IWYU pragma: export
#include "numpy/ufuncobject.h" // IWYU pragma: export
#include "numpy/arrayobject.h" // IWYU pragma: export
#include "numpy/ndarraytypes.h" // IWYU pragma: export
#include "numpy/npy_common.h" // IWYU pragma: export
#include "numpy/numpyconfig.h" // IWYU pragma: export
#include "numpy/ufuncobject.h" // IWYU pragma: export

namespace tsl {

Expand Down