Skip to content

Commit

Permalink
Remove RunAndCompare functionality from HloRunnerAgnosticTestBase.
Browse files Browse the repository at this point in the history
This functionality is now fully contained in `HloRunnerAgnosticReferenceMixin`
and therefore is no longer needed in `HloRunnerAgnosticTestBase`.

This change temporarily adds the mixin to `HloPjRtTestBase`. Next, we'll go
through all tests that extend these base classes and will move the uses of the
mixins to the leaves.

PiperOrigin-RevId: 714996600
  • Loading branch information
nvgrw authored and Google-ML-Automation committed Jan 13, 2025
1 parent 135717d commit e27a06f
Show file tree
Hide file tree
Showing 13 changed files with 40 additions and 308 deletions.
4 changes: 2 additions & 2 deletions xla/service/cost_modelling/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -54,10 +54,10 @@ xla_cc_test(
"//xla/service:platform_util",
"//xla/tests:hlo_runner_agnostic_test_base",
"//xla/tsl/lib/core:status_test_util",
"//xla/tsl/platform:statusor",
"//xla/tsl/platform:test",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/strings:string_view",
"@com_google_googletest//:gtest_main",
"@tsl//tsl/platform:statusor",
"@tsl//tsl/platform:test",
],
)
6 changes: 2 additions & 4 deletions xla/service/cost_modelling/op_cost_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,8 @@ limitations under the License.
#include "xla/shape_util.h"
#include "xla/tests/hlo_runner_agnostic_test_base.h"
#include "xla/tsl/lib/core/status_test_util.h"
#include "tsl/platform/statusor.h"
#include "tsl/platform/test.h"
#include "xla/tsl/platform/statusor.h"
#include "xla/tsl/platform/test.h"

