Skip to content

Commit

Permalink
[xla] Add LiteralPool and LiteralCanonicalizer to share constant lite…
Browse files Browse the repository at this point in the history
…rals between HLO modules

This change saves a lot of host memory from duplicate constant literals in instantiated HLO modules.

PiperOrigin-RevId: 706775155
  • Loading branch information
ezhulenev authored and Google-ML-Automation committed Dec 17, 2024
1 parent 33f1796 commit f3dd9ec
Show file tree
Hide file tree
Showing 13 changed files with 497 additions and 0 deletions.
27 changes: 27 additions & 0 deletions xla/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -629,6 +629,33 @@ xla_cc_test(
],
)

cc_library(
name = "literal_pool",
srcs = ["literal_pool.cc"],
hdrs = ["literal_pool.h"],
visibility = ["//visibility:public"],
deps = [
":literal",
":shape_util",
"@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/synchronization",
"@tsl//tsl/platform:logging",
],
)

xla_cc_test(
name = "literal_pool_test",
srcs = ["literal_pool_test.cc"],
deps = [
":literal",
":literal_pool",
":literal_util",
"@tsl//tsl/platform:test",
"@tsl//tsl/platform:test_main",
],
)

cc_library(
name = "literal_util",
srcs = ["literal_util.cc"],
Expand Down
1 change: 1 addition & 0 deletions xla/hlo/ir/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ cc_library(
"//xla:array",
"//xla:comparison_util",
"//xla:literal",
"//xla:literal_pool",
"//xla:literal_util",
"//xla:printer",
"//xla:protobuf_util",
Expand Down
1 change: 1 addition & 0 deletions xla/hlo/ir/dfs_hlo_visitor_with_default.h
Original file line number Diff line number Diff line change
Expand Up @@ -380,6 +380,7 @@ class DfsHloRewriteVisitor : public DfsHloVisitorWithDefault {

// Mark the computation as having changed.
void MarkAsChanged() { changed_ = true; }
void MarkAsMaybeChanged(bool changed) { changed_ |= changed; }

private:
bool changed_ = false;
Expand Down
13 changes: 13 additions & 0 deletions xla/hlo/ir/hlo_instructions.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ limitations under the License.
#include "xla/hlo/ir/hlo_opcode.h"
#include "xla/layout.h"
#include "xla/literal.h"
#include "xla/literal_pool.h"
#include "xla/printer.h"
#include "xla/service/hlo.pb.h"
#include "xla/shape.h"
Expand Down Expand Up @@ -1343,6 +1344,18 @@ class HloConstantInstruction : public HloInstruction {
return hlo->opcode() == HloOpcode::kConstant;
}

// Canonicalize constant literal using the given literal pool.
bool Canonicalize(LiteralPool* literal_pool) {
if (literal_pool && literal_) {
auto canonical = literal_pool->GetCanonicalLiteral(literal_);
if (canonical != literal_) {
literal_ = std::move(canonical);
return true;
}
}
return false;
}

private:
bool IsElementwiseImpl(
const std::optional<int64_t>& operand_idx) const override;
Expand Down
33 changes: 33 additions & 0 deletions xla/hlo/transforms/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -2026,6 +2026,39 @@ cc_library(
],
)

cc_library(
name = "literal_canonicalizer",
srcs = ["literal_canonicalizer.cc"],
hdrs = ["literal_canonicalizer.h"],
deps = [
"//xla:literal_pool",
"//xla/hlo/ir:hlo",
"//xla/hlo/pass:hlo_pass",
"//xla/hlo/pass:hlo_pass_pipeline",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings:string_view",
"@tsl//tsl/platform:errors",
],
)

xla_cc_test(
name = "literal_canonicalizer_test",
srcs = ["literal_canonicalizer_test.cc"],
deps = [
":literal_canonicalizer",
"//xla:literal_pool",
"//xla/hlo/ir:hlo",
"//xla/hlo/parser:hlo_parser",
"//xla/hlo/testlib:hlo_hardware_independent_test_base",
"@com_google_absl//absl/strings:string_view",
"@tsl//tsl/platform:statusor",
"@tsl//tsl/platform:test",
"@tsl//tsl/platform:test_main",
],
)

cc_library(
name = "optimize_input_output_buffer_alias",
srcs = ["simplifiers/optimize_input_output_buffer_alias.cc"],
Expand Down
75 changes: 75 additions & 0 deletions xla/hlo/transforms/literal_canonicalizer.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
/* Copyright 2024 The OpenXLA Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "xla/hlo/transforms/literal_canonicalizer.h"

#include <cstddef>

#include "absl/container/flat_hash_set.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/string_view.h"
#include "xla/hlo/ir/dfs_hlo_visitor.h"
#include "xla/hlo/ir/dfs_hlo_visitor_with_default.h"
#include "xla/hlo/ir/hlo_casting_utils.h"
#include "xla/hlo/ir/hlo_instruction.h"
#include "xla/hlo/ir/hlo_instructions.h"
#include "xla/hlo/ir/hlo_module.h"
#include "xla/literal_pool.h"
#include "tsl/platform/errors.h"
#include "tsl/platform/logging.h"

namespace xla {
namespace {

class LiteralCanonicalizerVisitor : public DfsHloRewriteVisitor {
public:
LiteralCanonicalizerVisitor(LiteralPool* literal_pool, size_t min_size_bytes)
: literal_pool_(literal_pool), min_size_bytes_(min_size_bytes) {}

absl::Status HandleConstant(HloInstruction* hlo) final {
auto* constant = Cast<HloConstantInstruction>(hlo);
if (constant->HasLiteral() &&
constant->literal().size_bytes() >= min_size_bytes_) {
MarkAsMaybeChanged(constant->Canonicalize(literal_pool_));
}
return absl::OkStatus();
}

private:
LiteralPool* literal_pool_;
size_t min_size_bytes_;
};

} // namespace

LiteralCanonicalizer::LiteralCanonicalizer(LiteralPool* literal_pool,
size_t min_size_bytes)
: literal_pool_(literal_pool), min_size_bytes_(min_size_bytes) {}

absl::StatusOr<bool> LiteralCanonicalizer::Run(
HloModule* module,
const absl::flat_hash_set<absl::string_view>& execution_threads) {
// Every time we canonicalize literals in a module, we garbage collect expired
// literals from the pool.
size_t num_erased = literal_pool_->GarbageCollect();
VLOG(3) << "Garbage collected " << num_erased << " expired literals";

LiteralCanonicalizerVisitor visitor(literal_pool_, min_size_bytes_);
TF_RETURN_IF_ERROR(module->entry_computation()->Accept(&visitor));
return visitor.changed();
}

} // namespace xla
50 changes: 50 additions & 0 deletions xla/hlo/transforms/literal_canonicalizer.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
/* Copyright 2024 The OpenXLA Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef XLA_HLO_TRANSFORMS_LITERAL_CANONICALIZER_H_
#define XLA_HLO_TRANSFORMS_LITERAL_CANONICALIZER_H_

#include <cstddef>

#include "absl/container/flat_hash_set.h"
#include "absl/status/statusor.h"
#include "absl/strings/string_view.h"
#include "xla/hlo/ir/hlo_module.h"
#include "xla/hlo/pass/hlo_pass_interface.h"
#include "xla/literal_pool.h"

namespace xla {

// Canonicalizes literals larger than 'min_size_bytes' in the HLO module using
// the given literal pool.
class LiteralCanonicalizer : public HloModulePass {
public:
LiteralCanonicalizer(LiteralPool* literal_pool, size_t min_size_bytes);

using HloPassInterface::Run;
absl::StatusOr<bool> Run(
HloModule* module,
const absl::flat_hash_set<absl::string_view>& execution_threads) override;

absl::string_view name() const override { return "literal-canonicalizer"; }

protected:
LiteralPool* literal_pool_;
size_t min_size_bytes_;
};

} // namespace xla

#endif // XLA_HLO_TRANSFORMS_LITERAL_CANONICALIZER_H_
61 changes: 61 additions & 0 deletions xla/hlo/transforms/literal_canonicalizer_test.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
/* Copyright 2024 The OpenXLA Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "xla/hlo/transforms/literal_canonicalizer.h"

#include "absl/strings/string_view.h"
#include "xla/hlo/ir/hlo_casting_utils.h"
#include "xla/hlo/ir/hlo_instructions.h"
#include "xla/hlo/parser/hlo_parser.h"
#include "xla/hlo/testlib/hlo_hardware_independent_test_base.h"
#include "xla/literal_pool.h"
#include "tsl/platform/statusor.h"
#include "tsl/platform/test.h"

namespace xla {
namespace {

class LiteralCanonicalizerTest : public HloHardwareIndependentTestBase {};

TEST_F(LiteralCanonicalizerTest, CanonicalizeConstants) {
absl::string_view hlo_string = R"(
HloModule m
ENTRY %entry {
ROOT %c0 = f32[4] constant({1.0, 2.0, 3.0, 4.0})
}
)";

TF_ASSERT_OK_AND_ASSIGN(auto module0,
ParseAndReturnVerifiedModule(hlo_string));
TF_ASSERT_OK_AND_ASSIGN(auto module1,
ParseAndReturnVerifiedModule(hlo_string));

LiteralPool literal_pool;
LiteralCanonicalizer literal_canonicalizer(&literal_pool, 0);

EXPECT_FALSE(literal_canonicalizer.Run(module0.get()).value());
EXPECT_TRUE(literal_canonicalizer.Run(module1.get()).value());

auto* c0 = Cast<HloConstantInstruction>(
module0->entry_computation()->root_instruction());
auto* c1 = Cast<HloConstantInstruction>(
module1->entry_computation()->root_instruction());

EXPECT_EQ(c0->literal(), c1->literal());
}

} // namespace
} // namespace xla
Loading

0 comments on commit f3dd9ec

Please sign in to comment.