Skip to content

Commit

Permalink
#sdy support JAX callbacks through the Shardy XLA round-trip pipeline.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 705875941
  • Loading branch information
bartchr808 authored and Google-ML-Automation committed Dec 16, 2024
1 parent 0954ab3 commit 7b44115
Show file tree
Hide file tree
Showing 17 changed files with 327 additions and 17 deletions.
2 changes: 2 additions & 0 deletions xla/service/spmd/shardy/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@ cc_library(
"@llvm-project//mlir:Support",
"@shardy//shardy/dialect/sdy/ir:dialect",
"@shardy//shardy/dialect/sdy/ir:register",
"@stablehlo//:stablehlo_ops",
],
)

Expand Down Expand Up @@ -131,6 +132,7 @@ xla_cc_binary(
"//xla/service/spmd/shardy/round_trip_common:open_while_free_vars_sharding",
"//xla/service/spmd/shardy/sdy_round_trip:export_ops",
"//xla/service/spmd/shardy/sdy_round_trip:export_shardy_attrs",
"//xla/service/spmd/shardy/sdy_round_trip:import_callback_custom_calls",
"//xla/service/spmd/shardy/sdy_round_trip:import_shardy_attrs",
"//xla/service/spmd/shardy/sdy_round_trip:pipelines",
"//xla/service/spmd/shardy/sdy_round_trip:remove_size_one_axes",
Expand Down
1 change: 1 addition & 0 deletions xla/service/spmd/shardy/mhlo_round_trip/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ cc_library(
"//xla/hlo/translate/mhlo_to_hlo:type_to_shape",
"//xla/mlir_hlo",
"//xla/service/spmd/shardy:constants",
"//xla/service/spmd/shardy:utils",
"@com_google_absl//absl/algorithm:container",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:FuncDialect",
Expand Down
41 changes: 41 additions & 0 deletions xla/service/spmd/shardy/mhlo_round_trip/export_shardings.cc
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ limitations under the License.
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/LogicalResult.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
Expand All @@ -44,7 +45,9 @@ limitations under the License.
#include "mlir/IR/DialectRegistry.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/SymbolTable.h"
#include "mlir/IR/TypeRange.h"
#include "mlir/IR/Value.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
Expand All @@ -62,6 +65,7 @@ limitations under the License.
#include "xla/hlo/translate/mhlo_to_hlo/type_to_shape.h"
#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h"
#include "xla/service/spmd/shardy/constants.h"
#include "xla/service/spmd/shardy/utils.h"
#include "xla/shape.h"
#include "xla/shape_util.h"

Expand All @@ -86,6 +90,8 @@ using ::mlir::success;
using ::mlir::SymbolTable;
using ::mlir::func::FuncOp;

using ::mlir::stablehlo::CustomCallOp;

using ::mlir::sdy::AxisRefAttr;
using ::mlir::sdy::DimensionShardingAttr;
using ::mlir::sdy::kShardingAttr;
Expand Down Expand Up @@ -190,13 +196,46 @@ LogicalResult exportFunc(FuncOp funcOp, const SymbolTable& symbolTable,
return success();
}

// Restores the result form that the rest of the XLA pipeline expects for host
// callback custom calls (`xla_python_cpu_callback` and
// `xla_python_gpu_callback`): either a tuple result read through a
// `get_tuple_element`, or no results at all. (The no-result form was changed
// earlier because Shardy shardings expect at least one result, which is needed
// to attach a maximal sharding to the callback.)
//
// Behavior per case:
// - Any other call target: `op` is left untouched.
// - Results unused: `op` is replaced by a result-less custom call with empty
//   result layouts.
// - Exactly one used, non-tuple result: `op` is replaced by a custom call
//   returning a 1-tuple followed by a `get_tuple_element` at index 0, which
//   inherits the callback's `kXlaShardingAttr` when one is present.
// - Any other used form (e.g. already a tuple): `op` is left untouched.
void rewriteHostCallback(CustomCallOp op) {
  StringRef targetName = op.getCallTargetName();
  if (targetName != "xla_python_cpu_callback" &&
      targetName != "xla_python_gpu_callback") {
    return;
  }
  // New ops are inserted right before `op`.
  mlir::IRRewriter rewriter(op);
  if (!op->use_empty()) {
    if (op.getNumResults() == 1 &&
        !isa<mlir::TupleType>(op->getResultTypes().front())) {
      CustomCallOp tupleCustomCall = buildCustomCall(
          op, mlir::TupleType::get(op->getContext(), {op->getResultTypes()}),
          rewriter);
      auto getTupleElement =
          rewriter.create<mlir::stablehlo::GetTupleElementOp>(
              op.getLoc(), op->getResultTypes().front(),
              tupleCustomCall.getResult(0), rewriter.getI32IntegerAttr(0));
      // `getAttr` returns a null attribute when the callback carries no
      // sharding, and a null value must never be stored in an op's attribute
      // dictionary, so only forward the sharding when it exists.
      if (auto sharding = op->getAttr(kXlaShardingAttr)) {
        getTupleElement->setAttr(kXlaShardingAttr, sharding);
      }
      rewriter.replaceOp(op, getTupleElement);
    }
    return;
  }
  // No uses: drop the results entirely, along with their layouts.
  CustomCallOp newCustomCall =
      buildCustomCall(op, SmallVector<mlir::Type>(), rewriter);
  newCustomCall.setResultLayoutsAttr(rewriter.getArrayAttr({}));
  rewriter.eraseOp(op);
}

class ExportMhloShardingsPass
: public PassWrapper<ExportMhloShardingsPass, OperationPass<ModuleOp>> {
public:
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(ExportMhloShardingsPass)

void runOnOperation() final {
ModuleOp moduleOp = getOperation();

mlir::SymbolTableCollection symbolTableCollection;
SymbolTable& symbolTable = symbolTableCollection.getSymbolTable(moduleOp);

Expand Down Expand Up @@ -236,6 +275,8 @@ class ExportMhloShardingsPass
llvm::make_early_inc_range(moduleOp.getOps<MeshOp>())) {
symbolTable.erase(meshOp);
}

moduleOp.walk([](CustomCallOp op) { rewriteHostCallback(op); });
}

StringRef getArgument() const override {
Expand Down
2 changes: 1 addition & 1 deletion xla/service/spmd/shardy/mhlo_round_trip/mhlo_import.cc
Original file line number Diff line number Diff line change
Expand Up @@ -413,7 +413,7 @@ TensorShardingAttr convertToSdySharding(
// device.
if (hloSharding.HasUniqueDevice()) {
return TensorShardingAttr::getFullyClosed(
ctx, rank,
ctx, /*rank=*/0,
deviceIdToMaximalMeshName.lookup(hloSharding.GetUniqueDevice()));
}
CHECK(!hloSharding.IsTuple());
Expand Down
2 changes: 2 additions & 0 deletions xla/service/spmd/shardy/sdy_opt_main.cc
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ limitations under the License.
#include "xla/service/spmd/shardy/round_trip_common/open_while_free_vars_sharding.h"
#include "xla/service/spmd/shardy/sdy_round_trip/export_ops.h"
#include "xla/service/spmd/shardy/sdy_round_trip/export_shardy_attrs.h"
#include "xla/service/spmd/shardy/sdy_round_trip/import_callback_custom_calls.h"
#include "xla/service/spmd/shardy/sdy_round_trip/import_shardy_attrs.h"
#include "xla/service/spmd/shardy/sdy_round_trip/pipelines.h"
#include "xla/service/spmd/shardy/sdy_round_trip/remove_size_one_axes.h"
Expand Down Expand Up @@ -72,6 +73,7 @@ int main(int argc, char** argv) {

xla::sdy::registerSdyRoundTripMhloToHloToMhloPass();
xla::sdy::registerSdyRoundTripExportShardyAttrsPass();
xla::sdy::registerSdyRoundTripImportCallbackCustomCallsPass();
xla::sdy::registerSdyRoundTripImportShardyAttrsPass();
xla::sdy::registerSdyRoundTripRemoveSizeOneAxesPass();
xla::sdy::registerSdyRoundTripExportOpsPass();
Expand Down
19 changes: 18 additions & 1 deletion xla/service/spmd/shardy/sdy_round_trip/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -126,13 +126,31 @@ cc_library(
],
)

# Library for the SDY round-trip import pass that modifies the return types of
# XLA host callback custom calls to be compatible with SDY (see
# import_callback_custom_calls.h).
cc_library(
name = "import_callback_custom_calls",
srcs = ["import_callback_custom_calls.cc"],
hdrs = ["import_callback_custom_calls.h"],
deps = [
"//xla/service/spmd/shardy:constants",
"//xla/service/spmd/shardy:utils",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:Support",
"@llvm-project//mlir:TransformUtils",
"@shardy//shardy/dialect/sdy/ir:dialect",
"@stablehlo//:stablehlo_ops",
],
)

cc_library(
name = "pipelines",
srcs = ["pipelines.cc"],
hdrs = ["pipelines.h"],
deps = [
":export_ops",
":export_shardy_attrs",
":import_callback_custom_calls",
":import_shardy_attrs",
":remove_size_one_axes",
":shard_map_export",
Expand All @@ -143,6 +161,5 @@ cc_library(
"//xla/service/spmd/shardy/round_trip_common:pipeline_passes",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:Support",
"@llvm-project//mlir:Transforms",
],
)
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
/* Copyright 2024 The OpenXLA Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "xla/service/spmd/shardy/sdy_round_trip/import_callback_custom_calls.h"

#include <memory>

#include "llvm/ADT/StringRef.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/DialectRegistry.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassRegistry.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Support/TypeID.h"
#include "stablehlo/dialect/StablehloOps.h"
#include "xla/service/spmd/shardy/utils.h"

namespace xla {
namespace sdy {

namespace {

using ::mlir::ModuleOp;
using ::mlir::StringRef;
using ::mlir::stablehlo::CustomCallOp;

// Module pass that gives result-less XLA host callback custom calls
// (`xla_python_cpu_callback` and `xla_python_gpu_callback`) a single dummy
// result, so that the return types of those custom calls are compatible with
// SDY. The dummy result takes the type of the callback's first operand and,
// when operand layouts are present, the layout of that operand.
class SdyRoundTripImportCallbackCustomCallsPass
    : public mlir::PassWrapper<SdyRoundTripImportCallbackCustomCallsPass,
                               mlir::OperationPass<ModuleOp>> {
 public:
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
      SdyRoundTripImportCallbackCustomCallsPass)

  // Rewrites every matching custom call in the module in place.
  void runOnOperation() final {
    getOperation().walk([&](CustomCallOp op) {
      StringRef targetName = op.getCallTargetName();
      if (op->getNumResults() != 0 ||
          (targetName != "xla_python_cpu_callback" &&
           targetName != "xla_python_gpu_callback")) {
        return;
      }
      // The dummy result borrows its type from operand 0, so a callback
      // without operands cannot be rewritten; leave it untouched rather than
      // reading a non-existent operand.
      if (op->getNumOperands() == 0) {
        return;
      }
      // New ops are inserted right before `op`.
      mlir::IRRewriter rewriter(op);
      CustomCallOp newCustomCall =
          buildCustomCall(op, op->getOperand(0).getType(), rewriter);
      // `operand_layouts` is an optional attribute: only mirror the first
      // operand's layout onto the new result when a layout is actually there.
      mlir::ArrayAttr operandLayouts = op.getOperandLayoutsAttr();
      if (operandLayouts && !operandLayouts.getValue().empty()) {
        newCustomCall.setResultLayoutsAttr(
            rewriter.getArrayAttr({operandLayouts.getValue().front()}));
      }
      rewriter.eraseOp(op);
    });
  }

  // Command-line name of the pass.
  StringRef getArgument() const override {
    return "xla-sdy-round-trip-import-callback-custom-calls";
  }

  // Human-readable summary of the pass.
  StringRef getDescription() const override {
    return "Modifies the return types of XLA host callback custom calls to be "
           "compatible with SDY";
  }

  // The pass creates StableHLO custom calls, so that dialect must be loaded.
  void getDependentDialects(mlir::DialectRegistry& registry) const final {
    registry.insert<mlir::stablehlo::StablehloDialect>();
  }
};

} // namespace

// Factory for the pass; the caller owns the returned instance.
std::unique_ptr<mlir::Pass> createSdyRoundTripImportCallbackCustomCallsPass() {
  auto importCallbacksPass =
      std::make_unique<SdyRoundTripImportCallbackCustomCallsPass>();
  return importCallbacksPass;
}

void registerSdyRoundTripImportCallbackCustomCallsPass() {
mlir::registerPass(createSdyRoundTripImportCallbackCustomCallsPass);
}

} // namespace sdy
} // namespace xla
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
/* Copyright 2024 The OpenXLA Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef XLA_SERVICE_SPMD_SHARDY_SDY_ROUND_TRIP_IMPORT_CALLBACK_CUSTOM_CALLS_H_
#define XLA_SERVICE_SPMD_SHARDY_SDY_ROUND_TRIP_IMPORT_CALLBACK_CUSTOM_CALLS_H_

#include <memory>

#include "mlir/Pass/Pass.h"

namespace xla {
namespace sdy {

// Creates the pass that modifies the return types of XLA host callback custom
// calls to be compatible with SDY.
//
// Shardy shardings require an op to have at least one result, but XLA host
// callback custom calls are not guaranteed to return a value. So that such a
// custom call can still be given a maximal sharding, the pass changes its
// return type to return a dummy value.
std::unique_ptr<mlir::Pass> createSdyRoundTripImportCallbackCustomCallsPass();

// Registers the xla-sdy-round-trip-import-callback-custom-calls pass.
void registerSdyRoundTripImportCallbackCustomCallsPass();

} // namespace sdy
} // namespace xla

#endif // XLA_SERVICE_SPMD_SHARDY_SDY_ROUND_TRIP_IMPORT_CALLBACK_CUSTOM_CALLS_H_
15 changes: 14 additions & 1 deletion xla/service/spmd/shardy/sdy_round_trip/import_shardy_attrs.cc
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,17 @@ void convertShardyAttrs(FuncOp funcOp, IRRewriter& rewriter) {
if (!dictAttr) {
return;
}
// `SendOp` and `RecvOp` can have a sharding when doing TPU callbacks
// through JAX.
if (auto sendOp = mlir::dyn_cast<mlir::stablehlo::SendOp>(op)) {
sendOp->setAttr(kShardingAttr,
parseStringAttr<TensorShardingPerValueAttr>(
dictAttr, kShardingRoundTripAttr));
} else if (auto recvOp = mlir::dyn_cast<mlir::stablehlo::RecvOp>(op)) {
recvOp->setAttr(kShardingAttr,
parseStringAttr<TensorShardingPerValueAttr>(
dictAttr, kShardingRoundTripAttr));
}
// NOTE: we are only setting the sharding on known custom-calls. For any
// other op that has a `kShardingRoundTripAttr` we discard it. XLA sometimes
// creates new instructions, copying over the operand's frontend attrs,
Expand Down Expand Up @@ -139,7 +150,9 @@ void convertShardyAttrs(FuncOp funcOp, IRRewriter& rewriter) {
}
if (targetName == kShardingCustomCallTargetName ||
targetName == kSPMDFullToShardShapeCallTargetName ||
targetName == kSPMDShardToFullShapeCallTargetName) {
targetName == kSPMDShardToFullShapeCallTargetName ||
targetName == "xla_python_cpu_callback" ||
targetName == "xla_python_gpu_callback") {
customCallOp->setAttr(kShardingAttr,
parseStringAttr<TensorShardingPerValueAttr>(
dictAttr, kShardingRoundTripAttr));
Expand Down
2 changes: 2 additions & 0 deletions xla/service/spmd/shardy/sdy_round_trip/pipelines.cc
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ limitations under the License.
#include "xla/service/spmd/shardy/round_trip_common/pipeline_passes.h"
#include "xla/service/spmd/shardy/sdy_round_trip/export_ops.h"
#include "xla/service/spmd/shardy/sdy_round_trip/export_shardy_attrs.h"
#include "xla/service/spmd/shardy/sdy_round_trip/import_callback_custom_calls.h"
#include "xla/service/spmd/shardy/sdy_round_trip/import_shardy_attrs.h"
#include "xla/service/spmd/shardy/sdy_round_trip/remove_size_one_axes.h"
#include "xla/service/spmd/shardy/sdy_round_trip/shard_map_export.h"
Expand All @@ -49,6 +50,7 @@ void addSdyRoundTripExportPipeline(mlir::OpPassManager& pm) {

void addSdyRoundTripImportPipeline(mlir::OpPassManager& pm) {
addCommonPreImportPasses(pm);
pm.addPass(createSdyRoundTripImportCallbackCustomCallsPass());
pm.addPass(createSdyRoundTripImportShardyAttrsPass());
pm.addPass(createSdyRoundTripShardMapImportPass());
pm.addPass(createSdyRoundTripRemoveSizeOneAxesPass());
Expand Down
Loading

0 comments on commit 7b44115

Please sign in to comment.