From e076af48a6d8c906d25e45fa93a5fb8c37ecc817 Mon Sep 17 00:00:00 2001 From: Bart Chrzaszcz Date: Fri, 13 Dec 2024 07:21:42 -0800 Subject: [PATCH] #sdy support JAX callbacks through the Shardy XLA round-trip pipeline. PiperOrigin-RevId: 705875941 --- xla/service/spmd/shardy/BUILD | 3 + xla/service/spmd/shardy/constants.h | 8 ++ xla/service/spmd/shardy/mhlo_round_trip/BUILD | 17 +++ .../export_callback_custom_calls.cc | 120 ++++++++++++++++++ .../export_callback_custom_calls.h | 42 ++++++ .../mhlo_round_trip/export_shardings.cc | 11 +- .../shardy/mhlo_round_trip/mhlo_export.cc | 2 + xla/service/spmd/shardy/sdy_opt_main.cc | 4 + xla/service/spmd/shardy/sdy_round_trip/BUILD | 18 ++- .../import_callback_custom_calls.cc | 91 +++++++++++++ .../import_callback_custom_calls.h | 41 ++++++ .../sdy_round_trip/import_shardy_attrs.cc | 16 ++- .../spmd/shardy/sdy_round_trip/pipelines.cc | 2 + .../shardy/test/mhlo_export_pipeline.mlir | 68 ++++++++++ .../test/sdy_round_trip_import_pipeline.mlir | 15 +++ xla/service/spmd/shardy/utils.cc | 24 ++++ xla/service/spmd/shardy/utils.h | 12 ++ 17 files changed, 485 insertions(+), 9 deletions(-) create mode 100644 xla/service/spmd/shardy/mhlo_round_trip/export_callback_custom_calls.cc create mode 100644 xla/service/spmd/shardy/mhlo_round_trip/export_callback_custom_calls.h create mode 100644 xla/service/spmd/shardy/sdy_round_trip/import_callback_custom_calls.cc create mode 100644 xla/service/spmd/shardy/sdy_round_trip/import_callback_custom_calls.h diff --git a/xla/service/spmd/shardy/BUILD b/xla/service/spmd/shardy/BUILD index 9cfab1d6355eb7..a25a4347db9f8d 100644 --- a/xla/service/spmd/shardy/BUILD +++ b/xla/service/spmd/shardy/BUILD @@ -88,6 +88,7 @@ cc_library( "@llvm-project//mlir:Support", "@shardy//shardy/dialect/sdy/ir:dialect", "@shardy//shardy/dialect/sdy/ir:register", + "@stablehlo//:stablehlo_ops", ], ) @@ -119,6 +120,7 @@ xla_cc_binary( deps = [ "//xla/mlir_hlo", "//xla/mlir_hlo:mhlo_passes", + "//xla/service/spmd/shardy/mhlo_round_trip:export_callback_custom_calls", "//xla/service/spmd/shardy/mhlo_round_trip:export_ops", "//xla/service/spmd/shardy/mhlo_round_trip:export_shardings", "//xla/service/spmd/shardy/mhlo_round_trip:mhlo_export", @@ -132,6 +134,7 @@ xla_cc_binary( "//xla/service/spmd/shardy/round_trip_common:open_while_free_vars_sharding", "//xla/service/spmd/shardy/sdy_round_trip:export_ops", "//xla/service/spmd/shardy/sdy_round_trip:export_shardy_attrs", + "//xla/service/spmd/shardy/sdy_round_trip:import_callback_custom_calls", "//xla/service/spmd/shardy/sdy_round_trip:import_shardy_attrs", "//xla/service/spmd/shardy/sdy_round_trip:pipelines", "//xla/service/spmd/shardy/sdy_round_trip:remove_size_one_axes", diff --git a/xla/service/spmd/shardy/constants.h b/xla/service/spmd/shardy/constants.h index ac227366096c37..4ebd8d3690d066 100644 --- a/xla/service/spmd/shardy/constants.h +++ b/xla/service/spmd/shardy/constants.h @@ -38,6 +38,14 @@ inline constexpr llvm::StringRef kSPMDFullToShardShapeCallTargetName = inline constexpr llvm::StringRef kSPMDShardToFullShapeCallTargetName = "SPMDShardToFullShape"; +// The target name of the Python CPU callback custom call. +inline constexpr llvm::StringRef kPythonCpuCallbackCustomCallTargetName = + "xla_python_cpu_callback"; + +// The target name of the Python GPU callback custom call. +inline constexpr llvm::StringRef kPythonGpuCallbackCustomCallTargetName = + "xla_python_gpu_callback"; + // The attribute name for backend config. inline constexpr llvm::StringRef kXlaBackendConfigAttr = "backend_config"; diff --git a/xla/service/spmd/shardy/mhlo_round_trip/BUILD b/xla/service/spmd/shardy/mhlo_round_trip/BUILD index fb62d32e756f16..8bb924fd9fe31a 100644 --- a/xla/service/spmd/shardy/mhlo_round_trip/BUILD +++ b/xla/service/spmd/shardy/mhlo_round_trip/BUILD @@ -83,11 +83,28 @@ cc_library( ], ) +cc_library( + name = "export_callback_custom_calls", + srcs = ["export_callback_custom_calls.cc"], + hdrs = ["export_callback_custom_calls.h"], + deps = [ + "//xla/service/spmd/shardy:constants", + "//xla/service/spmd/shardy:utils", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:Support", + "@shardy//shardy/dialect/sdy/ir:dialect", + "@stablehlo//:stablehlo_ops", + ], +) + cc_library( name = "mhlo_export", srcs = ["mhlo_export.cc"], hdrs = ["mhlo_export.h"], deps = [ + ":export_callback_custom_calls", ":export_ops", ":export_shardings", ":shard_map_export", diff --git a/xla/service/spmd/shardy/mhlo_round_trip/export_callback_custom_calls.cc b/xla/service/spmd/shardy/mhlo_round_trip/export_callback_custom_calls.cc new file mode 100644 index 00000000000000..c8da99e34b5715 --- /dev/null +++ b/xla/service/spmd/shardy/mhlo_round_trip/export_callback_custom_calls.cc @@ -0,0 +1,120 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/spmd/shardy/mhlo_round_trip/export_callback_custom_calls.h" + +#include + +#include "llvm/ADT/StringRef.h" +#include "llvm/Support/Casting.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/DialectRegistry.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Pass/PassRegistry.h" +#include "mlir/Support/TypeID.h" +#include "shardy/dialect/sdy/ir/dialect.h" +#include "stablehlo/dialect/StablehloOps.h" +#include "xla/service/spmd/shardy/constants.h" +#include "xla/service/spmd/shardy/utils.h" + +namespace xla { +namespace sdy { + +namespace { + +using ::mlir::ModuleOp; +using ::mlir::OperationPass; +using ::mlir::PassWrapper; +using ::mlir::StringRef; + +using ::mlir::stablehlo::CustomCallOp; + +// Attempts to replace the `CustomCallOp` with a tuple version of it, and a +// `GetTupleElementOp` that gets the first element of the tuple. +// +// This only happens if the op has a single result and the result type is not +// a tuple. +void replaceCallbackWithTupleVersion(CustomCallOp customCall, + mlir::IRRewriter& rewriter) { + if (customCall.getNumResults() != 1 || + isa(customCall->getResultTypes().front())) { + return; + } + CustomCallOp tupleCustomCall = cloneCustomCallWithNewResultTypes( + customCall, + mlir::TupleType::get(customCall->getContext(), + {customCall->getResultTypes()}), + rewriter); + auto getTupleElement = rewriter.create( + customCall.getLoc(), customCall->getResultTypes().front(), + tupleCustomCall.getResult(0), rewriter.getI32IntegerAttr(0)); + getTupleElement->setAttr(kXlaShardingAttr, + customCall->getAttr(kXlaShardingAttr)); + rewriter.replaceOp(customCall, getTupleElement); +} + +class MhloRoundTripExportCallbackCustomCallsPass + : public PassWrapper> { + public: + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID( + MhloRoundTripExportCallbackCustomCallsPass) + + void runOnOperation() final { + getOperation().walk([&](CustomCallOp customCall) { + if (!isPythonCallbackCustomCall(customCall)) { + return; + } + mlir::IRRewriter rewriter(customCall); + if (!customCall->use_empty()) { + replaceCallbackWithTupleVersion(customCall, rewriter); + return; + } + CustomCallOp newCustomCall = cloneCustomCallWithNewResultTypes( + customCall, mlir::TypeRange(), rewriter); + newCustomCall.setResultLayoutsAttr(rewriter.getArrayAttr({})); + rewriter.eraseOp(customCall); + return; + }); + } + + StringRef getArgument() const override { + return "xla-sdy-mhlo-round-trip-export-callback-custom-calls"; + } + + StringRef getDescription() const override { + return "Converts the `CustomCallOp`s for host callbacks in XLA into the " + "pattern that the XLA compiler recognizes."; + } + + void getDependentDialects(mlir::DialectRegistry& registry) const final { + registry.insert(); + } +}; + +} // namespace + +std::unique_ptr createMhloRoundTripExportCallbackCustomCallsPass() { + return std::make_unique(); +} + +void registerMhloRoundTripExportCallbackCustomCallsPass() { + mlir::registerPass(createMhloRoundTripExportCallbackCustomCallsPass); +} + +} // namespace sdy +} // namespace xla diff --git a/xla/service/spmd/shardy/mhlo_round_trip/export_callback_custom_calls.h b/xla/service/spmd/shardy/mhlo_round_trip/export_callback_custom_calls.h new file mode 100644 index 00000000000000..b67955f7a80212 --- /dev/null +++ b/xla/service/spmd/shardy/mhlo_round_trip/export_callback_custom_calls.h @@ -0,0 +1,42 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_SERVICE_SPMD_SHARDY_MHLO_ROUND_TRIP_EXPORT_CALLBACK_CUSTOM_CALLS_H_ +#define XLA_SERVICE_SPMD_SHARDY_MHLO_ROUND_TRIP_EXPORT_CALLBACK_CUSTOM_CALLS_H_ + +#include + +#include "mlir/Pass/Pass.h" +#include "mlir/Pass/PassRegistry.h" + +namespace xla { +namespace sdy { + +// Creates a pass that converts the `CustomCallOp`s for host callbacks in XLA +// into the pattern that the XLA compiler recognizes. +// +// The rest of the XLA pipeline expects host callback custom calls to either be +// a tuple with a get_tuple_element or no results (which we changed due to +// shardy shardings expecting at least one result, and needing to attach a +// maximal sharding to the callbacks). +std::unique_ptr createMhloRoundTripExportCallbackCustomCallsPass(); + +// Registers the xla-sdy-mhlo-round-trip-export-callback-custom-calls pass. +void registerMhloRoundTripExportCallbackCustomCallsPass(); + +} // namespace sdy +} // namespace xla + +#endif // XLA_SERVICE_SPMD_SHARDY_MHLO_ROUND_TRIP_EXPORT_CALLBACK_CUSTOM_CALLS_H_ diff --git a/xla/service/spmd/shardy/mhlo_round_trip/export_shardings.cc b/xla/service/spmd/shardy/mhlo_round_trip/export_shardings.cc index 05be693ea09b12..bd5834c8249333 100644 --- a/xla/service/spmd/shardy/mhlo_round_trip/export_shardings.cc +++ b/xla/service/spmd/shardy/mhlo_round_trip/export_shardings.cc @@ -86,6 +86,8 @@ using ::mlir::success; using ::mlir::SymbolTable; using ::mlir::func::FuncOp; +using ::mlir::stablehlo::CustomCallOp; + using ::mlir::sdy::AxisRefAttr; using ::mlir::sdy::DimensionShardingAttr; using ::mlir::sdy::kShardingAttr; @@ -197,6 +199,7 @@ class ExportMhloShardingsPass void runOnOperation() final { ModuleOp moduleOp = getOperation(); + mlir::SymbolTableCollection symbolTableCollection; SymbolTable& symbolTable = symbolTableCollection.getSymbolTable(moduleOp); @@ -208,10 +211,10 @@ class ExportMhloShardingsPass } } - // StableHLO doesn't have an equivalent of `erf` and `topk` ops. - // If they have a sharding annotation, we need to move it into - // `mhlo.attributes`, which StableHLO->MHLO conversion would lift back up. - moduleOp.walk([&](mlir::stablehlo::CustomCallOp customCall) { + moduleOp.walk([&](CustomCallOp customCall) { + // StableHLO doesn't have an equivalent of `erf` and `topk` ops. + // If they have a sharding annotation, we need to move it into + // `mhlo.attributes`, which StableHLO->MHLO conversion would lift back up. StringRef callTargetName = customCall.getCallTargetName(); if (callTargetName != "mhlo.erf" && callTargetName != "mhlo.topk") { return; diff --git a/xla/service/spmd/shardy/mhlo_round_trip/mhlo_export.cc b/xla/service/spmd/shardy/mhlo_round_trip/mhlo_export.cc index 36aee9a64f266b..232e8c4d09da2c 100644 --- a/xla/service/spmd/shardy/mhlo_round_trip/mhlo_export.cc +++ b/xla/service/spmd/shardy/mhlo_round_trip/mhlo_export.cc @@ -20,6 +20,7 @@ limitations under the License. #include "mlir/Pass/PassManager.h" #include "mlir/Pass/PassRegistry.h" #include "mlir/Support/LLVM.h" +#include "xla/service/spmd/shardy/mhlo_round_trip/export_callback_custom_calls.h" #include "xla/service/spmd/shardy/mhlo_round_trip/export_ops.h" #include "xla/service/spmd/shardy/mhlo_round_trip/export_shardings.h" #include "xla/service/spmd/shardy/mhlo_round_trip/shard_map_export.h" @@ -36,6 +37,7 @@ void addMhloExportPipeline(mlir::OpPassManager& pm) { pm.addPass(createMhloRoundTripShardMapExportPass()); pm.addPass(createExportNamedComputationsPass()); pm.addPass(createExportMhloShardingsPass()); + pm.addPass(createMhloRoundTripExportCallbackCustomCallsPass()); } void registerMhloExportPipeline() { diff --git a/xla/service/spmd/shardy/sdy_opt_main.cc b/xla/service/spmd/shardy/sdy_opt_main.cc index 7f2dff488a7f00..1fd97e53d3d936 100644 --- a/xla/service/spmd/shardy/sdy_opt_main.cc +++ b/xla/service/spmd/shardy/sdy_opt_main.cc @@ -23,6 +23,7 @@ limitations under the License. #include "stablehlo/dialect/StablehloOps.h" #include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" #include "xla/mlir_hlo/mhlo/transforms/passes.h" +#include "xla/service/spmd/shardy/mhlo_round_trip/export_callback_custom_calls.h" #include "xla/service/spmd/shardy/mhlo_round_trip/export_ops.h" #include "xla/service/spmd/shardy/mhlo_round_trip/export_shardings.h" #include "xla/service/spmd/shardy/mhlo_round_trip/mhlo_export.h" @@ -36,6 +37,7 @@ limitations under the License. #include "xla/service/spmd/shardy/round_trip_common/open_while_free_vars_sharding.h" #include "xla/service/spmd/shardy/sdy_round_trip/export_ops.h" #include "xla/service/spmd/shardy/sdy_round_trip/export_shardy_attrs.h" +#include "xla/service/spmd/shardy/sdy_round_trip/import_callback_custom_calls.h" #include "xla/service/spmd/shardy/sdy_round_trip/import_shardy_attrs.h" #include "xla/service/spmd/shardy/sdy_round_trip/pipelines.h" #include "xla/service/spmd/shardy/sdy_round_trip/remove_size_one_axes.h" @@ -66,12 +68,14 @@ int main(int argc, char** argv) { xla::sdy::registerMhloExportPipeline(); xla::sdy::registerMhloExportShardingsPass(); + xla::sdy::registerMhloRoundTripExportCallbackCustomCallsPass(); xla::sdy::registerMhloRoundTripShardMapExportPass(); xla::sdy::registerExportNamedComputationsPass(); xla::sdy::registerExportOpsPass(); xla::sdy::registerSdyRoundTripMhloToHloToMhloPass(); xla::sdy::registerSdyRoundTripExportShardyAttrsPass(); + xla::sdy::registerSdyRoundTripImportCallbackCustomCallsPass(); xla::sdy::registerSdyRoundTripImportShardyAttrsPass(); xla::sdy::registerSdyRoundTripRemoveSizeOneAxesPass(); xla::sdy::registerSdyRoundTripExportOpsPass(); diff --git a/xla/service/spmd/shardy/sdy_round_trip/BUILD b/xla/service/spmd/shardy/sdy_round_trip/BUILD index 66dd2587a60d8e..3d5f950c31d92c 100644 --- a/xla/service/spmd/shardy/sdy_round_trip/BUILD +++ b/xla/service/spmd/shardy/sdy_round_trip/BUILD @@ -126,6 +126,22 @@ cc_library( ], ) +cc_library( + name = "import_callback_custom_calls", + srcs = ["import_callback_custom_calls.cc"], + hdrs = ["import_callback_custom_calls.h"], + deps = [ + "//xla/service/spmd/shardy:utils", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:TransformUtils", + "@shardy//shardy/dialect/sdy/ir:dialect", + "@stablehlo//:stablehlo_ops", + ], +) + cc_library( name = "pipelines", srcs = ["pipelines.cc"], @@ -133,6 +149,7 @@ cc_library( deps = [ ":export_ops", ":export_shardy_attrs", + ":import_callback_custom_calls", ":import_shardy_attrs", ":remove_size_one_axes", ":shard_map_export", @@ -143,6 +160,5 @@ cc_library( "//xla/service/spmd/shardy/round_trip_common:pipeline_passes", "@llvm-project//mlir:Pass", "@llvm-project//mlir:Support", - "@llvm-project//mlir:Transforms", ], ) diff --git a/xla/service/spmd/shardy/sdy_round_trip/import_callback_custom_calls.cc b/xla/service/spmd/shardy/sdy_round_trip/import_callback_custom_calls.cc new file mode 100644 index 00000000000000..0fa3f44d8204af --- /dev/null +++ b/xla/service/spmd/shardy/sdy_round_trip/import_callback_custom_calls.cc @@ -0,0 +1,91 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/spmd/shardy/sdy_round_trip/import_callback_custom_calls.h" + +#include + +#include "llvm/ADT/StringRef.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/DialectRegistry.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Pass/PassRegistry.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Support/TypeID.h" +#include "stablehlo/dialect/StablehloOps.h" +#include "xla/service/spmd/shardy/utils.h" + +namespace xla { +namespace sdy { + +namespace { + +using ::mlir::ModuleOp; +using ::mlir::StringRef; +using ::mlir::stablehlo::CustomCallOp; + +class SdyRoundTripImportCallbackCustomCallsPass + : public mlir::PassWrapper> { + public: + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID( + SdyRoundTripImportCallbackCustomCallsPass) + + void runOnOperation() final { + getOperation().walk([&](CustomCallOp op) { + if (op->getNumResults() != 0 || !isPythonCallbackCustomCall(op)) { + return; + } + mlir::IRRewriter rewriter(op); + // Shardy needs at least one op result to have a sharding annotation. + // Since the callback has no results, and we need to say the callbacks + // have a maximal sharding, we add a dummy result and set the result + // layout to the 0th operand layout. + CustomCallOp newCustomCall = cloneCustomCallWithNewResultTypes( + op, op->getOperand(0).getType(), rewriter); + newCustomCall.setResultLayoutsAttr(rewriter.getArrayAttr( + {op.getOperandLayoutsAttr().getValue().front()})); + rewriter.eraseOp(op); + }); + } + + StringRef getArgument() const override { + return "xla-sdy-round-trip-import-callback-custom-calls"; + } + + StringRef getDescription() const override { + return "Modifies the return types of XLA host callback custom calls to be " + "compatible with SDY"; + } + + void getDependentDialects(mlir::DialectRegistry& registry) const final { + registry.insert(); + } +}; + +} // namespace + +std::unique_ptr createSdyRoundTripImportCallbackCustomCallsPass() { + return std::make_unique(); +} + +void registerSdyRoundTripImportCallbackCustomCallsPass() { + mlir::registerPass(createSdyRoundTripImportCallbackCustomCallsPass); +} + +} // namespace sdy +} // namespace xla diff --git a/xla/service/spmd/shardy/sdy_round_trip/import_callback_custom_calls.h b/xla/service/spmd/shardy/sdy_round_trip/import_callback_custom_calls.h new file mode 100644 index 00000000000000..ce81f5ead47191 --- /dev/null +++ b/xla/service/spmd/shardy/sdy_round_trip/import_callback_custom_calls.h @@ -0,0 +1,41 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_SERVICE_SPMD_SHARDY_SDY_ROUND_TRIP_IMPORT_CALLBACK_CUSTOM_CALLS_H_ +#define XLA_SERVICE_SPMD_SHARDY_SDY_ROUND_TRIP_IMPORT_CALLBACK_CUSTOM_CALLS_H_ + +#include + +#include "mlir/Pass/Pass.h" + +namespace xla { +namespace sdy { + +// Creates the pass to modify the return types of XLA host callback custom calls +// to be compatible with SDY. +// +// Shardy shardings require an op to have at least one result, and the XLA host +// callback custom calls are not guaranteed to return a value. +// To allow the custom calls to have a maximal sharding, we change the return +// type to return a dummy value. +std::unique_ptr createSdyRoundTripImportCallbackCustomCallsPass(); + +// Registers the xla-sdy-round-trip-import-callback-custom-calls pass. +void registerSdyRoundTripImportCallbackCustomCallsPass(); + +} // namespace sdy +} // namespace xla + +#endif // XLA_SERVICE_SPMD_SHARDY_SDY_ROUND_TRIP_IMPORT_CALLBACK_CUSTOM_CALLS_H_ diff --git a/xla/service/spmd/shardy/sdy_round_trip/import_shardy_attrs.cc b/xla/service/spmd/shardy/sdy_round_trip/import_shardy_attrs.cc index a9a7f3003fb562..99936f482f7fb2 100644 --- a/xla/service/spmd/shardy/sdy_round_trip/import_shardy_attrs.cc +++ b/xla/service/spmd/shardy/sdy_round_trip/import_shardy_attrs.cc @@ -22,6 +22,7 @@ limitations under the License. #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/StringRef.h" +#include "llvm/Support/Casting.h" #include "mlir/AsmParser/AsmParser.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/IR/Attributes.h" @@ -65,8 +66,6 @@ using ::mlir::StringRef; using ::mlir::SymbolTable; using ::mlir::func::FuncOp; -using ::mlir::stablehlo::CustomCallOp; - using ::mlir::sdy::kShardingAttr; using ::mlir::sdy::kShardingRuleAttr; using ::mlir::sdy::MeshAttr; @@ -74,6 +73,8 @@ using ::mlir::sdy::OpShardingRuleAttr; using ::mlir::sdy::TensorShardingAttr; using ::mlir::sdy::TensorShardingPerValueAttr; +namespace stablehlo = ::mlir::stablehlo; + // Builds the shardy attributes coming from Shardy previously. This means // the module was exported from Shardy and we are now round-tripping back. // This should happen after the meshes were created from the `ModuleOp` attrs @@ -108,13 +109,19 @@ void convertShardyAttrs(FuncOp funcOp, IRRewriter& rewriter) { if (!dictAttr) { return; } + // `SendOp` and `RecvOp` can have a sharding when doing TPU callbacks + // through JAX. + if (isa(op)) { + op->setAttr(kShardingAttr, parseStringAttr( + dictAttr, kShardingRoundTripAttr)); + } // NOTE: we are only setting the sharding on known custom-calls. For any // other op that has a `kShardingRoundTripAttr` we discard it. XLA sometimes // creates new instructions, copying over the operand's frontend attrs, // which may mean the shapes are wrong when the new instruction is a reshape // for example. This does mean we can't fully round-trip b/w HLO and MLIR // after SDY propagation. - if (auto customCallOp = mlir::dyn_cast(op)) { + if (auto customCallOp = mlir::dyn_cast(op)) { StringRef targetName = customCallOp.getCallTargetName(); if (targetName == kFuncResultShardingTargetName) { // This is a temporary CustomCallOp that holds the sharding from a @@ -139,7 +146,8 @@ void convertShardyAttrs(FuncOp funcOp, IRRewriter& rewriter) { } if (targetName == kShardingCustomCallTargetName || targetName == kSPMDFullToShardShapeCallTargetName || - targetName == kSPMDShardToFullShapeCallTargetName) { + targetName == kSPMDShardToFullShapeCallTargetName || + isPythonCallbackCustomCall(customCallOp)) { customCallOp->setAttr(kShardingAttr, parseStringAttr( dictAttr, kShardingRoundTripAttr)); diff --git a/xla/service/spmd/shardy/sdy_round_trip/pipelines.cc b/xla/service/spmd/shardy/sdy_round_trip/pipelines.cc index 32e15074c843a1..0f92d457152cf4 100644 --- a/xla/service/spmd/shardy/sdy_round_trip/pipelines.cc +++ b/xla/service/spmd/shardy/sdy_round_trip/pipelines.cc @@ -26,6 +26,7 @@ limitations under the License. #include "xla/service/spmd/shardy/round_trip_common/pipeline_passes.h" #include "xla/service/spmd/shardy/sdy_round_trip/export_ops.h" #include "xla/service/spmd/shardy/sdy_round_trip/export_shardy_attrs.h" +#include "xla/service/spmd/shardy/sdy_round_trip/import_callback_custom_calls.h" #include "xla/service/spmd/shardy/sdy_round_trip/import_shardy_attrs.h" #include "xla/service/spmd/shardy/sdy_round_trip/remove_size_one_axes.h" #include "xla/service/spmd/shardy/sdy_round_trip/shard_map_export.h" @@ -49,6 +50,7 @@ void addSdyRoundTripExportPipeline(mlir::OpPassManager& pm) { void addSdyRoundTripImportPipeline(mlir::OpPassManager& pm) { addCommonPreImportPasses(pm); + pm.addPass(createSdyRoundTripImportCallbackCustomCallsPass()); pm.addPass(createSdyRoundTripImportShardyAttrsPass()); pm.addPass(createSdyRoundTripShardMapImportPass()); pm.addPass(createSdyRoundTripRemoveSizeOneAxesPass()); diff --git a/xla/service/spmd/shardy/test/mhlo_export_pipeline.mlir b/xla/service/spmd/shardy/test/mhlo_export_pipeline.mlir index d327cd439f07b6..ca9d1d5d00647f 100644 --- a/xla/service/spmd/shardy/test/mhlo_export_pipeline.mlir +++ b/xla/service/spmd/shardy/test/mhlo_export_pipeline.mlir @@ -246,6 +246,74 @@ func.func @custom_call_erf_topk( return %1#0 : tensor<16x2xf32> } +// CHECK-LABEL: @callback_transform_to_tuple +func.func @callback_transform_to_tuple(%arg0: tensor<2xf64> {sdy.sharding = #sdy.sharding<@mesh_5, [{"i"}]>}) -> (tensor<2xf64> {sdy.sharding = #sdy.sharding<@mesh_5, [{"i"}]>}) { + // CHECK-NEXT: %[[C:.*]] = stablehlo.constant + // CHECK-NEXT: %[[CALLBACK:.*]] = stablehlo.custom_call @xla_python_cpu_callback(%[[C]], %arg0) {{{.*}} : (tensor, tensor<2xf64>) -> tuple> + // CHECK-NEXT: %[[GET_TUPLE:.*]] = stablehlo.get_tuple_element %[[CALLBACK]][0] {mhlo.sharding = "{replicated}"} : (tuple>) -> tensor<2xf64> + // CHECK-NEXT: return %[[GET_TUPLE]] : tensor<2xf64> + %1 = stablehlo.constant dense<56560393354880> : tensor + %2 = stablehlo.custom_call @xla_python_cpu_callback(%1, %arg0) {api_version = 2 : i32, backend_config = "56560393354880", operand_layouts = [dense<> : tensor<0xindex>, dense<0> : tensor<1xindex>], result_layouts = [dense<0> : tensor<1xindex>], sdy.sharding = #sdy.sharding_per_value<[<@empty_mesh_0, [{}]>]>, xla_shape = "(f64[2]{0})"} : (tensor, tensor<2xf64>) -> tensor<2xf64> + return %2 : tensor<2xf64> +} + +// CHECK-LABEL: @callback_no_result +func.func private @callback_no_result(%arg0: tensor) { + // CHECK-NEXT: %[[C:.*]] = stablehlo.constant + // CHECK-NEXT: stablehlo.custom_call @xla_python_cpu_callback(%[[C]], %arg0) { + // CHECK-SAME: api_version = 2 : i32, backend_config = "56238273106176", + // CHECK-SAME: has_side_effect = true, mhlo.sharding = "{maximal device=0}", + // CHECK-SAME: operand_layouts = [dense<> : tensor<0xindex>, dense<> : tensor<0xindex>], + // CHECK-SAME: result_layouts = [] + // CHECK-SAME: } : (tensor, tensor) -> () + %c = stablehlo.constant dense<56238273106176> : tensor + %0 = stablehlo.custom_call @xla_python_cpu_callback(%c, %arg0) {api_version = 2 : i32, backend_config = "56238273106176", has_side_effect = true, operand_layouts = [dense<> : tensor<0xindex>, dense<> : tensor<0xindex>], result_layouts = [], sdy.sharding = #sdy.sharding_per_value<[<@maximal_mesh_0, []>]>} : (tensor, tensor) -> tuple<> + return +} + +// CHECK-LABEL: @callback_result_unused +func.func private @callback_result_unused(%arg0: tensor) { + // CHECK-NEXT: %[[C:.*]] = stablehlo.constant + // CHECK-NEXT: stablehlo.custom_call @xla_python_cpu_callback(%[[C]], %arg0) { + // CHECK-SAME: api_version = 2 : i32, backend_config = "56238273106176", + // CHECK-SAME: has_side_effect = true, mhlo.sharding = "{maximal device=0}", + // CHECK-SAME: operand_layouts = [dense<> : tensor<0xindex>, dense<> : tensor<0xindex>], + // CHECK-SAME: result_layouts = [] + // CHECK-SAME: } : (tensor, tensor) -> () + %c = stablehlo.constant dense<56238273106176> : tensor + %0 = stablehlo.custom_call @xla_python_cpu_callback(%c, %arg0) {api_version = 2 : i32, backend_config = "56238273106176", has_side_effect = true, operand_layouts = [dense<> : tensor<0xindex>, dense<> : tensor<0xindex>], result_layouts = [dense<> : tensor<0xindex>], sdy.sharding = #sdy.sharding_per_value<[<@maximal_mesh_0, []>]>} : (tensor, tensor) -> tensor + return +} + +// CHECK-LABEL: @callback_tuple_result_token_used +func.func public @callback_tuple_result_token_used(%arg0: !stablehlo.token, %arg1: tensor<2xi64>) -> !stablehlo.token { + %c = stablehlo.constant dense<56238119409280> : tensor + // CHECK-NEXT: %[[C:.*]] = stablehlo.constant + // CHECK-NEXT: %[[CALLBACK:.*]] = stablehlo.custom_call @xla_python_cpu_callback(%[[C]], %arg0, %arg1) { + // CHECK-SAME: api_version = 2 : i32, backend_config = "56238119409280", + // CHECK-SAME: has_side_effect = true, mhlo.sharding = "{maximal device=0}", + // CHECK-SAME: operand_layouts = [dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<0> : tensor<1xindex>], + // CHECK-SAME: result_layouts = [dense<> : tensor<0xindex>] + // CHECK-SAME: } : (tensor, !stablehlo.token, tensor<2xi64>) -> tuple + // CHECK-NEXT: %[[TOKEN:.*]] = stablehlo.get_tuple_element %[[CALLBACK]][0] : (tuple) -> !stablehlo.token + // CHECK-NEXT: return %[[TOKEN]] : !stablehlo.token + %0 = stablehlo.custom_call @xla_python_cpu_callback(%c, %arg0, %arg1) {api_version = 2 : i32, backend_config = "56238119409280", has_side_effect = true, operand_layouts = [dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<0> : tensor<1xindex>], result_layouts = [dense<> : tensor<0xindex>], sdy.sharding = #sdy.sharding_per_value<[<@maximal_mesh_0, []>]>} : (tensor, !stablehlo.token, tensor<2xi64>) -> tuple + %1 = stablehlo.get_tuple_element %0[0] : (tuple) -> !stablehlo.token + return %1 : !stablehlo.token +} + +// CHECK-LABEL: @callback_no_tuple_result_used +func.func @callback_no_tuple_result_used(%arg0: tensor<2xf64>) -> tensor<2xf64> { + // CHECK-NEXT: %[[C:.*]] = stablehlo.constant + // CHECK-NEXT: %[[CALLBACK:.*]] = stablehlo.custom_call @xla_python_cpu_callback(%[[C]], %arg0) {{{.*}} : (tensor, tensor<2xf64>) -> tuple> + // CHECK-NEXT: %[[GET_TUPLE:.*]] = stablehlo.get_tuple_element %[[CALLBACK]][0] {mhlo.sharding = "{replicated}"} : (tuple>) -> tensor<2xf64> + // CHECK-NEXT: return %[[GET_TUPLE]] : tensor<2xf64> + %c = stablehlo.constant dense<18990036333952> : tensor + %0 = stablehlo.custom_call @xla_python_cpu_callback(%c, %arg0) {api_version = 2 : i32, backend_config = "18990036333952", operand_layouts = [dense<> : tensor<0xindex>, dense<0> : tensor<1xindex>], result_layouts = [dense<0> : tensor<1xindex>], sdy.sharding = #sdy.sharding_per_value<[<@empty_mesh_0, [{?}]>]>, xla_shape = "(f64[2]{0})"} : (tensor, tensor<2xf64>) -> tensor<2xf64> + return %0 : tensor<2xf64> +} + + // CHECK-LABEL: func private @foo // CHECK-SAME: %arg0: tensor<4x2xi32> {mhlo.sharding = "{devices=[4,1,8]<=[8,4]T(1,0) last_tile_dim_replicate}"} // CHECK-SAME: -> (tensor<4x2xi32> {mhlo.sharding = "{devices=[4,1,8]<=[8,4]T(1,0) last_tile_dim_replicate}"}) { diff --git a/xla/service/spmd/shardy/test/sdy_round_trip_import_pipeline.mlir b/xla/service/spmd/shardy/test/sdy_round_trip_import_pipeline.mlir index 8529477a8a1d00..ac38fb676b5225 100644 --- a/xla/service/spmd/shardy/test/sdy_round_trip_import_pipeline.mlir +++ b/xla/service/spmd/shardy/test/sdy_round_trip_import_pipeline.mlir @@ -241,3 +241,18 @@ func.func @import_sharding_group(%arg0: tensor<8x8xf32>) -> tensor<8x8xf32> { stablehlo.custom_call @xla.sdy.ShardingGroup(%arg0) {has_side_effect = true, mhlo.frontend_attributes = {xla.sdy.sharding_group_id = "21 : i64"}} : (tensor<8x8xf32>) -> () return %arg0 : tensor<8x8xf32> } + +// ----- + +func.func @callback_no_result(%arg0: tensor) { + // CHECK: %[[C:.*]] = sdy.constant + // CHECK-NEXT: stablehlo.custom_call @xla_python_cpu_callback(%[[C]], %arg0) { + // CHECK-SAME: api_version = 2 : i32, backend_config = "56238273106176", + // CHECK-SAME: has_side_effect = true, + // CHECK-SAME: operand_layouts = [dense<> : tensor<0xindex>, dense<> : tensor<0xindex>], + // CHECK-SAME: result_layouts = [dense<> : tensor<0xindex>] + // CHECK-SAME: } : (tensor, tensor) -> tensor + %c = stablehlo.constant dense<56238273106176> : tensor + stablehlo.custom_call @xla_python_cpu_callback(%c, %arg0) {api_version = 2 : i32, backend_config = "56238273106176", has_side_effect = true, operand_layouts = [dense<> : tensor<0xindex>, dense<> : tensor<0xindex>], result_layouts = []} : (tensor, tensor) -> () + return +} diff --git a/xla/service/spmd/shardy/utils.cc b/xla/service/spmd/shardy/utils.cc index 8bd04c8f6f1ab2..62eecad007b040 100644 --- a/xla/service/spmd/shardy/utils.cc +++ b/xla/service/spmd/shardy/utils.cc @@ -30,9 +30,12 @@ limitations under the License. #include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/MLIRContext.h" #include "mlir/IR/Operation.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/IR/TypeRange.h" #include "mlir/Support/LLVM.h" #include "shardy/dialect/sdy/ir/register.h" #include "shardy/dialect/sdy/ir/utils.h" +#include "stablehlo/dialect/StablehloOps.h" #include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" #include "xla/service/spmd/shardy/constants.h" @@ -50,6 +53,7 @@ using ::mlir::StringRef; using xla::sdy::kFrontendAttributesAttr; using ::mlir::func::FuncOp; +using ::mlir::stablehlo::CustomCallOp; DictionaryAttr getFrontendAttrs(Operation* op) { return op->getAttrOfType(kFrontendAttributesAttr); @@ -185,5 +189,25 @@ void loadAllRequiredDialects(mlir::MLIRContext* context) { context->loadAllAvailableDialects(); } +CustomCallOp cloneCustomCallWithNewResultTypes(CustomCallOp op, + mlir::TypeRange resultTypes, + mlir::IRRewriter& rewriter) { + auto customCallOp = rewriter.create( + op.getLoc(), resultTypes, op.getOperands(), op.getCallTargetNameAttr(), + op.getHasSideEffectAttr(), op.getBackendConfigAttr(), + op.getApiVersionAttr(), op.getCalledComputations(), + op.getOperandLayoutsAttr(), op.getResultLayoutsAttr(), + op.getOutputOperandAliases()); + customCallOp->setDiscardableAttrs(mlir::DictionaryAttr::get( + op->getContext(), llvm::to_vector(op->getDiscardableAttrs()))); + return customCallOp; +}; + +bool isPythonCallbackCustomCall(mlir::stablehlo::CustomCallOp op) { + mlir::StringRef targetName = op.getCallTargetName(); + return targetName == kPythonCpuCallbackCustomCallTargetName || + targetName == kPythonGpuCallbackCustomCallTargetName; +} + } // namespace sdy } // namespace xla diff --git a/xla/service/spmd/shardy/utils.h b/xla/service/spmd/shardy/utils.h index fbdcbca4913c93..7975a55599d648 100644 --- a/xla/service/spmd/shardy/utils.h +++ b/xla/service/spmd/shardy/utils.h @@ -28,7 +28,10 @@ limitations under the License. #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/MLIRContext.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/IR/TypeRange.h" #include "mlir/Support/LLVM.h" +#include "stablehlo/dialect/StablehloOps.h" namespace xla { namespace sdy { @@ -101,6 +104,15 @@ std::optional tryGetFrontendAttr(mlir::Operation* op, return std::nullopt; } +// Builds a new `stablehlo.custom_call` with the same operands and attributes +// as `op` but with new `resultTypes`. +mlir::stablehlo::CustomCallOp cloneCustomCallWithNewResultTypes( + mlir::stablehlo::CustomCallOp op, mlir::TypeRange resultTypes, + mlir::IRRewriter& rewriter); + +// Whether `op` is a Python callback custom call. +bool isPythonCallbackCustomCall(mlir::stablehlo::CustomCallOp op); + } // namespace sdy } // namespace xla