diff --git a/xla/BUILD b/xla/BUILD index 9e9cf1343781ef..2d8a0fe99a1ab7 100644 --- a/xla/BUILD +++ b/xla/BUILD @@ -629,6 +629,33 @@ xla_cc_test( ], ) +cc_library( + name = "literal_pool", + srcs = ["literal_pool.cc"], + hdrs = ["literal_pool.h"], + visibility = ["//visibility:public"], + deps = [ + ":literal", + ":shape_util", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/synchronization", + "@tsl//tsl/platform:logging", + ], +) + +xla_cc_test( + name = "literal_pool_test", + srcs = ["literal_pool_test.cc"], + deps = [ + ":literal", + ":literal_pool", + ":literal_util", + "@tsl//tsl/platform:test", + "@tsl//tsl/platform:test_main", + ], +) + cc_library( name = "literal_util", srcs = ["literal_util.cc"], diff --git a/xla/hlo/ir/BUILD b/xla/hlo/ir/BUILD index 616f686e946ffa..2b40f783d01abe 100644 --- a/xla/hlo/ir/BUILD +++ b/xla/hlo/ir/BUILD @@ -65,6 +65,7 @@ cc_library( "//xla:array", "//xla:comparison_util", "//xla:literal", + "//xla:literal_pool", "//xla:literal_util", "//xla:printer", "//xla:protobuf_util", diff --git a/xla/hlo/ir/dfs_hlo_visitor_with_default.h b/xla/hlo/ir/dfs_hlo_visitor_with_default.h index c9ba49231955ab..56846cac1d9647 100644 --- a/xla/hlo/ir/dfs_hlo_visitor_with_default.h +++ b/xla/hlo/ir/dfs_hlo_visitor_with_default.h @@ -380,6 +380,7 @@ class DfsHloRewriteVisitor : public DfsHloVisitorWithDefault { // Mark the computation as having changed. void MarkAsChanged() { changed_ = true; } + void MarkAsMaybeChanged(bool changed) { changed_ |= changed; } private: bool changed_ = false; diff --git a/xla/hlo/ir/hlo_instructions.h b/xla/hlo/ir/hlo_instructions.h index 6830061d85036e..1ca2bfddd55592 100644 --- a/xla/hlo/ir/hlo_instructions.h +++ b/xla/hlo/ir/hlo_instructions.h @@ -40,6 +40,7 @@ limitations under the License. #include "xla/hlo/ir/hlo_opcode.h" #include "xla/layout.h" #include "xla/literal.h" +#include "xla/literal_pool.h" #include "xla/printer.h" #include "xla/service/hlo.pb.h" #include "xla/shape.h" @@ -1343,6 +1344,18 @@ class HloConstantInstruction : public HloInstruction { return hlo->opcode() == HloOpcode::kConstant; } + // Canonicalize constant literal using the given literal pool. + bool Canonicalize(LiteralPool* literal_pool) { + if (literal_pool && literal_) { + auto canonical = literal_pool->GetCanonicalLiteral(literal_); + if (canonical != literal_) { + literal_ = std::move(canonical); + return true; + } + } + return false; + } + private: bool IsElementwiseImpl( const std::optional& operand_idx) const override; diff --git a/xla/hlo/transforms/BUILD b/xla/hlo/transforms/BUILD index d56e5e0ce76876..3b965e0a64cec9 100644 --- a/xla/hlo/transforms/BUILD +++ b/xla/hlo/transforms/BUILD @@ -2026,6 +2026,39 @@ cc_library( ], ) +cc_library( + name = "literal_canonicalizer", + srcs = ["literal_canonicalizer.cc"], + hdrs = ["literal_canonicalizer.h"], + deps = [ + "//xla:literal_pool", + "//xla/hlo/ir:hlo", + "//xla/hlo/pass:hlo_pass", + "//xla/hlo/pass:hlo_pass_pipeline", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", + "@tsl//tsl/platform:errors", + ], +) + +xla_cc_test( + name = "literal_canonicalizer_test", + srcs = ["literal_canonicalizer_test.cc"], + deps = [ + ":literal_canonicalizer", + "//xla:literal_pool", + "//xla/hlo/ir:hlo", + "//xla/hlo/parser:hlo_parser", + "//xla/hlo/testlib:hlo_hardware_independent_test_base", + "@com_google_absl//absl/strings:string_view", + "@tsl//tsl/platform:statusor", + "@tsl//tsl/platform:test", + "@tsl//tsl/platform:test_main", + ], +) + cc_library( name = "optimize_input_output_buffer_alias", srcs = ["simplifiers/optimize_input_output_buffer_alias.cc"], diff --git a/xla/hlo/transforms/literal_canonicalizer.cc b/xla/hlo/transforms/literal_canonicalizer.cc new file mode 100644 index 00000000000000..3712881a4f7927 --- /dev/null +++ b/xla/hlo/transforms/literal_canonicalizer.cc @@ -0,0 +1,75 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/hlo/transforms/literal_canonicalizer.h" + +#include + +#include "absl/container/flat_hash_set.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "xla/hlo/ir/dfs_hlo_visitor.h" +#include "xla/hlo/ir/dfs_hlo_visitor_with_default.h" +#include "xla/hlo/ir/hlo_casting_utils.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_instructions.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/literal_pool.h" +#include "tsl/platform/errors.h" +#include "tsl/platform/logging.h" + +namespace xla { +namespace { + +class LiteralCanonicalizerVisitor : public DfsHloRewriteVisitor { + public: + LiteralCanonicalizerVisitor(LiteralPool* literal_pool, size_t min_size_bytes) + : literal_pool_(literal_pool), min_size_bytes_(min_size_bytes) {} + + absl::Status HandleConstant(HloInstruction* hlo) final { + auto* constant = Cast(hlo); + if (constant->HasLiteral() && + constant->literal().size_bytes() >= min_size_bytes_) { + MarkAsMaybeChanged(constant->Canonicalize(literal_pool_)); + } + return absl::OkStatus(); + } + + private: + LiteralPool* literal_pool_; + size_t min_size_bytes_; +}; + +} // namespace + +LiteralCanonicalizer::LiteralCanonicalizer(LiteralPool* literal_pool, + size_t min_size_bytes) + : literal_pool_(literal_pool), min_size_bytes_(min_size_bytes) {} + +absl::StatusOr LiteralCanonicalizer::Run( + HloModule* module, + const absl::flat_hash_set& execution_threads) { + // Every time we canonicalize literals in a module, we garbage collect expired + // literals from the pool. + size_t num_erased = literal_pool_->GarbageCollect(); + VLOG(3) << "Garbage collected " << num_erased << " expired literals"; + + LiteralCanonicalizerVisitor visitor(literal_pool_, min_size_bytes_); + TF_RETURN_IF_ERROR(module->entry_computation()->Accept(&visitor)); + return visitor.changed(); +} + +} // namespace xla diff --git a/xla/hlo/transforms/literal_canonicalizer.h b/xla/hlo/transforms/literal_canonicalizer.h new file mode 100644 index 00000000000000..26d1768f374a79 --- /dev/null +++ b/xla/hlo/transforms/literal_canonicalizer.h @@ -0,0 +1,50 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_HLO_TRANSFORMS_LITERAL_CANONICALIZER_H_ +#define XLA_HLO_TRANSFORMS_LITERAL_CANONICALIZER_H_ + +#include + +#include "absl/container/flat_hash_set.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/hlo/pass/hlo_pass_interface.h" +#include "xla/literal_pool.h" + +namespace xla { + +// Canonicalizes literals larger than 'min_size_bytes' in the HLO module using +// the given literal pool. +class LiteralCanonicalizer : public HloModulePass { + public: + LiteralCanonicalizer(LiteralPool* literal_pool, size_t min_size_bytes); + + using HloPassInterface::Run; + absl::StatusOr Run( + HloModule* module, + const absl::flat_hash_set& execution_threads) override; + + absl::string_view name() const override { return "literal-canonicalizer"; } + + protected: + LiteralPool* literal_pool_; + size_t min_size_bytes_; +}; + +} // namespace xla + +#endif // XLA_HLO_TRANSFORMS_LITERAL_CANONICALIZER_H_ diff --git a/xla/hlo/transforms/literal_canonicalizer_test.cc b/xla/hlo/transforms/literal_canonicalizer_test.cc new file mode 100644 index 00000000000000..95afd269d4b090 --- /dev/null +++ b/xla/hlo/transforms/literal_canonicalizer_test.cc @@ -0,0 +1,61 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/hlo/transforms/literal_canonicalizer.h" + +#include "absl/strings/string_view.h" +#include "xla/hlo/ir/hlo_casting_utils.h" +#include "xla/hlo/ir/hlo_instructions.h" +#include "xla/hlo/parser/hlo_parser.h" +#include "xla/hlo/testlib/hlo_hardware_independent_test_base.h" +#include "xla/literal_pool.h" +#include "tsl/platform/statusor.h" +#include "tsl/platform/test.h" + +namespace xla { +namespace { + +class LiteralCanonicalizerTest : public HloHardwareIndependentTestBase {}; + +TEST_F(LiteralCanonicalizerTest, CanonicalizeConstants) { + absl::string_view hlo_string = R"( + HloModule m + + ENTRY %entry { + ROOT %c0 = f32[4] constant({1.0, 2.0, 3.0, 4.0}) + } + )"; + + TF_ASSERT_OK_AND_ASSIGN(auto module0, + ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(auto module1, + ParseAndReturnVerifiedModule(hlo_string)); + + LiteralPool literal_pool; + LiteralCanonicalizer literal_canonicalizer(&literal_pool, 0); + + EXPECT_FALSE(literal_canonicalizer.Run(module0.get()).value()); + EXPECT_TRUE(literal_canonicalizer.Run(module1.get()).value()); + + auto* c0 = Cast( + module0->entry_computation()->root_instruction()); + auto* c1 = Cast( + module1->entry_computation()->root_instruction()); + + EXPECT_EQ(c0->literal(), c1->literal()); +} + +} // namespace +} // namespace xla diff --git a/xla/literal_pool.cc b/xla/literal_pool.cc new file mode 100644 index 00000000000000..e3ce7269621f6b --- /dev/null +++ b/xla/literal_pool.cc @@ -0,0 +1,114 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/literal_pool.h" + +#include +#include +#include +#include +#include + +#include "absl/synchronization/mutex.h" +#include "xla/literal.h" +#include "xla/shape.h" +#include "tsl/platform/logging.h" + +namespace xla { + +LiteralPool* LiteralPool::Default() { + static auto* pool = new LiteralPool(); + return pool; +} + +// Erases expired weak pointers from the vector and returns the number of +// elements that were erased. +static size_t EraseExpiredLiterals( + std::vector>& literals) { + auto it = std::remove_if(literals.begin(), literals.end(), + [](auto& ptr) { return ptr.expired(); }); + size_t num_erased = std::distance(it, literals.end()); + + literals.erase(it, literals.end()); + return num_erased; +} + +size_t LiteralPool::GarbageCollect() { + absl::MutexLock lock(&mu_); + size_t num_erased = 0; + + for (auto& [shape, literals] : literals_) { + num_erased += EraseExpiredLiterals(literals); + } + + VLOG(3) << "Garbage collected " << num_erased << " literals"; + return num_erased; +} + +size_t LiteralPool::GarbageCollect(Shape shape) { + absl::MutexLock lock(&mu_); + size_t num_erased = 0; + + if (auto it = literals_.find(shape); it != literals_.end()) { + num_erased = EraseExpiredLiterals(it->second); + } + + VLOG(3) << "Garbage collected " << num_erased << " literals for shape " + << shape.ToString(); + return num_erased; +} + +// Tried to find a canonical literal in the pool. Return nullptr if not found. +static std::shared_ptr FindCanonicalLiteral( + std::vector>& literals, const Literal& literal) { + for (std::weak_ptr& ptr : literals) { + if (auto locked_ptr = ptr.lock()) { + if (locked_ptr->Equal(literal, /*layout_sensitive=*/true)) { + return locked_ptr; + } + } + } + + return nullptr; +} + +std::shared_ptr LiteralPool::GetCanonicalLiteral( + const Literal& literal) { + absl::MutexLock lock(&mu_); + + auto& literals = literals_[literal.shape()]; + if (auto ptr = FindCanonicalLiteral(literals, literal)) { + return ptr; + } + + std::shared_ptr new_literal = literal.CloneToUnique(); + literals.push_back(new_literal); + return new_literal; +} + +std::shared_ptr LiteralPool::GetCanonicalLiteral( + std::shared_ptr literal) { + absl::MutexLock lock(&mu_); + + auto& literals = literals_[literal->shape()]; + if (auto ptr = FindCanonicalLiteral(literals, *literal)) { + return ptr; + } + + literals.push_back(literal); + return literal; +} + +} // namespace xla diff --git a/xla/literal_pool.h b/xla/literal_pool.h new file mode 100644 index 00000000000000..4e53181b05e9a6 --- /dev/null +++ b/xla/literal_pool.h @@ -0,0 +1,67 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_LITERAL_POOL_H_ +#define XLA_LITERAL_POOL_H_ + +#include +#include +#include + +#include "absl/base/thread_annotations.h" +#include "absl/container/flat_hash_map.h" +#include "absl/synchronization/mutex.h" +#include "xla/literal.h" +#include "xla/shape.h" + +namespace xla { + +// Literal pool provides a mechanism to deduplicate identical literals and +// share them across multiple HLO modules. +class LiteralPool { + public: + // Returns a default literal pool that can be used across multiple HLO modules + // in a process. + static LiteralPool* Default(); + + // Returns a canonical literal from the pool. If the literal is not in the + // pool, it is added to the pool and returned back. + std::shared_ptr GetCanonicalLiteral(const Literal& literal); + + // Returns a canonical literal from the pool. If the literal is not in the + // pool, it is added to the pool and returned back. + std::shared_ptr GetCanonicalLiteral( + std::shared_ptr literal); + + // Runs garbage collection on all the literals in the pool. Returns the number + // of literals that were garbage collected. + size_t GarbageCollect(); + + // Runs garbage collection on literals with the given shape. Returns the + // number of literals that were garbage collected. + size_t GarbageCollect(Shape shape); + + private: + // We keep weak pointers to the literals in the pool to allow for garbage + // collection when owning HLO modules are destroyed. We run periodic garbage + // collection to clean up the literals that are no longer referenced. + absl::Mutex mu_; + absl::flat_hash_map>> literals_ + ABSL_GUARDED_BY(mu_); +}; + +} // namespace xla + +#endif // XLA_LITERAL_POOL_H_ diff --git a/xla/literal_pool_test.cc b/xla/literal_pool_test.cc new file mode 100644 index 00000000000000..b655c8c4661f77 --- /dev/null +++ b/xla/literal_pool_test.cc @@ -0,0 +1,45 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/literal_pool.h" + +#include "xla/literal_util.h" +#include "tsl/platform/test.h" + +namespace xla { +namespace { + +TEST(LiteralPoolTest, GetCanonicalLiteral) { + LiteralPool pool; + + auto l0 = LiteralUtil::CreateR2({{1., 2.}, {3., 4.}}); + auto l1 = LiteralUtil::CreateR2({{2., 1.}, {4., 3.}}); + + { // Use nested scope to allow garbage collection below. + auto cl0_0 = pool.GetCanonicalLiteral(l0); + auto cl0_1 = pool.GetCanonicalLiteral(l0); + ASSERT_EQ(cl0_0, cl0_1); + + auto cl1_0 = pool.GetCanonicalLiteral(l1); + auto cl1_1 = pool.GetCanonicalLiteral(l1); + ASSERT_NE(cl0_0, cl1_0); + ASSERT_EQ(cl1_0, cl1_1); + } + + ASSERT_EQ(pool.GarbageCollect(), 2); +} + +} // namespace +} // namespace xla diff --git a/xla/service/cpu/BUILD b/xla/service/cpu/BUILD index 2add6936cb6d16..0431b4dab3e290 100644 --- a/xla/service/cpu/BUILD +++ b/xla/service/cpu/BUILD @@ -235,6 +235,7 @@ cc_library( "//xla:cpu_function_runtime", "//xla:debug_options_flags", "//xla:literal", + "//xla:literal_pool", "//xla:protobuf_util", "//xla:shape_util", "//xla:status_macros", @@ -271,6 +272,7 @@ cc_library( "//xla/hlo/transforms:hlo_constant_folding", "//xla/hlo/transforms:hlo_dce", "//xla/hlo/transforms:hlo_memory_scheduler", + "//xla/hlo/transforms:literal_canonicalizer", "//xla/hlo/transforms:logistic_expander", "//xla/hlo/transforms:operand_upcaster", "//xla/hlo/transforms:optimization_barrier_expander", diff --git a/xla/service/cpu/cpu_compiler.cc b/xla/service/cpu/cpu_compiler.cc index 305cabc3c99e5c..70fe1f7403bcbf 100644 --- a/xla/service/cpu/cpu_compiler.cc +++ b/xla/service/cpu/cpu_compiler.cc @@ -111,6 +111,7 @@ limitations under the License. #include "xla/hlo/transforms/expanders/rng_bit_generator_expander.h" #include "xla/hlo/transforms/expanders/rng_expander.h" #include "xla/hlo/transforms/expanders/stochastic_convert_decomposer.h" +#include "xla/hlo/transforms/literal_canonicalizer.h" #include "xla/hlo/transforms/operand_upcaster.h" #include "xla/hlo/transforms/simplifiers/algebraic_simplifier.h" #include "xla/hlo/transforms/simplifiers/batch_dot_simplification.h" @@ -135,6 +136,7 @@ limitations under the License. #include "xla/hlo/transforms/while_loop_trip_count_annotator.h" #include "xla/hlo/translate/hlo_to_mhlo/hlo_to_mlir_hlo.h" #include "xla/literal.h" +#include "xla/literal_pool.h" #include "xla/map_util.h" #include "xla/mlir_hlo/transforms/passes.h" #include "xla/primitive_util.h" @@ -751,6 +753,12 @@ absl::Status CpuCompiler::RunHloPassesThroughLayoutAssn( pipeline.AddPass( SubByteNormalization::SET_ELEMENT_SIZE); } + + // Finally canonicalize all literals larger than 1024 bytes in the module to + // reuse the same literal across multiple HLO modules. + pipeline.AddPass(LiteralPool::Default(), + /*min_size_bytes=*/1024); + return pipeline.Run(module).status(); }