Skip to content

Commit

Permalink
#sdy support JAX callbacks through the Shardy XLA round-trip pipeline.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 705875941
  • Loading branch information
bartchr808 authored and Google-ML-Automation committed Jan 8, 2025
1 parent d83a315 commit e076af4
Show file tree
Hide file tree
Showing 17 changed files with 485 additions and 9 deletions.
3 changes: 3 additions & 0 deletions xla/service/spmd/shardy/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ cc_library(
"@llvm-project//mlir:Support",
"@shardy//shardy/dialect/sdy/ir:dialect",
"@shardy//shardy/dialect/sdy/ir:register",
"@stablehlo//:stablehlo_ops",
],
)

Expand Down Expand Up @@ -119,6 +120,7 @@ xla_cc_binary(
deps = [
"//xla/mlir_hlo",
"//xla/mlir_hlo:mhlo_passes",
"//xla/service/spmd/shardy/mhlo_round_trip:export_callback_custom_calls",
"//xla/service/spmd/shardy/mhlo_round_trip:export_ops",
"//xla/service/spmd/shardy/mhlo_round_trip:export_shardings",
"//xla/service/spmd/shardy/mhlo_round_trip:mhlo_export",
Expand All @@ -132,6 +134,7 @@ xla_cc_binary(
"//xla/service/spmd/shardy/round_trip_common:open_while_free_vars_sharding",
"//xla/service/spmd/shardy/sdy_round_trip:export_ops",
"//xla/service/spmd/shardy/sdy_round_trip:export_shardy_attrs",
"//xla/service/spmd/shardy/sdy_round_trip:import_callback_custom_calls",
"//xla/service/spmd/shardy/sdy_round_trip:import_shardy_attrs",
"//xla/service/spmd/shardy/sdy_round_trip:pipelines",
"//xla/service/spmd/shardy/sdy_round_trip:remove_size_one_axes",
Expand Down
8 changes: 8 additions & 0 deletions xla/service/spmd/shardy/constants.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,14 @@ inline constexpr llvm::StringRef kSPMDFullToShardShapeCallTargetName =
inline constexpr llvm::StringRef kSPMDShardToFullShapeCallTargetName =
"SPMDShardToFullShape";

// The `call_target_name` of the custom call that represents a Python callback
// on CPU.
inline constexpr llvm::StringRef kPythonCpuCallbackCustomCallTargetName =
"xla_python_cpu_callback";

// The `call_target_name` of the custom call that represents a Python callback
// on GPU.
inline constexpr llvm::StringRef kPythonGpuCallbackCustomCallTargetName =
"xla_python_gpu_callback";

// The attribute name for backend config.
inline constexpr llvm::StringRef kXlaBackendConfigAttr = "backend_config";

Expand Down
17 changes: 17 additions & 0 deletions xla/service/spmd/shardy/mhlo_round_trip/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -83,11 +83,28 @@ cc_library(
],
)

# Pass library for `export_callback_custom_calls.{cc,h}`. It is a dependency of
# `:mhlo_export` in this package and of the `sdy_opt` binary.
cc_library(
name = "export_callback_custom_calls",
srcs = ["export_callback_custom_calls.cc"],
hdrs = ["export_callback_custom_calls.h"],
deps = [
"//xla/service/spmd/shardy:constants",
"//xla/service/spmd/shardy:utils",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:Support",
"@shardy//shardy/dialect/sdy/ir:dialect",
"@stablehlo//:stablehlo_ops",
],
)

cc_library(
name = "mhlo_export",
srcs = ["mhlo_export.cc"],
hdrs = ["mhlo_export.h"],
deps = [
":export_callback_custom_calls",
":export_ops",
":export_shardings",
":shard_map_export",
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
/* Copyright 2024 The OpenXLA Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "xla/service/spmd/shardy/mhlo_round_trip/export_callback_custom_calls.h"

#include <memory>

#include "llvm/ADT/StringRef.h"
#include "llvm/Support/Casting.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/DialectRegistry.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassRegistry.h"
#include "mlir/Support/TypeID.h"
#include "shardy/dialect/sdy/ir/dialect.h"
#include "stablehlo/dialect/StablehloOps.h"
#include "xla/service/spmd/shardy/constants.h"
#include "xla/service/spmd/shardy/utils.h"

namespace xla {
namespace sdy {

namespace {

using ::mlir::ModuleOp;
using ::mlir::OperationPass;
using ::mlir::PassWrapper;
using ::mlir::StringRef;

using ::mlir::stablehlo::CustomCallOp;

// Rewrites `customCall` into its tuple form: a clone of the custom call whose
// single result is a one-element tuple wrapping the original result type,
// followed by a `GetTupleElementOp` that extracts element 0 and takes over all
// uses of the original op.
//
// The op is left untouched unless it has exactly one result and that result
// is not already a tuple.
//
// If `customCall` carries a `kXlaShardingAttr`, it is copied onto the new
// `GetTupleElementOp`; if it carries none, the `GetTupleElementOp` is created
// without a sharding attribute.
//
// `rewriter` is used to create the new ops and to replace `customCall`.
void replaceCallbackWithTupleVersion(CustomCallOp customCall,
                                     mlir::IRRewriter& rewriter) {
  if (customCall.getNumResults() != 1 ||
      isa<mlir::TupleType>(customCall->getResultTypes().front())) {
    return;
  }
  CustomCallOp tupleCustomCall = cloneCustomCallWithNewResultTypes(
      customCall,
      mlir::TupleType::get(customCall->getContext(),
                           {customCall->getResultTypes()}),
      rewriter);
  auto getTupleElement = rewriter.create<mlir::stablehlo::GetTupleElementOp>(
      customCall.getLoc(), customCall->getResultTypes().front(),
      tupleCustomCall.getResult(0), rewriter.getI32IntegerAttr(0));
  // `getAttr` yields a null attribute when the op has no sharding. Only
  // forward a real attribute rather than handing a null value to `setAttr`.
  if (mlir::Attribute sharding = customCall->getAttr(kXlaShardingAttr)) {
    getTupleElement->setAttr(kXlaShardingAttr, sharding);
  }
  rewriter.replaceOp(customCall, getTupleElement);
}

// Module pass that visits every `CustomCallOp` accepted by
// `isPythonCallbackCustomCall` and rewrites it:
//  - a callback whose results have uses is handed to
//    `replaceCallbackWithTupleVersion`;
//  - a callback without any uses is replaced by a clone that has no results
//    and an empty result-layouts attribute.
class MhloRoundTripExportCallbackCustomCallsPass
    : public PassWrapper<MhloRoundTripExportCallbackCustomCallsPass,
                         OperationPass<ModuleOp>> {
 public:
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
      MhloRoundTripExportCallbackCustomCallsPass)

  void runOnOperation() final {
    ModuleOp module = getOperation();
    module.walk([&](CustomCallOp callbackOp) {
      // Custom calls that are not Python callbacks are left alone.
      if (!isPythonCallbackCustomCall(callbackOp)) {
        return;
      }
      mlir::IRRewriter rewriter(callbackOp);
      if (callbackOp->use_empty()) {
        // Nothing consumes the results: re-create the call without results.
        CustomCallOp resultlessCall = cloneCustomCallWithNewResultTypes(
            callbackOp, mlir::TypeRange(), rewriter);
        resultlessCall.setResultLayoutsAttr(rewriter.getArrayAttr({}));
        rewriter.eraseOp(callbackOp);
      } else {
        // Results are consumed: switch to the tuple + get_tuple_element form.
        replaceCallbackWithTupleVersion(callbackOp, rewriter);
      }
    });
  }

  // Command-line argument under which the pass is exposed.
  StringRef getArgument() const override {
    return "xla-sdy-mhlo-round-trip-export-callback-custom-calls";
  }

  // One-line summary of the pass.
  StringRef getDescription() const override {
    return "Converts the `CustomCallOp`s for host callbacks in XLA into the "
           "pattern that the XLA compiler recognizes.";
  }

  // Declares the dialects this pass depends on.
  void getDependentDialects(mlir::DialectRegistry& registry) const final {
    registry.insert<mlir::sdy::SdyDialect>();
  }
};

} // namespace

// Returns a newly allocated instance of
// `MhloRoundTripExportCallbackCustomCallsPass`, owned by the caller.
std::unique_ptr<mlir::Pass> createMhloRoundTripExportCallbackCustomCallsPass() {
  std::unique_ptr<mlir::Pass> pass =
      std::make_unique<MhloRoundTripExportCallbackCustomCallsPass>();
  return pass;
}

void registerMhloRoundTripExportCallbackCustomCallsPass() {
mlir::registerPass(createMhloRoundTripExportCallbackCustomCallsPass);
}

} // namespace sdy
} // namespace xla
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
/* Copyright 2024 The OpenXLA Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef XLA_SERVICE_SPMD_SHARDY_MHLO_ROUND_TRIP_EXPORT_CALLBACK_CUSTOM_CALLS_H_
#define XLA_SERVICE_SPMD_SHARDY_MHLO_ROUND_TRIP_EXPORT_CALLBACK_CUSTOM_CALLS_H_

#include <memory>

#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassRegistry.h"

namespace xla {
namespace sdy {

// Creates a pass that converts the `CustomCallOp`s for host callbacks in XLA
// into the pattern that the XLA compiler recognizes.
//
// The rest of the XLA pipeline expects a host callback custom call to either
// return a tuple that is read through a get_tuple_element, or to have no
// results at all. Earlier in the pipeline that form was changed, because
// Shardy shardings expect at least one result and a maximal sharding needs to
// be attached to the callback. This pass restores the expected form:
//  - a callback with a single, non-tuple result that has uses becomes a
//    custom call returning a one-element tuple plus a get_tuple_element;
//  - a callback whose results are unused becomes a custom call with no
//    results.
std::unique_ptr<mlir::Pass> createMhloRoundTripExportCallbackCustomCallsPass();

// Registers the pass whose argument is
// `xla-sdy-mhlo-round-trip-export-callback-custom-calls`.
void registerMhloRoundTripExportCallbackCustomCallsPass();

} // namespace sdy
} // namespace xla

#endif // XLA_SERVICE_SPMD_SHARDY_MHLO_ROUND_TRIP_EXPORT_CALLBACK_CUSTOM_CALLS_H_
11 changes: 7 additions & 4 deletions xla/service/spmd/shardy/mhlo_round_trip/export_shardings.cc
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,8 @@ using ::mlir::success;
using ::mlir::SymbolTable;
using ::mlir::func::FuncOp;

using ::mlir::stablehlo::CustomCallOp;

using ::mlir::sdy::AxisRefAttr;
using ::mlir::sdy::DimensionShardingAttr;
using ::mlir::sdy::kShardingAttr;
Expand Down Expand Up @@ -197,6 +199,7 @@ class ExportMhloShardingsPass

void runOnOperation() final {
ModuleOp moduleOp = getOperation();

mlir::SymbolTableCollection symbolTableCollection;
SymbolTable& symbolTable = symbolTableCollection.getSymbolTable(moduleOp);

Expand All @@ -208,10 +211,10 @@ class ExportMhloShardingsPass
}
}

// StableHLO doesn't have an equivalent of `erf` and `topk` ops.
// If they have a sharding annotation, we need to move it into
// `mhlo.attributes`, which StableHLO->MHLO conversion would lift back up.
moduleOp.walk([&](mlir::stablehlo::CustomCallOp customCall) {
moduleOp.walk([&](CustomCallOp customCall) {
// StableHLO doesn't have an equivalent of `erf` and `topk` ops.
// If they have a sharding annotation, we need to move it into
// `mhlo.attributes`, which StableHLO->MHLO conversion would lift back up.
StringRef callTargetName = customCall.getCallTargetName();
if (callTargetName != "mhlo.erf" && callTargetName != "mhlo.topk") {
return;
Expand Down
2 changes: 2 additions & 0 deletions xla/service/spmd/shardy/mhlo_round_trip/mhlo_export.cc
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ limitations under the License.
#include "mlir/Pass/PassManager.h"
#include "mlir/Pass/PassRegistry.h"
#include "mlir/Support/LLVM.h"
#include "xla/service/spmd/shardy/mhlo_round_trip/export_callback_custom_calls.h"
#include "xla/service/spmd/shardy/mhlo_round_trip/export_ops.h"
#include "xla/service/spmd/shardy/mhlo_round_trip/export_shardings.h"
#include "xla/service/spmd/shardy/mhlo_round_trip/shard_map_export.h"
Expand All @@ -36,6 +37,7 @@ void addMhloExportPipeline(mlir::OpPassManager& pm) {
pm.addPass(createMhloRoundTripShardMapExportPass());
pm.addPass(createExportNamedComputationsPass());
pm.addPass(createExportMhloShardingsPass());
pm.addPass(createMhloRoundTripExportCallbackCustomCallsPass());
}

void registerMhloExportPipeline() {
Expand Down
4 changes: 4 additions & 0 deletions xla/service/spmd/shardy/sdy_opt_main.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ limitations under the License.
#include "stablehlo/dialect/StablehloOps.h"
#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h"
#include "xla/mlir_hlo/mhlo/transforms/passes.h"
#include "xla/service/spmd/shardy/mhlo_round_trip/export_callback_custom_calls.h"
#include "xla/service/spmd/shardy/mhlo_round_trip/export_ops.h"
#include "xla/service/spmd/shardy/mhlo_round_trip/export_shardings.h"
#include "xla/service/spmd/shardy/mhlo_round_trip/mhlo_export.h"
Expand All @@ -36,6 +37,7 @@ limitations under the License.
#include "xla/service/spmd/shardy/round_trip_common/open_while_free_vars_sharding.h"
#include "xla/service/spmd/shardy/sdy_round_trip/export_ops.h"
#include "xla/service/spmd/shardy/sdy_round_trip/export_shardy_attrs.h"
#include "xla/service/spmd/shardy/sdy_round_trip/import_callback_custom_calls.h"
#include "xla/service/spmd/shardy/sdy_round_trip/import_shardy_attrs.h"
#include "xla/service/spmd/shardy/sdy_round_trip/pipelines.h"
#include "xla/service/spmd/shardy/sdy_round_trip/remove_size_one_axes.h"
Expand Down Expand Up @@ -66,12 +68,14 @@ int main(int argc, char** argv) {

xla::sdy::registerMhloExportPipeline();
xla::sdy::registerMhloExportShardingsPass();
xla::sdy::registerMhloRoundTripExportCallbackCustomCallsPass();
xla::sdy::registerMhloRoundTripShardMapExportPass();
xla::sdy::registerExportNamedComputationsPass();
xla::sdy::registerExportOpsPass();

xla::sdy::registerSdyRoundTripMhloToHloToMhloPass();
xla::sdy::registerSdyRoundTripExportShardyAttrsPass();
xla::sdy::registerSdyRoundTripImportCallbackCustomCallsPass();
xla::sdy::registerSdyRoundTripImportShardyAttrsPass();
xla::sdy::registerSdyRoundTripRemoveSizeOneAxesPass();
xla::sdy::registerSdyRoundTripExportOpsPass();
Expand Down
18 changes: 17 additions & 1 deletion xla/service/spmd/shardy/sdy_round_trip/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -126,13 +126,30 @@ cc_library(
],
)

# Library for `import_callback_custom_calls.{cc,h}`. It is a dependency of
# `:pipelines` in this package and of the `sdy_opt` binary.
cc_library(
name = "import_callback_custom_calls",
srcs = ["import_callback_custom_calls.cc"],
hdrs = ["import_callback_custom_calls.h"],
deps = [
"//xla/service/spmd/shardy:utils",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:Support",
"@llvm-project//mlir:TransformUtils",
"@shardy//shardy/dialect/sdy/ir:dialect",
"@stablehlo//:stablehlo_ops",
],
)

cc_library(
name = "pipelines",
srcs = ["pipelines.cc"],
hdrs = ["pipelines.h"],
deps = [
":export_ops",
":export_shardy_attrs",
":import_callback_custom_calls",
":import_shardy_attrs",
":remove_size_one_axes",
":shard_map_export",
Expand All @@ -143,6 +160,5 @@ cc_library(
"//xla/service/spmd/shardy/round_trip_common:pipeline_passes",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:Support",
"@llvm-project//mlir:Transforms",
],
)
Loading

0 comments on commit e076af4

Please sign in to comment.