diff --git a/xla/service/cost_modelling/BUILD b/xla/service/cost_modelling/BUILD index 38b304b5ce1a6..c12df2742859a 100644 --- a/xla/service/cost_modelling/BUILD +++ b/xla/service/cost_modelling/BUILD @@ -54,10 +54,10 @@ xla_cc_test( "//xla/service:platform_util", "//xla/tests:hlo_runner_agnostic_test_base", "//xla/tsl/lib/core:status_test_util", + "//xla/tsl/platform:statusor", + "//xla/tsl/platform:test", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/strings:string_view", "@com_google_googletest//:gtest_main", - "@tsl//tsl/platform:statusor", - "@tsl//tsl/platform:test", ], ) diff --git a/xla/service/cost_modelling/op_cost_test.cc b/xla/service/cost_modelling/op_cost_test.cc index 0ec01352397b9..1cfbcad1874c5 100644 --- a/xla/service/cost_modelling/op_cost_test.cc +++ b/xla/service/cost_modelling/op_cost_test.cc @@ -31,8 +31,8 @@ limitations under the License. #include "xla/shape_util.h" #include "xla/tests/hlo_runner_agnostic_test_base.h" #include "xla/tsl/lib/core/status_test_util.h" -#include "tsl/platform/statusor.h" -#include "tsl/platform/test.h" +#include "xla/tsl/platform/statusor.h" +#include "xla/tsl/platform/test.h" namespace xla { namespace { @@ -86,8 +86,6 @@ class OpCostTest : public HloRunnerAgnosticTestBase { protected: OpCostTest() : HloRunnerAgnosticTestBase( - std::make_unique( - PlatformUtil::GetDefaultPlatform().value()), std::make_unique( PlatformUtil::GetDefaultPlatform().value())) {} diff --git a/xla/service/gpu/BUILD b/xla/service/gpu/BUILD index deb84e2014b59..4a010cf3b6c03 100644 --- a/xla/service/gpu/BUILD +++ b/xla/service/gpu/BUILD @@ -2594,9 +2594,10 @@ xla_cc_test( "//xla/stream_executor:device_description", "//xla/tests:hlo_runner_agnostic_test_base", "//xla/tests:xla_internal_test_main", + "//xla/tsl/platform:statusor", + "//xla/tsl/platform:test", "@com_google_absl//absl/strings", "@com_google_googletest//:gtest_main", - "@tsl//tsl/platform:statusor", ], ) diff --git a/xla/service/gpu/gpu_fusible_test.cc b/xla/service/gpu/gpu_fusible_test.cc index de1b51168682e..5c3af4434a3ba 100644 --- a/xla/service/gpu/gpu_fusible_test.cc +++ b/xla/service/gpu/gpu_fusible_test.cc @@ -16,8 +16,8 @@ limitations under the License. #include "xla/service/gpu/gpu_fusible.h" #include +#include -#include #include "absl/strings/str_cat.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_opcode.h" @@ -27,7 +27,8 @@ limitations under the License. #include "xla/service/platform_util.h" #include "xla/stream_executor/device_description.h" #include "xla/tests/hlo_runner_agnostic_test_base.h" -#include "tsl/platform/statusor.h" +#include "xla/tsl/platform/statusor.h" +#include "xla/tsl/platform/test.h" namespace xla { namespace gpu { @@ -46,8 +47,6 @@ class GpuFusibleTest : public HloRunnerAgnosticTestBase { public: GpuFusibleTest() : HloRunnerAgnosticTestBase( - std::make_unique( - PlatformUtil::GetDefaultPlatform().value()), std::make_unique( PlatformUtil::GetDefaultPlatform().value())), device_description_(MakeDeviceDescription()) {} diff --git a/xla/service/gpu/transforms/BUILD b/xla/service/gpu/transforms/BUILD index fc3a25a9cd56b..1f7752a8cd938 100644 --- a/xla/service/gpu/transforms/BUILD +++ b/xla/service/gpu/transforms/BUILD @@ -3483,10 +3483,9 @@ xla_cc_test( "//xla/tests:hlo_runner_agnostic_test_base", "//xla/tests:test_utils", "//xla/tsl/lib/core:status_test_util", + "//xla/tsl/platform:statusor", + "//xla/tsl/platform:test", "@com_google_absl//absl/log", - "@com_google_googletest//:gtest", "@com_google_googletest//:gtest_main", - "@tsl//tsl/platform:statusor", - "@tsl//tsl/platform:test", ], ) diff --git a/xla/service/gpu/transforms/ragged_all_to_all_decomposer_test.cc b/xla/service/gpu/transforms/ragged_all_to_all_decomposer_test.cc index be1ddb782e3f6..9b89b871f681e 100644 --- a/xla/service/gpu/transforms/ragged_all_to_all_decomposer_test.cc +++ b/xla/service/gpu/transforms/ragged_all_to_all_decomposer_test.cc @@ -17,8 +17,6 @@ limitations under the License. #include -#include -#include #include "absl/log/log.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/testlib/filecheck.h" @@ -28,8 +26,8 @@ limitations under the License. #include "xla/tests/hlo_runner_agnostic_test_base.h" #include "xla/tests/test_utils.h" #include "xla/tsl/lib/core/status_test_util.h" -#include "tsl/platform/statusor.h" -#include "tsl/platform/test.h" +#include "xla/tsl/platform/statusor.h" +#include "xla/tsl/platform/test.h" namespace xla { namespace gpu { @@ -39,8 +37,6 @@ class RaggedAllToAllDecomposerTest : public HloRunnerAgnosticTestBase { public: RaggedAllToAllDecomposerTest() : HloRunnerAgnosticTestBase( - std::make_unique( - PlatformUtil::GetDefaultPlatform().value()), std::make_unique( PlatformUtil::GetDefaultPlatform().value())) {} }; diff --git a/xla/tests/BUILD b/xla/tests/BUILD index f8c1ecd3beae7..b98fef95d0f51 100644 --- a/xla/tests/BUILD +++ b/xla/tests/BUILD @@ -170,6 +170,7 @@ cc_library( srcs = ["hlo_test_base.cc"], hdrs = ["hlo_test_base.h"], deps = [ + ":hlo_runner_agnostic_reference_mixin", ":hlo_runner_agnostic_test_base", ":pjrt_client_registry", "//xla:error_spec", @@ -177,9 +178,7 @@ cc_library( "//xla:util", "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", - "//xla/hlo/pass:hlo_pass", "//xla/hlo/testlib:filecheck", - "//xla/hlo/testlib:hlo_hardware_independent_test_base", "//xla/pjrt:pjrt_client", "//xla/service:backend", "//xla/service:computation_placer_hdr", @@ -193,16 +192,14 @@ cc_library( "//xla/stream_executor:platform", "//xla/stream_executor:stream_executor_memory_allocator", "//xla/tsl/lib/core:status_test_util", - "@com_google_absl//absl/base:core_headers", + "//xla/tsl/platform:status", + "//xla/tsl/platform:statusor", + "//xla/tsl/platform:test", "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:string_view", - "@tsl//tsl/platform:logging", - "@tsl//tsl/platform:status", - "@tsl//tsl/platform:statusor", - "@tsl//tsl/platform:test", ], ) @@ -216,7 +213,6 @@ cc_library( ":test_utils", "//xla:error_spec", "//xla:literal", - "//xla:shape_util", "//xla:util", "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", @@ -290,18 +286,16 @@ cc_library( srcs = ["hlo_pjrt_test_base.cc"], hdrs = ["hlo_pjrt_test_base.h"], deps = [ + ":hlo_pjrt_interpreter_reference_mixin", ":hlo_runner_agnostic_test_base", ":pjrt_client_registry", "//xla:util", "//xla:xla_data_proto_cc", "//xla/pjrt:pjrt_client", - "//xla/pjrt/interpreter:interpreter_client", "//xla/service:hlo_runner_interface", "//xla/service:hlo_runner_pjrt", - "//xla/service:interpreter_plugin", # reference backend "@com_google_absl//absl/log:check", "@com_google_absl//absl/status:statusor", - "@tsl//tsl/platform:logging", ], ) diff --git a/xla/tests/hlo_pjrt_test_base.cc b/xla/tests/hlo_pjrt_test_base.cc index 8eb3002d26dbe..34647c6df37db 100644 --- a/xla/tests/hlo_pjrt_test_base.cc +++ b/xla/tests/hlo_pjrt_test_base.cc @@ -21,14 +21,12 @@ limitations under the License. #include "absl/log/check.h" #include "absl/status/statusor.h" -#include "xla/pjrt/interpreter/interpreter_client.h" #include "xla/pjrt/pjrt_client.h" #include "xla/service/hlo_runner_interface.h" #include "xla/service/hlo_runner_pjrt.h" +#include "xla/tests/hlo_pjrt_interpreter_reference_mixin.h" #include "xla/tests/hlo_runner_agnostic_test_base.h" #include "xla/tests/pjrt_client_registry.h" -#include "xla/util.h" -#include "tsl/platform/logging.h" namespace xla { namespace { @@ -52,21 +50,12 @@ std::unique_ptr GetHloRunnerForTest() { *std::move(client), device_shape_representation_fn, device_shape_size_fn); } -std::unique_ptr GetHloRunnerForReference() { - return std::make_unique( - std::make_unique(), - InterpreterClient::DeviceShapeRepresentation, - InterpreterClient::ShapeSizeBytes, - /*use_parameter_layout_on_device=*/true); -} - } // namespace HloPjRtTestBase::HloPjRtTestBase(HloPjRtTestBaseOptions options) - : HloRunnerAgnosticTestBase(GetHloRunnerForTest(), - GetHloRunnerForReference(), - options.verifier_layout_sensitive, - options.allow_mixed_precision_in_hlo_verifier, - options.instruction_can_change_layout_func) {} + : HloPjRtInterpreterReferenceMixin( + GetHloRunnerForTest(), options.verifier_layout_sensitive, + options.allow_mixed_precision_in_hlo_verifier, + options.instruction_can_change_layout_func) {} } // namespace xla diff --git a/xla/tests/hlo_pjrt_test_base.h b/xla/tests/hlo_pjrt_test_base.h index fe7b95dfba363..9c374b2070bb7 100644 --- a/xla/tests/hlo_pjrt_test_base.h +++ b/xla/tests/hlo_pjrt_test_base.h @@ -16,6 +16,7 @@ limitations under the License. #ifndef XLA_TESTS_HLO_PJRT_TEST_BASE_H_ #define XLA_TESTS_HLO_PJRT_TEST_BASE_H_ +#include "xla/tests/hlo_pjrt_interpreter_reference_mixin.h" #include "xla/tests/hlo_runner_agnostic_test_base.h" #include "xla/util.h" #include "xla/xla_data.pb.h" @@ -28,9 +29,10 @@ struct HloPjRtTestBaseOptions { HloPredicate instruction_can_change_layout_func; }; -class HloPjRtTestBase : public HloRunnerAgnosticTestBase { +class HloPjRtTestBase + : public HloPjRtInterpreterReferenceMixin { protected: - // This uses the SE interpreter backend for the reference backend and + // This uses the PjRt interpreter backend for the reference backend and // automatically finds a PjRt backend for the test backend. explicit HloPjRtTestBase(HloPjRtTestBaseOptions options = {}); }; diff --git a/xla/tests/hlo_runner_agnostic_test_base.cc b/xla/tests/hlo_runner_agnostic_test_base.cc index b781a0eebd37d..341ab477e5599 100644 --- a/xla/tests/hlo_runner_agnostic_test_base.cc +++ b/xla/tests/hlo_runner_agnostic_test_base.cc @@ -44,7 +44,6 @@ limitations under the License. #include "xla/service/hlo_module_util.h" #include "xla/service/hlo_runner_interface.h" #include "xla/service/hlo_verifier.h" -#include "xla/shape.h" #include "xla/tests/literal_test_util.h" #include "xla/tests/test_utils.h" #include "xla/tsl/platform/errors.h" @@ -56,45 +55,15 @@ limitations under the License. namespace xla { -namespace { - -bool ProgramShapesEqual(const ProgramShape& lhs, const ProgramShape& rhs) { - if (lhs.parameters_size() != rhs.parameters_size()) { - return false; - } - for (int i = 0; i < lhs.parameters_size(); i++) { - if (!Shape::Equal().IgnoreElementSizeInLayout()(lhs.parameters(i), - rhs.parameters(i))) { - return false; - } - } - return Shape::Equal().IgnoreElementSizeInLayout()(lhs.result(), rhs.result()); -} - -ProgramShape GetProgramShapeWithLayout(const HloModule& module) { - ProgramShape program_shape; - const auto* entry = module.entry_computation(); - for (const auto* param : entry->parameter_instructions()) { - *program_shape.add_parameters() = param->shape(); - *program_shape.add_parameter_names() = param->name(); - } - *program_shape.mutable_result() = entry->root_instruction()->shape(); - return program_shape; -} - -} // namespace - HloRunnerAgnosticTestBase::HloRunnerAgnosticTestBase( absl::Nonnull> test_runner, - absl::Nonnull> reference_runner, const bool verifier_layout_sensitive, const bool allow_mixed_precision_in_hlo_verifier, const HloPredicate instruction_can_change_layout_func) : HloHardwareIndependentTestBase(verifier_layout_sensitive, allow_mixed_precision_in_hlo_verifier, instruction_can_change_layout_func), - test_runner_(std::move(test_runner)), - reference_runner_(std::move(reference_runner)) {} + test_runner_(std::move(test_runner)) {} std::unique_ptr HloRunnerAgnosticTestBase::CreateNewVerifiedModule( @@ -231,71 +200,6 @@ HloRunnerAgnosticTestBase::ExecuteReplicated( /*device_assignment=*/device_assignment); } -::testing::AssertionResult HloRunnerAgnosticTestBase::RunAndCompare( - std::unique_ptr module, absl::Span arguments, - const std::optional& error, - const std::function& reference_preprocessor, - const std::function& test_preprocessor) { - const absl::StatusOr<::testing::AssertionResult> result = - RunAndCompareInternal(std::move(module), arguments, error, - /*run_hlo_passes=*/true, reference_preprocessor, - test_preprocessor); - if (!result.ok()) { - return ::testing::AssertionFailure() << result.status(); - } - return *result; -} - -::testing::AssertionResult HloRunnerAgnosticTestBase::RunAndCompareNoHloPasses( - std::unique_ptr module, - const absl::Span arguments, - const std::optional& error, - const std::function& reference_preprocessor, - const std::function& test_preprocessor) { - const absl::StatusOr<::testing::AssertionResult> result = - RunAndCompareInternal(std::move(module), arguments, error, - /*run_hlo_passes=*/false, reference_preprocessor, - test_preprocessor); - if (!result.ok()) { - return ::testing::AssertionFailure() << result.status(); - } - return *result; -} - -::testing::AssertionResult HloRunnerAgnosticTestBase::RunAndCompare( - std::unique_ptr module, const std::optional& error, - const std::function& reference_preprocessor, - const std::function& test_preprocessor, - const std::optional args_max_bits_of_precision) { - const std::vector fake_arguments = - MakeFakeArguments(module.get(), /*pseudo_random=*/true, - /*use_large_range=*/false, - /*treat_gte_as_data_formatting=*/false, - args_max_bits_of_precision) - .value(); - std::vector fake_argument_ptrs; - absl::c_transform( - fake_arguments, std::back_inserter(fake_argument_ptrs), - [](const Literal& literal) { return const_cast(&literal); }); - - return RunAndCompare(std::move(module), fake_argument_ptrs, error, - reference_preprocessor, test_preprocessor); -} - -::testing::AssertionResult HloRunnerAgnosticTestBase::RunAndCompareNoHloPasses( - std::unique_ptr module, const std::optional& error, - const std::function& reference_preprocessor, - const std::function& test_preprocessor) { - const std::vector fake_arguments = - MakeFakeArguments(module.get()).value(); - std::vector fake_argument_ptrs; - absl::c_transform( - fake_arguments, std::back_inserter(fake_argument_ptrs), - [](const Literal& literal) { return const_cast(&literal); }); - return RunAndCompareNoHloPasses(std::move(module), fake_argument_ptrs, error, - reference_preprocessor, test_preprocessor); -} - ::testing::AssertionResult HloRunnerAgnosticTestBase::Run( std::unique_ptr module, const bool run_hlo_passes, const std::function& test_preprocessor) { @@ -320,22 +224,6 @@ ::testing::AssertionResult HloRunnerAgnosticTestBase::Run( : ::testing::AssertionFailure() << output.status().message(); } -::testing::AssertionResult HloRunnerAgnosticTestBase::RunAndCompare( - const absl::string_view hlo_string, const std::optional& error, - const std::function& reference_preprocessor, - const std::function& test_preprocessor, - const std::optional args_max_bits_of_precision) { - absl::StatusOr> module = - ParseAndReturnVerifiedModule(hlo_string); - if (!module.ok()) { - return ::testing::AssertionFailure() - << "Error while parsing HLO text format: " - << module.status().ToString(); - } - return RunAndCompare(*std::move(module), error, reference_preprocessor, - test_preprocessor, args_max_bits_of_precision); -} - ::testing::AssertionResult HloRunnerAgnosticTestBase::RunAndCompareTwoModulesReplicated( std::unique_ptr module_0, std::unique_ptr module_1, @@ -721,69 +609,6 @@ ::testing::AssertionResult HloRunnerAgnosticTestBase::RunMultipleTimes( return ::testing::AssertionSuccess(); } -::testing::AssertionResult HloRunnerAgnosticTestBase::RunAndCompareNoHloPasses( - const absl::string_view hlo_string, const std::optional& error, - const std::function& reference_preprocessor, - const std::function& test_preprocessor) { - absl::StatusOr> module = - ParseAndReturnVerifiedModule(hlo_string); - if (!module.ok()) { - return ::testing::AssertionFailure() - << "Error while parsing HLO text format: " - << module.status().ToString(); - } - return RunAndCompareNoHloPasses(*std::move(module), error, - reference_preprocessor, test_preprocessor); -} - -absl::StatusOr> -HloRunnerAgnosticTestBase::MakeReferenceModule( - const HloModule& test_module, - const std::function& reference_preprocessor) { - std::unique_ptr reference_module = test_module.Clone(); - const ProgramShape program_shape = GetProgramShapeWithLayout(test_module); - - if (reference_preprocessor != nullptr) { - reference_preprocessor(reference_module.get()); - if (!ProgramShapesEqual(program_shape, - GetProgramShapeWithLayout(*reference_module))) { - return InvalidArgument( - "reference preprocessor must not modify the program shape"); - } - } - TF_RETURN_IF_ERROR(verifier().Run(reference_module.get()).status()); - return std::move(reference_module); -} - -absl::StatusOr<::testing::AssertionResult> -HloRunnerAgnosticTestBase::RunAndCompareInternal( - std::unique_ptr module, - const absl::Span arguments, - const std::optional& error, const bool run_hlo_passes, - const std::function& reference_preprocessor, - const std::function& test_preprocessor) { - TF_RETURN_IF_ERROR(verifier().Run(module.get()).status()); - TF_ASSIGN_OR_RETURN(std::unique_ptr reference_module, - MakeReferenceModule(*module, reference_preprocessor)); - TF_RETURN_IF_ERROR(PreprocessModuleForTestRunner(module.get())); - if (test_preprocessor != nullptr) { - test_preprocessor(module.get()); - } - // Execute on two backends. - TF_ASSIGN_OR_RETURN( - const Literal test, - test_runner_->Execute(std::move(module), arguments, run_hlo_passes)); - TF_ASSIGN_OR_RETURN(const Literal reference, - reference_runner_->Execute(std::move(reference_module), - arguments, run_hlo_passes)); - if (reference.IsAll(0)) { - LOG(WARNING) << "Reference value is only zeros."; - } - - return LiteralTestUtil::NearOrEqual(/*expected=*/reference, /*actual=*/test, - error); -} - absl::StatusOr<::testing::AssertionResult> HloRunnerAgnosticTestBase::RunAndCompareTwoModulesInternalReplicated( std::unique_ptr module_0, std::unique_ptr module_1, diff --git a/xla/tests/hlo_runner_agnostic_test_base.h b/xla/tests/hlo_runner_agnostic_test_base.h index 3bb8c5b787a91..efbf871fa2820 100644 --- a/xla/tests/hlo_runner_agnostic_test_base.h +++ b/xla/tests/hlo_runner_agnostic_test_base.h @@ -85,7 +85,6 @@ class HloRunnerAgnosticTestBase : public HloHardwareIndependentTestBase { protected: explicit HloRunnerAgnosticTestBase( absl::Nonnull> test_runner, - absl::Nonnull> reference_runner, bool verifier_layout_sensitive = false, bool allow_mixed_precision_in_hlo_verifier = true, HloPredicate instruction_can_change_layout_func = {}); @@ -163,59 +162,15 @@ class HloRunnerAgnosticTestBase : public HloHardwareIndependentTestBase { std::vector> arguments, int64_t num_replicas, bool run_hlo_passes, DeviceAssignment* device_assignment = nullptr); - // Executes the given hlo module on two backends and compares results. - // - // 'arguments': the input of the hlo module. - // - // 'error': if has value, expects the results to be near (within the error - // bound). Otherwise, expects the results to be equal. - // - // 'reference_preprocessor': the module should be ready to run on the test - // backend, but it might need to be tailored so that it is able to run on the - // reference backend. Note that the program shape of the module must not be - // modified. - ::testing::AssertionResult RunAndCompare( - std::unique_ptr module, absl::Span arguments, - const std::optional& error, - const std::function& reference_preprocessor = nullptr, - const std::function& test_preprocessor = nullptr); - - // Same as above, except that the module will be executed without Hlo - // optimization. - ::testing::AssertionResult RunAndCompareNoHloPasses( - std::unique_ptr module, absl::Span arguments, - const std::optional& error, - const std::function& reference_preprocessor = nullptr, - const std::function& test_preprocessor = nullptr); - - // Executes an hlo module with fake inputs and compares the results. - ::testing::AssertionResult RunAndCompare( - std::unique_ptr module, const std::optional& error, - const std::function& reference_preprocessor = nullptr, - const std::function& test_preprocessor = nullptr, - std::optional args_max_bits_of_precision = std::nullopt); - - // Same as above, except that the module will be executed without Hlo - // optimization. - ::testing::AssertionResult RunAndCompareNoHloPasses( - std::unique_ptr module, const std::optional& error, - const std::function& reference_preprocessor = nullptr, - const std::function& test_preprocessor = nullptr); - // Executes an hlo module with fake inputs and checks that the execution is // successful. ::testing::AssertionResult Run( std::unique_ptr module, bool run_hlo_passes, const std::function& test_preprocessor = nullptr); - // Convenient wrappers for executing and comparing an hlo module with fake + // Convenient wrapper for executing and comparing an hlo module with fake // input. Module can be passed in directly, or parsed from an hlo_string, // or loaded from a file. - ::testing::AssertionResult RunAndCompare( - absl::string_view hlo_string, const std::optional& error, - const std::function& reference_preprocessor = nullptr, - const std::function& test_preprocessor = nullptr, - std::optional args_max_bits_of_precision = std::nullopt); ::testing::AssertionResult Run( absl::string_view hlo_string, bool run_hlo_passes = true, ExecutionProfile* profile = nullptr, @@ -296,10 +251,6 @@ class HloRunnerAgnosticTestBase : public HloHardwareIndependentTestBase { std::vector* profiles, const tsl::protobuf::Message* backend_config = nullptr, bool assert_determinism = false); - ::testing::AssertionResult RunAndCompareNoHloPasses( - absl::string_view hlo_string, const std::optional& error, - const std::function& reference_preprocessor = nullptr, - const std::function& test_preprocessor = nullptr); // Override this method to add a default preprocessing step that is applied to // the test module in all Run* methods. The intended usecase for this is to @@ -313,25 +264,8 @@ class HloRunnerAgnosticTestBase : public HloHardwareIndependentTestBase { } HloRunnerInterface& test_runner() const { return *test_runner_; } - HloRunnerInterface& reference_runner() const { return *reference_runner_; } private: - // Given the test module, makes a reference module that is ready to run on the - // reference platform. This assumes that the given module is ready to run on - // the test platform. - absl::StatusOr> MakeReferenceModule( - const HloModule& test_module, - const std::function& reference_preprocessor); - - // Runs the module on two platforms with or without running hlo passes and - // compares the results. Returns whether the results are near or equal. If any - // error happens before the results are computed, returns the error status. - absl::StatusOr<::testing::AssertionResult> RunAndCompareInternal( - std::unique_ptr module, absl::Span arguments, - const std::optional& error, bool run_hlo_passes, - const std::function& reference_preprocessor, - const std::function& test_preprocessor = nullptr); - // Runs the two module with or without running hlo passes and compares // the results. Returns whether the results are near or equal. If any // error happens before the results are computed, returns the error status. @@ -350,7 +284,6 @@ class HloRunnerAgnosticTestBase : public HloHardwareIndependentTestBase { const std::optional& error, bool run_hlo_passes); std::unique_ptr test_runner_; - std::unique_ptr reference_runner_; }; } // namespace xla diff --git a/xla/tests/hlo_test_base.cc b/xla/tests/hlo_test_base.cc index 896de2cb58aa1..527f6c5b1880d 100644 --- a/xla/tests/hlo_test_base.cc +++ b/xla/tests/hlo_test_base.cc @@ -37,14 +37,14 @@ limitations under the License. #include "xla/stream_executor/device_memory_allocator.h" #include "xla/stream_executor/platform.h" #include "xla/stream_executor/stream_executor_memory_allocator.h" +#include "xla/tests/hlo_runner_agnostic_reference_mixin.h" #include "xla/tests/hlo_runner_agnostic_test_base.h" #include "xla/tests/pjrt_client_registry.h" #include "xla/tsl/lib/core/status_test_util.h" +#include "xla/tsl/platform/status.h" +#include "xla/tsl/platform/statusor.h" +#include "xla/tsl/platform/test.h" #include "xla/util.h" -#include "tsl/platform/logging.h" -#include "tsl/platform/status.h" -#include "tsl/platform/statusor.h" -#include "tsl/platform/test.h" namespace xla { namespace { @@ -94,10 +94,10 @@ HloTestBase::HloTestBase(se::Platform* test_platform, bool verifier_layout_sensitive, bool allow_mixed_precision_in_hlo_verifier, HloPredicate instruction_can_change_layout_func) - : HloRunnerAgnosticTestBase( + : HloRunnerAgnosticReferenceMixin( + /*reference_runner=*/GetHloRunnerForReference(reference_platform) + .value(), /*test_runner=*/GetHloRunnerForTest(test_platform).value(), - /*reference_runner=*/ - GetHloRunnerForReference(reference_platform).value(), verifier_layout_sensitive, allow_mixed_precision_in_hlo_verifier), test_platform_(test_platform) {} diff --git a/xla/tests/hlo_test_base.h b/xla/tests/hlo_test_base.h index 00e755a010e2c..16db75d2bb4e8 100644 --- a/xla/tests/hlo_test_base.h +++ b/xla/tests/hlo_test_base.h @@ -18,22 +18,16 @@ limitations under the License. #include #include -#include #include #include #include -#include #include -#include "absl/base/attributes.h" #include "absl/log/log.h" -#include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "xla/error_spec.h" #include "xla/hlo/ir/hlo_module.h" -#include "xla/hlo/pass/hlo_pass_interface.h" -#include "xla/hlo/testlib/hlo_hardware_independent_test_base.h" #include "xla/literal.h" #include "xla/service/backend.h" #include "xla/service/computation_placer.h" @@ -42,10 +36,11 @@ limitations under the License. #include "xla/service/hlo_runner_interface.h" #include "xla/stream_executor/device_memory_allocator.h" #include "xla/stream_executor/platform.h" +#include "xla/tests/hlo_runner_agnostic_reference_mixin.h" #include "xla/tests/hlo_runner_agnostic_test_base.h" +#include "xla/tsl/platform/test.h" #include "xla/util.h" #include "xla/xla_data.pb.h" -#include "tsl/platform/test.h" namespace xla { @@ -83,7 +78,8 @@ namespace xla { // class. HloTestBase remains as a shim on tests during this migration process. // While we would prefer if you can avoid introducing new tests that use this // class, we are still working on documenting the exact migration procedure. -class HloTestBase : public HloRunnerAgnosticTestBase { +class HloTestBase + : public HloRunnerAgnosticReferenceMixin { public: // Compiles the given `hlo` with optimizations, and verifies that optimized // HLO matches the given FileCheck pattern.