From e4576065b418e3d63439214c4cb4021802ec9d32 Mon Sep 17 00:00:00 2001 From: Niklas Vangerow Date: Mon, 6 Jan 2025 14:21:05 -0800 Subject: [PATCH] Fix `HloRunnerAgnosticTestBase` includes. Many of the tests that extend `HloTestBase` rely on symbols included transitively. The main ones are: - `PlatformUtil` - `LiteralUtil` - `LiteralTestUtil` This patch adds includes for these explicitly. PiperOrigin-RevId: 712656979 --- xla/service/BUILD | 41 ++++++++++++++----- xla/service/cpu/BUILD | 2 + xla/service/cpu/conv_canonicalization_test.cc | 1 + .../cpu/cpu_instruction_fusion_test.cc | 1 + xla/service/hlo_creation_utils_test.cc | 12 +++++- xla/service/hlo_module_test.cc | 18 ++++++-- xla/service/hlo_schedule_test.cc | 7 ++-- xla/service/triangular_solve_expander_test.cc | 11 +++-- xla/tests/BUILD | 29 ++++++------- xla/tests/dot_operation_test.cc | 6 +-- xla/tests/hlo_runner_agnostic_test_base.cc | 15 +++---- xla/tests/hlo_runner_agnostic_test_base.h | 37 +++++------------ 12 files changed, 105 insertions(+), 75 deletions(-) diff --git a/xla/service/BUILD b/xla/service/BUILD index 8cd9cac1da809d..62197c2c8f6a90 100644 --- a/xla/service/BUILD +++ b/xla/service/BUILD @@ -1837,21 +1837,21 @@ xla_cc_test( name = "hlo_schedule_test", srcs = ["hlo_schedule_test.cc"], deps = [ + ":buffer_value", + "//xla:literal_util", "//xla:shape_util", "//xla:test_helpers", - "//xla:types", "//xla:xla_data_proto_cc", - "//xla/hlo/analysis:hlo_ordering", "//xla/hlo/ir:hlo", "//xla/hlo/transforms/simplifiers:hlo_dce", "//xla/hlo/transforms/simplifiers:hlo_memory_scheduler", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", "//xla/tsl/lib/core:status_test_util", + "//xla/tsl/platform:statusor", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/log", "@com_google_googletest//:gtest", - "@tsl//tsl/platform:statusor", ], ) @@ -2024,14 +2024,22 @@ xla_cc_test( ":hlo_creation_utils", ":pattern_matcher", ":pattern_matcher_gmock", + "//xla:array2d", + "//xla:literal", + "//xla:literal_util", "//xla:shape_util", "//xla:test", "//xla:xla_data_proto_cc", "//xla/hlo/evaluator:hlo_evaluator", "//xla/hlo/ir:hlo", + "//xla/hlo/testlib:verified_hlo_module", "//xla/tests:hlo_test_base", + "//xla/tests:literal_test_util", "//xla/tests:xla_internal_test_main", - "@tsl//tsl/platform:test", + "//xla/tsl/platform:statusor", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/types:span", + "@com_google_googletest//:gtest", ], ) @@ -2230,13 +2238,16 @@ xla_cc_test( shard_count = 12, deps = [ ":triangular_solve_expander", + "//xla:array2d", + "//xla:error_spec", "//xla:literal", + "//xla:literal_util", "//xla:reference_util", - "//xla:test", - "//xla:types", "//xla/tests:hlo_test_base", + "//xla/tests:literal_test_util", "//xla/tests:xla_internal_test_main", - "//xla/tsl/lib/core:status_test_util", + "//xla/tsl/platform:statusor", + "//xla/tsl/platform:test", ], ) @@ -3493,25 +3504,35 @@ xla_cc_test( name = "hlo_module_test", srcs = ["hlo_module_test.cc"], deps = [ + ":buffer_value", ":computation_placer_hdr", + ":hlo_module_config", ":test_compilation_environment_proto_cc", - "//xla:literal", + "//xla:comparison_util", + "//xla:debug_options_flags", + "//xla:literal_util", "//xla:shape_util", "//xla:test", "//xla:xla_data_proto_cc", "//xla:xla_proto_cc", "//xla/hlo/ir:hlo", + "//xla/hlo/testlib:verified_hlo_module", "//xla/hlo/transforms/simplifiers:hlo_memory_scheduler", "//xla/hlo/utils:hlo_matchers", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", "//xla/tsl/lib/core:status_test_util", "//xla/tsl/lib/strings:proto_serialization", + "//xla/tsl/platform:statusor", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/types:span", "@com_google_googletest//:gtest_main", - "@tsl//tsl/platform:errors", - "@tsl//tsl/platform:statusor", + "@tsl//tsl/platform:casts", + "@tsl//tsl/platform:protobuf", ], ) diff --git a/xla/service/cpu/BUILD b/xla/service/cpu/BUILD index 5dbad693e78471..430def53c196bb 100644 --- a/xla/service/cpu/BUILD +++ b/xla/service/cpu/BUILD @@ -1386,6 +1386,7 @@ xla_cc_test( tags = ["not_run:arm"], deps = [ ":cpu_instruction_fusion", + "//xla:literal_util", "//xla:shape_util", "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", @@ -1539,6 +1540,7 @@ xla_cc_test( deps = [ ":conv_canonicalization", ":target_machine_features_stub", + "//xla:literal_util", "//xla:test", "//xla:test_helpers", "//xla:util", diff --git a/xla/service/cpu/conv_canonicalization_test.cc b/xla/service/cpu/conv_canonicalization_test.cc index 00c9ee256452c9..6f6ebd96fb64c2 100644 --- a/xla/service/cpu/conv_canonicalization_test.cc +++ b/xla/service/cpu/conv_canonicalization_test.cc @@ -20,6 +20,7 @@ limitations under the License. #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_module.h" +#include "xla/literal_util.h" #include "xla/service/cpu/target_machine_features_stub.h" #include "xla/test.h" #include "xla/test_helpers.h" diff --git a/xla/service/cpu/cpu_instruction_fusion_test.cc b/xla/service/cpu/cpu_instruction_fusion_test.cc index 6b4de145d8e809..787c4d138b3448 100644 --- a/xla/service/cpu/cpu_instruction_fusion_test.cc +++ b/xla/service/cpu/cpu_instruction_fusion_test.cc @@ -29,6 +29,7 @@ limitations under the License. #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_opcode.h" #include "xla/hlo/utils/hlo_matchers.h" +#include "xla/literal_util.h" #include "xla/service/transpose_folding.h" #include "xla/shape.h" #include "xla/tests/hlo_test_base.h" diff --git a/xla/service/hlo_creation_utils_test.cc b/xla/service/hlo_creation_utils_test.cc index 252345fbbbc5ff..debabe09c3c51e 100644 --- a/xla/service/hlo_creation_utils_test.cc +++ b/xla/service/hlo_creation_utils_test.cc @@ -15,19 +15,29 @@ limitations under the License. #include "xla/service/hlo_creation_utils.h" +#include #include +#include +#include "absl/log/check.h" +#include "absl/types/span.h" +#include "xla/array2d.h" #include "xla/hlo/evaluator/hlo_evaluator.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_module.h" +#include "xla/hlo/ir/hlo_opcode.h" +#include "xla/hlo/testlib/verified_hlo_module.h" +#include "xla/literal.h" +#include "xla/literal_util.h" #include "xla/service/pattern_matcher.h" #include "xla/service/pattern_matcher_gmock.h" #include "xla/shape.h" #include "xla/shape_util.h" #include "xla/test.h" #include "xla/tests/hlo_test_base.h" +#include "xla/tests/literal_test_util.h" +#include "xla/tsl/platform/statusor.h" #include "xla/xla_data.pb.h" -#include "tsl/platform/test.h" namespace xla { namespace { diff --git a/xla/service/hlo_module_test.cc b/xla/service/hlo_module_test.cc index 339feeb8fd2d4e..960f107c9117b9 100644 --- a/xla/service/hlo_module_test.cc +++ b/xla/service/hlo_module_test.cc @@ -24,25 +24,37 @@ limitations under the License. #include #include +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" +#include "absl/status/statusor.h" #include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" #include "absl/types/span.h" +#include "xla/comparison_util.h" +#include "xla/debug_options_flags.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_opcode.h" #include "xla/hlo/ir/hlo_original_value.h" +#include "xla/hlo/testlib/verified_hlo_module.h" #include "xla/hlo/transforms/simplifiers/hlo_memory_scheduler.h" #include "xla/hlo/utils/hlo_matchers.h" -#include "xla/literal.h" +#include "xla/literal_util.h" +#include "xla/service/buffer_value.h" #include "xla/service/computation_placer.h" +#include "xla/service/hlo_module_config.h" #include "xla/service/test_compilation_environment.pb.h" +#include "xla/shape.h" #include "xla/shape_util.h" #include "xla/test.h" #include "xla/tests/hlo_test_base.h" #include "xla/tsl/lib/core/status_test_util.h" #include "xla/tsl/lib/strings/proto_serialization.h" +#include "xla/tsl/platform/statusor.h" #include "xla/xla.pb.h" #include "xla/xla_data.pb.h" -#include "tsl/platform/errors.h" -#include "tsl/platform/statusor.h" +#include "tsl/platform/casts.h" +#include "tsl/platform/protobuf.h" namespace xla { diff --git a/xla/service/hlo_schedule_test.cc b/xla/service/hlo_schedule_test.cc index d18c8527893c81..fd89bcc5b23fc5 100644 --- a/xla/service/hlo_schedule_test.cc +++ b/xla/service/hlo_schedule_test.cc @@ -22,19 +22,20 @@ limitations under the License. #include #include "absl/algorithm/container.h" #include "absl/log/log.h" -#include "xla/hlo/analysis/hlo_ordering.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_opcode.h" #include "xla/hlo/transforms/simplifiers/hlo_dce.h" #include "xla/hlo/transforms/simplifiers/hlo_memory_scheduler.h" +#include "xla/literal_util.h" +#include "xla/service/buffer_value.h" +#include "xla/shape.h" #include "xla/shape_util.h" #include "xla/test_helpers.h" #include "xla/tests/hlo_test_base.h" #include "xla/tsl/lib/core/status_test_util.h" -#include "xla/types.h" +#include "xla/tsl/platform/statusor.h" #include "xla/xla_data.pb.h" -#include "tsl/platform/statusor.h" namespace xla { namespace { diff --git a/xla/service/triangular_solve_expander_test.cc b/xla/service/triangular_solve_expander_test.cc index fa382b24d0d9db..1a2ba8c71ece6e 100644 --- a/xla/service/triangular_solve_expander_test.cc +++ b/xla/service/triangular_solve_expander_test.cc @@ -15,15 +15,20 @@ limitations under the License. #include "xla/service/triangular_solve_expander.h" +#include #include +#include #include +#include "xla/array2d.h" +#include "xla/error_spec.h" #include "xla/literal.h" +#include "xla/literal_util.h" #include "xla/reference_util.h" -#include "xla/test.h" #include "xla/tests/hlo_test_base.h" -#include "xla/tsl/lib/core/status_test_util.h" -#include "xla/types.h" +#include "xla/tests/literal_test_util.h" +#include "xla/tsl/platform/statusor.h" +#include "xla/tsl/platform/test.h" namespace xla { namespace { diff --git a/xla/tests/BUILD b/xla/tests/BUILD index e8ab69dffb4dce..a4e570f8939d04 100644 --- a/xla/tests/BUILD +++ b/xla/tests/BUILD @@ -214,35 +214,27 @@ cc_library( deps = [ ":literal_test_util", ":test_utils", - "//xla:debug_options_flags", "//xla:error_spec", "//xla:literal", - "//xla:literal_util", - "//xla:shape_layout", "//xla:shape_util", "//xla:test_helpers", "//xla:util", "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", - "//xla/hlo/ir:hlo_module_group", - "//xla/hlo/pass:hlo_pass", "//xla/hlo/testlib:hlo_hardware_independent_test_base", "//xla/hlo/testlib:verified_hlo_module", - "//xla/hlo/utils:hlo_query", - "//xla/service:backend", - "//xla/service:computation_layout", "//xla/service:computation_placer_hdr", "//xla/service:executable", "//xla/service:hlo_module_config", "//xla/service:hlo_module_util", - "//xla/service:hlo_runner", "//xla/service:hlo_runner_interface", "//xla/service:hlo_verifier", "//xla/service:interpreter_plugin", # reference backend - "//xla/service:platform_util", - "//xla/stream_executor:device_memory_allocator", + "//xla/tsl/platform:errors", + "//xla/tsl/platform:logging", + "//xla/tsl/platform:statusor", + "//xla/tsl/platform:test", "@com_google_absl//absl/algorithm:container", - "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/base:nullability", "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", @@ -252,10 +244,7 @@ cc_library( "@com_google_absl//absl/strings:string_view", "@com_google_absl//absl/types:span", "@llvm-project//llvm:Support", - "@tsl//tsl/platform:errors", - "@tsl//tsl/platform:logging", - "@tsl//tsl/platform:statusor", - "@tsl//tsl/platform:test", + "@tsl//tsl/platform:protobuf", ], ) @@ -979,6 +968,8 @@ xla_test( "//xla/stream_executor:device_description", "//xla/stream_executor:platform", "//xla/stream_executor:stream_executor_memory_allocator", + "//xla/tsl/platform:test", + "//xla/tsl/platform:test_benchmark", "@com_google_absl//absl/log", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", @@ -1070,7 +1061,10 @@ xla_test( "//xla/hlo/builder/lib:arithmetic", "//xla/hlo/builder/lib:matrix", "//xla/hlo/parser:hlo_parser", + "//xla/service:platform_util", "//xla/stream_executor:stream_executor_memory_allocator", + "//xla/tsl/platform:test", + "//xla/tsl/platform:test_benchmark", "@com_google_absl//absl/strings", "@tsl//tsl/platform:ml_dtypes", "@tsl//tsl/platform:test", @@ -1158,7 +1152,10 @@ xla_test( "//xla/hlo/builder/lib:arithmetic", "//xla/hlo/builder/lib:matrix", "//xla/hlo/parser:hlo_parser", + "//xla/service:platform_util", "//xla/stream_executor:stream_executor_memory_allocator", + "//xla/tsl/platform:test", + "//xla/tsl/platform:test_benchmark", "@com_google_absl//absl/strings", "@tsl//tsl/platform:ml_dtypes", "@tsl//tsl/platform:test", diff --git a/xla/tests/dot_operation_test.cc b/xla/tests/dot_operation_test.cc index 674ada04d96c30..2acc860804d0d6 100644 --- a/xla/tests/dot_operation_test.cc +++ b/xla/tests/dot_operation_test.cc @@ -22,21 +22,21 @@ limitations under the License. #include "xla/array3d.h" #include "xla/client/local_client.h" #include "xla/error_spec.h" -#include "xla/hlo/builder/lib/arithmetic.h" #include "xla/hlo/builder/lib/matrix.h" #include "xla/hlo/builder/xla_builder.h" #include "xla/hlo/parser/hlo_parser.h" #include "xla/literal_util.h" #include "xla/primitive_util.h" #include "xla/reference_util.h" +#include "xla/service/platform_util.h" #include "xla/shape_util.h" #include "xla/stream_executor/stream_executor_memory_allocator.h" #include "xla/tests/client_library_test_base.h" #include "xla/tests/hlo_test_base.h" #include "xla/tests/test_macros.h" +#include "xla/tsl/platform/test.h" +#include "xla/tsl/platform/test_benchmark.h" #include "tsl/platform/ml_dtypes.h" -#include "tsl/platform/test.h" -#include "tsl/platform/test_benchmark.h" #if TENSORFLOW_USE_ROCM #include "rocm/rocm_config.h" diff --git a/xla/tests/hlo_runner_agnostic_test_base.cc b/xla/tests/hlo_runner_agnostic_test_base.cc index 402159a1858530..b781a0eebd37d0 100644 --- a/xla/tests/hlo_runner_agnostic_test_base.cc +++ b/xla/tests/hlo_runner_agnostic_test_base.cc @@ -30,19 +30,13 @@ limitations under the License. #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_join.h" -#include "absl/strings/str_replace.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" #include "llvm/ADT/STLExtras.h" -#include "xla/debug_options_flags.h" #include "xla/error_spec.h" #include "xla/hlo/ir/hlo_instruction.h" -#include "xla/hlo/ir/hlo_module_group.h" -#include "xla/hlo/ir/hlo_opcode.h" -#include "xla/hlo/pass/hlo_pass_interface.h" #include "xla/hlo/testlib/hlo_hardware_independent_test_base.h" #include "xla/hlo/testlib/verified_hlo_module.h" -#include "xla/hlo/utils/hlo_query.h" #include "xla/literal.h" #include "xla/service/computation_placer.h" #include "xla/service/executable.h" @@ -53,11 +47,12 @@ limitations under the License. #include "xla/shape.h" #include "xla/tests/literal_test_util.h" #include "xla/tests/test_utils.h" +#include "xla/tsl/platform/errors.h" +#include "xla/tsl/platform/logging.h" +#include "xla/tsl/platform/statusor.h" +#include "xla/tsl/platform/test.h" #include "xla/util.h" -#include "tsl/platform/errors.h" -#include "tsl/platform/logging.h" -#include "tsl/platform/statusor.h" -#include "tsl/platform/test.h" +#include "tsl/platform/protobuf.h" namespace xla { diff --git a/xla/tests/hlo_runner_agnostic_test_base.h b/xla/tests/hlo_runner_agnostic_test_base.h index e43ddec3e28926..9b8ae26f615f45 100644 --- a/xla/tests/hlo_runner_agnostic_test_base.h +++ b/xla/tests/hlo_runner_agnostic_test_base.h @@ -24,7 +24,6 @@ limitations under the License. #include #include -#include "absl/base/attributes.h" #include "absl/base/nullability.h" #include "absl/log/log.h" #include "absl/status/status.h" @@ -35,31 +34,17 @@ limitations under the License. #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_module.h" -#include "xla/hlo/ir/hlo_module_group.h" -#include "xla/hlo/ir/hlo_opcode.h" -#include "xla/hlo/pass/hlo_pass_interface.h" #include "xla/hlo/testlib/hlo_hardware_independent_test_base.h" #include "xla/hlo/testlib/verified_hlo_module.h" -#include "xla/layout.h" #include "xla/literal.h" -#include "xla/literal_util.h" -#include "xla/service/backend.h" -#include "xla/service/computation_layout.h" #include "xla/service/computation_placer.h" #include "xla/service/executable.h" #include "xla/service/hlo_module_config.h" -#include "xla/service/hlo_runner.h" #include "xla/service/hlo_runner_interface.h" -#include "xla/service/hlo_verifier.h" -#include "xla/service/platform_util.h" -#include "xla/shape_layout.h" -#include "xla/shape_util.h" -#include "xla/stream_executor/device_memory_allocator.h" #include "xla/test_helpers.h" -#include "xla/tests/literal_test_util.h" +#include "xla/tsl/platform/test.h" #include "xla/util.h" #include "xla/xla_data.pb.h" -#include "tsl/platform/test.h" namespace xla { @@ -189,7 +174,7 @@ class HloRunnerAgnosticTestBase : public HloHardwareIndependentTestBase { // backend, but it might need to be tailored so that it is able to run on the // reference backend. Note that the program shape of the module must not be // modified. - [[nodiscard]] ::testing::AssertionResult RunAndCompare( + ::testing::AssertionResult RunAndCompare( std::unique_ptr module, absl::Span arguments, const std::optional& error, const std::function& reference_preprocessor = nullptr, @@ -197,14 +182,14 @@ class HloRunnerAgnosticTestBase : public HloHardwareIndependentTestBase { // Same as above, except that the module will be executed without Hlo // optimization. - [[nodiscard]] ::testing::AssertionResult RunAndCompareNoHloPasses( + ::testing::AssertionResult RunAndCompareNoHloPasses( std::unique_ptr module, absl::Span arguments, const std::optional& error, const std::function& reference_preprocessor = nullptr, const std::function& test_preprocessor = nullptr); // Executes an hlo module with fake inputs and compares the results. - [[nodiscard]] ::testing::AssertionResult RunAndCompare( + ::testing::AssertionResult RunAndCompare( std::unique_ptr module, const std::optional& error, const std::function& reference_preprocessor = nullptr, const std::function& test_preprocessor = nullptr, @@ -212,26 +197,26 @@ class HloRunnerAgnosticTestBase : public HloHardwareIndependentTestBase { // Same as above, except that the module will be executed without Hlo // optimization. - [[nodiscard]] ::testing::AssertionResult RunAndCompareNoHloPasses( + ::testing::AssertionResult RunAndCompareNoHloPasses( std::unique_ptr module, const std::optional& error, const std::function& reference_preprocessor = nullptr, const std::function& test_preprocessor = nullptr); // Executes an hlo module with fake inputs and checks that the execution is // successful. - [[nodiscard]] ::testing::AssertionResult Run( + ::testing::AssertionResult Run( std::unique_ptr module, bool run_hlo_passes, const std::function& test_preprocessor = nullptr); // Convenient wrappers for executing and comparing an hlo module with fake // input. Module can be passed in directly, or parsed from an hlo_string, // or loaded from a file. - [[nodiscard]] ::testing::AssertionResult RunAndCompare( + ::testing::AssertionResult RunAndCompare( absl::string_view hlo_string, const std::optional& error, const std::function& reference_preprocessor = nullptr, const std::function& test_preprocessor = nullptr, std::optional args_max_bits_of_precision = std::nullopt); - [[nodiscard]] ::testing::AssertionResult Run( + ::testing::AssertionResult Run( absl::string_view hlo_string, bool run_hlo_passes = true, ExecutionProfile* profile = nullptr, const tsl::protobuf::Message* backend_config = nullptr, @@ -299,19 +284,19 @@ class HloRunnerAgnosticTestBase : public HloHardwareIndependentTestBase { const std::optional& error, bool run_hlo_passes = true); // Executes an hlo module with fake inputs on multiple replicas. - [[nodiscard]] ::testing::AssertionResult RunReplicated( + ::testing::AssertionResult RunReplicated( absl::string_view hlo_string, bool run_hlo_passes = true, int64_t num_replicas = 1, const tsl::protobuf::Message* backend_config = nullptr); // If assert_determinism is true, the assertion will fail unless all runs // produce exactly the same output. - [[nodiscard]] ::testing::AssertionResult RunMultipleTimes( + ::testing::AssertionResult RunMultipleTimes( absl::string_view hlo_string, bool run_hlo_passes, std::vector* profiles, const tsl::protobuf::Message* backend_config = nullptr, bool assert_determinism = false); - [[nodiscard]] ::testing::AssertionResult RunAndCompareNoHloPasses( + ::testing::AssertionResult RunAndCompareNoHloPasses( absl::string_view hlo_string, const std::optional& error, const std::function& reference_preprocessor = nullptr, const std::function& test_preprocessor = nullptr);