Skip to content

Commit

Permalink
Add GetDefaultLayout to PjRtTopologyDescription. This is needed to su…
Browse files Browse the repository at this point in the history
…pport py_compile_only_client.

PiperOrigin-RevId: 622042765
  • Loading branch information
pschuh authored and copybara-github committed Apr 5, 2024
1 parent b0c6c26 commit c0e91f8
Show file tree
Hide file tree
Showing 8 changed files with 50 additions and 0 deletions.
6 changes: 6 additions & 0 deletions xla/pjrt/cpu/cpu_client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -285,6 +285,12 @@ absl::string_view TfrtCpuDeviceDescription::ToString() const {
machine_attributes);
}

absl::StatusOr<Layout> TfrtCpuTopologyDescription::GetDefaultLayout(
PrimitiveType element_type, absl::Span<const int64_t> dims) const {
Shape shape = ShapeUtil::MakeShape(element_type, dims);
return LayoutUtil::GetWithDefaultLayout(shape).layout();
}

absl::StatusOr<std::string> TfrtCpuTopologyDescription::Serialize() const {
std::string result;
if (!tsl::SerializeToStringDeterministic(cpu_topology_.ToProto(), &result)) {
Expand Down
4 changes: 4 additions & 0 deletions xla/pjrt/cpu/cpu_client.h
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,10 @@ class TfrtCpuTopologyDescription : public PjRtTopologyDescription {
return attributes_;
}

StatusOr<Layout> GetDefaultLayout(
PrimitiveType element_type,
absl::Span<const int64_t> dims) const override;

private:
const PjRtPlatformId platform_id_;
const std::string platform_name_;
Expand Down
6 changes: 6 additions & 0 deletions xla/pjrt/gpu/se_gpu_pjrt_client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1130,6 +1130,12 @@ absl::StatusOr<std::string> StreamExecutorGpuTopologyDescription::Serialize()
return result;
}

absl::StatusOr<Layout> StreamExecutorGpuTopologyDescription::GetDefaultLayout(
PrimitiveType element_type, absl::Span<const int64_t> dims) const {
Shape shape = ShapeUtil::MakeShape(element_type, dims);
return LayoutUtil::GetWithDefaultLayout(shape).layout();
}

std::vector<std::unique_ptr<PjRtStreamExecutorDevice>> BuildLocalDevices(
std::map<int, std::unique_ptr<LocalDeviceState>> local_device_states,
int node_id) {
Expand Down
4 changes: 4 additions & 0 deletions xla/pjrt/gpu/se_gpu_pjrt_client.h
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,10 @@ class StreamExecutorGpuTopologyDescription : public PjRtTopologyDescription {
return attributes_;
}

StatusOr<Layout> GetDefaultLayout(
PrimitiveType element_type,
absl::Span<const int64_t> dims) const override;

private:
const PjRtPlatformId platform_id_;
const std::string platform_name_;
Expand Down
6 changes: 6 additions & 0 deletions xla/pjrt/pjrt_c_api_client.h
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,12 @@ class PjRtCApiTopologyDescription : public PjRtTopologyDescription {
return attributes_;
}

StatusOr<Layout> GetDefaultLayout(
PrimitiveType element_type,
absl::Span<const int64_t> dims) const override {
return Unimplemented("PJRT C API does not support GetDefaultLayout");
}

private:
std::unique_ptr<PjRtCApiCompiler> compiler_;
const PJRT_Api* c_api_;
Expand Down
9 changes: 9 additions & 0 deletions xla/pjrt/pjrt_compiler.h
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,15 @@ class PjRtTopologyDescription {
// Returns vendor specific attributes about the topology.
virtual const absl::flat_hash_map<std::string, PjRtDeviceAttribute>&
Attributes() const = 0;

// Returns the default device layout for a buffer with `element_type` and
// `dims`. The default layout is a platform-specific layout used when no other
// layout is specified, e.g. for host-to-device transfers. When compiling, the
// default layout is used for program arguments and outputs unless
// user-specified or compiler-chosen layouts are requested via the
// "mhlo.layout_mode" attribute.
virtual StatusOr<Layout> GetDefaultLayout(
PrimitiveType element_type, absl::Span<const int64_t> dims) const = 0;
};

// Abstract interface that all registered compilers must implement.
Expand Down
10 changes: 10 additions & 0 deletions xla/pjrt/pjrt_compiler_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,11 @@ class PjRtTestTopology : public PjRtTopologyDescription {
const override {
LOG(FATAL) << "Unused";
}
StatusOr<Layout> GetDefaultLayout(
PrimitiveType element_type,
absl::Span<const int64_t> dims) const override {
return Unimplemented("TestTopology does not support GetDefaultLayout");
}
};

TEST(PjRtCompilerTest, CompilerNotRegistered) {
Expand Down Expand Up @@ -85,6 +90,11 @@ TEST(PjRtCompilerTest, CompilerRegistered) {
const override {
LOG(FATAL) << "Unused";
}
StatusOr<Layout> GetDefaultLayout(
PrimitiveType element_type,
absl::Span<const int64_t> dims) const override {
return Unimplemented("TestTopology does not support GetDefaultLayout");
}
};
PjRtTestTopology topology;

Expand Down
5 changes: 5 additions & 0 deletions xla/python/py_compile_only_client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,11 @@ class CompileOnlyIfRtClient final
return topology_;
}

StatusOr<Layout> GetDefaultLayout(PrimitiveType element_type,
absl::Span<const int64_t> dims) override {
return topology_->GetDefaultLayout(element_type, dims);
}

private:
InvalidIfrtCompiler default_compiler_;
std::shared_ptr<PjRtTopologyDescription> topology_;
Expand Down

0 comments on commit c0e91f8

Please sign in to comment.