Use absl::string_view instead of std::string_view as some environments (e.g. Android) don't provide std::string_view.

PiperOrigin-RevId: 707210600
klucke authored and Google-ML-Automation committed Dec 17, 2024
1 parent 9ba7f35 commit 2225cf1
Showing 35 changed files with 160 additions and 175 deletions.
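
For context, a minimal, self-contained sketch (not taken from this commit) of the substitution pattern the diff applies throughout. absl::string_view is an alias for std::string_view on toolchains that provide the standard type and falls back to Abseil's own implementation otherwise, so call sites like the one below also build in environments (e.g. some Android configurations) that lack <string_view>. The Quote helper is hypothetical and exists only for illustration.

#include <string>

#include "absl/strings/string_view.h"  // replaces: #include <string_view>

// Hypothetical helper: returns a quoted copy of a borrowed string, mirroring
// how the XLA call sites pass borrowed text (capsule names, repr() output)
// without owning it.
std::string Quote(absl::string_view name) {  // was: std::string_view name
  return "\"" + std::string(name) + "\"";
}

int main() {
  absl::string_view capsule_name = "dltensor";
  return Quote(capsule_name) == "\"dltensor\"" ? 0 : 1;
}

The matching Bazel dependency is "@com_google_absl//absl/strings:string_view" (or the umbrella ":strings" target), which is why the BUILD hunks below add those deps alongside the source changes.
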
9 changes: 9 additions & 0 deletions xla/python/BUILD
@@ -263,6 +263,7 @@ cc_library(
"@com_google_absl//absl/base",
"@com_google_absl//absl/container:inlined_vector",
"@com_google_absl//absl/hash",
"@com_google_absl//absl/log:check",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:str_format",
"@nanobind",
@@ -502,12 +503,15 @@ cc_library(
"//xla:comparison_util",
"//xla/pjrt:exceptions",
"//xla/pjrt:host_callback",
"//xla/pjrt:transpose",
"//xla/service:custom_call_status",
"//xla/service:custom_call_target_registry",
"//xla/service:platform_util",
"@com_google_absl//absl/base",
"@com_google_absl//absl/log:check",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:span",
"@nanobind",
"@tsl//tsl/platform:errors",
] + if_rocm(
@@ -589,6 +593,7 @@ cc_library(
"@nanobind",
"@local_config_python//:python_headers", # build_cleaner: keep
"//xla/pjrt:pjrt_client",
"//xla/pjrt:pjrt_layout",
"//xla/pjrt:status_casters",
"@tsl//tsl/platform:logging",
"@tsl//tsl/profiler/lib:traceme",
@@ -631,6 +636,9 @@ cc_library(
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
"@tsl//tsl/platform:errors",
"@tsl//tsl/platform:status",
"@tsl//tsl/platform:statusor",
],
)

@@ -1422,6 +1430,7 @@ cc_library(
copts = ["-fexceptions"],
features = ["-use_header_modules"],
deps = [
"@com_google_absl//absl/strings:string_view",
"@com_google_absl//absl/types:span",
"@nanobind",
# copybara:uncomment "//third_party/py/numpy:multiarray",
7 changes: 3 additions & 4 deletions xla/python/callback.cc
@@ -23,7 +23,6 @@ limitations under the License.
#include <memory>
#include <optional>
#include <string>
#include <string_view>
#include <utility>
#include <vector>

@@ -32,10 +31,10 @@ limitations under the License.
#include "absl/status/statusor.h"
#include "absl/strings/str_format.h"
#include "absl/strings/str_join.h"
#include "absl/strings/string_view.h"
#include "absl/types/span.h"
#include "nanobind/nanobind.h"
#include "nanobind/stl/string_view.h" // IWYU pragma: keep
#include "xla/pjrt/host_callback.h"
#include "xla/pjrt/transpose.h"
#include "xla/primitive_util.h"
#include "xla/python/nb_numpy.h"
@@ -127,7 +126,7 @@ absl::StatusOr<nb::tuple> CpuCallback::Call(nb::tuple args) {
if (!PyTuple_Check(result_object.ptr())) {
return absl::InternalError(
absl::StrFormat("CPU callback expected a tuple result, got %s",
nb::cast<std::string_view>(nb::repr(result_object))));
nb::cast<absl::string_view>(nb::repr(result_object))));
}
if (PyTuple_Size(result_object.ptr()) != results_.size()) {
return absl::InternalError(
@@ -142,7 +141,7 @@ absl::StatusOr<nb::tuple> CpuCallback::Call(nb::tuple args) {
if (!output.is_none()) {
return absl::InternalError(absl::StrFormat(
"Token output from Python callback should be None, got %s",
nb::cast<std::string_view>(nb::repr(output))));
nb::cast<absl::string_view>(nb::repr(output))));
}
continue;
}
11 changes: 5 additions & 6 deletions xla/python/custom_call_sharding.cc
@@ -19,7 +19,6 @@ limitations under the License.
#include <memory>
#include <optional>
#include <string>
#include <string_view>
#include <tuple>
#include <utility>
#include <vector>
@@ -93,7 +92,7 @@ class PyCustomCallPartitionerCallbacks {
xla::Shape result_shape = std::move(std::get<2>(args_tuple));
std::optional<xla::HloSharding> result_sharding =
std::move(std::get<3>(args_tuple));
std::string_view backend_config = std::move(std::get<4>(args_tuple));
absl::string_view backend_config = std::move(std::get<4>(args_tuple));

{
nb::gil_scoped_acquire gil;
@@ -118,7 +117,7 @@ class PyCustomCallPartitionerCallbacks {
return xla::Internal(
"Shardings returned from partitioning: expected "
"Tuple[bytes, List[HloSharding], HloSharding] got: %s",
nb::cast<std::string_view>(nb::repr(py_result)));
nb::cast<absl::string_view>(nb::repr(py_result)));
}
} catch (const nb::python_error& e) {
return xla::Internal("custom_partitioner: %s", e.what());
@@ -136,7 +135,7 @@ class PyCustomCallPartitionerCallbacks {
std::vector<std::optional<xla::HloSharding>> arg_shardings =
std::move(std::get<1>(args_tuple));
xla::Shape result_shape = std::move(std::get<2>(args_tuple));
std::string_view backend_config = std::move(std::get<3>(args_tuple));
absl::string_view backend_config = std::move(std::get<3>(args_tuple));

std::optional<HloSharding> result;
nb::gil_scoped_acquire gil;
@@ -161,7 +160,7 @@
TF_ASSIGN_OR_RETURN(auto args_tuple, jax::ReadArgs(args));
xla::HloSharding result_sharding = std::move(std::get<0>(args_tuple));
xla::Shape result_shape = std::move(std::get<1>(args_tuple));
std::string_view backend_config = std::move(std::get<2>(args_tuple));
absl::string_view backend_config = std::move(std::get<2>(args_tuple));

nb::gil_scoped_acquire gil;
try {
@@ -229,7 +228,7 @@ void BuildCustomCallShardingPybindAPI(nb::module_& m) {
return;
}

if (std::string_view(c_api->name()) != "pjrt_c_api") {
if (absl::string_view(c_api->name()) != "pjrt_c_api") {
throw absl::InvalidArgumentError(
"Argument to register_custom_call_partitioner was not a "
"pjrt_c_api capsule.");
27 changes: 15 additions & 12 deletions xla/python/custom_partition_callback.cc
@@ -21,7 +21,6 @@ limitations under the License.
#include <memory>
#include <optional>
#include <string>
#include <string_view>
#include <tuple>
#include <utility>
#include <vector>
@@ -31,6 +30,7 @@ limitations under the License.
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/string_view.h"
#include "xla/debug_options_flags.h"
#include "xla/hlo/builder/xla_computation.h"
#include "xla/hlo/ir/hlo_casting_utils.h"
@@ -46,8 +46,11 @@ limitations under the License.
#include "xla/pjrt/mlir_to_hlo.h"
#include "xla/service/call_inliner.h"
#include "xla/service/custom_call_sharding_helper.h"
#include "xla/service/spmd/spmd_partitioner_util.h"
#include "xla/service/spmd/spmd_partitioner.h"
#include "xla/util.h"
#include "tsl/platform/errors.h"
#include "tsl/platform/status.h"
#include "tsl/platform/statusor.h"

namespace xla {

@@ -202,8 +205,8 @@ void SetCAPIString(JAX_CustomCallPartitioner_string& out, std::string result,
out.size = scratch.back().size();
}

std::string_view ToStringView(JAX_CustomCallPartitioner_string data) {
return std::string_view(data.data, data.size);
absl::string_view ToStringView(JAX_CustomCallPartitioner_string data) {
return absl::string_view(data.data, data.size);
}

void SetCAPIAval(JAX_CustomCallPartitioner_aval& result,
@@ -343,7 +346,7 @@ PartitionScratch PopulateArgs(JAX_CustomCallPartitioner_Partition_Args* args,

absl::StatusOr<std::tuple<
std::vector<xla::Shape>, std::vector<std::optional<xla::HloSharding>>,
xla::Shape, std::optional<xla::HloSharding>, std::string_view>>
xla::Shape, std::optional<xla::HloSharding>, absl::string_view>>
ReadArgs(JAX_CustomCallPartitioner_Partition_Args* args) {
std::vector<xla::Shape> shapes;
std::vector<std::optional<xla::HloSharding>> shardings;
@@ -369,14 +372,14 @@ ReadArgs(JAX_CustomCallPartitioner_Partition_Args* args) {
}
return std::tuple<std::vector<xla::Shape>,
std::vector<std::optional<xla::HloSharding>>, xla::Shape,
std::optional<xla::HloSharding>, std::string_view>(
std::optional<xla::HloSharding>, absl::string_view>(
std::move(shapes), std::move(shardings), std::move(result_shape),
std::move(result_sharding), ToStringView(args->backend_config));
}

absl::StatusOr<std::tuple<std::vector<xla::Shape>,
std::vector<std::optional<xla::HloSharding>>,
xla::Shape, std::string_view>>
xla::Shape, absl::string_view>>
ReadArgs(JAX_CustomCallPartitioner_InferShardingFromOperands_Args* args) {
std::vector<xla::Shape> shapes;
std::vector<std::optional<xla::HloSharding>> shardings;
@@ -397,9 +400,9 @@ ReadArgs(JAX_CustomCallPartitioner_InferShardingFromOperands_Args* args) {
TF_ASSIGN_OR_RETURN(auto result_shape, ReadHloShape(args->result_shape));
return std::tuple<std::vector<xla::Shape>,
std::vector<std::optional<xla::HloSharding>>, xla::Shape,
std::string_view>(std::move(shapes), std::move(shardings),
std::move(result_shape),
ToStringView(args->backend_config));
absl::string_view>(std::move(shapes), std::move(shardings),
std::move(result_shape),
ToStringView(args->backend_config));
}

PartitionScratch PopulateArgs(
@@ -455,11 +458,11 @@ absl::StatusOr<std::optional<xla::HloSharding>> ConsumeResults(
return ReadHloSharding(args->result_sharding);
}

absl::StatusOr<std::tuple<xla::HloSharding, xla::Shape, std::string_view>>
absl::StatusOr<std::tuple<xla::HloSharding, xla::Shape, absl::string_view>>
ReadArgs(JAX_CustomCallPartitioner_PropagateUserSharding_Args* args) {
TF_ASSIGN_OR_RETURN(auto shape, ReadHloShape(args->result_shape));
TF_ASSIGN_OR_RETURN(auto sharding, ReadHloSharding(args->result_sharding));
return std::tuple<xla::HloSharding, xla::Shape, std::string_view>(
return std::tuple<xla::HloSharding, xla::Shape, absl::string_view>(
std::move(sharding), std::move(shape),
ToStringView(args->backend_config));
}
7 changes: 3 additions & 4 deletions xla/python/custom_partition_callback.h
@@ -18,7 +18,6 @@ limitations under the License.
#include <memory>
#include <optional>
#include <string>
#include <string_view>
#include <tuple>

#include "xla/hlo/ir/hlo_instruction.h"
@@ -37,7 +36,7 @@ PartitionScratch PopulateArgs(JAX_CustomCallPartitioner_Partition_Args* args,
const xla::HloInstruction* instruction);
absl::StatusOr<std::tuple<
std::vector<xla::Shape>, std::vector<std::optional<xla::HloSharding>>,
xla::Shape, std::optional<xla::HloSharding>, std::string_view>>
xla::Shape, std::optional<xla::HloSharding>, absl::string_view>>
ReadArgs(JAX_CustomCallPartitioner_Partition_Args* args);
void PopulateResults(
absl::StatusOr<std::tuple<std::string, std::vector<xla::HloSharding>,
@@ -50,7 +49,7 @@ ConsumeResults(JAX_CustomCallPartitioner_Partition_Args* args);

absl::StatusOr<std::tuple<std::vector<xla::Shape>,
std::vector<std::optional<xla::HloSharding>>,
xla::Shape, std::string_view>>
xla::Shape, absl::string_view>>
ReadArgs(JAX_CustomCallPartitioner_InferShardingFromOperands_Args* args);
PartitionScratch PopulateArgs(
JAX_CustomCallPartitioner_InferShardingFromOperands_Args* args,
@@ -61,7 +60,7 @@ void PopulateResults(
absl::StatusOr<std::optional<xla::HloSharding>> ConsumeResults(
JAX_CustomCallPartitioner_InferShardingFromOperands_Args* args);

absl::StatusOr<std::tuple<xla::HloSharding, xla::Shape, std::string_view>>
absl::StatusOr<std::tuple<xla::HloSharding, xla::Shape, absl::string_view>>
ReadArgs(JAX_CustomCallPartitioner_PropagateUserSharding_Args* args);
PartitionScratch PopulateArgs(
JAX_CustomCallPartitioner_PropagateUserSharding_Args* args,
9 changes: 4 additions & 5 deletions xla/python/dlpack.cc
@@ -22,7 +22,6 @@ limitations under the License.
#include <memory>
#include <numeric>
#include <optional>
#include <string_view>
#include <utility>
#include <vector>

@@ -458,11 +457,11 @@ absl::StatusOr<nb::object> DLPackManagedTensorToBuffer(
auto* cpu_pjrt_client = cpu_client ? (*cpu_client)->pjrt_client() : nullptr;
auto* gpu_pjrt_client = gpu_client ? (*gpu_client)->pjrt_client() : nullptr;

if (std::string_view(tensor.name()) != kDlTensorCapsuleName) {
if (absl::string_view(tensor.name()) != kDlTensorCapsuleName) {
return InvalidArgument(
"DLPack tensor must be a capsule with name \"dltensor\", got \"%s\". "
"Note that a DLPack tensor may be consumed at most once.",
std::string_view(tensor.name()));
absl::string_view(tensor.name()));
}
DLManagedTensor* dlmt = static_cast<DLManagedTensor*>(tensor.data());
if (dlmt->dl_tensor.ndim < 0) {
@@ -552,11 +551,11 @@ absl::StatusOr<nb::object> DLPackManagedTensorToBuffer(
"DLPack is only supported for devices addressable by the current "
"process.");
}
if (std::string_view(tensor.name()) != kDlTensorCapsuleName) {
if (absl::string_view(tensor.name()) != kDlTensorCapsuleName) {
return InvalidArgument(
"DLPack tensor must be a capsule with name \"dltensor\", got \"%s\". "
"Note that a DLPack tensor may be consumed at most once.",
std::string_view(tensor.name()));
absl::string_view(tensor.name()));
}
DLManagedTensor* dlmt = static_cast<DLManagedTensor*>(tensor.data());
if (dlmt->dl_tensor.ndim < 0) {
11 changes: 6 additions & 5 deletions xla/python/jax_jit.cc
@@ -33,7 +33,6 @@ limitations under the License.
#include <optional>
#include <stdexcept>
#include <string>
#include <string_view>
#include <utility>
#include <vector>

@@ -45,6 +44,7 @@ limitations under the License.
#include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h"
#include "absl/strings/str_join.h"
#include "absl/strings/string_view.h"
#include "absl/types/span.h"
#include "nanobind/nanobind.h"
#include "nanobind/stl/optional.h" // IWYU pragma: keep
@@ -53,6 +53,7 @@ limitations under the License.
#include "nanobind/stl/string_view.h" // IWYU pragma: keep
#include "nanobind/stl/vector.h" // IWYU pragma: keep
#include "xla/pjrt/pjrt_client.h"
#include "xla/pjrt/pjrt_layout.h"
#include "xla/pjrt/status_casters.h"
#include "xla/python/nb_absl_inlined_vector.h" // IWYU pragma: keep
#include "xla/python/nb_absl_span.h" // IWYU pragma: keep
@@ -147,7 +148,7 @@ bool FetchMemoriesFlag() {

std::string ArgumentSignature::DebugString() const {
auto py_object_formatter = [](std::string* out, const nb::object& o) {
out->append(nb::cast<std::string_view>(nb::str(o)));
out->append(nb::cast<absl::string_view>(nb::str(o)));
};
auto treedef_formatter = [](std::string* out, const xla::PyTreeDef& d) {
out->append(d.ToString());
@@ -188,16 +189,16 @@ bool ArgumentSignature::operator==(const ArgumentSignature& other) const {
"static arguments should be comparable using __eq__."
"The following error was raised when comparing two objects of "
"types ",
nb::cast<std::string_view>(nb::str(a.type())), " and ",
nb::cast<std::string_view>(nb::str(b.type())),
nb::cast<absl::string_view>(nb::str(a.type())), " and ",
nb::cast<absl::string_view>(nb::str(b.type())),
". The error was:\n", e.what()));
}
});
}

std::string CallSignature::DebugString() const {
auto py_object_formatter = [](std::string* out, const nb::object& o) {
out->append(nb::cast<std::string_view>(nb::str(o)));
out->append(nb::cast<absl::string_view>(nb::str(o)));
};
auto signature_formatter = [](std::string* out,
const xla::PyArgSignature& s) {
7 changes: 3 additions & 4 deletions xla/python/jax_jit.h
@@ -22,7 +22,6 @@ limitations under the License.
#include <optional>
#include <stdexcept>
#include <string>
#include <string_view>
#include <utility>
#include <vector>

@@ -140,8 +139,8 @@ H AbslHashValue(H h, const ArgumentSignature& s) {
throw std::invalid_argument(absl::StrCat(
"Non-hashable static arguments are not supported. An error occurred "
"while trying to hash an object of type ",
nanobind::cast<std::string_view>(nanobind::str(static_arg.type())),
", ", nanobind::cast<std::string_view>(nanobind::str(static_arg)),
nanobind::cast<absl::string_view>(nanobind::str(static_arg.type())),
", ", nanobind::cast<absl::string_view>(nanobind::str(static_arg)),
". The error was:\n", e.what(), "\n"));
}
h = H::combine(std::move(h), hash);
@@ -185,7 +184,7 @@ absl::Status ParseArguments(
// (a) equality (delegated to Python) of the static arguments.
struct CallSignature {
// Not part of the signature, but we need it for error messages.
std::string_view function_name;
absl::string_view function_name;

ArgumentSignature arg_signature;
