From c0e91f85a140ee791b22793853a0d077f2a3ba84 Mon Sep 17 00:00:00 2001 From: Parker Schuh Date: Thu, 4 Apr 2024 19:35:20 -0700 Subject: [PATCH] Add GetDefaultLayout to PjRtTopologyDescription. This is needed to support py_compile_only_client. PiperOrigin-RevId: 622042765 --- xla/pjrt/cpu/cpu_client.cc | 6 ++++++ xla/pjrt/cpu/cpu_client.h | 4 ++++ xla/pjrt/gpu/se_gpu_pjrt_client.cc | 6 ++++++ xla/pjrt/gpu/se_gpu_pjrt_client.h | 4 ++++ xla/pjrt/pjrt_c_api_client.h | 6 ++++++ xla/pjrt/pjrt_compiler.h | 9 +++++++++ xla/pjrt/pjrt_compiler_test.cc | 10 ++++++++++ xla/python/py_compile_only_client.cc | 5 +++++ 8 files changed, 50 insertions(+) diff --git a/xla/pjrt/cpu/cpu_client.cc b/xla/pjrt/cpu/cpu_client.cc index 53ac1d31755257..559a3fa77f927b 100644 --- a/xla/pjrt/cpu/cpu_client.cc +++ b/xla/pjrt/cpu/cpu_client.cc @@ -285,6 +285,12 @@ absl::string_view TfrtCpuDeviceDescription::ToString() const { machine_attributes); } +absl::StatusOr TfrtCpuTopologyDescription::GetDefaultLayout( + PrimitiveType element_type, absl::Span dims) const { + Shape shape = ShapeUtil::MakeShape(element_type, dims); + return LayoutUtil::GetWithDefaultLayout(shape).layout(); +} + absl::StatusOr TfrtCpuTopologyDescription::Serialize() const { std::string result; if (!tsl::SerializeToStringDeterministic(cpu_topology_.ToProto(), &result)) { diff --git a/xla/pjrt/cpu/cpu_client.h b/xla/pjrt/cpu/cpu_client.h index 34a63690b4b0d9..b302243ecfecda 100644 --- a/xla/pjrt/cpu/cpu_client.h +++ b/xla/pjrt/cpu/cpu_client.h @@ -183,6 +183,10 @@ class TfrtCpuTopologyDescription : public PjRtTopologyDescription { return attributes_; } + StatusOr GetDefaultLayout( + PrimitiveType element_type, + absl::Span dims) const override; + private: const PjRtPlatformId platform_id_; const std::string platform_name_; diff --git a/xla/pjrt/gpu/se_gpu_pjrt_client.cc b/xla/pjrt/gpu/se_gpu_pjrt_client.cc index 9dec918f993274..c6b1865a6915f2 100644 --- a/xla/pjrt/gpu/se_gpu_pjrt_client.cc +++ b/xla/pjrt/gpu/se_gpu_pjrt_client.cc @@ -1130,6 +1130,12 @@ absl::StatusOr StreamExecutorGpuTopologyDescription::Serialize() return result; } +absl::StatusOr StreamExecutorGpuTopologyDescription::GetDefaultLayout( + PrimitiveType element_type, absl::Span dims) const { + Shape shape = ShapeUtil::MakeShape(element_type, dims); + return LayoutUtil::GetWithDefaultLayout(shape).layout(); +} + std::vector> BuildLocalDevices( std::map> local_device_states, int node_id) { diff --git a/xla/pjrt/gpu/se_gpu_pjrt_client.h b/xla/pjrt/gpu/se_gpu_pjrt_client.h index 529258e2a90cc0..39c9327683d2e0 100644 --- a/xla/pjrt/gpu/se_gpu_pjrt_client.h +++ b/xla/pjrt/gpu/se_gpu_pjrt_client.h @@ -132,6 +132,10 @@ class StreamExecutorGpuTopologyDescription : public PjRtTopologyDescription { return attributes_; } + StatusOr GetDefaultLayout( + PrimitiveType element_type, + absl::Span dims) const override; + private: const PjRtPlatformId platform_id_; const std::string platform_name_; diff --git a/xla/pjrt/pjrt_c_api_client.h b/xla/pjrt/pjrt_c_api_client.h index d1d50261f7f6da..92c1c3044cc931 100644 --- a/xla/pjrt/pjrt_c_api_client.h +++ b/xla/pjrt/pjrt_c_api_client.h @@ -218,6 +218,12 @@ class PjRtCApiTopologyDescription : public PjRtTopologyDescription { return attributes_; } + StatusOr GetDefaultLayout( + PrimitiveType element_type, + absl::Span dims) const override { + return Unimplemented("PJRT C API does not support GetDefaultLayout"); + } + private: std::unique_ptr compiler_; const PJRT_Api* c_api_; diff --git a/xla/pjrt/pjrt_compiler.h b/xla/pjrt/pjrt_compiler.h index d624fd0cf99cd0..46c363ace13619 100644 --- a/xla/pjrt/pjrt_compiler.h +++ b/xla/pjrt/pjrt_compiler.h @@ -139,6 +139,15 @@ class PjRtTopologyDescription { // Returns vendor specific attributes about the topology. virtual const absl::flat_hash_map& Attributes() const = 0; + + // Returns the default device layout for a buffer with `element_type` and + // `dims`. The default layout is a platform-specific layout used when no other + // layout is specified, e.g. for host-to-device transfers. When compiling, the + // default layout is used for program arguments and outputs unless + // user-specified or compiler-chosen layouts are requested via the + // "mhlo.layout_mode" attribute. + virtual StatusOr GetDefaultLayout( + PrimitiveType element_type, absl::Span dims) const = 0; }; // Abstract interface that all registered compilers must implement. diff --git a/xla/pjrt/pjrt_compiler_test.cc b/xla/pjrt/pjrt_compiler_test.cc index 182e3ba9f7b85d..98a2b8e8d5e16b 100644 --- a/xla/pjrt/pjrt_compiler_test.cc +++ b/xla/pjrt/pjrt_compiler_test.cc @@ -56,6 +56,11 @@ class PjRtTestTopology : public PjRtTopologyDescription { const override { LOG(FATAL) << "Unused"; } + StatusOr GetDefaultLayout( + PrimitiveType element_type, + absl::Span dims) const override { + return Unimplemented("TestTopology does not support GetDefaultLayout"); + } }; TEST(PjRtCompilerTest, CompilerNotRegistered) { @@ -85,6 +90,11 @@ TEST(PjRtCompilerTest, CompilerRegistered) { const override { LOG(FATAL) << "Unused"; } + StatusOr GetDefaultLayout( + PrimitiveType element_type, + absl::Span dims) const override { + return Unimplemented("TestTopology does not support GetDefaultLayout"); + } }; PjRtTestTopology topology; diff --git a/xla/python/py_compile_only_client.cc b/xla/python/py_compile_only_client.cc index 258285886bfa1f..0c92ebe8a9c861 100644 --- a/xla/python/py_compile_only_client.cc +++ b/xla/python/py_compile_only_client.cc @@ -228,6 +228,11 @@ class CompileOnlyIfRtClient final return topology_; } + StatusOr GetDefaultLayout(PrimitiveType element_type, + absl::Span dims) override { + return topology_->GetDefaultLayout(element_type, dims); + } + private: InvalidIfrtCompiler default_compiler_; std::shared_ptr topology_;