diff --git a/xla/python/BUILD b/xla/python/BUILD index a087f61f04c41..2aef58884209c 100644 --- a/xla/python/BUILD +++ b/xla/python/BUILD @@ -263,6 +263,7 @@ cc_library( "@com_google_absl//absl/base", "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/hash", + "@com_google_absl//absl/log:check", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", "@nanobind", @@ -502,12 +503,15 @@ cc_library( "//xla:comparison_util", "//xla/pjrt:exceptions", "//xla/pjrt:host_callback", + "//xla/pjrt:transpose", "//xla/service:custom_call_status", "//xla/service:custom_call_target_registry", "//xla/service:platform_util", "@com_google_absl//absl/base", + "@com_google_absl//absl/log:check", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", "@nanobind", "@tsl//tsl/platform:errors", ] + if_rocm( @@ -589,6 +593,7 @@ cc_library( "@nanobind", "@local_config_python//:python_headers", # build_cleaner: keep "//xla/pjrt:pjrt_client", + "//xla/pjrt:pjrt_layout", "//xla/pjrt:status_casters", "@tsl//tsl/platform:logging", "@tsl//tsl/profiler/lib:traceme", @@ -631,6 +636,9 @@ cc_library( "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", + "@tsl//tsl/platform:errors", + "@tsl//tsl/platform:status", + "@tsl//tsl/platform:statusor", ], ) @@ -1422,6 +1430,7 @@ cc_library( copts = ["-fexceptions"], features = ["-use_header_modules"], deps = [ + "@com_google_absl//absl/strings:string_view", "@com_google_absl//absl/types:span", "@nanobind", # copybara:uncomment "//third_party/py/numpy:multiarray", diff --git a/xla/python/callback.cc b/xla/python/callback.cc index 9d0f707b71d2e..5f4675df6ccb2 100644 --- a/xla/python/callback.cc +++ b/xla/python/callback.cc @@ -23,7 +23,6 @@ limitations under the License. #include #include #include -#include #include #include @@ -32,10 +31,10 @@ limitations under the License. #include "absl/status/statusor.h" #include "absl/strings/str_format.h" #include "absl/strings/str_join.h" +#include "absl/strings/string_view.h" #include "absl/types/span.h" #include "nanobind/nanobind.h" #include "nanobind/stl/string_view.h" // IWYU pragma: keep -#include "xla/pjrt/host_callback.h" #include "xla/pjrt/transpose.h" #include "xla/primitive_util.h" #include "xla/python/nb_numpy.h" @@ -127,7 +126,7 @@ absl::StatusOr CpuCallback::Call(nb::tuple args) { if (!PyTuple_Check(result_object.ptr())) { return absl::InternalError( absl::StrFormat("CPU callback expected a tuple result, got %s", - nb::cast(nb::repr(result_object)))); + nb::cast(nb::repr(result_object)))); } if (PyTuple_Size(result_object.ptr()) != results_.size()) { return absl::InternalError( @@ -142,7 +141,7 @@ absl::StatusOr CpuCallback::Call(nb::tuple args) { if (!output.is_none()) { return absl::InternalError(absl::StrFormat( "Token output from Python callback should be None, got %s", - nb::cast(nb::repr(output)))); + nb::cast(nb::repr(output)))); } continue; } diff --git a/xla/python/custom_call_sharding.cc b/xla/python/custom_call_sharding.cc index e25fdf835955e..0bc424c9c13be 100644 --- a/xla/python/custom_call_sharding.cc +++ b/xla/python/custom_call_sharding.cc @@ -19,7 +19,6 @@ limitations under the License. #include #include #include -#include #include #include #include @@ -93,7 +92,7 @@ class PyCustomCallPartitionerCallbacks { xla::Shape result_shape = std::move(std::get<2>(args_tuple)); std::optional result_sharding = std::move(std::get<3>(args_tuple)); - std::string_view backend_config = std::move(std::get<4>(args_tuple)); + absl::string_view backend_config = std::move(std::get<4>(args_tuple)); { nb::gil_scoped_acquire gil; @@ -118,7 +117,7 @@ class PyCustomCallPartitionerCallbacks { return xla::Internal( "Shardings returned from partitioning: expected " "Tuple[bytes, List[HloSharding], HloSharding] got: %s", - nb::cast(nb::repr(py_result))); + nb::cast(nb::repr(py_result))); } } catch (const nb::python_error& e) { return xla::Internal("custom_partitioner: %s", e.what()); @@ -136,7 +135,7 @@ class PyCustomCallPartitionerCallbacks { std::vector> arg_shardings = std::move(std::get<1>(args_tuple)); xla::Shape result_shape = std::move(std::get<2>(args_tuple)); - std::string_view backend_config = std::move(std::get<3>(args_tuple)); + absl::string_view backend_config = std::move(std::get<3>(args_tuple)); std::optional result; nb::gil_scoped_acquire gil; @@ -161,7 +160,7 @@ class PyCustomCallPartitionerCallbacks { TF_ASSIGN_OR_RETURN(auto args_tuple, jax::ReadArgs(args)); xla::HloSharding result_sharding = std::move(std::get<0>(args_tuple)); xla::Shape result_shape = std::move(std::get<1>(args_tuple)); - std::string_view backend_config = std::move(std::get<2>(args_tuple)); + absl::string_view backend_config = std::move(std::get<2>(args_tuple)); nb::gil_scoped_acquire gil; try { @@ -229,7 +228,7 @@ void BuildCustomCallShardingPybindAPI(nb::module_& m) { return; } - if (std::string_view(c_api->name()) != "pjrt_c_api") { + if (absl::string_view(c_api->name()) != "pjrt_c_api") { throw absl::InvalidArgumentError( "Argument to register_custom_call_partitioner was not a " "pjrt_c_api capsule."); diff --git a/xla/python/custom_partition_callback.cc b/xla/python/custom_partition_callback.cc index df49dfc1e37bc..3349385ffa43e 100644 --- a/xla/python/custom_partition_callback.cc +++ b/xla/python/custom_partition_callback.cc @@ -21,7 +21,6 @@ limitations under the License. #include #include #include -#include #include #include #include @@ -31,6 +30,7 @@ limitations under the License. #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" #include "xla/debug_options_flags.h" #include "xla/hlo/builder/xla_computation.h" #include "xla/hlo/ir/hlo_casting_utils.h" @@ -46,8 +46,11 @@ limitations under the License. #include "xla/pjrt/mlir_to_hlo.h" #include "xla/service/call_inliner.h" #include "xla/service/custom_call_sharding_helper.h" -#include "xla/service/spmd/spmd_partitioner_util.h" +#include "xla/service/spmd/spmd_partitioner.h" #include "xla/util.h" +#include "tsl/platform/errors.h" +#include "tsl/platform/status.h" +#include "tsl/platform/statusor.h" namespace xla { @@ -202,8 +205,8 @@ void SetCAPIString(JAX_CustomCallPartitioner_string& out, std::string result, out.size = scratch.back().size(); } -std::string_view ToStringView(JAX_CustomCallPartitioner_string data) { - return std::string_view(data.data, data.size); +absl::string_view ToStringView(JAX_CustomCallPartitioner_string data) { + return absl::string_view(data.data, data.size); } void SetCAPIAval(JAX_CustomCallPartitioner_aval& result, @@ -343,7 +346,7 @@ PartitionScratch PopulateArgs(JAX_CustomCallPartitioner_Partition_Args* args, absl::StatusOr, std::vector>, - xla::Shape, std::optional, std::string_view>> + xla::Shape, std::optional, absl::string_view>> ReadArgs(JAX_CustomCallPartitioner_Partition_Args* args) { std::vector shapes; std::vector> shardings; @@ -369,14 +372,14 @@ ReadArgs(JAX_CustomCallPartitioner_Partition_Args* args) { } return std::tuple, std::vector>, xla::Shape, - std::optional, std::string_view>( + std::optional, absl::string_view>( std::move(shapes), std::move(shardings), std::move(result_shape), std::move(result_sharding), ToStringView(args->backend_config)); } absl::StatusOr, std::vector>, - xla::Shape, std::string_view>> + xla::Shape, absl::string_view>> ReadArgs(JAX_CustomCallPartitioner_InferShardingFromOperands_Args* args) { std::vector shapes; std::vector> shardings; @@ -397,9 +400,9 @@ ReadArgs(JAX_CustomCallPartitioner_InferShardingFromOperands_Args* args) { TF_ASSIGN_OR_RETURN(auto result_shape, ReadHloShape(args->result_shape)); return std::tuple, std::vector>, xla::Shape, - std::string_view>(std::move(shapes), std::move(shardings), - std::move(result_shape), - ToStringView(args->backend_config)); + absl::string_view>(std::move(shapes), std::move(shardings), + std::move(result_shape), + ToStringView(args->backend_config)); } PartitionScratch PopulateArgs( @@ -455,11 +458,11 @@ absl::StatusOr> ConsumeResults( return ReadHloSharding(args->result_sharding); } -absl::StatusOr> +absl::StatusOr> ReadArgs(JAX_CustomCallPartitioner_PropagateUserSharding_Args* args) { TF_ASSIGN_OR_RETURN(auto shape, ReadHloShape(args->result_shape)); TF_ASSIGN_OR_RETURN(auto sharding, ReadHloSharding(args->result_sharding)); - return std::tuple( + return std::tuple( std::move(sharding), std::move(shape), ToStringView(args->backend_config)); } diff --git a/xla/python/custom_partition_callback.h b/xla/python/custom_partition_callback.h index 33cc31e75fc9b..6ba1789a038da 100644 --- a/xla/python/custom_partition_callback.h +++ b/xla/python/custom_partition_callback.h @@ -18,7 +18,6 @@ limitations under the License. #include #include #include -#include #include #include "xla/hlo/ir/hlo_instruction.h" @@ -37,7 +36,7 @@ PartitionScratch PopulateArgs(JAX_CustomCallPartitioner_Partition_Args* args, const xla::HloInstruction* instruction); absl::StatusOr, std::vector>, - xla::Shape, std::optional, std::string_view>> + xla::Shape, std::optional, absl::string_view>> ReadArgs(JAX_CustomCallPartitioner_Partition_Args* args); void PopulateResults( absl::StatusOr, @@ -50,7 +49,7 @@ ConsumeResults(JAX_CustomCallPartitioner_Partition_Args* args); absl::StatusOr, std::vector>, - xla::Shape, std::string_view>> + xla::Shape, absl::string_view>> ReadArgs(JAX_CustomCallPartitioner_InferShardingFromOperands_Args* args); PartitionScratch PopulateArgs( JAX_CustomCallPartitioner_InferShardingFromOperands_Args* args, @@ -61,7 +60,7 @@ void PopulateResults( absl::StatusOr> ConsumeResults( JAX_CustomCallPartitioner_InferShardingFromOperands_Args* args); -absl::StatusOr> +absl::StatusOr> ReadArgs(JAX_CustomCallPartitioner_PropagateUserSharding_Args* args); PartitionScratch PopulateArgs( JAX_CustomCallPartitioner_PropagateUserSharding_Args* args, diff --git a/xla/python/dlpack.cc b/xla/python/dlpack.cc index 2848fc20827b1..d3bf32ff46fef 100644 --- a/xla/python/dlpack.cc +++ b/xla/python/dlpack.cc @@ -22,7 +22,6 @@ limitations under the License. #include #include #include -#include #include #include @@ -458,11 +457,11 @@ absl::StatusOr DLPackManagedTensorToBuffer( auto* cpu_pjrt_client = cpu_client ? (*cpu_client)->pjrt_client() : nullptr; auto* gpu_pjrt_client = gpu_client ? (*gpu_client)->pjrt_client() : nullptr; - if (std::string_view(tensor.name()) != kDlTensorCapsuleName) { + if (absl::string_view(tensor.name()) != kDlTensorCapsuleName) { return InvalidArgument( "DLPack tensor must be a capsule with name \"dltensor\", got \"%s\". " "Note that a DLPack tensor may be consumed at most once.", - std::string_view(tensor.name())); + absl::string_view(tensor.name())); } DLManagedTensor* dlmt = static_cast(tensor.data()); if (dlmt->dl_tensor.ndim < 0) { @@ -552,11 +551,11 @@ absl::StatusOr DLPackManagedTensorToBuffer( "DLPack is only supported for devices addressable by the current " "process."); } - if (std::string_view(tensor.name()) != kDlTensorCapsuleName) { + if (absl::string_view(tensor.name()) != kDlTensorCapsuleName) { return InvalidArgument( "DLPack tensor must be a capsule with name \"dltensor\", got \"%s\". " "Note that a DLPack tensor may be consumed at most once.", - std::string_view(tensor.name())); + absl::string_view(tensor.name())); } DLManagedTensor* dlmt = static_cast(tensor.data()); if (dlmt->dl_tensor.ndim < 0) { diff --git a/xla/python/jax_jit.cc b/xla/python/jax_jit.cc index 78c909caa39d2..1ecbce58fc5b0 100644 --- a/xla/python/jax_jit.cc +++ b/xla/python/jax_jit.cc @@ -33,7 +33,6 @@ limitations under the License. #include #include #include -#include #include #include @@ -45,6 +44,7 @@ limitations under the License. #include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" #include "absl/strings/str_join.h" +#include "absl/strings/string_view.h" #include "absl/types/span.h" #include "nanobind/nanobind.h" #include "nanobind/stl/optional.h" // IWYU pragma: keep @@ -53,6 +53,7 @@ limitations under the License. #include "nanobind/stl/string_view.h" // IWYU pragma: keep #include "nanobind/stl/vector.h" // IWYU pragma: keep #include "xla/pjrt/pjrt_client.h" +#include "xla/pjrt/pjrt_layout.h" #include "xla/pjrt/status_casters.h" #include "xla/python/nb_absl_inlined_vector.h" // IWYU pragma: keep #include "xla/python/nb_absl_span.h" // IWYU pragma: keep @@ -147,7 +148,7 @@ bool FetchMemoriesFlag() { std::string ArgumentSignature::DebugString() const { auto py_object_formatter = [](std::string* out, const nb::object& o) { - out->append(nb::cast(nb::str(o))); + out->append(nb::cast(nb::str(o))); }; auto treedef_formatter = [](std::string* out, const xla::PyTreeDef& d) { out->append(d.ToString()); @@ -188,8 +189,8 @@ bool ArgumentSignature::operator==(const ArgumentSignature& other) const { "static arguments should be comparable using __eq__." "The following error was raised when comparing two objects of " "types ", - nb::cast(nb::str(a.type())), " and ", - nb::cast(nb::str(b.type())), + nb::cast(nb::str(a.type())), " and ", + nb::cast(nb::str(b.type())), ". The error was:\n", e.what())); } }); @@ -197,7 +198,7 @@ bool ArgumentSignature::operator==(const ArgumentSignature& other) const { std::string CallSignature::DebugString() const { auto py_object_formatter = [](std::string* out, const nb::object& o) { - out->append(nb::cast(nb::str(o))); + out->append(nb::cast(nb::str(o))); }; auto signature_formatter = [](std::string* out, const xla::PyArgSignature& s) { diff --git a/xla/python/jax_jit.h b/xla/python/jax_jit.h index 8f77a7b7a8369..f732ddd483410 100644 --- a/xla/python/jax_jit.h +++ b/xla/python/jax_jit.h @@ -22,7 +22,6 @@ limitations under the License. #include #include #include -#include #include #include @@ -140,8 +139,8 @@ H AbslHashValue(H h, const ArgumentSignature& s) { throw std::invalid_argument(absl::StrCat( "Non-hashable static arguments are not supported. An error occurred " "while trying to hash an object of type ", - nanobind::cast(nanobind::str(static_arg.type())), - ", ", nanobind::cast(nanobind::str(static_arg)), + nanobind::cast(nanobind::str(static_arg.type())), + ", ", nanobind::cast(nanobind::str(static_arg)), ". The error was:\n", e.what(), "\n")); } h = H::combine(std::move(h), hash); @@ -185,7 +184,7 @@ absl::Status ParseArguments( // (a) equality (delegated to Python) of the static arguments. struct CallSignature { // Not part of the signature, but we need it for error messages. - std::string_view function_name; + absl::string_view function_name; ArgumentSignature arg_signature; diff --git a/xla/python/mlir.cc b/xla/python/mlir.cc index 36e19d2e7f94a..2083367b87d42 100644 --- a/xla/python/mlir.cc +++ b/xla/python/mlir.cc @@ -14,7 +14,6 @@ limitations under the License. ==============================================================================*/ #include -#include #include "mhlo/transforms/passes.h" #include "absl/status/status.h" @@ -36,10 +35,8 @@ limitations under the License. #include "nanobind/stl/string.h" // IWYU pragma: keep #include "nanobind/stl/string_view.h" // IWYU pragma: keep #include "stablehlo/dialect/Serialization.h" -#include "stablehlo/dialect/StablehloOps.h" #include "xla/hlo/builder/xla_computation.h" #include "xla/hlo/translate/hlo_to_mhlo/hlo_to_mlir_hlo.h" -#include "xla/mlir/utils/error_util.h" #include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" #include "xla/mlir_hlo/mhlo/transforms/passes.h" #include "xla/pjrt/mlir_to_hlo.h" @@ -110,7 +107,7 @@ absl::StatusOr PyXlaComputationToMlirModule( } absl::StatusOr PyMlirModuleToXlaComputation( - std::string_view mlir_module, bool use_tuple_args, bool return_tuple) { + absl::string_view mlir_module, bool use_tuple_args, bool return_tuple) { mlir::MLIRContext context; TF_ASSIGN_OR_RETURN(mlir::OwningOpRef module, ParseMlirModuleString(mlir_module, context)); @@ -123,7 +120,7 @@ absl::StatusOr PyMlirModuleToXlaComputation( return computation; } -absl::StatusOr PyMhloToStablehlo(std::string_view mlir_module) { +absl::StatusOr PyMhloToStablehlo(absl::string_view mlir_module) { mlir::MLIRContext context; if (VLOG_IS_ON(3)) context.disableMultithreading(); // JAX can be customized in a way that involves operations from custom @@ -156,7 +153,7 @@ absl::StatusOr PyStablehloToMhlo(const nb::bytes& mlir_module) { TF_ASSIGN_OR_RETURN( mlir::OwningOpRef module, ParseMlirModuleString( - std::string_view(mlir_module.c_str(), mlir_module.size()), context)); + absl::string_view(mlir_module.c_str(), mlir_module.size()), context)); mlir::PassManager pm(&context); if (VLOG_IS_ON(3)) EnablePrintBeforeAndAfter(pm); pm.addPass(mlir::mhlo::createStablehloLegalizeToHloPass()); @@ -171,7 +168,7 @@ absl::StatusOr PyStablehloToMhlo(const nb::bytes& mlir_module) { } absl::StatusOr PySerializePortableArtifact( - std::string_view mlir_module, std::string_view target) { + absl::string_view mlir_module, absl::string_view target) { mlir::MLIRContext context; if (VLOG_IS_ON(3)) context.disableMultithreading(); TF_ASSIGN_OR_RETURN(mlir::OwningOpRef module, @@ -189,7 +186,7 @@ absl::StatusOr PyDeserializePortableArtifact( mlir::MLIRContext context; mlir::OwningOpRef module = mlir::stablehlo::deserializePortableArtifact( - std::string_view(bytecode_str.c_str(), bytecode_str.size()), + absl::string_view(bytecode_str.c_str(), bytecode_str.size()), &context); if (!module) return tsl::errors::InvalidArgument("Failed to deserialize StableHLO"); @@ -208,8 +205,8 @@ void BuildMlirSubmodule(nb::module_& m) { "mlir_module_to_xla_computation", [](const nb::bytes& bytecode, bool use_tuple_args, bool return_tuple) { return xla::ValueOrThrow(PyMlirModuleToXlaComputation( - std::string_view(bytecode.c_str(), bytecode.size()), use_tuple_args, - return_tuple)); + absl::string_view(bytecode.c_str(), bytecode.size()), + use_tuple_args, return_tuple)); }, nb::arg("mlir_module"), nb::arg("use_tuple_args") = false, nb::arg("return_tuple") = false); @@ -221,7 +218,7 @@ void BuildMlirSubmodule(nb::module_& m) { "mhlo_to_stablehlo", [](const nb::bytes& bytecode) { return xla::ValueOrThrow(PyMhloToStablehlo( - std::string_view(bytecode.c_str(), bytecode.size()))); + absl::string_view(bytecode.c_str(), bytecode.size()))); }, nb::arg("mlir_module")); mlir_module.def("mhlo_to_stablehlo", @@ -232,9 +229,9 @@ void BuildMlirSubmodule(nb::module_& m) { nb::arg("mlir_module")); mlir_module.def( "serialize_portable_artifact", - [](const nb::bytes& bytecode, std::string_view target) { + [](const nb::bytes& bytecode, absl::string_view target) { return xla::ValueOrThrow(PySerializePortableArtifact( - std::string_view(bytecode.c_str(), bytecode.size()), target)); + absl::string_view(bytecode.c_str(), bytecode.size()), target)); }, nb::arg("mlir_module"), nb::arg("target")); mlir_module.def("serialize_portable_artifact", @@ -250,7 +247,7 @@ void BuildMlirSubmodule(nb::module_& m) { std::string buffer; llvm::raw_string_ostream os(buffer); xla::ThrowIfError(RefinePolymorphicShapes( - std::string_view(bytecode.c_str(), bytecode.size()), os, + absl::string_view(bytecode.c_str(), bytecode.size()), os, enable_shape_assertions, validate_static_shapes)); return nb::bytes(buffer.data(), buffer.size()); }, diff --git a/xla/python/nb_numpy.h b/xla/python/nb_numpy.h index b4ed1c9cc92c0..94820d464b302 100644 --- a/xla/python/nb_numpy.h +++ b/xla/python/nb_numpy.h @@ -26,8 +26,8 @@ limitations under the License. #include #include -#include +#include "absl/strings/string_view.h" #include "absl/types/span.h" #include "nanobind/nanobind.h" #include "xla/tsl/python/lib/core/numpy.h" @@ -46,7 +46,7 @@ class nb_dtype : public nanobind::object { explicit nb_dtype(const nanobind::str& format) : nb_dtype(from_args(format)) {} - explicit nb_dtype(std::string_view format) + explicit nb_dtype(absl::string_view format) : nb_dtype(from_args(nanobind::str(format.data(), format.size()))) {} static nb_dtype from_args(const nanobind::object& args); diff --git a/xla/python/pjit.cc b/xla/python/pjit.cc index 9a1ef9e1a621e..2cdfb929221be 100644 --- a/xla/python/pjit.cc +++ b/xla/python/pjit.cc @@ -26,7 +26,6 @@ limitations under the License. #include #include #include -#include #include // NOLINT #include #include @@ -1069,8 +1068,8 @@ static PyGetSetDef PjitFunction_tp_getset[] = { PyObject* PjitFunction_tp_repr(PyObject* self) { try { const std::string& repr = absl::StrFormat( - "", - nb::cast(nb::repr(nb::getattr(self, "__wrapped__")))); + "", nb::cast(nb::repr( + nb::getattr(self, "__wrapped__")))); return PyUnicode_FromString(repr.c_str()); } catch (...) { // Ignore all errors when accessing a repr. diff --git a/xla/python/pprof_profile_builder.cc b/xla/python/pprof_profile_builder.cc index e3bf8104eab9a..21d8d3cca881b 100644 --- a/xla/python/pprof_profile_builder.cc +++ b/xla/python/pprof_profile_builder.cc @@ -18,7 +18,6 @@ limitations under the License. #include // IWYU pragma: keep #include -#include #include #include "absl/status/statusor.h" @@ -34,7 +33,7 @@ namespace nb = nanobind; PprofProfileBuilder::PprofProfileBuilder() { CHECK_EQ(0, StringId("")); } -int PprofProfileBuilder::StringId(std::string_view s) { +int PprofProfileBuilder::StringId(absl::string_view s) { auto ret = strings_.emplace(s, profile_.string_table_size()); if (ret.second) { profile_.add_string_table(s.data(), s.size()); @@ -48,11 +47,11 @@ int PprofProfileBuilder::FunctionId(PyCodeObject* code) { if (ret.second) { auto* function = profile_.add_function(); function->set_id(ret.first->second); - int name = StringId(nb::cast(nb::str(code->co_name))); + int name = StringId(nb::cast(nb::str(code->co_name))); function->set_name(name); function->set_system_name(name); function->set_filename( - StringId(nb::cast(nb::str(code->co_filename)))); + StringId(nb::cast(nb::str(code->co_filename)))); function->set_start_line(code->co_firstlineno); } return ret.first->second; diff --git a/xla/python/pprof_profile_builder.h b/xla/python/pprof_profile_builder.h index ca0e6f04e57f9..8c1ee9afb784a 100644 --- a/xla/python/pprof_profile_builder.h +++ b/xla/python/pprof_profile_builder.h @@ -19,7 +19,6 @@ limitations under the License. #include #include -#include #include #include "absl/container/flat_hash_map.h" @@ -36,7 +35,7 @@ class PprofProfileBuilder { tensorflow::tfprof::pprof::Profile& profile() { return profile_; } // Adds or returns the ID of `s` in the table. - int StringId(std::string_view s); + int StringId(absl::string_view s); // Adds or returns the ID of a function. int FunctionId(PyCodeObject* code); diff --git a/xla/python/profiler.cc b/xla/python/profiler.cc index 9afe7d695ff7c..20b75b4e500a8 100644 --- a/xla/python/profiler.cc +++ b/xla/python/profiler.cc @@ -15,14 +15,13 @@ limitations under the License. #include "xla/python/profiler.h" -#include #include #include -#include #include #include #include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" #include "absl/types/span.h" #include "nanobind/nanobind.h" #include "nanobind/stl/pair.h" // IWYU pragma: keep @@ -30,10 +29,7 @@ limitations under the License. #include "nanobind/stl/string_view.h" // IWYU pragma: keep #include "nanobind/stl/unique_ptr.h" // IWYU pragma: keep #include "nanobind/stl/vector.h" // IWYU pragma: keep -#include "xla/backends/profiler/plugin/plugin_tracer.h" -#include "xla/backends/profiler/plugin/profiler_c_api.h" #include "xla/pjrt/c/pjrt_c_api.h" -#include "xla/pjrt/c/pjrt_c_api_profiler_extension.h" #include "xla/pjrt/exceptions.h" #include "xla/pjrt/status_casters.h" #include "xla/python/aggregate_profile.h" @@ -44,8 +40,6 @@ limitations under the License. #include "xla/tsl/profiler/rpc/profiler_server.h" #include "tsl/platform/macros.h" #include "tsl/platform/protobuf.h" // IWYU pragma: keep -#include "tsl/profiler/lib/profiler_factory.h" -#include "tsl/profiler/lib/profiler_interface.h" #include "tsl/profiler/lib/profiler_session.h" #include "tsl/profiler/lib/traceme.h" @@ -93,7 +87,7 @@ class TraceMeWrapper { static void AppendMetadata(std::string* name, const nb::kwargs& kwargs) { name->push_back('#'); for (const auto& kv : kwargs) { - absl::StrAppend(name, nb::cast(kv.first), "=", + absl::StrAppend(name, nb::cast(kv.first), "=", EncodePyObject(kv.second), ","); } name->back() = '#'; @@ -131,7 +125,7 @@ struct ProfilerSessionWrapper { static std::string GetFdoProfile(const std::string& xspace, bool as_textproto = false) { tensorflow::profiler::XSpace xspace_proto; - // TODO(phawkins): change to std::string_view when protobuf is + // TODO(phawkins): change to absl::string_view when protobuf is // updated in XLA. xspace_proto.ParseFromString(std::string(xspace.c_str(), xspace.size())); tensorflow::profiler::ProfiledInstructionsProto fdo_profile; @@ -161,7 +155,7 @@ void BuildProfilerSubmodule(nb::module_& m) { }, nb::arg("port")); profiler.def("register_plugin_profiler", [](nb::capsule c_api) -> void { - if (std::string_view(c_api.name()) != "pjrt_c_api") { + if (absl::string_view(c_api.name()) != "pjrt_c_api") { throw xla::XlaRuntimeError( "Argument to register_plugin_profiler was not a pjrt_c_api capsule."); } @@ -211,7 +205,7 @@ void BuildProfilerSubmodule(nb::module_& m) { [](ProfilerSessionWrapper* sess, nb::bytes xspace, const std::string& tensorboard_dir) -> void { tensorflow::profiler::XSpace xspace_proto; - // TODO(phawkins): change to std::string_view when protobuf is + // TODO(phawkins): change to absl::string_view when protobuf is // updated in XLA. xspace_proto.ParseFromString( std::string(xspace.c_str(), xspace.size())); diff --git a/xla/python/profiler/internal/python_hooks.cc b/xla/python/profiler/internal/python_hooks.cc index 4f6a9a4942803..4f691c08b0d15 100644 --- a/xla/python/profiler/internal/python_hooks.cc +++ b/xla/python/profiler/internal/python_hooks.cc @@ -61,7 +61,7 @@ std::string GetEventName(PyObject* co_filename, PyObject* co_name, " ", function); } -std::string GetEventName(std::string_view method_name, PyObject* module) { +std::string GetEventName(absl::string_view method_name, PyObject* module) { // Python stack does not have a filename/line_no for native calls. // Use module name and function/method name instead. std::string filename; diff --git a/xla/python/py_array.cc b/xla/python/py_array.cc index ef1655b1ad97b..5dd40c177a42f 100644 --- a/xla/python/py_array.cc +++ b/xla/python/py_array.cc @@ -27,7 +27,6 @@ limitations under the License. #include #include #include -#include #include #include #include @@ -645,7 +644,7 @@ absl::Status PyArray::set_arrays(nb::object obj) { if (!nb::isinstance(obj)) { return InvalidArgument("Unsupported arg when setting Array._arrays: %s", - nb::cast(nb::str(obj.type()))); + nb::cast(nb::str(obj.type()))); } nb::list list(obj); @@ -676,7 +675,7 @@ absl::Status PyArray::set_arrays(nb::object obj) { shapes.push_back(ifrt_arrays.back()->shape()); } else { return InvalidArgument("Unsupported arg when setting Array._arrays: %s", - nb::cast(nb::str(obj.type()))); + nb::cast(nb::str(obj.type()))); } } const ifrt::MemoryKind first_memory_kind = @@ -786,7 +785,7 @@ absl::Status PyArray::CopySingleDeviceArrayToHostAsync() { arr.GetStorage().dynamic_shape, arr.ifrt_array()); } -absl::StatusOr PyArray::AssertUnsharded(std::string_view api) { +absl::StatusOr PyArray::AssertUnsharded(absl::string_view api) { if (ifrt_array() == nullptr) { return InvalidArgument("%s( called on deleted or donated buffer", api); } @@ -1119,11 +1118,11 @@ absl::StatusOr> PyArray::BatchedCopyToDeviceWithSharding( auto transfer_guard_formatter = [&py_array, &dst_sharding] { return absl::StrCat( - "aval=", nb::cast(nb::repr(py_array.aval())), + "aval=", nb::cast(nb::repr(py_array.aval())), ", sharding=", - nb::cast(nb::repr(py_array.sharding())), + nb::cast(nb::repr(py_array.sharding())), ", dst_sharding=", - nb::cast(nb::repr(dst_sharding))); + nb::cast(nb::repr(dst_sharding))); }; TF_RETURN_IF_ERROR( jax::ApplyTransferGuardToDeviceToDevice(transfer_guard_formatter)); @@ -1187,8 +1186,8 @@ absl::StatusOr PyArray::BatchedDevicePut( } auto transfer_guard_formatter = [&aval, &sharding] { return absl::StrCat( - "aval=", nb::cast(nb::repr(aval)), - ", dst_sharding=", nb::cast(nb::repr(sharding))); + "aval=", nb::cast(nb::repr(aval)), + ", dst_sharding=", nb::cast(nb::repr(sharding))); }; GlobalPyRefManager()->CollectGarbage(); @@ -1702,7 +1701,7 @@ absl::Status PyArray::RegisterTypes(nb::module_& m) { throw nb::type_error( absl::StrCat( "Unsupported type for elements in `arrays`: ", - nb::cast(nb::str(arrays[0].type()))) + nb::cast(nb::str(arrays[0].type()))) .c_str()); } }, diff --git a/xla/python/py_array.h b/xla/python/py_array.h index 39731a9b6200e..46c2279224b81 100644 --- a/xla/python/py_array.h +++ b/xla/python/py_array.h @@ -22,7 +22,6 @@ limitations under the License. #include #include #include -#include #include #include @@ -295,7 +294,7 @@ class PyArray : public nanobind::object { std::vector objs); private: - absl::StatusOr AssertUnsharded(std::string_view api); + absl::StatusOr AssertUnsharded(absl::string_view api); void CheckAndRearrange(); diff --git a/xla/python/py_client.cc b/xla/python/py_client.cc index e9819ba4bb68d..2adae5fe40a26 100644 --- a/xla/python/py_client.cc +++ b/xla/python/py_client.cc @@ -23,7 +23,6 @@ limitations under the License. #include #include #include -#include #include #include #include @@ -36,7 +35,6 @@ limitations under the License. #include "absl/strings/string_view.h" #include "absl/types/span.h" #include "llvm/Support/Casting.h" -#include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/MLIRContext.h" #include "mlir/IR/OwningOpRef.h" @@ -91,7 +89,6 @@ limitations under the License. #include "xla/status_macros.h" #include "xla/tsl/concurrency/ref_count.h" #include "xla/util.h" -#include "tsl/platform/casts.h" #include "tsl/platform/errors.h" #include "tsl/platform/logging.h" #include "tsl/platform/status.h" @@ -489,7 +486,7 @@ PyClient::DeserializeExecutable(nb_class_ptr client, TF_ASSIGN_OR_RETURN( ifrt_loaded_executable, client->ifrt_client_->GetDefaultCompiler()->DeserializeLoadedExecutable( - std::string_view(serialized.c_str(), serialized.size()), + absl::string_view(serialized.c_str(), serialized.size()), std::move(ifrt_deserialize_options))); } TF_ASSIGN_OR_RETURN(fingerprint, ifrt_loaded_executable->Fingerprint()); @@ -785,7 +782,7 @@ PyType_Slot PyClient::slots_[] = { }, nb::arg("dtype"), nb::arg("shard_shape"), nb::arg("device")) .def("__getattr__", - [](PyClient& client, std::string_view name) -> nb::object { + [](PyClient& client, absl::string_view name) -> nb::object { const auto& attrs = client.Attributes().map(); auto it = attrs.find(name); if (it != attrs.end()) { diff --git a/xla/python/py_client.h b/xla/python/py_client.h index 32b15a22b80b6..351d72eb42438 100644 --- a/xla/python/py_client.h +++ b/xla/python/py_client.h @@ -22,7 +22,6 @@ limitations under the License. #include #include #include -#include #include #include @@ -95,7 +94,7 @@ class PyClient { return shared_ptr_pjrt_client(); } - std::string_view platform_name() const { + absl::string_view platform_name() const { // TODO(phawkins): this is a temporary backwards compatibility shim. We // changed the name PJRT reports for GPU platforms to "cuda" or "rocm", but // we haven't yet updated JAX clients that expect "gpu". Migrate users and @@ -107,14 +106,16 @@ class PyClient { return ifrt_client_->platform_name(); } } - std::string_view raw_platform_name() const { + absl::string_view raw_platform_name() const { // TODO(parkers): Once platform_name() is the same, remove this. return ifrt_client_->platform_name(); } - std::string_view platform_version() const { + absl::string_view platform_version() const { return ifrt_client_->platform_version(); } - std::string_view runtime_type() const { return ifrt_client_->runtime_type(); } + absl::string_view runtime_type() const { + return ifrt_client_->runtime_type(); + } // Returns implementation-specific attributes about this client, e.g. the PJRT // C API version if applicable. diff --git a/xla/python/py_client_gpu.cc b/xla/python/py_client_gpu.cc index d1c01a62d16a7..73d2e8edafaa9 100644 --- a/xla/python/py_client_gpu.cc +++ b/xla/python/py_client_gpu.cc @@ -13,30 +13,35 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include +#include +#include +#include +#include #include #include "absl/base/casts.h" +#include "absl/log/check.h" #include "absl/status/statusor.h" #include "absl/strings/ascii.h" #include "absl/strings/numbers.h" +#include "absl/types/span.h" #include "xla/service/custom_call_status.h" -#include "tsl/platform/errors.h" #if TENSORFLOW_USE_ROCM #include "rocm/include/hip/hip_runtime.h" #else #include "third_party/gpus/cuda/include/cuda.h" #include "third_party/gpus/cuda/include/cuda_runtime_api.h" +#include "third_party/gpus/cuda/include/driver_types.h" #endif #include "nanobind/nanobind.h" #include "xla/pjrt/exceptions.h" #include "xla/pjrt/host_callback.h" +#include "xla/pjrt/transpose.h" #include "xla/primitive_util.h" #include "xla/python/callback.h" #include "xla/python/nb_numpy.h" #include "xla/service/custom_call_target_registry.h" #include "xla/service/platform_util.h" - #if TENSORFLOW_USE_ROCM #define gpuSuccess hipSuccess #define gpuStreamHandle hipStream_t @@ -109,7 +114,7 @@ void XlaPythonGpuCallback(gpuStreamHandle stream, void** buffers, callback->Call(host_input_arrays); LeaveHostCallback(); if (!maybe_result_tuple.ok()) { - std::string_view msg = maybe_result_tuple.status().message(); + absl::string_view msg = maybe_result_tuple.status().message(); XlaCustomCallStatusSetFailure(status, msg.data(), msg.length()); return; } diff --git a/xla/python/py_compile_only_client.cc b/xla/python/py_compile_only_client.cc index 9dde801ff5a7f..d366ef93c096b 100644 --- a/xla/python/py_compile_only_client.cc +++ b/xla/python/py_compile_only_client.cc @@ -20,7 +20,6 @@ limitations under the License. #include #include #include -#include #include #include @@ -63,7 +62,6 @@ limitations under the License. #include "xla/python/ifrt/tuple.h" #include "xla/python/ifrt/value.h" #include "xla/python/nb_class_ptr.h" -#include "xla/python/pjrt_ifrt/pjrt_array.h" #include "xla/python/pjrt_ifrt/pjrt_attribute_map_util.h" #include "xla/python/pjrt_ifrt/pjrt_dtype.h" #include "xla/python/pjrt_ifrt/pjrt_executable.h" @@ -372,7 +370,7 @@ class CompileOnlyPyClient : public PyClient { } absl::StatusOr> CompileUnloaded( - std::string_view mlir_module, CompileOptions options, + absl::string_view mlir_module, CompileOptions options, std::vector host_callbacks) { if (!host_callbacks.empty()) { return Unimplemented( @@ -422,7 +420,7 @@ void RegisterCompileOnlyClient(nb::module_& m) { [](CompileOnlyPyClient& self, nb::bytes mlir_module, CompileOptions options, std::vector host_callbacks) { return ValueOrThrow(self.CompileUnloaded( - std::string_view(mlir_module.c_str(), mlir_module.size()), + absl::string_view(mlir_module.c_str(), mlir_module.size()), std::move(options), std::move(host_callbacks))); }, nb::arg("computation"), nb::arg("compile_options") = CompileOptions(), diff --git a/xla/python/py_device.cc b/xla/python/py_device.cc index 9139454bc36cd..6a9f4ef781b84 100644 --- a/xla/python/py_device.cc +++ b/xla/python/py_device.cc @@ -22,13 +22,13 @@ limitations under the License. #include #include #include -#include #include #include #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_join.h" +#include "absl/strings/string_view.h" #include "llvm/Support/Casting.h" #include "nanobind/nanobind.h" #include "nanobind/stl/optional.h" // IWYU pragma: keep @@ -66,7 +66,7 @@ int PyDevice::id() const { return device_->Id().value(); } int PyDevice::process_index() const { return device_->ProcessIndex(); } -std::string_view PyDevice::platform() const { +absl::string_view PyDevice::platform() const { // TODO(phawkins): this is a temporary backwards // compatibility shim. We changed the name PJRT // reports for GPU platforms to "cuda" or "rocm", @@ -75,13 +75,13 @@ std::string_view PyDevice::platform() const { // code. if (client_->platform_name() == "cuda" || client_->platform_name() == "rocm") { - return std::string_view("gpu"); + return absl::string_view("gpu"); } else { return client_->platform_name(); } } -std::string_view PyDevice::device_kind() const { return device_->Kind(); } +absl::string_view PyDevice::device_kind() const { return device_->Kind(); } std::optional PyDevice::local_hardware_id() const { // TODO(phawkins): consider supporting this for non-PJRT devices. @@ -96,9 +96,9 @@ std::optional PyDevice::local_hardware_id() const { return local_hardware_id; } -std::string_view PyDevice::Str() const { return device_->DebugString(); } +absl::string_view PyDevice::Str() const { return device_->DebugString(); } -std::string_view PyDevice::Repr() const { return device_->ToString(); } +absl::string_view PyDevice::Repr() const { return device_->ToString(); } absl::Status PyDevice::TransferToInfeed(LiteralSlice literal) { GlobalPyRefManager()->CollectGarbage(); @@ -136,7 +136,7 @@ absl::StatusOr PyDevice::TransferFromOutfeed(Shape shape) { } absl::StatusOr> PyDevice::Memory( - std::string_view kind) const { + absl::string_view kind) const { ifrt::Memory* result_memory_space = nullptr; for (auto* memory_space : device_->Memories()) { if (memory_space->Kind().memory_kind() == kind) { @@ -321,7 +321,7 @@ PyType_Slot PyDevice::slots_[] = { } try { auto device = nb::cast(nb::handle(self)); - auto name = nb::cast(nb::handle(key)); + auto name = nb::cast(nb::handle(key)); const auto& attrs = device->device_->Attributes().map(); auto it = attrs.find(name); if (it != attrs.end()) { diff --git a/xla/python/py_device.h b/xla/python/py_device.h index 7151fccb114a6..6acd35b1da990 100644 --- a/xla/python/py_device.h +++ b/xla/python/py_device.h @@ -20,7 +20,6 @@ limitations under the License. #include #include -#include #include "absl/status/status.h" #include "absl/status/statusor.h" @@ -49,18 +48,18 @@ class PyDevice { int id() const; int process_index() const; - std::string_view platform() const; - std::string_view device_kind() const; + absl::string_view platform() const; + absl::string_view device_kind() const; std::optional local_hardware_id() const; - std::string_view Str() const; - std::string_view Repr() const; + absl::string_view Str() const; + absl::string_view Repr() const; absl::Status TransferToInfeed(LiteralSlice literal); absl::StatusOr TransferFromOutfeed(Shape shape); absl::StatusOr> Memory( - std::string_view kind) const; + absl::string_view kind) const; absl::StatusOr> DefaultMemory() const; nanobind::list AddressableMemories() const; absl::StatusOr> MemoryStats() const; diff --git a/xla/python/py_executable.cc b/xla/python/py_executable.cc index 0bdff1204ac2f..face6782350fb 100644 --- a/xla/python/py_executable.cc +++ b/xla/python/py_executable.cc @@ -21,7 +21,6 @@ limitations under the License. #include #include #include -#include #include #include #include @@ -30,10 +29,10 @@ limitations under the License. #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" #include "absl/types/span.h" #include "nanobind/nanobind.h" #include "xla/hlo/ir/hlo_module.h" -#include "xla/pjrt/pjrt_executable.h" #include "xla/pjrt/pjrt_future.h" #include "xla/pjrt/pjrt_layout.h" #include "xla/python/ifrt/array.h" @@ -408,7 +407,7 @@ PyLoadedExecutable::HloModules() const { return ifrt_loaded_executable_->GetHloModules(); } -absl::StatusOr>> +absl::StatusOr>> PyLoadedExecutable::GetOutputMemoryKinds() const { nb::gil_scoped_release gil_release; return ifrt_loaded_executable_->GetOutputMemoryKinds(); diff --git a/xla/python/py_executable.h b/xla/python/py_executable.h index e032ee7b4acdd..9af7a4a783970 100644 --- a/xla/python/py_executable.h +++ b/xla/python/py_executable.h @@ -21,7 +21,6 @@ limitations under the License. #include #include #include -#include #include #include #include @@ -187,7 +186,7 @@ class PyLoadedExecutable { absl::StatusOr>> HloModules() const; - absl::StatusOr>> + absl::StatusOr>> GetOutputMemoryKinds() const; absl::StatusOr>> GetParameterLayouts() diff --git a/xla/python/py_memory_space.cc b/xla/python/py_memory_space.cc index c55f0d0438396..990b1ba6ec5f8 100644 --- a/xla/python/py_memory_space.cc +++ b/xla/python/py_memory_space.cc @@ -17,12 +17,11 @@ limitations under the License. #include -#include #include +#include "absl/strings/string_view.h" #include "nanobind/nanobind.h" #include "nanobind/stl/string_view.h" // IWYU pragma: keep -#include "xla/pjrt/pjrt_client.h" #include "xla/python/ifrt/device.h" #include "xla/python/nb_class_ptr.h" #include "xla/python/py_client.h" @@ -37,7 +36,7 @@ PyMemorySpace::PyMemorySpace(nb_class_ptr client, int PyMemorySpace::process_index() const { return client_->process_index(); } -std::string_view PyMemorySpace::platform() const { +absl::string_view PyMemorySpace::platform() const { // TODO(phawkins): this is a temporary backwards // compatibility shim. We changed the name PJRT // reports for GPU platforms to "cuda" or "rocm", @@ -46,19 +45,19 @@ std::string_view PyMemorySpace::platform() const { // code. if (client_->platform_name() == "cuda" || client_->platform_name() == "rocm") { - return std::string_view("gpu"); + return absl::string_view("gpu"); } else { return client_->platform_name(); } } -std::string_view PyMemorySpace::kind() const { +absl::string_view PyMemorySpace::kind() const { return *memory_->Kind().memory_kind(); } -std::string_view PyMemorySpace::Str() const { return memory_->DebugString(); } +absl::string_view PyMemorySpace::Str() const { return memory_->DebugString(); } -std::string_view PyMemorySpace::Repr() const { return memory_->ToString(); } +absl::string_view PyMemorySpace::Repr() const { return memory_->ToString(); } nb::list PyMemorySpace::AddressableByDevices() const { nb::list devices; diff --git a/xla/python/py_memory_space.h b/xla/python/py_memory_space.h index 9b5507b55422e..bc0773ed43667 100644 --- a/xla/python/py_memory_space.h +++ b/xla/python/py_memory_space.h @@ -18,8 +18,6 @@ limitations under the License. #include -#include - #include "nanobind/nanobind.h" #include "xla/python/ifrt/memory.h" #include "xla/python/nb_class_ptr.h" @@ -42,11 +40,11 @@ class PyMemorySpace { ifrt::Memory* memory_space() const { return memory_; } int process_index() const; - std::string_view platform() const; - std::string_view kind() const; + absl::string_view platform() const; + absl::string_view kind() const; - std::string_view Str() const; - std::string_view Repr() const; + absl::string_view Str() const; + absl::string_view Repr() const; nanobind::list AddressableByDevices() const; diff --git a/xla/python/py_values.cc b/xla/python/py_values.cc index 7c3e18c873ac4..631b0bcb9b956 100644 --- a/xla/python/py_values.cc +++ b/xla/python/py_values.cc @@ -22,7 +22,6 @@ limitations under the License. #include #include #include -#include #include #include #include @@ -32,6 +31,7 @@ limitations under the License. #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_join.h" +#include "absl/strings/string_view.h" #include "absl/types/span.h" #include "nanobind/nanobind.h" #include "nanobind/stl/complex.h" // IWYU pragma: keep @@ -44,7 +44,6 @@ limitations under the License. #include "xla/python/ifrt/memory.h" #include "xla/python/ifrt/shape.h" #include "xla/python/ifrt/sharding.h" -#include "xla/python/nb_helpers.h" #include "xla/python/nb_numpy.h" #include "xla/python/pjrt_ifrt/pjrt_dtype.h" #include "xla/python/py_array.h" @@ -83,7 +82,7 @@ absl::StatusOr HandlePythonScalar( "Unable to convert Python scalar to %s. This most likely means the " "value (%s) overflows the range of the type.", PrimitiveType_Name(primitive_util::NativeToPrimitiveType()), - nb::cast(nb::repr(obj))); + nb::cast(nb::repr(obj))); } std::variant data; @@ -130,7 +129,7 @@ absl::StatusOr HandlePythonInt( "Unable to convert Python scalar to %s. This most likely means the " "value (%s) overflows the range of the type.", PrimitiveType_Name(primitive_util::NativeToPrimitiveType()), - nb::cast(nb::repr(obj))); + nb::cast(nb::repr(obj))); } type = S32; } else { @@ -141,7 +140,7 @@ absl::StatusOr HandlePythonInt( "Unable to convert Python scalar to %s. This most likely means the " "value (%s) overflows the range of the type.", PrimitiveType_Name(primitive_util::NativeToPrimitiveType()), - nb::cast(nb::repr(obj))); + nb::cast(nb::repr(obj))); } type = S64; } @@ -451,7 +450,7 @@ absl::StatusOr DevicePut(nb::handle arg, "Not supported: The C++ jax jit execution path, only accepts " "DeviceArray, Numpy arrays scalars of supported types " "(see implementation), or Python scalars. Got type ", - nb::cast(nb::str(arg.type())))); + nb::cast(nb::str(arg.type())))); } return res->second(arg, client, to_device, options, to_memory_kind); } @@ -641,7 +640,7 @@ absl::StatusOr PyArgSignatureOfValue(nb::handle arg, "Buffer/DeviceArray, Numpy " "arrays scalars of supported types " "(see implementation), or Python scalars. Got type ", - nb::cast(nb::str(arg.type())))); + nb::cast(nb::str(arg.type())))); } return res->second(arg, jax_enable_x64); } diff --git a/xla/python/pytree.cc b/xla/python/pytree.cc index 138316c722d56..a374c2df6bff9 100644 --- a/xla/python/pytree.cc +++ b/xla/python/pytree.cc @@ -29,7 +29,6 @@ limitations under the License. #include #include #include -#include #include #include diff --git a/xla/python/pytree.h b/xla/python/pytree.h index 55ddf041232d5..fc16fdd40136c 100644 --- a/xla/python/pytree.h +++ b/xla/python/pytree.h @@ -25,7 +25,6 @@ limitations under the License. #include #include #include -#include #include #include diff --git a/xla/python/sharding.cc b/xla/python/sharding.cc index c1bae6a50a58a..06e5d7870c187 100644 --- a/xla/python/sharding.cc +++ b/xla/python/sharding.cc @@ -20,22 +20,20 @@ limitations under the License. #include #include #include -#include #include #include "absl/hash/hash.h" #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_join.h" +#include "absl/strings/string_view.h" #include "nanobind/nanobind.h" #include "nanobind/stl/string.h" // IWYU pragma: keep #include "nanobind/stl/string_view.h" // IWYU pragma: keep #include "xla/hlo/ir/hlo_sharding.h" #include "xla/pjrt/status_casters.h" -#include "xla/python/ifrt/device.h" #include "xla/python/ifrt/device_list.h" #include "xla/python/nb_class_ptr.h" -#include "xla/python/nb_helpers.h" #include "xla/python/nb_numpy.h" #include "xla/python/py_client.h" #include "xla/python/py_device_list.h" @@ -83,9 +81,10 @@ nb::object CheckAndCanonicalizeMemoryKind( } nb::object device_kind = addressable_device_list->GetItem(0).attr("device_kind"); - std::string_view device_kind_str = nb::cast(device_kind); + absl::string_view device_kind_str = + nb::cast(device_kind); auto py_str_formatter = [](std::string* out, nb::handle h) { - *out += nb::cast(nb::str(h)); + *out += nb::cast(nb::str(h)); }; throw nb::value_error( absl::StrCat( @@ -93,7 +92,7 @@ nb::object CheckAndCanonicalizeMemoryKind( ". Device ", device_kind_str, " can address the following memory kinds: ", absl::StrJoin(*supported_memory_kinds, ", ", py_str_formatter), - ". Got memory kind: ", nb::cast(memory_kind)) + ". Got memory kind: ", nb::cast(memory_kind)) .c_str()); } // If memory kind is None, canonicalize to default memory. diff --git a/xla/python/traceback.cc b/xla/python/traceback.cc index 19e4f94d4f8d9..a9d35e4d04d74 100644 --- a/xla/python/traceback.cc +++ b/xla/python/traceback.cc @@ -20,14 +20,15 @@ limitations under the License. #include #include #include -#include #include #include #include "absl/base/casts.h" #include "absl/hash/hash.h" +#include "absl/log/check.h" #include "absl/strings/str_format.h" #include "absl/strings/str_join.h" +#include "absl/strings/string_view.h" #include "nanobind/nanobind.h" #include "nanobind/stl/optional.h" // IWYU pragma: keep #include "nanobind/stl/string.h" // IWYU pragma: keep @@ -108,8 +109,8 @@ Traceback::Traceback(Traceback&& other) noexcept } std::string Traceback::Frame::ToString() const { - return absl::StrFormat("%s:%d (%s)", nb::cast(file_name), - line_num, nb::cast(function_name)); + return absl::StrFormat("%s:%d (%s)", nb::cast(file_name), + line_num, nb::cast(function_name)); } std::string Traceback::ToString() const { @@ -230,8 +231,8 @@ void BuildTracebackSubmodule(nb::module_& m) { .def_ro("line_num", &Traceback::Frame::line_num) .def("__repr__", [](const Traceback::Frame& frame) { return absl::StrFormat( - "%s;%s:%d", nb::cast(frame.function_name), - nb::cast(frame.file_name), frame.line_num); + "%s;%s:%d", nb::cast(frame.function_name), + nb::cast(frame.file_name), frame.line_num); }); nb::class_ traceback(m, "Traceback", diff --git a/xla/python/types.cc b/xla/python/types.cc index 125f96a75fdf2..50366be350bc0 100644 --- a/xla/python/types.cc +++ b/xla/python/types.cc @@ -20,7 +20,6 @@ limitations under the License. #include #include #include -#include #include #include #include @@ -29,6 +28,7 @@ limitations under the License. #include "absl/container/inlined_vector.h" #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" #include "absl/types/span.h" #include "nanobind/nanobind.h" #include "nanobind/ndarray.h" // IWYU pragma: keep @@ -39,7 +39,6 @@ limitations under the License. #include "xla/literal.h" #include "xla/pjrt/exceptions.h" #include "xla/python/ifrt/dtype.h" -#include "xla/python/nb_helpers.h" #include "xla/python/nb_numpy.h" #include "xla/python/pjrt_ifrt/pjrt_dtype.h" #include "xla/shape.h" @@ -175,7 +174,7 @@ absl::StatusOr DtypeToPrimitiveType(const nb_dtype& np_type) { return custom_it->second; } return InvalidArgument("Unknown NumPy dtype %s char %c kind %c itemsize %d", - nb::cast(nb::repr(np_type)), + nb::cast(nb::repr(np_type)), np_type.char_(), np_type.kind(), np_type.itemsize()); } diff --git a/xla/python/xla.cc b/xla/python/xla.cc index 1f9f76ed3c469..51c96229493e4 100644 --- a/xla/python/xla.cc +++ b/xla/python/xla.cc @@ -257,7 +257,7 @@ NB_MODULE(xla_extension, m) { // like ClientAndPtr). nb::bytes serialized = nb::cast(t[0]); absl::StatusOr layout = PjRtXlaLayout::Deserialize( - std::string_view(serialized.c_str(), serialized.size())); + absl::string_view(serialized.c_str(), serialized.size())); ThrowIfError(layout.status()); new (self) PjRtXlaLayout(std::move(*layout)); }); @@ -691,8 +691,8 @@ NB_MODULE(xla_extension, m) { // `blocking_key_value_get_bytes()`. .def( "key_value_set", - [](DistributedRuntimeClient& client, std::string_view key, - std::string_view value, bool allow_overwrite) { + [](DistributedRuntimeClient& client, absl::string_view key, + absl::string_view value, bool allow_overwrite) { nb::gil_scoped_release gil_release; xla::ThrowIfError(client.KeyValueSet(key, value, allow_overwrite)); }, @@ -702,18 +702,18 @@ NB_MODULE(xla_extension, m) { // Use `key_value_set_bytes()` and `blocking_key_value_get_bytes()`. .def( "key_value_set_bytes", - [](DistributedRuntimeClient& client, std::string_view key, + [](DistributedRuntimeClient& client, absl::string_view key, nb::bytes value, bool allow_overwrite) { nb::gil_scoped_release gil_release; xla::ThrowIfError(client.KeyValueSet( - key, std::string_view(value.c_str(), value.size()), + key, absl::string_view(value.c_str(), value.size()), allow_overwrite)); }, nb::arg("key"), nb::arg("value"), nb::arg("allow_overwrite") = false) // Assumes that all values in the directory are Python strings. .def( "key_value_dir_get", - [](DistributedRuntimeClient& client, std::string_view key) { + [](DistributedRuntimeClient& client, absl::string_view key) { nb::gil_scoped_release gil_release; return xla::ValueOrThrow(client.KeyValueDirGet(key)); }, @@ -723,7 +723,7 @@ NB_MODULE(xla_extension, m) { // explicitly. .def( "key_value_dir_get_bytes", - [](DistributedRuntimeClient& client, std::string_view key) + [](DistributedRuntimeClient& client, absl::string_view key) -> std::vector> { nb::gil_scoped_release gil_release; std::vector> result = @@ -740,7 +740,7 @@ NB_MODULE(xla_extension, m) { nb::arg("key")) .def( "key_value_delete", - [](DistributedRuntimeClient& client, std::string_view key) { + [](DistributedRuntimeClient& client, absl::string_view key) { nb::gil_scoped_release gil_release; return xla::ThrowIfError(client.KeyValueDelete(key)); }, @@ -861,7 +861,7 @@ NB_MODULE(xla_extension, m) { return nb::bytes(serialized.data(), serialized.size()); }) .def("__getattr__", - [](ifrt::Topology& topology, std::string_view name) -> nb::object { + [](ifrt::Topology& topology, absl::string_view name) -> nb::object { const auto& attrs = topology.Attributes().map(); auto it = attrs.find(name); if (it != attrs.end()) { diff --git a/xla/python/xla_compiler.cc b/xla/python/xla_compiler.cc index 66496043ab2a7..13d3de2e50f1a 100644 --- a/xla/python/xla_compiler.cc +++ b/xla/python/xla_compiler.cc @@ -20,7 +20,6 @@ limitations under the License. #include #include #include -#include #include #include @@ -376,7 +375,7 @@ absl::Status PyRegisterCustomCallTarget(const std::string& fn_name, api_version)); } -absl::Status PyRegisterCustomTypeId(std::string_view type_name, +absl::Status PyRegisterCustomTypeId(absl::string_view type_name, nb::object type_id) { nb::capsule capsule; if (!nb::try_cast(type_id, capsule)) { @@ -1156,7 +1155,8 @@ void BuildXlaCompilerSubmodule(nb::module_& m) { for (const auto& [name, registration] : *ffi_handlers) { nb::dict bundle; - auto export_handler = [&](std::string_view name, XLA_FFI_Handler* h) { + auto export_handler = [&](absl::string_view name, + XLA_FFI_Handler* h) { if (h != nullptr) { bundle[nb::str(name.data(), name.size())] = nb::capsule(reinterpret_cast(h)); @@ -1178,7 +1178,7 @@ void BuildXlaCompilerSubmodule(nb::module_& m) { m.def( "register_custom_type_id", - [](std::string_view type_name, nb::object type_id) { + [](absl::string_view type_name, nb::object type_id) { xla::ThrowIfError(PyRegisterCustomTypeId(type_name, type_id)); }, nb::arg("type_name"), nb::arg("type_id"));