Skip to content

Commit

Permalink
In progress experimention for supporting JAX Arrays with variable-wid…
Browse files Browse the repository at this point in the history
…th strings (i.e., with dtype = StringDType).

PiperOrigin-RevId: 707153208
  • Loading branch information
Google-ML-Automation committed Dec 17, 2024
1 parent 09bc536 commit 03c4c28
Show file tree
Hide file tree
Showing 5 changed files with 198 additions and 3 deletions.
2 changes: 2 additions & 0 deletions xla/python/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -352,6 +352,7 @@ cc_library(
# placeholder for index annotation deps
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/base",
"@com_google_absl//absl/cleanup",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/container:inlined_vector",
Expand All @@ -373,6 +374,7 @@ cc_library(
"@llvm-project//mlir:Pass",
"@nanobind",
"@shardy//shardy/dialect/sdy/ir:dialect",
"//third_party/py/numpy:headers",
"@local_config_python//:python_headers", # buildcleaner: keep
"//xla:comparison_util",
"//xla:literal",
Expand Down
83 changes: 83 additions & 0 deletions xla/python/py_array.cc
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ limitations under the License.
#include "absl/base/casts.h"
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/log/log.h"
#include "absl/status/status.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h"
Expand Down Expand Up @@ -88,6 +89,7 @@ limitations under the License.
#include "xla/shape_util.h"
#include "xla/status_macros.h"
#include "xla/tsl/concurrency/ref_count.h"
#include "xla/tsl/python/lib/core/numpy.h" // IWYU pragma: keep
#include "xla/util.h"
#include "xla/xla_data.pb.h"
#include "tsl/platform/errors.h"
Expand Down Expand Up @@ -1563,15 +1565,96 @@ absl::StatusOr<nb::object> PyHostValue::AsNumPyArray(
} else {
TF_RETURN_IF_ERROR(ready_.Await());
}

TF_RETURN_IF_ERROR(ConvertStringArrayContentsToNumpyArray(ifrt_array));
return value_;
}

absl::Status PyHostValue::ConvertStringArrayContentsToNumpyArray(
ifrt::Array* ifrt_array) {
if (string_array_contents_ == nullptr) {
return absl::OkStatus();
}

#ifndef NPY_2_0_API_VERSION
return absl::FailedPreconditionError(
"String arrays are not supported in this NumPy version.");
#else

auto numpy_dtype = nb::steal<nb_dtype>(
reinterpret_cast<PyObject*>(PyArray_DescrFromType(NPY_VSTRING)));
value_ = nb_numpy_ndarray(numpy_dtype, ifrt_array->shape().dims(),
/*strides=*/std::nullopt);

auto dst_py_array_obj = reinterpret_cast<::PyArrayObject*>(value_.ptr());
auto iter =
nb::steal(PyArray_IterNew(reinterpret_cast<PyObject*>(dst_py_array_obj)));
for (auto& cord : *string_array_contents_) {
absl::string_view input_str_view = cord.Flatten();
auto py_unicode = nb::steal(PyUnicode_FromStringAndSize(
input_str_view.data(), input_str_view.size()));
if (py_unicode.ptr() == nullptr) {
return absl::InternalError("PyUnicode_FromStringAndSize failed");
}
if (PyArray_SETITEM(dst_py_array_obj,
static_cast<char*>(PyArray_ITER_DATA(iter.ptr())),
py_unicode.ptr()) != 0) {
return absl::InternalError("PyArray_SETITEM failed");
}
PyArray_ITER_NEXT(iter.ptr());
}

value_.attr("flags").attr("writeable") = nb::bool_(false);

string_array_contents_.reset();

return absl::OkStatus();

#endif
}

absl::Status PyHostValue::CopyStringArrayToHostAsync(
std::optional<Shape>& dynamic_shape_holder, ifrt::Array* ifrt_array) {
auto transfer_guard_formatter = [ifrt_array] {
return absl::StrCat(
"shape=(", absl::StrJoin(ifrt_array->shape().dims(), ","),
"), dtype=", ifrt_array->dtype().DebugString(), ", device=",
ifrt_array->sharding().devices()->devices().front()->DebugString());
};
TF_RETURN_IF_ERROR(
jax::ApplyTransferGuardToDeviceToHost(transfer_guard_formatter));

TF_ASSIGN_OR_RETURN(nb_dtype dtype, IfrtDtypeToNbDtype(ifrt_array->dtype()));
auto shape = ifrt_array->shape();

// Allocate a vector of cords to hold the contents of the array until
// they are until they are ultimately converted to a numpy array as part
// of the `AsNumPyArray` call.
string_array_contents_ =
std::make_shared<std::vector<absl::Cord>>(shape.num_elements());
ready_ = ifrt_array->CopyToHostBuffer(string_array_contents_->data(),
/*byte_strides=*/std::nullopt,
ifrt::ArrayCopySemantics::kAlwaysCopy);

ready_.OnReady(
[string_array_contents = string_array_contents_](absl::Status) {
}); // Keeps the cords alive until the copy is done.

return absl::OkStatus();
}

absl::Status PyHostValue::CopyToHostAsync(
std::optional<Shape>& dynamic_shape_holder, ifrt::Array* ifrt_array) {
if (ready_.IsValid()) {
// The array value has been populated, so CopyToHostAsync has been called.
return absl::OkStatus();
}

// Copying in Arrays of type kString requires some special handling
if (ifrt_array->dtype().kind() == ifrt::DType::kString) {
return CopyStringArrayToHostAsync(dynamic_shape_holder, ifrt_array);
}

auto* arr = llvm::dyn_cast_or_null<ifrt::PjRtCompatibleArray>(ifrt_array);
if (arr != nullptr && !arr->pjrt_buffers().front()->IsTuple() &&
IsZeroCopyableCpuBuffer(arr->pjrt_buffers().front().get())) {
Expand Down
12 changes: 12 additions & 0 deletions xla/python/py_array.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ limitations under the License.
#include "absl/log/check.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/cord.h"
#include "absl/types/span.h"
#include "llvm/Support/Casting.h"
#include "nanobind/nanobind.h"
Expand Down Expand Up @@ -69,8 +70,19 @@ class PyHostValue {
std::optional<Shape>& dynamic_shape_holder, ifrt::Array* ifrt_array);

private:
absl::Status CopyStringArrayToHostAsync(
std::optional<Shape>& dynamic_shape_holder, ifrt::Array* ifrt_array);

absl::Status ConvertStringArrayContentsToNumpyArray(ifrt::Array* ifrt_array);

ifrt::Future<> ready_;
nb_numpy_ndarray value_;

// Optional field, only used for arrays of type kString. This vector of cords
// serves as input buffer for the CopyToHostBuffer call. It holds these
// contents until it is lazily converted it to a numpy array when the user
// calls `AsNumPyArray`.
std::shared_ptr<std::vector<absl::Cord>> string_array_contents_;
};

// Private to PyArray, but you cannot forward declare member classes.
Expand Down
96 changes: 96 additions & 0 deletions xla/python/py_values.cc
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,20 @@ limitations under the License.
#include <exception>
#include <functional>
#include <memory>
#include <optional>
#include <string>
#include <type_traits>
#include <utility>
#include <variant>
#include <vector>

#include "absl/cleanup/cleanup.h"
#include "absl/container/flat_hash_map.h"
#include "absl/container/inlined_vector.h"
#include "absl/log/log.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/cord.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "absl/strings/string_view.h"
Expand Down Expand Up @@ -66,6 +72,53 @@ namespace xla {

namespace {

absl::StatusOr<std::vector<absl::Cord>> StringDTypeArrayToCords(
PyArrayObject* py_array_obj) {
if (PyArray_SIZE(py_array_obj) == 0) {
return absl::InvalidArgumentError("empty numpy array");
}

NpyIter* iter =
NpyIter_New(py_array_obj,
NPY_ITER_READONLY | NPY_ITER_EXTERNAL_LOOP | NPY_ITER_REFS_OK,
NPY_CORDER, NPY_NO_CASTING, nullptr);
if (iter == nullptr) {
return absl::InternalError("failed to make numpy iterator");
}
absl::Cleanup cleanup = [iter] { NpyIter_Deallocate(iter); };

NpyIter_IterNextFunc* iternext = NpyIter_GetIterNext(iter, nullptr);
if (iternext == nullptr) {
NpyIter_Deallocate(iter);
return absl::InternalError("failed to get the next iterator-function");
}

// Pointers to the data, stride and inner_size in the iterator. Updated when
// the iterator is advanced.
char** dataptr = NpyIter_GetDataPtrArray(iter);
npy_intp* strideptr = NpyIter_GetInnerStrideArray(iter);
npy_intp* innersizeptr = NpyIter_GetInnerLoopSizePtr(iter);

std::vector<absl::Cord> cords;
cords.reserve(PyArray_SIZE(py_array_obj));

do {
char* data = *dataptr;
npy_intp stride = *strideptr;
npy_intp count = *innersizeptr;

while (count--) {
auto py_unicode = PyArray_GETITEM(py_array_obj, data);
Py_ssize_t len;
auto str = PyUnicode_AsUTF8AndSize(py_unicode, &len);
cords.push_back(absl::Cord(absl::string_view(str, len)));
data += stride;
}
} while (iternext(iter));

return cords;
}

using DevicePutFunc = std::function<absl::StatusOr<DevicePutResultFn>(
nb::handle, ifrt::Client*, ifrt::Device*, const DevicePutOptions& options,
ifrt::MemoryKind to_memory_kind)>;
Expand Down Expand Up @@ -246,10 +299,53 @@ absl::StatusOr<DevicePutResultFn> HandleNumpyScalar(
};
}

absl::StatusOr<DevicePutResultFn> HandleStringNumpyArray(
nb::handle h, ifrt::Client* client, ifrt::Device* to_device,
const DevicePutOptions& options, ifrt::MemoryKind to_memory_kind) {
xla::nb_numpy_ndarray array = nb::cast<xla::nb_numpy_ndarray>(h);
auto py_array_obj = reinterpret_cast<PyArrayObject*>(array.ptr());
TF_ASSIGN_OR_RETURN(auto cords, StringDTypeArrayToCords(py_array_obj));
for (const auto& cord : cords) {
LOG(INFO) << "2D cord: " << cord;
}

// Assemble all the parameters of MakeArrayFromHostBuffer
void* data = cords.data();
ifrt::Shape shape(
absl::MakeSpan(static_cast<const int64_t*>(array.shape()), array.ndim()));
std::shared_ptr<xla::ifrt::Sharding> sharding =
xla::ifrt::SingleDeviceSharding::Create(to_device, to_memory_kind);

auto on_done_with_host_buffer = [cords = std::move(cords)] {};

return [client, data = data, shape = std::move(shape),
sharding = std::move(sharding),
on_done_with_host_buffer =
std::move(on_done_with_host_buffer)]() mutable
-> absl::StatusOr<DevicePutResult> {
TF_ASSIGN_OR_RETURN(
auto ifrt_array,
client->MakeArrayFromHostBuffer(
data, ifrt::DType(ifrt::DType::kString), std::move(shape),
/*byte_strides=*/std::nullopt, std::move(sharding),
ifrt::Client::HostBufferSemantics::kImmutableUntilTransferCompletes,
std::move(on_done_with_host_buffer)));

return DevicePutResult(std::move(ifrt_array), /*weak_type=*/false);
};
}

absl::StatusOr<DevicePutResultFn> HandleNumpyArray(
nb::handle h, ifrt::Client* client, ifrt::Device* to_device,
const DevicePutOptions& options, ifrt::MemoryKind to_memory_kind) {
xla::nb_numpy_ndarray array = nb::cast<xla::nb_numpy_ndarray>(h);

// String numpy arrays require substantially different processing.
if (array.dtype().char_() == 'T') {
return HandleStringNumpyArray(h, client, to_device, options,
to_memory_kind);
}

TF_ASSIGN_OR_RETURN(PrimitiveType type, DtypeToPrimitiveType(array.dtype()));

PrimitiveType squashed_type;
Expand Down
8 changes: 5 additions & 3 deletions xla/tsl/python/lib/core/numpy.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,11 @@ limitations under the License.
#include <Python.h>
// clang-format on

#include "numpy/arrayobject.h" // IWYU pragma: export
#include "numpy/npy_common.h" // IWYU pragma: export
#include "numpy/ufuncobject.h" // IWYU pragma: export
#include "numpy/arrayobject.h" // IWYU pragma: export
#include "numpy/ndarraytypes.h" // IWYU pragma: export
#include "numpy/npy_common.h" // IWYU pragma: export
#include "numpy/numpyconfig.h" // IWYU pragma: export
#include "numpy/ufuncobject.h" // IWYU pragma: export

namespace tsl {

Expand Down

0 comments on commit 03c4c28

Please sign in to comment.