Skip to content

Commit

Permalink
Add virtual method IsInlineableCallOp to CallInliner to allow sub…
Browse files Browse the repository at this point in the history
…classes to change which call instructions to inline. And clean up `#include`s in `call_inliner.cc`.

PiperOrigin-RevId: 621842248
  • Loading branch information
bartchr808 authored and copybara-github committed Apr 4, 2024
1 parent b879cfa commit bd5e32f
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 5 deletions.
9 changes: 9 additions & 0 deletions xla/service/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -969,10 +969,19 @@ cc_library(
":hlo_dce",
":hlo_domain_isolator",
":hlo_pass",
"//xla:status",
"//xla:status_macros",
"//xla:statusor",
"//xla:util",
"//xla/hlo/ir:hlo",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/log",
"@com_google_absl//absl/log:check",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings:string_view",
"@tsl//tsl/platform:errors",
"@tsl//tsl/platform:statusor",
],
)

Expand Down
24 changes: 20 additions & 4 deletions xla/service/call_inliner.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,26 @@ limitations under the License.
#include "xla/service/call_inliner.h"

#include <memory>

#include <utility>
#include <vector>

#include "absl/container/flat_hash_set.h"
#include "absl/log/check.h"
#include "absl/log/log.h"
#include "absl/status/statusor.h"
#include "absl/strings/string_view.h"
#include "xla/hlo/ir/dfs_hlo_visitor_with_default.h"
#include "xla/hlo/ir/hlo_instruction.h"
#include "xla/hlo/ir/hlo_opcode.h"
#include "xla/hlo/ir/hlo_sharding_metadata.h"
#include "xla/service/call_graph.h"
#include "xla/service/hlo_dce.h"
#include "xla/service/hlo_domain_isolator.h"
#include "xla/status.h"
#include "xla/status_macros.h"
#include "xla/util.h"
#include "tsl/platform/errors.h"
#include "tsl/platform/statusor.h"

namespace xla {
namespace {
Expand Down Expand Up @@ -136,6 +148,11 @@ CallInliner::Inline(HloInstruction* call) {
return visitor.ConsumeInstructionMap();
}

bool CallInliner::IsInlineableCallOp(HloInstruction* instruction) const {
return instruction->opcode() == HloOpcode::kCall &&
!instruction->parent()->IsAsyncComputation();
}

absl::StatusOr<bool> CallInliner::Run(
HloModule* module,
const absl::flat_hash_set<absl::string_view>& execution_threads) {
Expand All @@ -156,8 +173,7 @@ absl::StatusOr<bool> CallInliner::Run(
// used for parallel device computation.
// TODO(b/229887502): update the inliner to ignore only parallel
// device type async call instead of all.
if (instruction->opcode() == HloOpcode::kCall &&
!instruction->parent()->IsAsyncComputation()) {
if (IsInlineableCallOp(instruction)) {
const auto& callees = instruction->called_computations();
TF_RET_CHECK(callees.size() == 1);
if (!single_call_site_ || call_graph->GetNode(instruction->to_apply())
Expand All @@ -182,7 +198,7 @@ absl::StatusOr<bool> CallInliner::Run(
// Run DCE to remove called computations which are now becoming unused.
// This can result then in problems if within the called computation, there
// were send/recv instructions, which the module group verifier will flag as
// error findingthe same channel ID used for multiple send/recv
// error finding the same channel ID used for multiple send/recv
// instructions.
TF_RETURN_IF_ERROR(HloDCE().Run(module, execution_threads).status());
}
Expand Down
10 changes: 9 additions & 1 deletion xla/service/call_inliner.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,12 @@ limitations under the License.
#define XLA_SERVICE_CALL_INLINER_H_

#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/status/statusor.h"
#include "absl/strings/string_view.h"
#include "xla/hlo/ir/hlo_instruction.h"
#include "xla/hlo/ir/hlo_module.h"
#include "xla/service/hlo_pass_interface.h"
#include "xla/statusor.h"

namespace xla {

Expand Down Expand Up @@ -48,6 +52,10 @@ class CallInliner : public HloModulePass {
HloModule* module,
const absl::flat_hash_set<absl::string_view>& execution_threads) override;

// Returns true if the instruction is a kCall operation and is eligible for
// inlining.
virtual bool IsInlineableCallOp(HloInstruction* instruction) const;

private:
bool single_call_site_;
bool update_domain_;
Expand Down

0 comments on commit bd5e32f

Please sign in to comment.