Skip to content

Commit

Permalink
Add kCpu property tag.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 719007295
  • Loading branch information
nvgrw authored and Google-ML-Automation committed Jan 23, 2025
1 parent b3f7ab6 commit 53c999d
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 0 deletions.
3 changes: 3 additions & 0 deletions xla/service/hlo_runner.cc
Original file line number Diff line number Diff line change
Expand Up @@ -718,6 +718,9 @@ bool HloRunner::HasProperty(const HloRunnerPropertyTag::Type tag) const {
return std::holds_alternative<stream_executor::RocmComputeCapability>(
device_description.gpu_compute_capability());
}
if (tag == HloRunnerPropertyTag::kCpu) {
return backend().platform()->Name() == "Host";
}
return false;
}

Expand Down
2 changes: 2 additions & 0 deletions xla/service/hlo_runner_interface.h
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,8 @@ class HloRunnerPropertyTag final {
static constexpr Type kDefault = 0;
// Indicates that the runner is using ROCm.
static constexpr Type kUsingGpuRocm = 1;
// Indicates that this runner is a CPU runner.
static constexpr Type kCpu = 2;

private:
HloRunnerPropertyTag() = default;
Expand Down
3 changes: 3 additions & 0 deletions xla/service/hlo_runner_pjrt.cc
Original file line number Diff line number Diff line change
Expand Up @@ -662,6 +662,9 @@ bool HloRunnerPjRt::HasProperty(const HloRunnerPropertyTag::Type tag) const {
if (tag == HloRunnerPropertyTag::kUsingGpuRocm) {
return pjrt_client_->platform_name() == xla::RocmName();
}
if (tag == HloRunnerPropertyTag::kCpu) {
return pjrt_client_->platform_name() == xla::CpuName();
}
return false;
}

Expand Down

0 comments on commit 53c999d

Please sign in to comment.