namespace xla {
namespace {
Expand Down Expand Up @@ -86,8 +86,6 @@ class OpCostTest : public HloRunnerAgnosticTestBase {
protected:
OpCostTest()
: HloRunnerAgnosticTestBase(
std::make_unique<HloRunner>(
PlatformUtil::GetDefaultPlatform().value()),
std::make_unique<HloRunner>(
PlatformUtil::GetDefaultPlatform().value())) {}

Expand Down
3 changes: 2 additions & 1 deletion xla/service/gpu/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -2594,9 +2594,10 @@ xla_cc_test(
"//xla/stream_executor:device_description",
"//xla/tests:hlo_runner_agnostic_test_base",
"//xla/tests:xla_internal_test_main",
"//xla/tsl/platform:statusor",
"//xla/tsl/platform:test",
"@com_google_absl//absl/strings",
"@com_google_googletest//:gtest_main",
"@tsl//tsl/platform:statusor",
],
)

Expand Down
7 changes: 3 additions & 4 deletions xla/service/gpu/gpu_fusible_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@ limitations under the License.
#include "xla/service/gpu/gpu_fusible.h"

#include <memory>
#include <vector>

#include <gtest/gtest.h>
#include "absl/strings/str_cat.h"
#include "xla/hlo/ir/hlo_instruction.h"
#include "xla/hlo/ir/hlo_opcode.h"
Expand All @@ -27,7 +27,8 @@ limitations under the License.
#include "xla/service/platform_util.h"
#include "xla/stream_executor/device_description.h"
#include "xla/tests/hlo_runner_agnostic_test_base.h"
#include "tsl/platform/statusor.h"
#include "xla/tsl/platform/statusor.h"
#include "xla/tsl/platform/test.h"

namespace xla {
namespace gpu {
Expand All @@ -46,8 +47,6 @@ class GpuFusibleTest : public HloRunnerAgnosticTestBase {
public:
GpuFusibleTest()
: HloRunnerAgnosticTestBase(
std::make_unique<HloRunner>(
PlatformUtil::GetDefaultPlatform().value()),
std::make_unique<HloRunner>(
PlatformUtil::GetDefaultPlatform().value())),
device_description_(MakeDeviceDescription()) {}
Expand Down
5 changes: 2 additions & 3 deletions xla/service/gpu/transforms/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -3483,10 +3483,9 @@ xla_cc_test(
"//xla/tests:hlo_runner_agnostic_test_base",
"//xla/tests:test_utils",
"//xla/tsl/lib/core:status_test_util",
"//xla/tsl/platform:statusor",
"//xla/tsl/platform:test",
"@com_google_absl//absl/log",
"@com_google_googletest//:gtest",
"@com_google_googletest//:gtest_main",
"@tsl//tsl/platform:statusor",
"@tsl//tsl/platform:test",
],
)
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,6 @@ limitations under the License.

#include <memory>

#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include "absl/log/log.h"
#include "xla/hlo/ir/hlo_instruction.h"
#include "xla/hlo/testlib/filecheck.h"
Expand All @@ -28,8 +26,8 @@ limitations under the License.
#include "xla/tests/hlo_runner_agnostic_test_base.h"
#include "xla/tests/test_utils.h"
#include "xla/tsl/lib/core/status_test_util.h"
#include "tsl/platform/statusor.h"
#include "tsl/platform/test.h"
#include "xla/tsl/platform/statusor.h"
#include "xla/tsl/platform/test.h"

namespace xla {
namespace gpu {
Expand All @@ -39,8 +37,6 @@ class RaggedAllToAllDecomposerTest : public HloRunnerAgnosticTestBase {
public:
RaggedAllToAllDecomposerTest()
: HloRunnerAgnosticTestBase(
std::make_unique<HloRunner>(
PlatformUtil::GetDefaultPlatform().value()),
std::make_unique<HloRunner>(
PlatformUtil::GetDefaultPlatform().value())) {}
};
Expand Down
16 changes: 5 additions & 11 deletions xla/tests/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -170,16 +170,15 @@ cc_library(
srcs = ["hlo_test_base.cc"],
hdrs = ["hlo_test_base.h"],
deps = [
":hlo_runner_agnostic_reference_mixin",
":hlo_runner_agnostic_test_base",
":pjrt_client_registry",
"//xla:error_spec",
"//xla:literal",
"//xla:util",
"//xla:xla_data_proto_cc",
"//xla/hlo/ir:hlo",
"//xla/hlo/pass:hlo_pass",
"//xla/hlo/testlib:filecheck",
"//xla/hlo/testlib:hlo_hardware_independent_test_base",
"//xla/pjrt:pjrt_client",
"//xla/service:backend",
"//xla/service:computation_placer_hdr",
Expand All @@ -193,16 +192,14 @@ cc_library(
"//xla/stream_executor:platform",
"//xla/stream_executor:stream_executor_memory_allocator",
"//xla/tsl/lib/core:status_test_util",
"@com_google_absl//absl/base:core_headers",
"//xla/tsl/platform:status",
"//xla/tsl/platform:statusor",
"//xla/tsl/platform:test",
"@com_google_absl//absl/log",
"@com_google_absl//absl/log:check",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings:string_view",
"@tsl//tsl/platform:logging",
"@tsl//tsl/platform:status",
"@tsl//tsl/platform:statusor",
"@tsl//tsl/platform:test",
],
)

Expand All @@ -216,7 +213,6 @@ cc_library(
":test_utils",
"//xla:error_spec",
"//xla:literal",
"//xla:shape_util",
"//xla:util",
"//xla:xla_data_proto_cc",
"//xla/hlo/ir:hlo",
Expand Down Expand Up @@ -290,18 +286,16 @@ cc_library(
srcs = ["hlo_pjrt_test_base.cc"],
hdrs = ["hlo_pjrt_test_base.h"],
deps = [
":hlo_pjrt_interpreter_reference_mixin",
":hlo_runner_agnostic_test_base",
":pjrt_client_registry",
"//xla:util",
"//xla:xla_data_proto_cc",
"//xla/pjrt:pjrt_client",
"//xla/pjrt/interpreter:interpreter_client",
"//xla/service:hlo_runner_interface",
"//xla/service:hlo_runner_pjrt",
"//xla/service:interpreter_plugin", # reference backend
"@com_google_absl//absl/log:check",
"@com_google_absl//absl/status:statusor",
"@tsl//tsl/platform:logging",
],
)

Expand Down
21 changes: 5 additions & 16 deletions xla/tests/hlo_pjrt_test_base.cc
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,12 @@ limitations under the License.

#include "absl/log/check.h"
#include "absl/status/statusor.h"
#include "xla/pjrt/interpreter/interpreter_client.h"
#include "xla/pjrt/pjrt_client.h"
#include "xla/service/hlo_runner_interface.h"
#include "xla/service/hlo_runner_pjrt.h"
#include "xla/tests/hlo_pjrt_interpreter_reference_mixin.h"
#include "xla/tests/hlo_runner_agnostic_test_base.h"
#include "xla/tests/pjrt_client_registry.h"
#include "xla/util.h"
#include "tsl/platform/logging.h"

namespace xla {
namespace {
Expand All @@ -52,21 +50,12 @@ std::unique_ptr<HloRunnerInterface> GetHloRunnerForTest() {
*std::move(client), device_shape_representation_fn, device_shape_size_fn);
}

std::unique_ptr<HloRunnerInterface> GetHloRunnerForReference() {
return std::make_unique<HloRunnerPjRt>(
std::make_unique<InterpreterClient>(),
InterpreterClient::DeviceShapeRepresentation,
InterpreterClient::ShapeSizeBytes,
/*use_parameter_layout_on_device=*/true);
}

} // namespace

HloPjRtTestBase::HloPjRtTestBase(HloPjRtTestBaseOptions options)
: HloRunnerAgnosticTestBase(GetHloRunnerForTest(),
GetHloRunnerForReference(),
options.verifier_layout_sensitive,
options.allow_mixed_precision_in_hlo_verifier,
options.instruction_can_change_layout_func) {}
: HloPjRtInterpreterReferenceMixin<HloRunnerAgnosticTestBase>(
GetHloRunnerForTest(), options.verifier_layout_sensitive,
options.allow_mixed_precision_in_hlo_verifier,
options.instruction_can_change_layout_func) {}

} // namespace xla
6 changes: 4 additions & 2 deletions xla/tests/hlo_pjrt_test_base.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ limitations under the License.
#ifndef XLA_TESTS_HLO_PJRT_TEST_BASE_H_
#define XLA_TESTS_HLO_PJRT_TEST_BASE_H_

#include "xla/tests/hlo_pjrt_interpreter_reference_mixin.h"
#include "xla/tests/hlo_runner_agnostic_test_base.h"
#include "xla/util.h"
#include "xla/xla_data.pb.h"
Expand All @@ -28,9 +29,10 @@ struct HloPjRtTestBaseOptions {
HloPredicate instruction_can_change_layout_func;
};

class HloPjRtTestBase : public HloRunnerAgnosticTestBase {
class HloPjRtTestBase
: public HloPjRtInterpreterReferenceMixin<HloRunnerAgnosticTestBase> {
protected:
// This uses the SE interpreter backend for the reference backend and
// This uses the PjRt interpreter backend for the reference backend and
// automatically finds a PjRt backend for the test backend.
explicit HloPjRtTestBase(HloPjRtTestBaseOptions options = {});
};
Expand Down
Loading

0 comments on commit e27a06f

Please sign in to comment.