From f5c78de054154a3eed066bc61dd7b60b6d2872af Mon Sep 17 00:00:00 2001 From: anakinxc <103552181+anakinxc@users.noreply.github.com> Date: Fri, 2 Dec 2022 11:56:53 +0800 Subject: [PATCH] Weekly sync (#43) * repo-sync-2022-12-02T10:50:14+0800 * Fix merge * Fix merge * Fix yacl commit id. --- CHANGELOG.md | 4 +- examples/cpp/BUILD.bazel | 3 +- examples/cpp/simple_pphlo.cc | 37 +- examples/python/millionare.py | 2 +- examples/python/ml/flax_mlp.py | 2 +- examples/python/ml/jax_lr.py | 2 +- examples/python/ml/ss_lr.py | 2 +- examples/python/ml/ss_xgb.py | 2 +- examples/python/ml/stax_mnist_classifier.py | 2 +- examples/python/ml/stax_nn.py | 12 +- examples/python/ml/tf_experiment.py | 2 +- examples/python/stats/corr.py | 2 +- examples/python/stats/pvalue.py | 2 +- examples/python/stats/woe.py | 2 +- spu/BUILD.bazel | 2 +- spu/binding/BUILD.bazel | 3 +- spu/binding/_lib.cc | 10 +- spu/binding/util/frontend_test.py | 12 +- spu/compiler/core/core.cc | 3 + spu/compiler/passes/optimize_maxpool.cc | 102 +- spu/compiler/passes/optimize_select.cc | 35 +- spu/compiler/tests/ops.mlir | 27 + ...ol_optimize.mlir => optimize_maxpool.mlir} | 3 +- spu/compiler/tests/optimize_select.mlir | 30 + spu/core/trace.cc | 4 +- spu/core/trace.h | 12 +- spu/device/BUILD.bazel | 45 +- spu/device/api.cc | 319 +++++ spu/device/{pphlo/type_checker.h => api.h} | 23 +- spu/device/executor.cc | 207 +--- spu/device/executor.h | 97 +- spu/device/frame.cc | 60 - spu/device/frame.h | 60 - spu/device/pphlo/BUILD.bazel | 59 +- spu/device/pphlo/executor.cc | 106 -- spu/device/pphlo/executor.h | 44 - spu/device/pphlo/executor_debug_runner.cc | 2 +- spu/device/pphlo/pphlo_executor.cc | 1063 +++++++++++++++++ spu/device/pphlo/pphlo_executor.h | 40 + ...xecutor_test.cc => pphlo_executor_test.cc} | 121 +- spu/device/pphlo/region_executor.cc | 931 --------------- spu/device/pphlo/region_executor.h | 232 ---- spu/device/pphlo/type_checker.cc | 87 -- spu/device/pphlo/xla_verifier.h | 6 +- spu/device/profiler.h | 65 - 
spu/device/type_checker.cc | 21 - spu/device/type_checker.h | 31 - spu/dialect/pphlo_ops.cc | 126 +- spu/dialect/pphlo_ops.td | 57 +- spu/kernel/context.h | 10 - spu/kernel/hal/concat.cc | 5 + spu/kernel/hal/shape_ops.cc | 5 + spu/kernel/hlo/reduce.cc | 249 ++-- spu/kernel/hlo/reduce.h | 11 +- spu/kernel/hlo/select_and_scatter.cc | 67 +- spu/kernel/hlo/sort.cc | 215 +++- spu/kernel/hlo/sort.h | 9 +- spu/mpc/aby3/arithmetic.cc | 162 ++- spu/mpc/aby3/ot.cc | 1 + spu/mpc/aby3/protocol.cc | 2 +- 60 files changed, 2508 insertions(+), 2349 deletions(-) create mode 100644 spu/compiler/tests/ops.mlir rename spu/compiler/tests/{maxpool_optimize.mlir => optimize_maxpool.mlir} (85%) create mode 100644 spu/compiler/tests/optimize_select.mlir create mode 100644 spu/device/api.cc rename spu/device/{pphlo/type_checker.h => api.h} (54%) delete mode 100644 spu/device/frame.cc delete mode 100644 spu/device/frame.h delete mode 100644 spu/device/pphlo/executor.cc delete mode 100644 spu/device/pphlo/executor.h create mode 100644 spu/device/pphlo/pphlo_executor.cc create mode 100644 spu/device/pphlo/pphlo_executor.h rename spu/device/pphlo/{executor_test.cc => pphlo_executor_test.cc} (93%) delete mode 100644 spu/device/pphlo/region_executor.cc delete mode 100644 spu/device/pphlo/region_executor.h delete mode 100644 spu/device/pphlo/type_checker.cc delete mode 100644 spu/device/profiler.h delete mode 100644 spu/device/type_checker.cc delete mode 100644 spu/device/type_checker.h diff --git a/CHANGELOG.md b/CHANGELOG.md index f48b1d71..df167613 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,8 +8,10 @@ ## staging > please add your unreleased change here. 
- - [Feature] ECDH-PSI supports white box interconnection mode +- [Feature] Various performance improvements +- [3p] Build with Tensorflow 2.11.0 +- [bugfix] Fix various crashes ## 20221116 - [SPU] 0.3.0 release diff --git a/examples/cpp/BUILD.bazel b/examples/cpp/BUILD.bazel index 5e87c33c..d90234ff 100644 --- a/examples/cpp/BUILD.bazel +++ b/examples/cpp/BUILD.bazel @@ -49,8 +49,9 @@ spu_cc_binary( srcs = ["simple_pphlo.cc"], deps = [ ":utils", + "//spu/device:api", "//spu/device:io", - "//spu/device/pphlo:executor", + "//spu/device/pphlo:pphlo_executor", "@llvm-project//llvm:Support", ], ) diff --git a/examples/cpp/simple_pphlo.cc b/examples/cpp/simple_pphlo.cc index 0a6b86b0..3263bde8 100644 --- a/examples/cpp/simple_pphlo.cc +++ b/examples/cpp/simple_pphlo.cc @@ -21,36 +21,37 @@ #include "examples/cpp/utils.h" #include "spdlog/spdlog.h" +#include "spu/device/api.h" #include "spu/device/io.h" -#include "spu/device/pphlo/executor.h" +#include "spu/device/pphlo/pphlo_executor.h" // This example demostrates the basic compute functionality of spu vm. -void constant_add(spu::device::Executor* executor) { +void constant_add(spu::HalContext* hctx) { // Write the assembly, this code simple add two numbers. // - `%1` is a constant public integer, with dtype int32 and value 1. // - `%2` is a constant public integer, with dtype int32 and value 2. // - `%3` is the sum of two integers. 
// - `dbg_print` print the value of `%3` - constexpr auto code = R"PPHlo( -func @main() -> () { + constexpr auto code = R"( +func.func @main() -> () { %0 = "pphlo.constant"() {value = dense<1> : tensor} : () -> tensor> %1 = "pphlo.constant"() {value = dense<2> : tensor} : () -> tensor> %2 = "pphlo.add"(%0, %1) : (tensor>, tensor>) -> tensor> "pphlo.dbg_print"(%2) : (tensor>) -> () return -} -)PPHlo"; +})"; // Run it, with no input and output, (since the program does not contain IO) spu::device::SymbolTable env; - executor->runWithEnv(code, {}, {}, &env); + spu::device::pphlo::PPHloExecutor executor; + spu::device::execute(&executor, hctx, code, {}, {}, &env); } // This example demostrates how to pass parameters. -void parameters(spu::device::Executor* executor) { +void parameters(spu::HalContext* hctx) { // In this example, data owner also participates the computation progress, // which is called "colocated mode" in spu system. - spu::device::ColocatedIo cio(executor->getContext()); + spu::device::ColocatedIo cio(hctx); if (cio.getRank() == 0) { // rank-0, set a float variable 3.14 as 'x' to the device. @@ -72,29 +73,29 @@ void parameters(spu::device::Executor* executor) { // - `%3` is the product of two values, it will do auto type promotion. // - `dbg_print` print the value of `%3` constexpr auto code = R"PPHlo( -func @main(%arg0: tensor>, %arg1: tensor>) -> () { - %0 = "pphlo.multiply"(%arg0, %arg1) : (tensor>, tensor>) -> tensor> - "pphlo.dbg_print"(%0) : (tensor>) -> () +func.func @main(%arg0: tensor>, %arg1: tensor>) -> () { + %0 = "pphlo.multiply"(%arg0, %arg1) : (tensor>, tensor>) -> tensor> + "pphlo.dbg_print"(%0) : (tensor>) -> () return -} - )PPHlo"; +})PPHlo"; // run the assembly, with // - "x" binding to the first parameter (position 0). // - "y" binding to the second parameter (position 1). // - there is no output bindings. 
- executor->runWithEnv(code, {"x", "y"}, {}, &cio.deviceSymbols()); + spu::device::pphlo::PPHloExecutor executor; + spu::device::execute(&executor, hctx, code, {"x", "y"}, {}, + &cio.deviceSymbols()); } int main(int argc, char** argv) { llvm::cl::ParseCommandLineOptions(argc, argv); auto hctx = MakeHalContext(); - spu::device::pphlo::PPHloExecutor executor(hctx.get()); - parameters(&executor); + parameters(hctx.get()); - constant_add(&executor); + constant_add(hctx.get()); return 0; } diff --git a/examples/python/millionare.py b/examples/python/millionare.py index a36dcc0d..58553988 100644 --- a/examples/python/millionare.py +++ b/examples/python/millionare.py @@ -17,7 +17,7 @@ # > bazel run -c opt //examples/python/utils:nodectl -- up # # Run this example script. -# > bazel run //examples/python:millionare +# > bazel run -c opt //examples/python:millionare import argparse diff --git a/examples/python/ml/flax_mlp.py b/examples/python/ml/flax_mlp.py index 35141859..29065d0d 100644 --- a/examples/python/ml/flax_mlp.py +++ b/examples/python/ml/flax_mlp.py @@ -16,7 +16,7 @@ # > bazel run -c opt //examples/python/utils:nodectl -- up # # Run this example script. -# > bazel run //examples/python/ml:flax_mlp +# > bazel run -c opt //examples/python/ml:flax_mlp import argparse import json diff --git a/examples/python/ml/jax_lr.py b/examples/python/ml/jax_lr.py index d145acdb..f4ba16ce 100644 --- a/examples/python/ml/jax_lr.py +++ b/examples/python/ml/jax_lr.py @@ -16,7 +16,7 @@ # > bazel run -c opt //examples/python/utils:nodectl -- up # # Run this example script. -# > bazel run //examples/python/ml:jax_lr +# > bazel run -c opt //examples/python/ml:jax_lr import argparse diff --git a/examples/python/ml/ss_lr.py b/examples/python/ml/ss_lr.py index 8b72031b..22aca049 100644 --- a/examples/python/ml/ss_lr.py +++ b/examples/python/ml/ss_lr.py @@ -16,7 +16,7 @@ # > bazel run -c opt //examples/python/utils:nodectl -- up # # Run this example script. 
-# > bazel run //examples/python/ml:ss_lr +# > bazel run -c opt //examples/python/ml:ss_lr import argparse import json diff --git a/examples/python/ml/ss_xgb.py b/examples/python/ml/ss_xgb.py index 196cac70..85776129 100644 --- a/examples/python/ml/ss_xgb.py +++ b/examples/python/ml/ss_xgb.py @@ -16,7 +16,7 @@ # > bazel run -c opt //examples/python/utils:nodectl -- up # # Run this example script. -# > bazel run //examples/python/ml:ss_xgb +# > bazel run -c opt //examples/python/ml:ss_xgb import argparse import json diff --git a/examples/python/ml/stax_mnist_classifier.py b/examples/python/ml/stax_mnist_classifier.py index 9b0c5e5e..0ba38422 100644 --- a/examples/python/ml/stax_mnist_classifier.py +++ b/examples/python/ml/stax_mnist_classifier.py @@ -23,7 +23,7 @@ # > bazel run -c opt //examples/python/utils:nodectl -- up # # Run this example script. -# > bazel run //examples/python/ml:stax_mnist_classifier +# > bazel run -c opt //examples/python/ml:stax_mnist_classifier import time import itertools diff --git a/examples/python/ml/stax_nn.py b/examples/python/ml/stax_nn.py index 6bb3a12b..e8a8e6fa 100644 --- a/examples/python/ml/stax_nn.py +++ b/examples/python/ml/stax_nn.py @@ -50,12 +50,12 @@ import argparse parser = argparse.ArgumentParser(description='distributed driver.') -parser.add_argument("--model", default='network_a') -parser.add_argument("-c", "--config", default="examples/python/conf/3pc.json") -parser.add_argument("-l", "--learning_rate", default=0.01) -parser.add_argument("-e", "--epoch", default=5) -parser.add_argument("-b", "--batch_size", default=128) -parser.add_argument("-o", "--optimizer", default="SGD") +parser.add_argument("--model", default='network_a', type=str) +parser.add_argument("-c", "--config", default="examples/python/conf/3pc.json", type=str) +parser.add_argument("-l", "--learning_rate", default=0.01, type=float) +parser.add_argument("-e", "--epoch", default=5, type=int) +parser.add_argument("-b", "--batch_size", default=128, 
type=int) +parser.add_argument("-o", "--optimizer", default="SGD", type=str) args = parser.parse_args() # Follows https://arxiv.org/pdf/2107.00501.pdf Appendix C. diff --git a/examples/python/ml/tf_experiment.py b/examples/python/ml/tf_experiment.py index 2f4ea817..3fc73f58 100644 --- a/examples/python/ml/tf_experiment.py +++ b/examples/python/ml/tf_experiment.py @@ -29,7 +29,7 @@ # > bazel run -c opt //examples/python/utils:nodectl -- up # # Run this example script. -# > bazel run //examples/python/ml:tf_experiment +# > bazel run -c opt //examples/python/ml:tf_experiment # This example is tf counterpart to //examples/python/ml:jax_lr diff --git a/examples/python/stats/corr.py b/examples/python/stats/corr.py index 564f10de..8d0c997f 100644 --- a/examples/python/stats/corr.py +++ b/examples/python/stats/corr.py @@ -16,7 +16,7 @@ # > bazel run -c opt //examples/python/utils:nodectl -- up # # Run this example script. -# > bazel run //examples/python/stats:corr +# > bazel run -c opt //examples/python/stats:corr import argparse import json diff --git a/examples/python/stats/pvalue.py b/examples/python/stats/pvalue.py index ac649900..2b731f43 100644 --- a/examples/python/stats/pvalue.py +++ b/examples/python/stats/pvalue.py @@ -16,7 +16,7 @@ # > bazel run -c opt //examples/python/utils:nodectl -- up # # Run this example script. -# > bazel run //examples/python/stats:pvalue +# > bazel run -c opt //examples/python/stats:pvalue import argparse diff --git a/examples/python/stats/woe.py b/examples/python/stats/woe.py index d9a71936..bb9dba6a 100644 --- a/examples/python/stats/woe.py +++ b/examples/python/stats/woe.py @@ -16,7 +16,7 @@ # > bazel run -c opt //examples/python/utils:nodectl -- up # # Run this example script. 
-# > bazel run //examples/python/stats:woe +# > bazel run -c opt //examples/python/stats:woe import argparse diff --git a/spu/BUILD.bazel b/spu/BUILD.bazel index f9b96706..6844405d 100644 --- a/spu/BUILD.bazel +++ b/spu/BUILD.bazel @@ -20,7 +20,7 @@ load("//bazel:spu.bzl", "spu_version_file") package(default_visibility = ["//visibility:public"]) -SPU_VERSION = "0.3.1b0" +SPU_VERSION = "0.3.1b2" proto_library( name = "spu_proto", diff --git a/spu/binding/BUILD.bazel b/spu/binding/BUILD.bazel index 8b192fca..e3d3bac5 100644 --- a/spu/binding/BUILD.bazel +++ b/spu/binding/BUILD.bazel @@ -41,8 +41,9 @@ pybind_extension( ":version_script.lds", "//spu/compiler:compile", "//spu/compiler/common:compilation_context", + "//spu/device:api", "//spu/device:io", - "//spu/device/pphlo:executor", + "//spu/device/pphlo:pphlo_executor", "//spu/psi:bucket_psi", "//spu/psi:memory_psi", "@yacl//yacl/link", diff --git a/spu/binding/_lib.cc b/spu/binding/_lib.cc index 31f2964c..4bace790 100644 --- a/spu/binding/_lib.cc +++ b/spu/binding/_lib.cc @@ -22,8 +22,9 @@ #include "spu/compiler/common/compilation_context.h" #include "spu/compiler/compile.h" #include "spu/core/type_util.h" +#include "spu/device/api.h" #include "spu/device/io.h" -#include "spu/device/pphlo/executor.h" +#include "spu/device/pphlo/pphlo_executor.h" #include "spu/kernel/context.h" #include "spu/kernel/value.h" #include "spu/psi/bucket_psi.h" @@ -205,8 +206,7 @@ void BindLink(py::module& m) { }); } -// Wrap Processor, it's workaround for protobuf pybind11/protoc conflict. - +// Wrap Runtime, it's workaround for protobuf pybind11/protoc conflict. 
class RuntimeWrapper { std::unique_ptr hctx_; @@ -226,8 +226,8 @@ class RuntimeWrapper { spu::ExecutableProto exec; YACL_ENFORCE(exec.ParseFromString(exec_pb)); - spu::device::pphlo::PPHloExecutor executor(hctx_.get()); - executor.runWithEnv(exec, &env_); + spu::device::pphlo::PPHloExecutor executor; + spu::device::execute(&executor, hctx_.get(), exec, &env_); } void SetVar(const std::string& name, const py::bytes& value) { diff --git a/spu/binding/util/frontend_test.py b/spu/binding/util/frontend_test.py index 0153d14a..f2e78946 100644 --- a/spu/binding/util/frontend_test.py +++ b/spu/binding/util/frontend_test.py @@ -37,14 +37,12 @@ def test_jax_compile(self): self.assertEqual(executable.name, "add") self.assertEqual(executable.input_names, ["in1", "in2"]) self.assertEqual(executable.output_names, ["test-out0"]) - self.assertMultiLineEqual( - executable.code.decode(), - "module @xla_computation_add {\n" + self.assertTrue( " func.func @main(%arg0: tensor<2x!pphlo.pub>," " %arg1: tensor<2x!pphlo.pub>) -> tensor<2x!pphlo.pub> {\n" " %0 = \"pphlo.add\"(%arg0, %arg1) : (tensor<2x!pphlo.pub>," " tensor<2x!pphlo.pub>) -> tensor<2x!pphlo.pub>\n" - " return %0 : tensor<2x!pphlo.pub>\n }\n}\n", + " return %0 : tensor<2x!pphlo.pub>\n }" in executable.code.decode() ) self.assertEqual(output.shape, (2,)) self.assertEqual(output.dtype, np.dtype("int32")) @@ -62,14 +60,12 @@ def test_tf_compile(self): self.assertEqual(executable.name, "add") self.assertEqual(executable.input_names, ["in1", "in2"]) self.assertEqual(executable.output_names, ["test-out0"]) - self.assertMultiLineEqual( - executable.code.decode(), - "module @a_inference_add_9__.9 {\n" + self.assertTrue( " func.func @main(%arg0: tensor<2x!pphlo.pub>," " %arg1: tensor<2x!pphlo.pub>) -> tensor<2x!pphlo.pub> {\n" " %0 = \"pphlo.add\"(%arg0, %arg1) : (tensor<2x!pphlo.pub>," " tensor<2x!pphlo.pub>) -> tensor<2x!pphlo.pub>\n" - " return %0 : tensor<2x!pphlo.pub>\n }\n}\n", + " return %0 : tensor<2x!pphlo.pub>\n }" in 
executable.code.decode() ) self.assertEqual(output.shape, (2,)) self.assertEqual(output.dtype, np.dtype("int64")) diff --git a/spu/compiler/core/core.cc b/spu/compiler/core/core.cc index e29f3f91..357f0c5c 100644 --- a/spu/compiler/core/core.cc +++ b/spu/compiler/core/core.cc @@ -47,6 +47,9 @@ void Core::buildPipeline(mlir::PassManager *pm) { optPM.addPass(mlir::pphlo::createOptimizeMaxPoolingPass()); optPM.addPass(mlir::pphlo::createDecomposeComparisonPass()); optPM.addPass(mlir::pphlo::createDecomposeMinMaxPass()); + + optPM.addPass(mlir::createCSEPass()); + optPM.addPass(mlir::pphlo::createReduceTruncationPass()); optPM.addPass(mlir::pphlo::createLowerMixedTypeOpPass()); diff --git a/spu/compiler/passes/optimize_maxpool.cc b/spu/compiler/passes/optimize_maxpool.cc index 4a58e3bd..db0b3c7c 100644 --- a/spu/compiler/passes/optimize_maxpool.cc +++ b/spu/compiler/passes/optimize_maxpool.cc @@ -34,105 +34,39 @@ namespace { struct SelectAndScatterConverter : public OpRewritePattern { private: TypeTools tools_; - // Rewrite reduce body from a unary max reduce to a binary GE reduce, - // Which returns both max value and onehot for max location - void rewriteReduceBody(Region &r, PatternRewriter &rewriter) const { - auto comp = mlir::dyn_cast(r.front().front()); - YACL_ENFORCE(comp); - - auto builder = OpBuilder::atBlockBegin(&r.front()); - - auto comp_ret = comp->getResultTypes()[0]; - auto comp_vis = tools_.getTypeVisibility(comp_ret); - auto index_ret_t = tools_.getTypeWithVisibility( - RankedTensorType::get({}, rewriter.getI8Type()), comp_vis); - auto ge_ret_t = tools_.getTypeWithVisibility( - RankedTensorType::get({}, rewriter.getI1Type()), comp_vis); - auto ge = builder.create(comp->getLoc(), ge_ret_t, - comp.lhs(), comp.rhs()); - - auto select1 = builder.create( - comp->getLoc(), TypeRange{comp_ret, index_ret_t}, ge, - ValueRange{comp.lhs(), r.getArgument(1)}, - ValueRange{comp.rhs(), r.getArgument(3)}); - - auto *operation = r.front().getTerminator(); - 
rewriter.updateRootInPlace( - operation, [&]() { operation->setOperands(select1->getResults()); }); - } Value rewriteReduceWindow(ReduceWindowOp op, PatternRewriter &rewriter) const { - auto pub_mask_type = tools_.getTypeWithVisibility(rewriter.getI8Type(), - Visibility::VIS_PUBLIC); - - std::vector window_shape( - op.window_dimensions().getValues().begin(), - op.window_dimensions().getValues().end()); - auto window_size = std::accumulate(window_shape.begin(), window_shape.end(), - 1, std::multiplies()); - - OpBuilder builder(op); - builder.setInsertionPointAfter(op.getOperation()); - - auto init_value = builder.create( - op->getLoc(), - DenseElementsAttr::get(RankedTensorType::get({}, rewriter.getI8Type()), - rewriter.getI8IntegerAttr(-1))); - - // Build a window mask as eye(n), where n = window size - std::vector mask(window_size * window_size, - rewriter.getI8IntegerAttr(0)); - for (int64_t idx = 0; idx < window_size; ++idx) { - mask[idx * window_size + idx] = rewriter.getI8IntegerAttr(1); - } - auto mask_const = builder.create( - op->getLoc(), - DenseElementsAttr::get(RankedTensorType::get({window_size, window_size}, - rewriter.getI8Type()), - mask)); - - // Rewrite reduce window from max to argmax - llvm::SmallVector operands; - operands.emplace_back(op.inputs()[0]); - operands.emplace_back(mask_const); - operands.emplace_back(op.init_values()[0]); - operands.emplace_back(init_value); + auto window_size = + std::accumulate(op.window_dimensions().getValues().begin(), + op.window_dimensions().getValues().end(), 1, + std::multiplies()); auto current_ret_type = - op->getResultTypes()[0].dyn_cast(); - auto current_ret_vis = tools_.getTypeVisibility(current_ret_type); + op.getResult(0).getType().dyn_cast(); std::vector index_result_shape = current_ret_type.getShape(); index_result_shape.emplace_back(window_size); + auto current_ret_vis = tools_.getTypeVisibility(current_ret_type); + auto index_result_type = RankedTensorType::get( index_result_shape, - 
tools_.getTypeWithVisibility(rewriter.getI8Type(), current_ret_vis)); + tools_.getTypeWithVisibility(rewriter.getI1Type(), current_ret_vis)); - auto new_reduce_window = builder.create( + OpBuilder builder(op); + builder.setInsertionPoint(op.getOperation()); + auto argmax = builder.create( op->getLoc(), SmallVector{current_ret_type, index_result_type}, - operands, op->getAttrs()); - - new_reduce_window.last_operand_is_window_maskAttr( - BoolAttr::get(op->getContext(), true)); - new_reduce_window.ignore_init_valueAttr( - BoolAttr::get(op->getContext(), true)); - - rewriter.inlineRegionBefore(op.body(), new_reduce_window.body(), - new_reduce_window.body().end()); - - new_reduce_window.body().insertArgument( - 1, RankedTensorType::get({}, pub_mask_type), op->getLoc()); - new_reduce_window.body().addArgument( - RankedTensorType::get({}, pub_mask_type), op->getLoc()); - - op->getResult(0).replaceAllUsesWith(new_reduce_window->getResult(0)); + op.inputs()[0], op.window_dimensions(), + op.window_strides().value_or(nullptr), + op.base_dilations().value_or(nullptr), + op.window_dilations().value_or(nullptr), + op.padding().value_or(nullptr)); - // Rewrite body - rewriteReduceBody(new_reduce_window.body(), rewriter); + op->getResult(0).replaceAllUsesWith(argmax->getResult(0)); - return new_reduce_window->getResults()[1]; + return argmax->getResult(1); } bool isSingleRegion(Region &r) const { diff --git a/spu/compiler/passes/optimize_select.cc b/spu/compiler/passes/optimize_select.cc index 03f0a9ce..3b3b1bb6 100644 --- a/spu/compiler/passes/optimize_select.cc +++ b/spu/compiler/passes/optimize_select.cc @@ -40,16 +40,41 @@ struct SelectConversion : public OpRewritePattern { LogicalResult matchAndRewrite(SelectOp op, PatternRewriter &rewriter) const override { + auto pred = op.pred(); // Only do this for certain select... 
- if (op.pred().getDefiningOp() != nullptr) { + if (pred.getDefiningOp() != nullptr) { // This select pred has already been optimized, bailout here return failure(); } - auto pref_a = rewriter.create(op->getLoc(), op.pred().getType(), - op.pred()); - rewriter.replaceOpWithNewOp(op, op->getResultTypes(), pref_a, - op.on_true(), op.on_false()); + // If this pred has only one use...do not rewrite, with mula1b is faster + if (pred.hasOneUse()) { + return failure(); + } + + auto number_of_selects = 0; + for (auto &use : pred.getUses()) { + if (mlir::isa(use.getOwner())) { + ++number_of_selects; + } + } + + // Although this value is used by multiple operations, there is still a + // single select + if (number_of_selects == 1) { + return failure(); + } + + OpBuilder builder(op); + builder.setInsertionPoint(pred.getDefiningOp()->getNextNode()); + auto pref_a = builder.create(pred.getDefiningOp()->getLoc(), + pred.getType(), pred); + + // Only replace select usage + pred.replaceUsesWithIf(pref_a, [](OpOperand &use) { + return mlir::isa(use.getOwner()); + }); + return success(); } }; diff --git a/spu/compiler/tests/ops.mlir b/spu/compiler/tests/ops.mlir new file mode 100644 index 00000000..f9724aab --- /dev/null +++ b/spu/compiler/tests/ops.mlir @@ -0,0 +1,27 @@ +// RUN: mlir-pphlo-opt %s -verify-diagnostics -split-input-file + +// ----- + +func.func @invalid_concate_dim() -> tensor> { + %0 = "pphlo.constant"() {value = dense<1.3347515E+38> : tensor} : () -> tensor> + // expected-error @+1 {{rank-0 values cannot be concatenated}} + %1 = "pphlo.concatenate"(%0) {dimension = 27755 : i64} : (tensor>) -> tensor> + %2 = "pphlo.constant"() {value = dense<5> : tensor} : () -> tensor> + "pphlo.return"(%2) : (tensor>) -> () +} + +// ----- + +// ----- + +func.func @invalid_broadcast_dim() -> tensor> { + %2 = "pphlo.constant"() {value = dense<[0x41DA6E5887800000, 0x41C94E3940000000, 0x41C4BD2007000000, 0x41DC95133AC00000, 0x41D1650CEC000000, 0x41C9DF42E7800000, 0x41D46C43B6800000, 
0x41C467EE0E800000, 0x41DC705F14400000]> : tensor<9xf64>} : () -> tensor<9x!pphlo.pub> + %3 = "pphlo.floor"(%2) : (tensor<9x!pphlo.pub>) -> tensor<9x!pphlo.pub> + %9 = "pphlo.concatenate"(%3) {dimension = 0 : i64} : (tensor<9x!pphlo.pub>) -> tensor<9x!pphlo.pub> + // expected-error @+1 {{broadcast_dimensions contains invalid value 13 for result with rank 1}} + %10 = "pphlo.broadcast"(%9) {broadcast_dimensions = dense<13> : tensor<1xi64>} : (tensor<9x!pphlo.pub>) -> tensor<9x!pphlo.pub> + %51 = "pphlo.constant"() {value = dense<5> : tensor} : () -> tensor> + "pphlo.return"(%51) : (tensor>) -> () +} + +// ----- diff --git a/spu/compiler/tests/maxpool_optimize.mlir b/spu/compiler/tests/optimize_maxpool.mlir similarity index 85% rename from spu/compiler/tests/maxpool_optimize.mlir rename to spu/compiler/tests/optimize_maxpool.mlir index 01bc7c16..597642af 100644 --- a/spu/compiler/tests/maxpool_optimize.mlir +++ b/spu/compiler/tests/optimize_maxpool.mlir @@ -5,10 +5,9 @@ func.func @main(%arg0: tensor<129x24x24x16x!pphlo.sec>, %arg1: tensor<129x2 %1 = "pphlo.constant"() {value = dense<0.000000e+00> : tensor} : () -> tensor> %2 = "pphlo.convert"(%0) : (tensor>) -> tensor> %3 = "pphlo.convert"(%1) : (tensor>) -> tensor> + //CHECK: "pphlo.argmax"(%arg0) {base_dilations = dense<1> : tensor<4xi64>, onehot_index = true, padding = dense<0> : tensor<4x2xi64>, window_dilations = dense<1> : tensor<4xi64>, window_dimensions = dense<[1, 2, 2, 1]> : tensor<4xi64>, window_strides = dense<1> : tensor<4xi64>} : (tensor<129x24x24x16x!pphlo.sec>) -> (tensor<129x23x23x16x!pphlo.sec>, tensor<129x23x23x16x4x!pphlo.sec>) %4 = "pphlo.reduce_window"(%arg0, %2) ({ ^bb0(%arg2: tensor>, %arg3: tensor>): - //CHECK: %[[GE:.+]] = "pphlo.greater_equal"(%arg2, %arg4) : (tensor>, tensor>) -> tensor> - //CHECK-NEXT: "pphlo.select"(%[[GE:.+]], %arg2, %arg3, %arg4, %arg5) : (tensor>, tensor>, tensor>, tensor>, tensor>) -> (tensor>, tensor>) %6 = "pphlo.maximum"(%arg2, %arg3) : (tensor>, tensor>) -> 
tensor> "pphlo.return"(%6) : (tensor>) -> () }) {base_dilations = dense<1> : tensor<4xi64>, padding = dense<0> : tensor<4x2xi64>, window_dilations = dense<1> : tensor<4xi64>, window_dimensions = dense<[1, 2, 2, 1]> : tensor<4xi64>, window_strides = dense<1> : tensor<4xi64>} : (tensor<129x24x24x16x!pphlo.sec>, tensor>) -> tensor<129x23x23x16x!pphlo.sec> diff --git a/spu/compiler/tests/optimize_select.mlir b/spu/compiler/tests/optimize_select.mlir new file mode 100644 index 00000000..d06139bc --- /dev/null +++ b/spu/compiler/tests/optimize_select.mlir @@ -0,0 +1,30 @@ +// RUN: mlir-pphlo-opt --optimize-select --split-input-file %s | FileCheck %s + +func.func @single_select() -> (tensor>) { + %0 = "pphlo.constant"() {value = dense<0xFF800000> : tensor} : () -> tensor> + %1 = "pphlo.constant"() {value = dense<0.000000e+00> : tensor} : () -> tensor> + %2 = "pphlo.less"(%0, %1): (tensor>, tensor>) -> tensor> + //CHECK-NOT: pphlo.prefer_a + %3 = "pphlo.select"(%2, %0, %1): (tensor>, tensor>, tensor>) -> tensor> + return %3: tensor> +} + +func.func @multi_selects() -> (tensor>, tensor>) { + %0 = "pphlo.constant"() {value = dense<0xFF800000> : tensor} : () -> tensor> + %1 = "pphlo.constant"() {value = dense<0.000000e+00> : tensor} : () -> tensor> + %2 = "pphlo.less"(%0, %1): (tensor>, tensor>) -> tensor> + //CHECK: pphlo.prefer_a + %3 = "pphlo.select"(%2, %0, %1): (tensor>, tensor>, tensor>) -> tensor> + %4 = "pphlo.select"(%2, %1, %0): (tensor>, tensor>, tensor>) -> tensor> + return %3, %4: tensor>, tensor> +} + +func.func @multi_uses_single_select() -> (tensor>, tensor>) { + %0 = "pphlo.constant"() {value = dense<0xFF800000> : tensor} : () -> tensor> + %1 = "pphlo.constant"() {value = dense<0.000000e+00> : tensor} : () -> tensor> + %2 = "pphlo.less"(%0, %1): (tensor>, tensor>) -> tensor> + //CHECK-NOT: pphlo.prefer_a + %3 = "pphlo.select"(%2, %0, %1): (tensor>, tensor>, tensor>) -> tensor> + %4 = "pphlo.not"(%2): (tensor>) -> tensor> + return %3, %4: tensor>, tensor> +} 
diff --git a/spu/core/trace.cc b/spu/core/trace.cc index 1f7fd792..e7147ed0 100644 --- a/spu/core/trace.cc +++ b/spu/core/trace.cc @@ -148,7 +148,7 @@ std::shared_ptr getDefaultLogger() { } // namespace -void Tracer::logActionBegin(int64_t id, int64_t flag, std::string_view name, +void Tracer::logActionBegin(int64_t id, int64_t flag, const std::string& name, const std::string& detail) { if ((flag & mask_ & TR_MODALL) == 0 || (mask_ & TR_LOGB) == 0) { // module is disabled or logging is disabled, ignore. @@ -162,7 +162,7 @@ void Tracer::logActionBegin(int64_t id, int64_t flag, std::string_view name, } } -void Tracer::logActionEnd(int64_t id, int64_t flag, std::string_view name, +void Tracer::logActionEnd(int64_t id, int64_t flag, const std::string& name, const std::string& detail) { if ((flag & mask_ & TR_MODALL) == 0 || (mask_ & TR_LOGE) == 0) { // module is disabled or logging is disabled, ignore. diff --git a/spu/core/trace.h b/spu/core/trace.h index 860bb62c..0b4ae682 100644 --- a/spu/core/trace.h +++ b/spu/core/trace.h @@ -117,7 +117,7 @@ struct ActionRecord { // the uuid of this action. int64_t id; // name of the action, the name should be static allocated. - std::string_view name; + std::string name; // detail of the action std::string detail; // the flag of this action. @@ -155,9 +155,9 @@ class Tracer final { // @flag, various attributes of the action. // @name, name of the action. // @detail, detail of the action. - void logActionBegin(int64_t id, int64_t flag, std::string_view name, + void logActionBegin(int64_t id, int64_t flag, const std::string& name, const std::string& detail = ""); - void logActionEnd(int64_t id, int64_t flag, std::string_view name, + void logActionEnd(int64_t id, int64_t flag, const std::string& name, const std::string& detail = ""); void addRecord(ActionRecord&& rec) { @@ -178,7 +178,7 @@ class TraceAction final { int64_t id_; // name of the action. - std::string_view name_; + std::string name_; // detail of the action. 
std::string detail_; @@ -190,7 +190,7 @@ class TraceAction final { int64_t saved_tracer_mask_; template - void begin(std::string_view name, Args&&... args) { + void begin(const std::string& name, Args&&... args) { name_ = name; start_ = std::chrono::high_resolution_clock::now(); @@ -238,7 +238,7 @@ class TraceAction final { // mask = ~TR_MOD2, means disable further TR_MOD2 tracing. template explicit TraceAction(std::shared_ptr tracer, int64_t flag, - int64_t mask, std::string_view name, Args&&... args) + int64_t mask, const std::string& name, Args&&... args) : tracer_(std::move(tracer)), flag_(flag), mask_(mask) { id_ = internal::genActionUuid(); begin(name, std::forward(args)...); diff --git a/spu/device/BUILD.bazel b/spu/device/BUILD.bazel index 7b219946..ce6fa32d 100644 --- a/spu/device/BUILD.bazel +++ b/spu/device/BUILD.bazel @@ -21,7 +21,7 @@ package(default_visibility = ["//visibility:public"]) spu_cc_library( name = "device", deps = [ - ":frame", + ":api", ":io", ], ) @@ -65,29 +65,29 @@ spu_cc_test( ) spu_cc_library( - name = "frame", - srcs = ["frame.cc"], - hdrs = ["frame.h"], + name = "executor", + srcs = ["executor.cc"], + hdrs = ["executor.h"], deps = [ - ":type_checker", + ":symbol_table", + "//spu:spu_cc_proto", "//spu/dialect:pphlo_dialect", + "//spu/kernel:context", "//spu/kernel:value", "@llvm-project//mlir:IR", - "@yacl//yacl/base:exception", ], ) spu_cc_library( - name = "executor", - srcs = ["executor.cc"], - hdrs = ["executor.h"], + name = "api", + srcs = ["api.cc"], + hdrs = ["api.h"], deps = [ - ":profiler", - ":symbol_table", - "//spu:spu_cc_proto", - "//spu/kernel:context", - "//spu/kernel:value", - "//spu/kernel/hal:polymorphic", + ":executor", + "//spu/device/pphlo:pphlo_executor", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Parser", ], ) @@ -114,18 +114,3 @@ spu_cc_library( "//spu/kernel:value", ], ) - -spu_cc_library( - name = "type_checker", - srcs = ["type_checker.cc"], - hdrs = 
["type_checker.h"], - deps = [ - "//spu/kernel:value", - "@llvm-project//mlir:IR", - ], -) - -spu_cc_library( - name = "profiler", - hdrs = ["profiler.h"], -) diff --git a/spu/device/api.cc b/spu/device/api.cc new file mode 100644 index 00000000..994f41f0 --- /dev/null +++ b/spu/device/api.cc @@ -0,0 +1,319 @@ +// Copyright 2021 Ant Group Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "spu/device/api.h" + +#include +#include +#include +#include + +#include "llvm/Support/ErrorHandling.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Parser/Parser.h" +#include "spdlog/spdlog.h" + +#include "spu/device/pphlo/pphlo_executor.h" +#include "spu/dialect/pphlo_dialect.h" + +namespace spu::device { +namespace { + +class TimeitGuard { + TimePoint start_; + Duration &duration_; + +public: + explicit TimeitGuard(Duration &dur) : duration_(dur) { + start_ = std::chrono::high_resolution_clock::now(); + } + + ~TimeitGuard() { + duration_ = std::chrono::high_resolution_clock::now() - start_; + } +}; + +double getSeconds(const Duration &dur) { + return std::chrono::duration_cast>(dur).count(); +} + +[[maybe_unused]] double getSeconds(const TimePoint &start, + const TimePoint &end) { + return std::chrono::duration_cast>(end - start) + .count(); +} + +struct ExecutionStats { + Duration total_time() const { + return infeed_time + execution_time + outfeed_time; + } + Duration infeed_time; + Duration execution_time; + Duration outfeed_time; +}; 
+ +struct CommunicationStats { + size_t send_bytes = 0; + size_t send_actions = 0; + + void reset(const std::shared_ptr &lctx) { + if (!lctx) { + return; + } + send_actions = lctx->GetStats()->sent_actions; + send_bytes = lctx->GetStats()->sent_bytes; + } + + void diff(const std::shared_ptr &lctx) { + if (!lctx) { + return; + } + send_bytes = lctx->GetStats()->sent_bytes - send_bytes; + send_actions = lctx->GetStats()->sent_actions - send_actions; + } +}; + +struct ActionKey { + std::string_view name; + int64_t flag; + bool operator<(const ActionKey &other) const { + return std::tie(name, flag) < std::tie(other.name, other.flag); + } +}; + +struct ActionStats { + // number of actions executed. + size_t count = 0; + // total duration time. + Duration total_time = {}; + + inline double getTotalTimeInSecond() const { + return std::chrono::duration_cast>(total_time) + .count(); + } +}; + +void dumpExecutableToFolder(const ExecutableProto &executable, size_t rank, + absl::Span inputs, + const std::string &dump_dir) { + // Naming convention for dumped files must align with debug runner. + std::filesystem::path dump_folder(dump_dir); + dump_folder /= executable.name(); + + std::filesystem::create_directories(dump_folder); + + // dump executable. + if (rank == 0) { + auto fname = dump_folder / std::string("executable.txt"); + SPDLOG_INFO("Dump executable to {}", fname); + std::ofstream ir_file(fname, std::ios::binary | std::ios::out); + ir_file << executable.SerializeAsString(); + } + + // dump all inputs. 
+ { + size_t var_counter = 0; + for (const auto &val : inputs) { + auto fname = + dump_folder / fmt::format("data_{}_{}.txt", rank, var_counter++); + SPDLOG_INFO("Dump data to {}", fname); + std::ofstream inputs_file(fname, std::ios::binary | std::ios::out); + inputs_file << val.toProto().SerializeAsString(); + } + } +} + +void printProfilingData(const std::string &name, + const ExecutionStats &exec_stats, + const CommunicationStats &comm_stats) { + // print overall information + SPDLOG_INFO( + "[Profiling] SPU execution {} completed, input processing took {}s, " + "execution took {}s, output processing took {}s, total time {}s.", + name, getSeconds(exec_stats.infeed_time), + getSeconds(exec_stats.execution_time), + getSeconds(exec_stats.outfeed_time), getSeconds(exec_stats.total_time())); + + // print action trace information + { + std::map stats; + + const auto &tracer = getTracer(GET_CTX_NAME(hctx_)); + const auto &records = tracer->getRecords(); + + for (const auto &rec : records) { + auto &stat = stats[{rec.name, rec.flag}]; + stat.count++; + stat.total_time += + std::chrono::duration_cast(rec.end - rec.start); + } + + static std::map kModules = { + {TR_HLO, "HLO"}, {TR_HAL, "HAL"}, {TR_MPC, "MPC"}}; + + for (const auto &[mod_flag, mod_name] : kModules) { + double total_time = 0.0; + for (const auto &[key, stat] : stats) { + if ((key.flag & mod_flag) != 0) { + total_time += stat.getTotalTimeInSecond(); + } + } + SPDLOG_INFO("{} profiling: total time {}", mod_name, total_time); + for (const auto &[key, stat] : stats) { + if ((key.flag & mod_flag) != 0) { + SPDLOG_INFO("- {}, executed {} times, duration {}s", key.name, + stat.count, stat.getTotalTimeInSecond()); + } + } + } + } + + // print link statistics + SPDLOG_INFO("Link details: total send bytes {}, send actions {}", + comm_stats.send_bytes, comm_stats.send_actions); +} + +void setupTrace(const spu::RuntimeConfig &rt_config) { + int64_t tr_mask = 0; + if (rt_config.enable_action_trace()) { + tr_mask |= 
TR_LOG; + } + + if (rt_config.enable_pphlo_profile()) { + tr_mask |= TR_HLO; + tr_mask |= TR_REC; + } + + if (rt_config.enable_hal_profile()) { + tr_mask |= TR_HAL | TR_MPC; + tr_mask |= TR_REC; + } + + getTracer(GET_CTX_NAME(ctx))->setMask(tr_mask); + getTracer(GET_CTX_NAME(ctx))->clearRecords(); +} + +void SPUErrorHandler(void *use_data, const char *reason, bool gen_crash_diag) { + (void)use_data; + (void)gen_crash_diag; + YACL_THROW(reason); +} + +std::mutex ErrorHandlerMutex; +void installLLVMErrorHandler() { + std::lock_guard guard(ErrorHandlerMutex); + llvm::remove_fatal_error_handler(); + llvm::install_fatal_error_handler(SPUErrorHandler); +} + +[[maybe_unused]] void removeLLVMErrorHandler() { + std::lock_guard guard(ErrorHandlerMutex); + llvm::remove_fatal_error_handler(); +} + +} // namespace + +void executeImpl(OpExecutor *executor, spu::HalContext *hctx, + const ExecutableProto &executable, SymbolTable *env) { + setupTrace(hctx->rt_config()); + installLLVMErrorHandler(); + + CommunicationStats comm_stats; + comm_stats.reset(hctx->lctx()); + ExecutionStats exec_stats; + + // prepare inputs from environment. + std::vector inputs; + { + TimeitGuard timeit(exec_stats.infeed_time); + inputs.reserve(executable.input_names_size()); + for (int32_t idx = 0; idx < executable.input_names_size(); idx++) { + inputs.emplace_back(env->getVar(executable.input_names(idx))); + } + } + + // TODO: rename this flag, enable_executable_dump? + const RuntimeConfig rt_config = hctx->rt_config(); + if (rt_config.enable_processor_dump()) { + const bool isRefHal = hctx->lctx() == nullptr; + const size_t rank = isRefHal ? 
0 : hctx->lctx()->Rank(); + dumpExecutableToFolder(executable, rank, inputs, + rt_config.processor_dump_dir()); + } + + // execution + std::vector outputs; + { + TimeitGuard timeit(exec_stats.execution_time); + + mlir::MLIRContext mlir_ctx; + mlir_ctx.loadDialect(); + + auto &engine = mlir_ctx.getDiagEngine(); + engine.registerHandler( + [&](mlir::Diagnostic &diag) { SPDLOG_ERROR(diag.str()); }); + + mlir_ctx.loadDialect(); + auto moduleOpRef = + mlir::parseSourceString(executable.code(), &mlir_ctx); + + YACL_ENFORCE(moduleOpRef, "MLIR parser failure"); + + SPDLOG_INFO("Executing module {}", + moduleOpRef->getName().value_or("Unnamed")); + + auto entry_function = moduleOpRef->lookupSymbol("main"); + YACL_ENFORCE(entry_function, "main module not found"); + + ExecutionOptions opts; + opts.do_type_check = rt_config.enable_type_checker(); + opts.do_log_execution = rt_config.enable_pphlo_trace(); + outputs = runRegion(executor, hctx, nullptr, entry_function.getBody(), + inputs, opts); + } + + // sync output to environment. 
+ { + TimeitGuard timeit(exec_stats.outfeed_time); + for (int32_t idx = 0; idx < executable.output_names_size(); idx++) { + env->setVar(executable.output_names(idx), outputs[idx]); + } + } + + comm_stats.diff(hctx->lctx()); + if ((getTracer(GET_CTX_NAME(hctx))->getMask() & TR_REC) != 0) { + printProfilingData(executable.name(), exec_stats, comm_stats); + } +} + +void execute(OpExecutor *executor, spu::HalContext *hctx, + const spu::ExecutableProto &executable, SymbolTable *env) { + return executeImpl(executor, hctx, executable, env); +} + +void execute(OpExecutor *executor, spu::HalContext *hctx, + const std::string &text, + const std::vector &input_names, + const std::vector &output_names, SymbolTable *env) { + ExecutableProto executable; + executable.set_name("unnamed"); + *executable.mutable_input_names() = {input_names.begin(), input_names.end()}; + *executable.mutable_output_names() = {output_names.begin(), + output_names.end()}; + executable.set_code(text); + + return executeImpl(executor, hctx, executable, env); +} + +} // namespace spu::device diff --git a/spu/device/pphlo/type_checker.h b/spu/device/api.h similarity index 54% rename from spu/device/pphlo/type_checker.h rename to spu/device/api.h index 89b5435e..e952cd63 100644 --- a/spu/device/pphlo/type_checker.h +++ b/spu/device/api.h @@ -14,14 +14,25 @@ #pragma once -#include "spu/device/type_checker.h" +#include +#include + +#include "spu/device/executor.h" +#include "spu/device/symbol_table.h" +#include "spu/kernel/context.h" +#include "spu/kernel/value.h" + +#include "spu/spu.pb.h" namespace spu::device { -class PPHloTypeChecker : public TypeChecker { -public: - ~PPHloTypeChecker() override = default; - void check(::mlir::Type type, const spu::Value &v) const override; -}; +void execute(OpExecutor *executor, HalContext *hctx, + const ExecutableProto &executable, SymbolTable *env); + +/// +void execute(OpExecutor *executor, spu::HalContext *hctx, + const std::string &text, + const std::vector 
&input_names, + const std::vector &output_names, SymbolTable *env); } // namespace spu::device diff --git a/spu/device/executor.cc b/spu/device/executor.cc index 2ef08534..e958d3fe 100644 --- a/spu/device/executor.cc +++ b/spu/device/executor.cc @@ -14,170 +14,91 @@ #include "spu/device/executor.h" -#include -#include -#include -#include - -#include "spdlog/spdlog.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/Value.h" +#include "yacl/base/exception.h" #include "spu/kernel/context.h" +#include "spu/kernel/value.h" namespace spu::device { -void Executor::runWithEnv(const ExecutableProto &exec, SymbolTable *env) { - // setup global states. - const RuntimeConfig rt_config = hctx_->rt_config(); - - // - const bool isRefHal = hctx_->lctx() == nullptr; - const size_t rank = isRefHal ? 0 : hctx_->lctx()->Rank(); - - module_name_ = exec.name(); +const spu::Value &SymbolScope::lookupValue(mlir::Value key) const { + auto itr = symbols_.find(key); - Timer timer; - Timer stage_timer; - - // prepare inputs from environment. - std::vector inputs; - inputs.reserve(exec.input_names_size()); - for (int32_t idx = 0; idx < exec.input_names_size(); idx++) { - const std::string &sym_name = exec.input_names(idx); - inputs.emplace_back(env->getVar(sym_name)); + if (itr != symbols_.end()) { + return itr->second; } - const auto input_time = stage_timer.count(); + if (parent_ != nullptr) { + return parent_->lookupValue(key); + } - // TODO: rename this flag, enable_executable_dump? - if (rt_config.enable_processor_dump()) { - // Naming convention for dumped files must align with debug runner. 
- std::filesystem::path dump_folder(rt_config.processor_dump_dir()); - dump_folder /= exec.name(); + // Somehow cannot find this value on stack, print a reasonable error + YACL_THROW("TODO: add more details"); + // YACL_THROW("Try to get a non-exist value, defined at {} ", + // mlirObjectToString(*v.getDefiningOp())); +} - std::filesystem::create_directories(dump_folder); +void SymbolScope::addValue(mlir::Value key, const spu::Value &val) { + symbols_[key] = val; +} - // dump executable. - if (rank == 0) { - auto fname = dump_folder / std::string("exec.txt"); - SPDLOG_INFO("Dump exec to {}", fname); - std::ofstream ir_file(fname, std::ios::binary | std::ios::out); - ir_file << exec.SerializeAsString(); - } +void SymbolScope::addValue(mlir::Value key, spu::Value &&val) { + symbols_[key] = std::move(val); +} - // dump all inputs. - { - size_t var_counter = 0; - for (const auto &val : inputs) { - auto fname = - dump_folder / fmt::format("data_{}_{}.txt", rank, var_counter++); - SPDLOG_INFO("Dump data to {}", fname); - std::ofstream inputs_file(fname, std::ios::binary | std::ios::out); - inputs_file << val.toProto().SerializeAsString(); - } - } +std::vector runRegion(OpExecutor *executor, // + HalContext *hctx, // + SymbolScope *parent_scope, // + mlir::Region &region, // + absl::Span params, // + const ExecutionOptions &opts) { + YACL_ENFORCE(region.getNumArguments() == params.size(), + "region requires {} arguments while got number of params {}", + region.getNumArguments(), params.size()); + + // create a new scope for this region. + SymbolScope sscope(parent_scope); + + // inject the parameters to region's symbol table.
+ for (const auto &blkarg : region.getArguments()) { + sscope.addValue(blkarg, params[blkarg.getArgNumber()]); } - // Profile: before execution stamp - stage_timer.reset(); - auto outputs = run(exec.code(), inputs); - const auto exec_time = stage_timer.count(); + YACL_ENFORCE(region.hasOneBlock()); + return runBlock(executor, hctx, &sscope, region.front(), params, opts); +} - // sync output to environment. - stage_timer.reset(); - for (int32_t idx = 0; idx < exec.output_names_size(); idx++) { - const std::string &sym_name = exec.output_names(idx); - env->setVar(sym_name, outputs[idx]); - } - const auto output_time = stage_timer.count(); - - // Collect time profile data - auto total_time = timer.count(); - - // Only one party prints for multi-threading simulation - if (hctx_->rt_config().enable_pphlo_profile()) { - SPDLOG_INFO( - "[Profiling] SPU execution {} completed, input processing took {}s, " - "execution took {}s, output processing took {}s, total time {}s.", - module_name_, input_time.count(), exec_time.count(), - output_time.count(), total_time.count()); - const auto &records = getProfileRecords(); - double total_time = .0; - for (const auto &[name, record] : records) { - total_time += record.time.count(); - } - SPDLOG_INFO("HLO profiling: total time: {}", total_time); - for (const auto &[name, record] : records) { - SPDLOG_INFO("- {}, executed {} times, duration {}s", name, record.count, - record.time.count()); - } +std::vector runBlock(OpExecutor *executor, HalContext *hctx, + SymbolScope *symbols, mlir::Block &block, + absl::Span params, + const ExecutionOptions &opts) { + for (auto &op : block.without_terminator()) { + executor->runKernel(hctx, symbols, op, opts); } - struct ActionKey { - std::string_view name; - int64_t flag; - bool operator<(const ActionKey &other) const { - return std::tie(name, flag) < std::tie(other.name, other.flag); - } - }; - - // helper utilities - struct ActionStatistic { - // number of actions executed.
- size_t count = 0; - // total duration time. - Duration total_time = {}; - - inline double getTotalTimeInSecond() const { - return std::chrono::duration_cast>( - total_time) - .count(); - } - }; - - std::map stats; - if (hctx_->rt_config().enable_hal_profile()) { - const auto &tracer = getTracer(GET_CTX_NAME(hctx_)); - const auto &records = tracer->getRecords(); - - for (const auto &rec : records) { - auto &stat = stats[{rec.name, rec.flag}]; - stat.count++; - stat.total_time += - std::chrono::duration_cast(rec.end - rec.start); - } - - static std::map kModules = { - {TR_HLO, "HLO"}, {TR_HAL, "HAL"}, {TR_MPC, "MPC"}}; - - for (const auto &[mod_flag, mod_name] : kModules) { - double total_time = 0.0; - for (const auto &[key, stat] : stats) { - if ((key.flag & mod_flag) != 0) { - total_time += stat.getTotalTimeInSecond(); - } - } - SPDLOG_INFO("{} profiling: total time {}", mod_name, total_time); - for (const auto &[key, stat] : stats) { - if ((key.flag & mod_flag) != 0) { - SPDLOG_INFO("- {}, executed {} times, duration {}s", key.name, - stat.count, stat.getTotalTimeInSecond()); - } - } + if (auto *termOp = block.getTerminator()) { + // TODO: enforce ReturnLike + std::vector results; + results.reserve(termOp->getNumOperands()); + for (const auto operand : termOp->getOperands()) { + results.emplace_back(symbols->lookupValue(operand)); } + return results; } + + // No terminator + YACL_THROW("Should not be here"); } -void Executor::runWithEnv(const std::string &text, - const std::vector &input_names, - const std::vector &output_names, - SymbolTable *env) { - ExecutableProto exec; - exec.set_name("unnamed"); - *exec.mutable_input_names() = {input_names.begin(), input_names.end()}; - *exec.mutable_output_names() = {output_names.begin(), output_names.end()}; - exec.set_code(text); - - runWithEnv(exec, env); +std::vector runBlockParallel(OpExecutor *executor, HalContext *hctx, + SymbolScope *symbols, + mlir::Block &block, + absl::Span params, + const ExecutionOptions 
&opts) { + YACL_THROW("TODO"); } } // namespace spu::device diff --git a/spu/device/executor.h b/spu/device/executor.h index f39c39e5..752847ab 100644 --- a/spu/device/executor.h +++ b/spu/device/executor.h @@ -14,59 +14,80 @@ #pragma once -#include +#include -#include "spu/device/profiler.h" -#include "spu/device/symbol_table.h" +#include "llvm/ADT/DenseMap.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/MLIRContext.h" + +#include "spu/kernel/context.h" #include "spu/kernel/value.h" -#include "spu/spu.pb.h" +namespace spu::device { -namespace spu { -class HalContext; -} +// +class SymbolScope final { + // The parent region, null if this region is isolated from above. + SymbolScope *parent_; -namespace spu::device { + // Local symbols inside this value. + // TODO: thread safety for parallel execution. + llvm::DenseMap symbols_; -// The executor interface, an executor evaluates a texted code with given -// inputs, and produce expected outputs. -class Executor { -protected: - HalContext *hctx_ = nullptr; +public: + explicit SymbolScope(SymbolScope *parent = nullptr) : parent_(parent) {} - // Profiling thingy - std::shared_ptr op_profiler_; + // return true if this is the root scope. + bool isRoot() const { return parent_ == nullptr; } - std::string module_name_ = "unnamed"; + // + const spu::Value &lookupValue(mlir::Value key) const; + void addValue(::mlir::Value key, const spu::Value &val); + void addValue(::mlir::Value key, spu::Value &&val); +}; +// This class encapsulate execution states used during the evaluation. +struct ExecutionOptions { + bool do_type_check = false; + bool do_log_execution = false; +}; + +class OpExecutor { public: - explicit Executor(HalContext *hctx) - : hctx_(hctx), op_profiler_(std::make_shared()){}; + virtual ~OpExecutor() = default; - virtual ~Executor() = default; + // + virtual void checkType(mlir::Type mlir_type, const spu::Value &v) const = 0; - // Return the HAL context. 
- HalContext *getContext() const { return hctx_; } + // return true if the operation has a corresponding kernel. + virtual bool hasKernel(mlir::Operation &op) const = 0; - /// Run a code snippet with given inputs. - // return a list of output values. - virtual std::vector - run(const std::string &code, const std::vector &inputs) = 0; + // run a kernel in a given region. + virtual void runKernelImpl(HalContext *hctx, SymbolScope *sscope, + mlir::Operation &op, + const ExecutionOptions &opts) = 0; - /// Return the op profiling records. - const Profiler::ExecutionRecordsT &getProfileRecords() const { - // op_profiler_ cannot be nullptr - return op_profiler_->getRecords(); + void runKernel(HalContext *hctx, SymbolScope *sscope, mlir::Operation &op, + const ExecutionOptions &opts = {}) { + return runKernelImpl(hctx, sscope, op, opts); } - - /// Evaluate an spu executable with given environment. - void runWithEnv(const ExecutableProto &exec, SymbolTable *env); - - /// - void runWithEnv(const std::string &text, - const std::vector &input_names, - const std::vector &output_names, - SymbolTable *env); }; +std::vector runRegion(OpExecutor *executor, HalContext *hctx, + SymbolScope *parent_scope, + mlir::Region &region, + absl::Span params, + const ExecutionOptions &opts = {}); + +std::vector runBlock(OpExecutor *executor, HalContext *hctx, + SymbolScope *symbols, mlir::Block &block, + absl::Span params, + const ExecutionOptions &opts); + +std::vector runBlockParallel(OpExecutor *executor, HalContext *hctx, + SymbolScope *symbols, + mlir::Block &block, + absl::Span params, + const ExecutionOptions &opts); + } // namespace spu::device diff --git a/spu/device/frame.cc b/spu/device/frame.cc deleted file mode 100644 index 853fd171..00000000 --- a/spu/device/frame.cc +++ /dev/null @@ -1,60 +0,0 @@ -// Copyright 2021 Ant Group Co., Ltd. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "spu/device/frame.h" - -#include "mlir/IR/BuiltinTypes.h" -#include "yacl/base/exception.h" - -#include "spu/dialect/pphlo_types.h" - -namespace spu::device { - -void Frame::releaseValue(::mlir::Value operand) { - YACL_ENFORCE(!segments_.empty(), - "Need at least one activate segment running"); - segments_.back().values_.erase(operand); -} - -void Frame::addValue(::mlir::Value operand, spu::Value &&val) { - YACL_ENFORCE(!segments_.empty(), - "Need at least one activate segment running"); - segments_.back().values_[operand] = std::move(val); -} - -void Frame::addValue(::mlir::Value operand, const spu::Value &val) { - YACL_ENFORCE(!segments_.empty(), - "Need at least one activate segment running"); - segments_.back().values_[operand] = val; -} - -const spu::Value *Frame::getValue(::mlir::Value operand) const { - const spu::Value *val = nullptr; - YACL_ENFORCE(!segments_.empty()); - for (auto siter = segments_.rbegin(); siter != segments_.rend(); ++siter) { - auto iter = siter->values_.find(operand); - if (iter != siter->values_.end()) { - val = &iter->second; - break; - } - } - // If type checker is enabled, do it at getter time - if ((val != nullptr) && type_checker_) { - type_checker_->check(operand.getType(), *val); - } - - return val; -} - -} // namespace spu::device diff --git a/spu/device/frame.h b/spu/device/frame.h deleted file mode 100644 index 3b8018ed..00000000 --- a/spu/device/frame.h +++ /dev/null @@ -1,60 +0,0 @@ -// Copyright 2021 Ant Group Co., Ltd. 
-// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#pragma once - -#include -#include - -#include "llvm/ADT/DenseMap.h" -#include "mlir/IR/Value.h" - -#include "spu/device/type_checker.h" -#include "spu/kernel/value.h" - -namespace spu::device { - -class ModuleRunner; - -// This class represents a call frame. -class Frame final { - struct RegionDataSegment { - llvm::DenseMap values_; - }; - - std::shared_ptr type_checker_; - std::deque segments_; - -public: - Frame() = default; - Frame(const Frame &) = delete; - Frame &operator=(const Frame &) = delete; - - void enterRegion() { segments_.emplace_back(); } - void leaveRegion() { segments_.pop_back(); } - - void setTypeCheker(std::shared_ptr checker) { - type_checker_ = std::move(checker); - } - - bool hasValue(::mlir::Value operand) const; - - void addValue(::mlir::Value operand, const spu::Value &val); - void addValue(::mlir::Value operand, spu::Value &&val); - - void releaseValue(::mlir::Value operand); - const spu::Value *getValue(::mlir::Value operand) const; -}; - -} // namespace spu::device diff --git a/spu/device/pphlo/BUILD.bazel b/spu/device/pphlo/BUILD.bazel index d3dab16f..b92c307c 100644 --- a/spu/device/pphlo/BUILD.bazel +++ b/spu/device/pphlo/BUILD.bazel @@ -17,72 +17,29 @@ load("//bazel:spu.bzl", "spu_cc_binary", "spu_cc_library", "spu_cc_test") package(default_visibility = ["//visibility:public"]) spu_cc_library( - name = "executor", - srcs = ["executor.cc"], - hdrs = ["executor.h"], + name = 
"pphlo_executor", + srcs = ["pphlo_executor.cc"], + hdrs = ["pphlo_executor.h"], deps = [ - ":region_executor", - ":type_checker", - "//spu/device:executor", - "//spu/device:frame", - "//spu/dialect:pphlo_dialect", - "@llvm-project//mlir:Parser", - ], -) - -spu_cc_library( - name = "region_executor", - srcs = [ - "region_executor.cc", - ], - hdrs = [ - "region_executor.h", - ], - deps = [ - ":type_checker", ":xla_verifier", - "//spu/device:frame", - "//spu/device:profiler", + "//spu/device:executor", "//spu/dialect:pphlo_dialect", "//spu/kernel/hlo", ], ) spu_cc_test( - name = "executor_test", - srcs = ["executor_test.cc"], + name = "pphlo_executor_test", + srcs = ["pphlo_executor_test.cc"], deps = [ - ":executor", + ":pphlo_executor", "//spu/compiler:compile", + "//spu/device:api", "//spu/device:io", "//spu/device:test_utils", ], ) -spu_cc_library( - name = "type_checker", - srcs = ["type_checker.cc"], - hdrs = ["type_checker.h"], - deps = [ - "//spu/device:type_checker", - "//spu/dialect:pphlo_dialect", - "//spu/kernel:value", - "@llvm-project//mlir:IR", - "@yacl//yacl/base:exception", - ], -) - -spu_cc_binary( - name = "executor_debug_runner", - testonly = True, - srcs = ["executor_debug_runner.cc"], - deps = [ - ":executor", - "//spu/device:test_utils", - "@llvm-project//llvm:Support", - ], -) - spu_cc_library( name = "xla_verifier", srcs = ["xla_verifier.cc"], diff --git a/spu/device/pphlo/executor.cc b/spu/device/pphlo/executor.cc deleted file mode 100644 index 62b319cd..00000000 --- a/spu/device/pphlo/executor.cc +++ /dev/null @@ -1,106 +0,0 @@ -// Copyright 2021 Ant Group Co., Ltd. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "spu/device/pphlo/executor.h" - -#include "llvm/Support/ErrorHandling.h" -#include "mlir/IR/BuiltinAttributes.h" -#include "mlir/IR/BuiltinOps.h" -#include "mlir/IR/Value.h" -#include "mlir/Parser/Parser.h" -#include "yacl/base/exception.h" - -#include "spu/device/frame.h" -#include "spu/device/pphlo/region_executor.h" -#include "spu/dialect/pphlo_dialect.h" -#include "spu/kernel/context.h" -#include "spu/kernel/value.h" - -namespace spu::device::pphlo { - -namespace { - -std::mutex ErrorHandlerMutex; - -void SPUErrorHandler(void * /*use_data*/, const char *reason, - bool /*gen_crash_diag*/) { - YACL_THROW(reason); -} - -} // namespace - -std::vector -PPHloExecutor::executeFunc(mlir::func::FuncOp &fcn, - llvm::ArrayRef inputs) { - Frame callFrame; - RegionExecutor executor(getContext(), &callFrame, op_profiler_); - return executor.executeRegion(fcn.getBody(), inputs); -} - -PPHloExecutor::PPHloExecutor(HalContext *ctx) : Executor(ctx) { - // Set an error handler - { - std::lock_guard guard(ErrorHandlerMutex); - llvm::remove_fatal_error_handler(); - llvm::install_fatal_error_handler(SPUErrorHandler); - } - - mlir::DialectRegistry registry; - registry.insert(); - mlir_context_ = std::make_unique(registry); - - int64_t tr_mask = 0; - if (ctx->rt_config().enable_action_trace()) { - tr_mask |= TR_LOG; - } - - if (ctx->rt_config().enable_hal_profile()) { - tr_mask |= TR_HLO | TR_HAL | TR_MPC; - tr_mask |= TR_REC; - } - - getTracer(GET_CTX_NAME(ctx))->setMask(tr_mask); - getTracer(GET_CTX_NAME(ctx))->clearRecords(); -} - 
-PPHloExecutor::~PPHloExecutor() { - std::lock_guard guard(ErrorHandlerMutex); - llvm::remove_fatal_error_handler(); -} - -std::vector -PPHloExecutor::run(const std::string &code, - const std::vector &inputs) { - auto moduleOpRef = - mlir::parseSourceString(code, mlir_context_.get()); - - if (hctx_->rt_config().enable_pphlo_trace()) { - SPDLOG_INFO("Executing module {}", - moduleOpRef->getName().value_or("Unnamed")); - } - - auto entry_function = moduleOpRef->lookupSymbol("main"); - YACL_ENFORCE(entry_function); - - return executeFunc(entry_function, inputs); -} - -mlir::OwningOpRef -PPHloExecutor::parseSourceString(const std::string &code) { - auto moduleOp = - mlir::parseSourceString(code, mlir_context_.get()); - return moduleOp; -} - -} // namespace spu::device::pphlo diff --git a/spu/device/pphlo/executor.h b/spu/device/pphlo/executor.h deleted file mode 100644 index ab7f95b0..00000000 --- a/spu/device/pphlo/executor.h +++ /dev/null @@ -1,44 +0,0 @@ -// Copyright 2021 Ant Group Co., Ltd. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#pragma once - -#include "mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/IR/BuiltinOps.h" -#include "mlir/IR/MLIRContext.h" - -#include "spu/device/executor.h" -#include "spu/kernel/value.h" - -namespace spu::device::pphlo { - -class PPHloExecutor : public Executor { -public: - explicit PPHloExecutor(HalContext *ctx); - - ~PPHloExecutor() override; - - std::vector run(const std::string &code, - const std::vector &inputs) override; - - mlir::OwningOpRef parseSourceString(const std::string &code); - -private: - std::vector executeFunc(mlir::func::FuncOp &fcn, - llvm::ArrayRef inputs); - - std::unique_ptr mlir_context_; -}; - -} // namespace spu::device::pphlo diff --git a/spu/device/pphlo/executor_debug_runner.cc b/spu/device/pphlo/executor_debug_runner.cc index 523b7010..52c77184 100644 --- a/spu/device/pphlo/executor_debug_runner.cc +++ b/spu/device/pphlo/executor_debug_runner.cc @@ -128,5 +128,5 @@ int main(int argc, char **argv) { hctx->lctx()->Rank(), data_file.c_str(), v); table.setVar(exec.input_names(var_counter), v); } - executor.runWithEnv(exec, &table); + spu::device::execute(&executor, exec, &table); } diff --git a/spu/device/pphlo/pphlo_executor.cc b/spu/device/pphlo/pphlo_executor.cc new file mode 100644 index 00000000..855576ff --- /dev/null +++ b/spu/device/pphlo/pphlo_executor.cc @@ -0,0 +1,1063 @@ +// Copyright 2022 Ant Group Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+
+#include "spu/device/pphlo/pphlo_executor.h"
+
+#include "llvm/Support/raw_os_ostream.h"
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/Location.h"
+
+#include "spu/dialect/pphlo_base_enums.h"
+#include "spu/dialect/pphlo_ops.h"
+#include "spu/kernel/hlo/basic_binary.h"
+#include "spu/kernel/hlo/basic_ternary.h"
+#include "spu/kernel/hlo/basic_unary.h"
+#include "spu/kernel/hlo/casting.h"
+#include "spu/kernel/hlo/const.h"
+#include "spu/kernel/hlo/control_flow.h"
+#include "spu/kernel/hlo/convolution.h"
+#include "spu/kernel/hlo/dynamic_slice.h"
+#include "spu/kernel/hlo/geometrical.h"
+#include "spu/kernel/hlo/indexing.h"
+#include "spu/kernel/hlo/rand.h"
+#include "spu/kernel/hlo/reduce.h"
+#include "spu/kernel/hlo/select_and_scatter.h"
+#include "spu/kernel/hlo/shift.h"
+#include "spu/kernel/hlo/sort.h"
+
+namespace {
+
+// Copies every element of a dense integer attribute into a vector, in the
+// attribute's iteration order.
+std::vector
+convertDenseIntElementAttr(const mlir::DenseIntElementsAttr &attr) {
+  std::vector ret;
+
+  for (const auto &v : attr.getValues()) {
+    ret.emplace_back(v);
+  }
+
+  return ret;
+}
+
+// Renders any MLIR object that exposes print(raw_ostream&) into a string.
+// Used for log lines and error messages.
+template
+std::string mlirObjectToString(T &&mlir_obj) {
+  std::string buf;
+  llvm::raw_string_ostream rss(buf);
+  mlir_obj.print(rss);
+  rss.flush();
+  return buf;
+}
+
+// Maps the expressed type of `mlir_ty` to the SPU plaintext type.
+// Floats of width 32/64 map to PT_F32/PT_F64, a 1-bit integer maps to PT_BOOL
+// and integers of width 8/16/32/64 map to the matching PT_I*/PT_U* value.
+// Throws (YACL_THROW) for every other type or width.
+spu::PtType getPtTypeFromMlirType(mlir::Type mlir_ty) {
+  mlir::pphlo::TypeTools tool;
+  auto express_type = tool.getExpressedType(mlir_ty);
+
+  if (auto ft = express_type.dyn_cast()) {
+    switch (ft.getWidth()) {
+    case 32:
+      return spu::PT_F32;
+    case 64:
+      return spu::PT_F64;
+    }
+  } else if (auto it = express_type.dyn_cast()) {
+    if (it.getWidth() == 1) {
+      return spu::PT_BOOL;
+    }
+    // In mlir, isSigned is for si[1-9][0-9]* type, isUnsigned is for
+    // ui[1-9][0-9]*, i[1-9][0-9]* is signless IntegerType... So here, we only
+    // check for isUnsigned, signless we treat it as signed.
+    // See https://reviews.llvm.org/D72533
+    switch (it.getWidth()) {
+    case 8:
+      return it.isUnsigned() ? spu::PT_U8 : spu::PT_I8;
+    case 16:
+      return it.isUnsigned() ? spu::PT_U16 : spu::PT_I16;
+    case 32:
+      return it.isUnsigned() ? spu::PT_U32 : spu::PT_I32;
+    case 64:
+      return it.isUnsigned() ? spu::PT_U64 : spu::PT_I64;
+    }
+  }
+  // Unsupported widths fall out of the switches above and end up here.
+  YACL_THROW("invalid type {}", mlirObjectToString(mlir_ty));
+}
+
+// Maps the expressed type of `mlir_ty` to the SPU runtime data type.
+// Integers map by width and signedness (signless is treated as signed, as in
+// getPtTypeFromMlirType); every float type maps to DT_FXP.
+// Throws (YACL_THROW) for unsupported integer widths and for other types.
+spu::DataType getDtypeFromMlirType(mlir::Type mlir_ty) {
+  mlir::pphlo::TypeTools tool;
+  auto express_type = tool.getExpressedType(mlir_ty);
+  if (auto int_ty = express_type.dyn_cast()) {
+    switch (int_ty.getWidth()) {
+    case 1:
+      return spu::DT_I1;
+    case 8:
+      return int_ty.isUnsigned() ? spu::DT_U8 : spu::DT_I8;
+    case 16:
+      return int_ty.isUnsigned() ? spu::DT_U16 : spu::DT_I16;
+    case 32:
+      return int_ty.isUnsigned() ? spu::DT_U32 : spu::DT_I32;
+    case 64:
+      return int_ty.isUnsigned() ? spu::DT_U64 : spu::DT_I64;
+    default:
+      YACL_THROW("unsupported int type {}", mlirObjectToString(mlir_ty));
+    }
+  } else if (auto flp_ty = express_type.dyn_cast()) {
+    return spu::DT_FXP;
+  }
+  YACL_THROW("invalid type {}", mlirObjectToString(mlir_ty));
+}
+
+// Convert mlir visibility to spu visibility
+spu::Visibility convertVisibility(mlir::pphlo::Visibility vis) {
+  switch (vis) {
+  case mlir::pphlo::Visibility::VIS_PUBLIC:
+    return spu::Visibility::VIS_PUBLIC;
+  case mlir::pphlo::Visibility::VIS_SECRET:
+    return spu::Visibility::VIS_SECRET;
+  }
+  YACL_THROW("Should not hit");
+}
+
+} // namespace
+
+namespace spu::device::pphlo {
+namespace {
+
+// Returns the runtime value bound to `key` in `scope`.
+// When opts.do_type_check is set, the value is also validated against the
+// MLIR type of `key`: same rank, same extent in every dimension, same dtype
+// and matching visibility. Any mismatch raises through YACL_ENFORCE /
+// YACL_THROW.
+const spu::Value &lookupValue(SymbolScope *scope, mlir::Value key,
+                              const ExecutionOptions &opts) {
+  const auto &val = scope->lookupValue(key);
+
+  if (opts.do_type_check) {
+    const auto mlir_type = key.getType();
+    {
+      const auto &mlir_shape =
+          mlir_type.dyn_cast().getShape();
+      const auto &spu_shape = val.shape();
+
+      YACL_ENFORCE(mlir_shape.size() == spu_shape.size(),
+                   "Runtime shape mismatch, expected={}, got={}",
+                   fmt::join(mlir_shape, "x"), fmt::join(spu_shape, "x"));
+
+      for (size_t idx = 0; idx < mlir_shape.size(); ++idx) {
+        YACL_ENFORCE(mlir_shape[idx] == spu_shape[idx],
+                     "Runtime shape mismatch at dim {}, expected={}, got={}",
+                     idx, fmt::join(mlir_shape, "x"),
+                     fmt::join(spu_shape, "x"));
+      }
+    }
+
+    // Check dtype
+    mlir::pphlo::TypeTools tool;
+    auto expectedType = getDtypeFromMlirType(mlir_type);
+    YACL_ENFORCE(expectedType == val.dtype(), "Expected mlir_type {}, got {}",
+                 expectedType, val.dtype());
+
+    // Check vtype
+    if (tool.isMPCType(mlir_type)) {
+      YACL_ENFORCE(val.isPublic());
+    } else if (tool.isMPCType(mlir_type)) {
+      YACL_ENFORCE(val.isSecret());
+    } else {
+      // Must throw unconditionally: enforcing on a string literal would
+      // always pass, because the literal itself is the (truthy) condition.
+      YACL_THROW("Unknown vtype");
+    }
+  }
+  return val;
+}
+
+// Defines execute() for a one-operand op: look up the operand, run the
+// kernel::hlo kernel named KernelName and bind the op result to its output.
+#define STANDARD_UNARY_OP_EXEC_IMPL(OpName, KernelName) \
+  void execute(OpExecutor *executor, HalContext *hctx, SymbolScope *sscope, \
+               mlir::pphlo::OpName &op, const ExecutionOptions &opts) { \
+    const auto in = lookupValue(sscope, op.getOperand(), opts); \
+    auto ret = kernel::hlo::KernelName(hctx, in); \
+    sscope->addValue(op.getResult(), std::move(ret)); \
+  }
+
+STANDARD_UNARY_OP_EXEC_IMPL(ReciprocalOp, Reciprocal)
+STANDARD_UNARY_OP_EXEC_IMPL(NegOp, Neg)
+STANDARD_UNARY_OP_EXEC_IMPL(ExpOp, Exp)
+STANDARD_UNARY_OP_EXEC_IMPL(Expm1Op, Expm1)
+STANDARD_UNARY_OP_EXEC_IMPL(LogOp, Log)
+STANDARD_UNARY_OP_EXEC_IMPL(Log1pOp, Log1p)
+STANDARD_UNARY_OP_EXEC_IMPL(FloorOp, Floor)
+STANDARD_UNARY_OP_EXEC_IMPL(CeilOp, Ceil)
+STANDARD_UNARY_OP_EXEC_IMPL(AbsOp, Abs)
+STANDARD_UNARY_OP_EXEC_IMPL(LogisticOp, Logistic)
+STANDARD_UNARY_OP_EXEC_IMPL(TanhOp, Tanh)
+STANDARD_UNARY_OP_EXEC_IMPL(NotOp, Not)
+STANDARD_UNARY_OP_EXEC_IMPL(RsqrtOp, Rsqrt)
+STANDARD_UNARY_OP_EXEC_IMPL(SqrtOp, Sqrt)
+STANDARD_UNARY_OP_EXEC_IMPL(RoundOp, Round_AFZ)
+
+#undef STANDARD_UNARY_OP_EXEC_IMPL
+
+// Defines execute() for a two-operand op: look up lhs and rhs, run the
+// kernel::hlo kernel named KernelName and bind the op result to its output.
+#define STANDARD_BINARY_OP_EXEC_IMPL(OpName, KernelName) \
+  void execute(OpExecutor *executor, HalContext *hctx, SymbolScope *sscope, \
+               mlir::pphlo::OpName &op, const ExecutionOptions &opts) { \
+    sscope->addValue( \
+        op.getResult(), \
+        kernel::hlo::KernelName(hctx, lookupValue(sscope, op.lhs(), opts), \
+                                lookupValue(sscope, op.rhs(), opts))); \
+  }
+
+STANDARD_BINARY_OP_EXEC_IMPL(AddOp, Add)
+STANDARD_BINARY_OP_EXEC_IMPL(EqualOp, Equal)
+STANDARD_BINARY_OP_EXEC_IMPL(NotEqualOp, NotEqual)
+STANDARD_BINARY_OP_EXEC_IMPL(LessEqualOp, LessEqual)
+STANDARD_BINARY_OP_EXEC_IMPL(GreaterEqualOp, GreaterEqual)
+STANDARD_BINARY_OP_EXEC_IMPL(SubtractOp, Sub)
+STANDARD_BINARY_OP_EXEC_IMPL(LessOp, Less)
+STANDARD_BINARY_OP_EXEC_IMPL(GreaterOp, Greater)
+STANDARD_BINARY_OP_EXEC_IMPL(MulOp, Mul)
+STANDARD_BINARY_OP_EXEC_IMPL(PowOp, Power)
+STANDARD_BINARY_OP_EXEC_IMPL(MaxOp, Max)
+STANDARD_BINARY_OP_EXEC_IMPL(MinOp, Min)
+STANDARD_BINARY_OP_EXEC_IMPL(AndOp, And)
+STANDARD_BINARY_OP_EXEC_IMPL(OrOp, Or)
+STANDARD_BINARY_OP_EXEC_IMPL(XorOp, Xor)
+STANDARD_BINARY_OP_EXEC_IMPL(DivOp, Div)
+STANDARD_BINARY_OP_EXEC_IMPL(ShiftLeftOp, Lshift)
+STANDARD_BINARY_OP_EXEC_IMPL(ShiftRightArithmeticOp, ARshift)
+STANDARD_BINARY_OP_EXEC_IMPL(ShiftRightLogicalOp, Rshift)
+
+#undef STANDARD_BINARY_OP_EXEC_IMPL
+
+// Dot: runs the Dot kernel, then reshapes the product to the shape declared
+// by the op's result type.
+void execute(OpExecutor *executor, HalContext *hctx, SymbolScope *sscope,
+             mlir::pphlo::DotOp &op, const ExecutionOptions &opts) {
+  auto ret = kernel::hlo::Dot(hctx, lookupValue(sscope, op.lhs(), opts),
+                              lookupValue(sscope, op.rhs(), opts));
+
+  const auto ret_shape =
+      op.getResult().getType().dyn_cast().getShape();
+
+  sscope->addValue(op.getResult(), kernel::hlo::Reshape(hctx, ret, ret_shape));
+}
+
+// DotGeneral: batched matrix product, one Dot per batch entry.
+// Only one layout is accepted: a single batch dimension at index 0 on both
+// sides, lhs contracting on dimension 2 and rhs contracting on dimension 1.
+// Anything else is rejected by the YACL_ENFORCE checks below.
+void execute(OpExecutor *executor, HalContext *hctx, SymbolScope *sscope,
+             mlir::pphlo::DotGeneralOp &op, const ExecutionOptions &opts) {
+  auto dnum = op.dot_dimension_numbers();
+  // Should in order
+  YACL_ENFORCE(dnum.getLhsBatchingDimensions().size() == 1 &&
+                   dnum.getLhsContractingDimensions().size() == 1 &&
+                   dnum.getLhsBatchingDimensions()[0] == 0 &&
+                   dnum.getLhsContractingDimensions()[0] == 2,
+               "LHS dims is not in order");
+  YACL_ENFORCE(dnum.getRhsBatchingDimensions().size() == 1 &&
+                   dnum.getRhsContractingDimensions().size() == 1 &&
+                   dnum.getRhsBatchingDimensions()[0] == 0 &&
+                   dnum.getRhsContractingDimensions()[0] == 1,
+               "RHS dims is not in order");
+
+  auto lhs = lookupValue(sscope, op.lhs(), opts);
+  auto rhs = lookupValue(sscope, op.rhs(), opts);
+  YACL_ENFORCE(lhs.shape()[0] == rhs.shape()[0], "Batch dim should equal");
+  int64_t num_batch = lhs.shape()[0];
+
+  std::vector results(num_batch);
+  std::vector lhs_slice_begin(3, 0);
+  std::vector lhs_slice_end = lhs.shape();
+  std::vector rhs_slice_begin(3, 0);
+  std::vector rhs_slice_end = rhs.shape();
+  std::vector strides(lhs.shape().size(), 1);
+
+  std::vector lhs_slice_shape{lhs.shape()[1], lhs.shape()[2]};
+  std::vector rhs_slice_shape{rhs.shape()[1], rhs.shape()[2]};
+  std::vector ret_slice_shape{1, lhs.shape()[1], rhs.shape()[2]};
+
+  for (int64_t batch_idx = 0; batch_idx < num_batch; ++batch_idx) {
+    // Narrow both slice windows to the current batch entry; the other two
+    // dimensions keep their full extent.
+    lhs_slice_begin[0] = batch_idx;
+    lhs_slice_end[0] = batch_idx + 1;
+    rhs_slice_begin[0] = batch_idx;
+    rhs_slice_end[0] = batch_idx + 1;
+    auto lhs_slice = kernel::hlo::Reshape(
+        hctx,
+        kernel::hlo::Slice(hctx, lhs, lhs_slice_begin, lhs_slice_end, strides),
+        lhs_slice_shape);
+    auto rhs_slice = kernel::hlo::Reshape(
+        hctx,
+        kernel::hlo::Slice(hctx, rhs, rhs_slice_begin, rhs_slice_end, strides),
+        rhs_slice_shape);
+    results[batch_idx] = kernel::hlo::Reshape(
+        hctx, kernel::hlo::Dot(hctx, lhs_slice, rhs_slice), ret_slice_shape);
+  }
+
+  auto ret_type = op.getResult().getType().dyn_cast();
+  auto ret = kernel::hlo::Reshape(
+      hctx, kernel::hlo::Concatenate(hctx, results, 0), ret_type.getShape());
+
+  sscope->addValue(op.getResult(), std::move(ret));
+}
+
+// Convolution: copies the op's dimension numbers and strides into a
+// ConvolutionConfig and dispatches to Convolution2D when there are exactly
+// two input spatial dimensions, to the generic Convolution otherwise.
+void execute(OpExecutor *executor, HalContext *hctx, SymbolScope *sscope,
+             mlir::pphlo::ConvolutionOp &op, const ExecutionOptions &opts) {
+  const auto &dnums = op.dimension_numbers();
+  const size_t num_spatial_dims = dnums.getOutputSpatialDimensions().size();
+  YACL_ENFORCE(num_spatial_dims == dnums.getInputSpatialDimensions().size());
+  YACL_ENFORCE(num_spatial_dims == dnums.getKernelSpatialDimensions().size());
+
+  const auto ret_shape =
+      op.getResult().getType().dyn_cast().getShape();
+
+  auto lhs = lookupValue(sscope, op.lhs(), opts);
+  auto rhs = lookupValue(sscope, op.rhs(), opts);
+
+  // Strides default to 1 per spatial dimension when the attribute is absent.
+  std::vector window_strides(dnums.getInputSpatialDimensions().size(),
+                                      1);
+  if (op.window_strides().has_value()) {
+    for (const auto &iter :
+         llvm::enumerate(op.window_strides()->getValues())) {
+      window_strides[iter.index()] = iter.value();
+    }
+  }
+
+  kernel::hlo::ConvolutionConfig config;
+  config.featureGroupCount = op.feature_group_count();
+  config.batchGroupCount = op.batch_group_count();
+  config.window_strides = window_strides;
+  config.inputBatchDimension = dnums.getInputBatchDimension();
+  config.inputFeatureDimension = dnums.getInputFeatureDimension();
+  config.inputSpatialDimensions = dnums.getInputSpatialDimensions();
+  config.kernelInputFeatureDimension = dnums.getKernelInputFeatureDimension();
+  config.kernelOutputFeatureDimension = dnums.getKernelOutputFeatureDimension();
+  config.kernelSpatialDimensions = dnums.getKernelSpatialDimensions();
+  config.outputBatchDimension = dnums.getOutputBatchDimension();
+  config.outputFeatureDimension = dnums.getOutputFeatureDimension();
+  config.outputSpatialDimensions = dnums.getOutputSpatialDimensions();
+
+  spu::Value result;
+
+  if (dnums.getInputSpatialDimensions().size() == 2) {
+    result = kernel::hlo::Convolution2D(hctx, lhs, rhs, config, ret_shape);
+  } else {
+    result = kernel::hlo::Convolution(hctx, lhs, rhs, config, ret_shape);
+  }
+
+  sscope->addValue(op.getResult(), std::move(result));
+}
+
+// DynamicUpdateSlice: resolves the runtime start indices and forwards
+// operand, update and indices to the DynamicUpdateSlice kernel.
+void execute(OpExecutor *executor, HalContext *hctx, SymbolScope *sscope,
+             mlir::pphlo::DynamicUpdateSliceOp &op,
+             const ExecutionOptions &opts) {
+  // Basic idea here, get a ref slice and update the whole slice..
+  // Start indices
+  std::vector start_indicies(op.start_indices().size());
+  const auto &operand = lookupValue(sscope, op.operand(), opts);
+  const auto &update = lookupValue(sscope, op.update(), opts);
+
+  for (const auto &idx : llvm::enumerate(op.start_indices())) {
+    start_indicies[idx.index()] = lookupValue(sscope, idx.value(), opts);
+  }
+
+  sscope->addValue(op.getResult(), kernel::hlo::DynamicUpdateSlice(
+                                       hctx, operand, update, start_indicies));
+}
+
+// DynamicSlice: reads the static slice sizes from the op attribute, resolves
+// the runtime start indices and forwards them to the DynamicSlice kernel.
+void execute(OpExecutor *executor, HalContext *hctx, SymbolScope *sscope,
+             mlir::pphlo::DynamicSliceOp &op, const ExecutionOptions &opts) {
+  // Start indices
+  auto iter = op.slice_sizes().getValues();
+  std::vector slice_size{iter.begin(), iter.end()};
+  const auto &operand = lookupValue(sscope, op.operand(), opts);
+  std::vector start_indicies(op.start_indices().size());
+
+  for (const auto &idx : llvm::enumerate(op.start_indices())) {
+    start_indicies[idx.index()] = lookupValue(sscope, idx.value(), opts);
+  }
+
+  sscope->addValue(
+      op.getResult(),
+      kernel::hlo::DynamicSlice(hctx, operand, slice_size, start_indicies));
+}
+
+// Gather: an empty operand is bound to the result unchanged; otherwise the
+// op's dimension numbers are copied into a GatherConfig for the Gather kernel.
+void execute(OpExecutor *executor, HalContext *hctx, SymbolScope *sscope,
+             mlir::pphlo::GatherOp &op, const ExecutionOptions &opts) {
+  // If input is empty, short circuit
+  auto operand = lookupValue(sscope, op.operand(), opts);
+  auto start_indicies = lookupValue(sscope, op.start_indices(), opts);
+  if (operand.numel() == 0) {
+    sscope->addValue(op.getResult(), operand);
+    return;
+  }
+
+  const auto &output_shape =
+      op.getResult().getType().dyn_cast().getShape();
+
+  const auto &dim_numbers = op.dimension_numbers();
+
+  kernel::hlo::GatherConfig config;
+  auto ss = convertDenseIntElementAttr(op.slice_sizes());
+  config.sliceSizes = ss;
+  config.indexVectorDim = dim_numbers.getIndexVectorDim();
+  config.offsetDims = dim_numbers.getOffsetDims();
+  config.collapsedSliceDims = dim_numbers.getCollapsedSliceDims();
+  config.startIndexMap = dim_numbers.getStartIndexMap();
+
+  sscope->addValue(
+      op.getResult(),
+      kernel::hlo::Gather(hctx, operand, start_indicies, config, output_shape));
+}
+
+// Sort: sorts all operands along op.dimension(), using the op's comparator
+// region as the comparison callback. The visibility of the comparator's
+// single return value is passed on to the Sort kernel.
+void execute(OpExecutor *executor, HalContext *hctx, SymbolScope *sscope,
+             mlir::pphlo::SortOp &op, const ExecutionOptions &opts) {
+  auto sort_dim = op.dimension();
+  auto is_stable = op.is_stable();
+  std::vector inputs(op->getNumOperands());
+  for (size_t idx = 0; idx < inputs.size(); ++idx) {
+    inputs[idx] = lookupValue(sscope, op->getOperand(idx), opts);
+  }
+
+  auto body_return =
+      llvm::dyn_cast(op.comparator().back().back());
+  YACL_ENFORCE(body_return, "Cannot find body return");
+  YACL_ENFORCE(body_return->getNumOperands() == 1,
+               "Comparator should have exactly one return");
+
+  mlir::pphlo::TypeTools type_tools;
+  auto return_vis =
+      type_tools.getTypeVisibility(body_return->getOperandTypes().front());
+
+  auto ret = kernel::hlo::Sort(
+      hctx, inputs, sort_dim, is_stable,
+      [&](absl::Span inputs) {
+        auto ret = runRegion(executor, hctx, sscope, op.comparator(), inputs);
+        return ret[0];
+      },
+      convertVisibility(return_vis));
+
+  // Unsigned counter: getNumResults() is unsigned, as in the other result
+  // loops of this file.
+  for (size_t idx = 0; idx < op->getNumResults(); ++idx) {
+    sscope->addValue(op->getResult(idx), std::move(ret[idx]));
+  }
+}
+
+// SelectAndScatter: builds window shape, strides (default 1) and padding
+// (default {0, 0}) from the op attributes, then calls the expanded
+// select-and-scatter kernel with the op's select and scatter regions as
+// callbacks.
+void execute(OpExecutor *executor, HalContext *hctx, SymbolScope *sscope,
+             mlir::pphlo::SelectAndScatterOp &op,
+             const ExecutionOptions &opts) {
+  auto operand = lookupValue(sscope, op.operand(), opts);
+  auto source = lookupValue(sscope, op.source(), opts);
+  auto init_val = lookupValue(sscope, op.init_value(), opts);
+
+  auto window_shape = convertDenseIntElementAttr(op.window_dimensions());
+
+  // build strides
+  std::vector window_strides(window_shape.size(), 1);
+  if (op.window_strides().has_value()) {
+    window_strides = convertDenseIntElementAttr(*op.window_strides());
+  }
+
+  // window padding
+  std::vector> window_padding(window_shape.size(),
+                                                          {0, 0});
+  if (op.padding().has_value()) {
+    const auto v = *op.padding();
+
+    // The attribute is read as consecutive (low, high) pairs, one per window
+    // dimension.
+    YACL_ENFORCE(window_padding.size() * 2 == (size_t)v.size());
+
+    for (size_t idx = 0; idx < window_padding.size(); ++idx) {
+      window_padding[idx] = {*(v.getValues().begin() + 2 * idx),
+                             *(v.getValues().begin() + 2 * idx + 1)};
+    }
+  }
+
+  // auto ret = kernel::hlo::SelectAndScatterNaive(
+  auto ret = kernel::hlo::SelectAndScatterExpanded(
+      hctx, operand, source, init_val, window_shape, window_strides,
+      window_padding,
+      [&](const spu::Value &selected, const spu::Value &current) {
+        auto ret =
+            runRegion(executor, hctx, sscope, op.select(), {selected, current});
+        return ret[0];
+      },
+      [&](const spu::Value &in, const spu::Value &scatter) {
+        auto ret =
+            runRegion(executor, hctx, sscope, op.scatter(), {in, scatter});
+        return ret[0];
+      });
+
+  sscope->addValue(op.getResult(), std::move(ret));
+}
+
+// MaxPoolScatter: builds window shape, strides (default 1) and padding
+// (default {0, 0}) from the op attributes and calls the MaxPoolScatter
+// kernel; the result type's shape is passed as the base shape.
+void execute(OpExecutor *executor, HalContext *hctx, SymbolScope *sscope,
+             mlir::pphlo::MaxPoolScatterOp &op, const ExecutionOptions &opts) {
+  auto scatter_indices = lookupValue(sscope, op.scatter_indices(), opts);
+  auto update = lookupValue(sscope, op.update(), opts);
+
+  auto window_shape =
+      convertDenseIntElementAttr(op.window_dimensions().value());
+
+  // build strides
+  std::vector window_strides(window_shape.size(), 1);
+  if (op.window_strides().has_value()) {
+    window_strides = convertDenseIntElementAttr(*op.window_strides());
+  }
+
+  // window padding
+  std::vector> window_padding(window_shape.size(),
+                                                          {0, 0});
+  if (op.padding().has_value()) {
+    const auto v = *op.padding();
+
+    YACL_ENFORCE(window_padding.size() * 2 == (size_t)v.size());
+
+    for (size_t idx = 0; idx < window_padding.size(); ++idx) {
+      window_padding[idx] = {*(v.getValues().begin() + 2 * idx),
+                             *(v.getValues().begin() + 2 * idx + 1)};
+    }
+  }
+
+  auto base_shape =
+      op.getResult().getType().dyn_cast().getShape();
+
+  auto ret =
+      kernel::hlo::MaxPoolScatter(hctx, scatter_indices, update, window_shape,
+                                  base_shape, window_strides, window_padding);
+
+  sscope->addValue(op.getResult(), std::move(ret));
+}
+
+// If: hands the condition and both branch regions (as callbacks taking no
+// arguments) to the IfElse kernel and binds every op result.
+void execute(OpExecutor *executor, HalContext *hctx, SymbolScope *sscope,
+             mlir::pphlo::IfOp &op, const ExecutionOptions &opts) {
+  auto conditional = lookupValue(sscope, op.condition(), opts);
+
+  auto results = kernel::hlo::IfElse(
+      hctx, conditional, //
+      [&]() { return runRegion(executor, hctx, sscope, op.true_branch(), {}); },
+      [&]() {
+        return runRegion(executor, hctx, sscope, op.false_branch(), {});
+      });
+
+  // Copy output
+  for (const auto &ret : llvm::enumerate(op->getResults())) {
+    sscope->addValue(ret.value(), results[ret.index()]);
+  }
+}
+
+// While: hands the operands plus the cond and body regions (as callbacks) to
+// the While kernel; the first value returned by the cond region is used as
+// the loop condition.
+void execute(OpExecutor *executor, HalContext *hctx, SymbolScope *sscope,
+             mlir::pphlo::WhileOp &op, const ExecutionOptions &opts) {
+  // First inputs vectors
+  std::vector inputs;
+  inputs.reserve(op->getNumOperands());
+
+  // Prepare inputs
+  for (const auto operand : op->getOperands()) {
+    inputs.emplace_back(lookupValue(sscope, operand, opts));
+  }
+
+  auto ret = kernel::hlo::While(
+      hctx, inputs, //
+      [&](absl::Span inputs) {
+        return runRegion(executor, hctx, sscope, op.cond(), inputs)[0];
+      },
+      [&](absl::Span inputs) {
+        return runRegion(executor, hctx, sscope, op.body(), inputs);
+      });
+
+  for (size_t idx = 0; idx < op->getNumResults(); ++idx) {
+    sscope->addValue(op->getResult(idx), std::move(ret[idx]));
+  }
+}
+
+// Invokes the trailing callable once for the runtime PT_TYPE, covering every
+// integer and float plaintext type listed below (PT_BOOL is not listed).
+// Any other value throws.
+#define DISPATCH_ALL_NONE_BOOL_PT_TYPES(PT_TYPE, NAME, ...) \
+  [&] { \
+    switch (PT_TYPE) { \
+      __CASE_PT_TYPE(spu::PT_I8, NAME, __VA_ARGS__) \
+      __CASE_PT_TYPE(spu::PT_U8, NAME, __VA_ARGS__) \
+      __CASE_PT_TYPE(spu::PT_I16, NAME, __VA_ARGS__) \
+      __CASE_PT_TYPE(spu::PT_U16, NAME, __VA_ARGS__) \
+      __CASE_PT_TYPE(spu::PT_I32, NAME, __VA_ARGS__) \
+      __CASE_PT_TYPE(spu::PT_U32, NAME, __VA_ARGS__) \
+      __CASE_PT_TYPE(spu::PT_I64, NAME, __VA_ARGS__) \
+      __CASE_PT_TYPE(spu::PT_U64, NAME, __VA_ARGS__) \
+      __CASE_PT_TYPE(spu::PT_F32, NAME, __VA_ARGS__) \
+      __CASE_PT_TYPE(spu::PT_F64, NAME, __VA_ARGS__) \
+    default: \
+      YACL_THROW("{} not implemented for pt_type={}", #NAME, PT_TYPE); \
+    } \
+  }()
+
+// Iota: creates a public iota whose length is the result extent along
+// op.iota_dimension(), then broadcasts it to the full result shape when the
+// result has more than one dimension.
+void execute(OpExecutor *executor, HalContext *hctx, SymbolScope *sscope,
+             mlir::pphlo::IotaOp &op, const ExecutionOptions &opts) {
+  const auto &ret_type =
+      op.output().getType().dyn_cast();
+  const size_t numel = ret_type.getShape()[op.iota_dimension()];
+
+  mlir::pphlo::TypeTools type_tools;
+  auto ret_el_type = type_tools.getExpressedType(ret_type);
+  auto pt_type = getPtTypeFromMlirType(ret_el_type);
+
+  spu::Value iota_ret;
+  DISPATCH_ALL_NONE_BOOL_PT_TYPES(pt_type, "_", [&] {
+    iota_ret = kernel::hlo::Iota(hctx, numel, VIS_PUBLIC);
+  });
+
+  if (ret_type.getShape().size() > 1) {
+    // Need a broadcast
+    iota_ret = kernel::hlo::Broadcast(hctx, iota_ret, ret_type.getShape(), {});
+  }
+
+  sscope->addValue(op.output(), std::move(iota_ret));
+}
+
+// Rem: forwards lhs and rhs to the Remainder kernel.
+void execute(OpExecutor *executor, HalContext *hctx, SymbolScope *sscope,
+             mlir::pphlo::RemOp &op, const ExecutionOptions &opts) {
+  // FIXME: When hal has a remainder, use that
+  auto lhs = lookupValue(sscope, op.lhs(), opts);
+  auto rhs = lookupValue(sscope, op.rhs(), opts);
+
+  auto ret = kernel::hlo::Remainder(hctx, lhs, rhs);
+  sscope->addValue(op.getResult(), std::move(ret));
+}
+
+// Transpose: permutes the operand according to the op's permutation attribute.
+void execute(OpExecutor *executor, HalContext *hctx, SymbolScope *sscope,
+             mlir::pphlo::TransposeOp &op, const ExecutionOptions &opts) {
+  sscope->addValue(
+      op.getResult(),
+      kernel::hlo::Transpose(hctx, lookupValue(sscope, op.getOperand(), opts),
+                             convertDenseIntElementAttr(op.permutation())));
+}
+
+// Broadcast: broadcasts the operand to the op's result shape along the
+// op's broadcast_dimensions.
+void execute(OpExecutor *executor, HalContext *hctx, SymbolScope *sscope,
+             mlir::pphlo::BroadcastOp &op, const ExecutionOptions &opts) {
+  auto to_shape = op.getType().dyn_cast().getShape();
+  sscope->addValue(op.getResult(),
+                   kernel::hlo::Broadcast(
+                       hctx, lookupValue(sscope, op.getOperand(), opts),
+                       to_shape,
+                       convertDenseIntElementAttr(op.broadcast_dimensions())));
+}
+
+// Reshape: reshapes the operand to the op's result shape.
+void execute(OpExecutor *executor, HalContext *hctx, SymbolScope *sscope,
+             mlir::pphlo::ReshapeOp &op, const ExecutionOptions &opts) {
+  auto to_shape = op.getType().dyn_cast().getShape();
+  sscope->addValue(
+      op.getResult(),
+      kernel::hlo::Reshape(hctx, lookupValue(sscope, op.getOperand(), opts),
+                           to_shape));
+}
+
+// Concatenate: joins all operands along op.dimension().
+void execute(OpExecutor *executor, HalContext *hctx, SymbolScope *sscope,
+             mlir::pphlo::ConcatenateOp &op, const ExecutionOptions &opts) {
+  std::vector values(op->getNumOperands());
+
+  for (size_t idx = 0; idx < op->getNumOperands(); ++idx) {
+    values[idx] = lookupValue(sscope, op->getOperand(idx), opts);
+  }
+
+  // set result
+  sscope->addValue(op.getResult(),
+                   kernel::hlo::Concatenate(hctx, values, op.dimension()));
+}
+
+// Slice: static slice using the op's start, limit and stride attributes.
+void execute(OpExecutor *executor, HalContext *hctx, SymbolScope *sscope,
+             mlir::pphlo::SliceOp &op, const ExecutionOptions &opts) {
+  sscope->addValue(
+      op.getResult(),
+      kernel::hlo::Slice(hctx, lookupValue(sscope, op.getOperand(), opts),
+                         convertDenseIntElementAttr(op.start_indices()),
+                         convertDenseIntElementAttr(op.limit_indices()),
+                         convertDenseIntElementAttr(op.strides())));
+}
+
+// Pad: validates that the padding value is a scalar, that each padding
+// attribute has one entry per operand dimension and that interior padding is
+// non-negative, then calls the Pad kernel.
+void execute(OpExecutor *executor, HalContext *hctx, SymbolScope *sscope,
+             mlir::pphlo::PadOp &op, const ExecutionOptions &opts) {
+  const auto &operand = lookupValue(sscope, op.operand(), opts);
+  const size_t operand_rank = operand.shape().size();
+  const auto &padding_value = lookupValue(sscope, op.padding_value(), opts);
+  YACL_ENFORCE(padding_value.shape().empty());
+
+  auto edge_padding_low = convertDenseIntElementAttr(op.edge_padding_low());
+  YACL_ENFORCE(edge_padding_low.size() == operand_rank);
+  auto edge_padding_high = convertDenseIntElementAttr(op.edge_padding_high());
+  YACL_ENFORCE(edge_padding_high.size() == operand_rank);
+  auto interior_padding = convertDenseIntElementAttr(op.interior_padding());
+  YACL_ENFORCE(interior_padding.size() == operand_rank);
+  YACL_ENFORCE(std::all_of(interior_padding.begin(), interior_padding.end(),
+                           [](int64_t i) { return i >= 0; }));
+
+  sscope->addValue(op.getResult(),
+                   kernel::hlo::Pad(hctx, operand, padding_value,
+                                    edge_padding_low, edge_padding_high,
+                                    interior_padding));
+}
+
+// Reverse: reverses the operand along the op's dimensions attribute.
+void execute(OpExecutor *executor, HalContext *hctx, SymbolScope *sscope,
+             mlir::pphlo::ReverseOp &op, const ExecutionOptions &opts) {
+  sscope->addValue(
+      op.getResult(),
+      kernel::hlo::Reverse(hctx, lookupValue(sscope, op.getOperand(), opts),
+                           convertDenseIntElementAttr(op.dimensions())));
+}
+
+// Reduce: the operands are split in half into inputs and init values. The
+// op's body region is the reducer; it receives the lhs values followed by the
+// rhs values. Every result is reshaped to the shape of the first result type.
+void execute(OpExecutor *executor, HalContext *hctx, SymbolScope *sscope,
+             mlir::pphlo::ReduceOp &op, const ExecutionOptions &opts) {
+  int64_t num_args = op->getNumOperands() / 2;
+  std::vector dimensions_to_reduce =
+      convertDenseIntElementAttr(op.dimensions());
+
+  std::vector input_args(num_args);
+  std::vector init_values(num_args);
+  for (int64_t i = 0; i < num_args; ++i) {
+    input_args[i] = lookupValue(sscope, op.inputs()[i], opts);
+    init_values[i] = lookupValue(sscope, op.init_values()[i], opts);
+  }
+
+  std::vector ret = kernel::hlo::Reduce(
+      hctx, input_args, init_values, dimensions_to_reduce,
+      [&](absl::Span lhs, absl::Span rhs) {
+        std::vector operands;
+        operands.reserve(lhs.size() + rhs.size());
+        operands.insert(operands.end(), lhs.begin(), lhs.end());
+        operands.insert(operands.end(), rhs.begin(), rhs.end());
+        return runRegion(executor, hctx, sscope, op.body(), operands);
+      });
+
+  const auto &output_shape =
+      op->getResultTypes()[0].dyn_cast().getShape();
+  for (size_t idx = 0; idx < op->getNumResults(); ++idx) {
+    sscope->addValue(op->getResult(idx),
+                     kernel::hlo::Reshape(hctx, ret[idx], output_shape));
+  }
+}
+
+// ReduceWindow: the operands are split in half into inputs and init values.
+// Window strides, window dilations and base dilations default to 1 and
+// padding defaults to {0, 0} when the attribute is absent. The op's body
+// region is the reducer.
+void execute(OpExecutor *executor, HalContext *hctx, SymbolScope *sscope,
+             mlir::pphlo::ReduceWindowOp &op, const ExecutionOptions &opts) {
+  int64_t num_args = op->getNumOperands() / 2;
+
+  std::vector input_args(num_args);
+  std::vector init_values(num_args);
+
+  for (int64_t i = 0; i < num_args; ++i) {
+    input_args[i] = lookupValue(sscope, op.inputs()[i], opts);
+    init_values[i] = lookupValue(sscope, op.init_values()[i], opts);
+  }
+
+  auto ret_shape = op->getResults()[0]
+                       .getType()
+                       .dyn_cast()
+                       .getShape();
+  auto window_shape = convertDenseIntElementAttr(op.window_dimensions());
+
+  // build strides
+  std::vector window_strides(window_shape.size(), 1);
+  if (op.window_strides().has_value()) {
+    window_strides = convertDenseIntElementAttr(*op.window_strides());
+  }
+
+  // window dilation
+  std::vector window_dilations(window_shape.size(), 1);
+  if (op.window_dilations().has_value()) {
+    window_dilations = convertDenseIntElementAttr(*op.window_dilations());
+  }
+
+  // window padding
+  std::vector> window_padding(window_shape.size(),
+                                                          {0, 0});
+  if (op.padding().has_value()) {
+    const auto v = *op.padding();
+
+    YACL_ENFORCE(window_padding.size() * 2 == (size_t)v.size());
+
+    for (size_t idx = 0; idx < window_padding.size(); ++idx) {
+      window_padding[idx] = {*(v.getValues().begin() + 2 * idx),
+                             *(v.getValues().begin() + 2 * idx + 1)};
+    }
+  }
+
+  // base dilation
+  std::vector base_dilation(window_shape.size(), 1);
+  if (op.base_dilations().has_value()) {
+    base_dilation = convertDenseIntElementAttr(*op.base_dilations());
+  }
+
+  kernel::hlo::ReduceWindowConfig config;
+  config.window_shape = window_shape;
+  config.window_strides = window_strides;
+  config.window_dilations = window_dilations;
+  config.window_padding = window_padding;
+  config.base_dilations = base_dilation;
+
+  auto rets = kernel::hlo::ReduceWindow(
+      hctx, input_args, init_values, ret_shape, config,
+      [&](absl::Span lhs, absl::Span rhs) {
+        std::vector operands;
+        operands.reserve(lhs.size() + rhs.size());
+        operands.insert(operands.end(), lhs.begin(), lhs.end());
+        operands.insert(operands.end(), rhs.begin(), rhs.end());
+        return runRegion(executor, hctx, sscope, op.body(), operands);
+      });
+
+  // Unsigned counter: getNumResults() is unsigned, as in the other result
+  // loops of this file.
+  for (size_t idx = 0; idx < op->getNumResults(); ++idx) {
+    sscope->addValue(op->getResults()[idx], std::move(rets[idx]));
+  }
+}
+
+// ArgMax: builds a ReduceWindowConfig with the same attribute defaults as
+// ReduceWindow and calls the ArgMax kernel. The pair it returns is bound to
+// result 0 (first) and result 1 (second).
+void execute(OpExecutor *executor, HalContext *hctx, SymbolScope *sscope,
+             mlir::pphlo::ArgMaxOp &op, const ExecutionOptions &opts) {
+  auto window_shape = convertDenseIntElementAttr(op.window_dimensions());
+
+  // build strides
+  std::vector window_strides(window_shape.size(), 1);
+  if (op.window_strides().has_value()) {
+    window_strides = convertDenseIntElementAttr(*op.window_strides());
+  }
+
+  // window dilation
+  std::vector window_dilations(window_shape.size(), 1);
+  if (op.window_dilations().has_value()) {
+    window_dilations = convertDenseIntElementAttr(*op.window_dilations());
+  }
+
+  // window padding
+  std::vector> window_padding(window_shape.size(),
+                                                          {0, 0});
+  if (op.padding().has_value()) {
+    const auto v = *op.padding();
+
+    YACL_ENFORCE(window_padding.size() * 2 == (size_t)v.size());
+
+    for (size_t idx = 0; idx < window_padding.size(); ++idx) {
+      window_padding[idx] = {*(v.getValues().begin() + 2 * idx),
+                             *(v.getValues().begin() + 2 * idx + 1)};
+    }
+  }
+
+  // base dilation
+  std::vector base_dilation(window_shape.size(), 1);
+  if (op.base_dilations().has_value()) {
+    base_dilation = convertDenseIntElementAttr(*op.base_dilations());
+  }
+
+  auto ret_shape = op->getResults()[0]
+                       .getType()
+                       .dyn_cast()
+                       .getShape();
+
+  kernel::hlo::ReduceWindowConfig config;
+  config.window_shape = window_shape;
+  config.window_strides = window_strides;
+  config.window_dilations = window_dilations;
+  config.window_padding = window_padding;
+  config.base_dilations = base_dilation;
+
+  auto ret = kernel::hlo::ArgMax(hctx, lookupValue(sscope, op.input(), opts),
+                                 ret_shape, config);
+
+  sscope->addValue(op.getResult(0), ret.first);
+
+  sscope->addValue(op.getResult(1), ret.second);
+}
+
+// Select: forwards pred, on_true and on_false to the Select kernel.
+void execute(OpExecutor *executor, HalContext *hctx, SymbolScope *sscope,
+             mlir::pphlo::SelectOp &op, const ExecutionOptions &opts) {
+  auto pred = lookupValue(sscope, op.pred(), opts);
+
+  auto on_true = lookupValue(sscope, op.on_true(), opts);
+  auto on_false = lookupValue(sscope, op.on_false(), opts);
+
+  sscope->addValue(op.getResult(),
+                   kernel::hlo::Select(hctx, pred, on_true, on_false));
+}
+
+// Rng: uniform random values of the result shape, with operands a and b
+// passed to the kernel as its two bound arguments.
+void execute(OpExecutor *executor, HalContext *hctx, SymbolScope *sscope,
+             mlir::pphlo::RngOp &op, const ExecutionOptions &opts) {
+  auto to_shape = op.getType().dyn_cast().getShape();
+  sscope->addValue(
+      op.getResult(),
+      kernel::hlo::Uniform_rand(hctx, lookupValue(sscope, op.a(), opts),
+                                lookupValue(sscope, op.b(), opts), to_shape));
+}
+
+// Convert: casts the operand to the dtype and visibility derived from the
+// op's result type.
+void execute(OpExecutor *executor, HalContext *hctx, SymbolScope *sscope,
+             mlir::pphlo::ConvertOp &op, const ExecutionOptions &opts) {
+  mlir::pphlo::TypeTools tool;
+  auto dst_dtype = getDtypeFromMlirType(op.getType());
+  auto dst_vtype = tool.isMPCType(op.getType())
+                       ? VIS_PUBLIC
+                       : VIS_SECRET;
+  auto in = lookupValue(sscope, op.getOperand(), opts);
+
+  sscope->addValue(op.getResult(),
+                   kernel::hlo::Cast(hctx, in, dst_vtype, dst_dtype));
+}
+
+// PreferA: adds a public zero constant (cast to the operand's dtype and
+// shaped like the operand) to the operand and binds the sum as the result.
+void execute(OpExecutor *executor, HalContext *hctx, SymbolScope *sscope,
+             mlir::pphlo::PreferAOp &op, const ExecutionOptions &opts) {
+  auto in = lookupValue(sscope, op.operand(), opts);
+  auto k0 = kernel::hlo::Cast(hctx, kernel::hlo::Constant(hctx, 0, in.shape()),
+                              VIS_PUBLIC, in.dtype());
+  sscope->addValue(op.getResult(), kernel::hlo::Add(hctx, in, k0));
+}
+
+// Sign: forwards the operand to the Sign kernel.
+void execute(OpExecutor *executor, HalContext *hctx, SymbolScope *sscope,
+             mlir::pphlo::SignOp &op, const ExecutionOptions &opts) {
+  auto in = lookupValue(sscope, op.operand(), opts);
+  sscope->addValue(op.getResult(), kernel::hlo::Sign(hctx, in));
+}
+
+// BitcastConvert: reinterprets the operand as the result dtype. Only casts
+// that keep the shape unchanged are accepted.
+void execute(OpExecutor *executor, HalContext *hctx, SymbolScope *sscope,
+             mlir::pphlo::BitcastConvertOp &op, const ExecutionOptions &opts) {
+  const auto &in_type =
+      op.getOperand().getType().dyn_cast();
+  const auto &out_type =
+      op.getResult().getType().dyn_cast();
+
+  // bitcast should not change total #bytes, so if sizeof(in_t) !=
+  // sizeof(out_t) will result to a shape change, thus it's enough to just
+  // ensure in_shape == out_shape
+  YACL_ENFORCE(in_type.getShape() == out_type.getShape(),
+               "bitcast with different size is not supported yet");
+
+  sscope->addValue(
+      op.getResult(),
+      kernel::hlo::Bitcast(hctx, lookupValue(sscope, op.getOperand(), opts),
+                           getDtypeFromMlirType(out_type)));
+}
+
+// Constant: wraps the attribute's raw data in a PtBufferView and builds a
+// constant of the declared shape. A splat attribute is viewed with an empty
+// shape and empty strides; the kernel still receives the full dst_shape.
+void execute(OpExecutor *executor, HalContext *hctx, SymbolScope *sscope,
+             mlir::pphlo::ConstantOp &op, const ExecutionOptions &opts) {
+  const auto &val = op.value();
+  const auto &dea = val.dyn_cast();
+  const auto &type = val.getType().dyn_cast();
+  const auto &dst_shape = type.getShape();
+  const auto &pt_type = getPtTypeFromMlirType(type.getElementType());
+
+  PtBufferView view(dea.getRawData().data(), pt_type,
+                    dea.isSplat() ? llvm::ArrayRef() : dst_shape,
+                    dea.isSplat() ? std::vector()
+                                  : makeCompactStrides(dst_shape));
+
+  sscope->addValue(op.getResult(),
+                   kernel::hlo::Constant(hctx, view, dst_shape));
+}
+
+// Clamp: forwards operand, min and max to the Clamp kernel.
+void execute(OpExecutor *executor, HalContext *hctx, SymbolScope *sscope,
+             mlir::pphlo::ClampOp &op, const ExecutionOptions &opts) {
+  sscope->addValue(op.getResult(),
+                   kernel::hlo::Clamp(hctx,
+                                      lookupValue(sscope, op.operand(), opts),
+                                      lookupValue(sscope, op.min(), opts),
+                                      lookupValue(sscope, op.max(), opts)));
+}
+
+// DbgPrint: passes the operand to hal::dbg_print; no result is bound.
+void execute(OpExecutor *executor, HalContext *hctx, SymbolScope *sscope,
+             mlir::pphlo::DbgPrintOp &op, const ExecutionOptions &opts) {
+  kernel::hal::dbg_print(hctx, lookupValue(sscope, op.operand(), opts));
+}
+
+// Defines execute() for an op that must never be dispatched here; calling it
+// always throws.
+#define DEFINE_UNIMPLEMENTED_OP(OpName) \
+  void execute(OpExecutor *executor, HalContext *hctx, SymbolScope *sscope, \
+               mlir::pphlo::OpName &, const ExecutionOptions &opts) { \
+    YACL_THROW("Lowered op should not occur at backend"); \
+  }
+
+DEFINE_UNIMPLEMENTED_OP(ReturnOp)
+
+#undef DEFINE_UNIMPLEMENTED_OP
+
+} // namespace
+
+// Returns true when `op` can be cast to the first op type of the pack or,
+// recursively, to any of the remaining ones; false once the pack is exhausted.
+template
+static bool hasKernelImpl(mlir::Operation &op) {
+  if (auto casted = llvm::dyn_cast(op)) {
+    return true;
+  } else {
+    if constexpr (!sizeof...(MoreOpT)) {
+      return false;
+    } else {
+      return hasKernelImpl(op);
+    }
+  }
+}
+
+// Reports whether `op` is one of the ops listed in pphlo_ops.cc.inc.
+bool PPHloExecutor::hasKernel(mlir::Operation &op) const {
+  return hasKernelImpl<
+#define GET_OP_LIST
+#include "spu/dialect/pphlo_ops.cc.inc"
+      >(op);
+}
+
+// Tries each op type of the pack in turn. On the first successful cast it
+// runs the matching execute() overload inside a trace action; when no op type
+// matches it throws with the printed op and its location.
+template
+static void dispatchOp(OpExecutor *executor, HalContext *hctx,
+                       SymbolScope *sscope, mlir::Operation &op,
+                       const ExecutionOptions &opts) {
+  if (auto casted = llvm::dyn_cast(op)) {
+    // Execute op
+    {
+      const auto fn_name = op.getName().getStringRef().str();
+      // NOTE(review): `hctx_` is not declared in this free function (the
+      // parameter is `hctx`); presumably GET_CTX_NAME does not evaluate its
+      // argument - confirm and switch to `hctx`.
+      SPU_TRACE_ACTION(GET_CTX_NAME(hctx_), (TR_HLO | TR_LAR), ~TR_HLO,
+                       fn_name);
+      execute(executor, hctx, sscope, casted, opts);
+    }
+
+    // currently we only support config verifier statically.
+    constexpr bool kEnableXlaVerifier = false;
+    if (kEnableXlaVerifier) {
+      XlaVerifier verifier(hctx);
+      // handle mixed (int, fxp) multiplication
+      if constexpr (std::is_same_v or
+                    std::is_same_v or
+                    std::is_same_v) {
+        spu::Value lhs = sscope->lookupValue(casted.lhs());
+        spu::Value rhs = sscope->lookupValue(casted.rhs());
+        spu::Value ret = sscope->lookupValue(casted.getResult());
+        mlir::pphlo::TypeTools type_tool;
+        auto lhs_type = type_tool.getExpressedType(casted.lhs().getType());
+        auto rhs_type = type_tool.getExpressedType(casted.rhs().getType());
+        auto ret_type =
+            type_tool.getExpressedType(casted.getResult().getType());
+
+        // Bring both operands to the result dtype before verifying.
+        if (lhs_type != ret_type) {
+          lhs = kernel::hlo::Cast(hctx, lhs, lhs.vtype(), ret.dtype());
+        }
+        if (rhs_type != ret_type) {
+          rhs = kernel::hlo::Cast(hctx, rhs, rhs.vtype(), ret.dtype());
+        }
+
+        verifier.verify(casted, {lhs, rhs}, {ret});
+      } else {
+        // Collect inputs
+        std::vector ins;
+        for (auto operand : op.getOperands()) {
+          ins.emplace_back(sscope->lookupValue(operand));
+        }
+        std::vector outs;
+        for (auto operand : op.getResults()) {
+          outs.emplace_back(sscope->lookupValue(operand));
+        }
+
+        verifier.verify(casted, ins, outs);
+      }
+    }
+  } else {
+    if constexpr (!sizeof...(MoreOpT)) {
+      YACL_THROW("Unhandled mlir op {} at {}", mlirObjectToString(op),
+                 mlirObjectToString(op.getLoc()));
+    } else {
+      dispatchOp(executor, hctx, sscope, op, opts);
+    }
+  }
+}
+
+// Executes one PPHLO operation: optionally logs it (opts.do_log_execution),
+// then dispatches over the op list from pphlo_ops.cc.inc.
+void PPHloExecutor::runKernelImpl(HalContext *hctx, SymbolScope *sscope,
+                                  mlir::Operation &op,
+                                  const ExecutionOptions &opts) {
+  if (opts.do_log_execution) {
+    SPDLOG_INFO("PPHLO {}", mlirObjectToString(op));
+  }
+  dispatchOp<
+#define GET_OP_LIST
+#include "spu/dialect/pphlo_ops.cc.inc"
+      >(this, hctx, sscope, op, opts);
+}
+
+// Intentionally empty here: the only type checks in this file are the ones
+// lookupValue() performs when opts.do_type_check is set.
+void PPHloExecutor::checkType(mlir::Type mlir_type, const spu::Value &v) const {
+}
+
+} // namespace spu::device::pphlo
diff --git a/spu/device/pphlo/pphlo_executor.h b/spu/device/pphlo/pphlo_executor.h new file mode 100644 index
00000000..49a60cc2 --- /dev/null +++ b/spu/device/pphlo/pphlo_executor.h @@ -0,0 +1,40 @@ +// Copyright 2022 Ant Group Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include + +#include "spu/device/executor.h" +#include "spu/device/pphlo/xla_verifier.h" +#include "spu/dialect/pphlo_ops.h" +#include "spu/dialect/pphlo_types.h" +#include "spu/kernel/context.h" +#include "spu/kernel/hlo/casting.h" + +namespace spu::device::pphlo { + +class PPHloExecutor : public OpExecutor { +public: + void checkType(mlir::Type mlir_type, const spu::Value &v) const override; + + // return true if the operation has a corresponding kernel. + bool hasKernel(mlir::Operation &op) const override; + + // run a kernel in a given region. + void runKernelImpl(HalContext *hcts, SymbolScope *sscope, mlir::Operation &op, + const ExecutionOptions &opts) override; +}; + +} // namespace spu::device::pphlo diff --git a/spu/device/pphlo/executor_test.cc b/spu/device/pphlo/pphlo_executor_test.cc similarity index 93% rename from spu/device/pphlo/executor_test.cc rename to spu/device/pphlo/pphlo_executor_test.cc index 69778897..d6aa2096 100644 --- a/spu/device/pphlo/executor_test.cc +++ b/spu/device/pphlo/pphlo_executor_test.cc @@ -13,6 +13,7 @@ // limitations under the License. 
#include +#include #include #include #include @@ -23,7 +24,8 @@ #include "spu/compiler/common/compilation_context.h" #include "spu/compiler/compile.h" -#include "spu/device/pphlo/executor.h" +#include "spu/device/api.h" +#include "spu/device/pphlo/pphlo_executor.h" #include "spu/device/symbol_table.h" #include "spu/device/test_utils.h" #include "spu/mpc/ref2k/ref2k.h" @@ -33,7 +35,6 @@ namespace spu::device { namespace { class Runner { - public: Runner(size_t world_size, FieldType field, ProtocolKind protocol) : world_size_(world_size) { @@ -49,7 +50,7 @@ class Runner { void addInput(const T &input, Visibility vis = Visibility::VIS_PUBLIC) { const std::string name = fmt::format("input{}", input_idx_++); io_->InFeed(name, input, vis); - exec_.add_input_names(name); + executable_.add_input_names(name); } std::string compileMHlo(const std::string &mhlo) { @@ -59,9 +60,9 @@ class Runner { void run(const std::string &mlir, size_t num_output = 1) { for (size_t idx = 0; idx < num_output; ++idx) { - exec_.add_output_names(fmt::format("output{}", idx)); + executable_.add_output_names(fmt::format("output{}", idx)); } - exec_.set_code(mlir); + executable_.set_code(mlir); ::spu::mpc::util::simulate( world_size_, [&](const std::shared_ptr &lctx) { RuntimeConfig conf; @@ -70,9 +71,9 @@ class Runner { // conf.set_enable_action_trace(true); } HalContext hctx(conf, lctx); - pphlo::PPHloExecutor executor(&hctx); auto *env = io_->GetSymbolTable(lctx->Rank()); - executor.runWithEnv(exec_, env); + pphlo::PPHloExecutor executor; + execute(&executor, &hctx, executable_, env); }); } @@ -105,7 +106,7 @@ class Runner { RuntimeConfig config_; std::unique_ptr io_; size_t input_idx_{0}; - ExecutableProto exec_; + ExecutableProto executable_; }; } // namespace @@ -128,6 +129,22 @@ func.func @main(%arg0: tensor>, %arg1: tensor>) r.verifyScalarOutput(3); } +TEST_P(ExecutorTest, InvalidIR) { + Runner r(std::get<0>(GetParam()), std::get<1>(GetParam()), + std::get<2>(GetParam())); + + 
ASSERT_THROW(r.run(R"( +func.func @main() -> tensor> { + %2 = "pphlo.constant"() {value = dense<[0x41DA6E5887800000, 0x41C94E3940000000, 0x41C4BD2007000000, 0x41DC95133AC00000, 0x41D1650CEC000000, 0x41C9DF42E7800000, 0x41D46C43B6800000, 0x41C467EE0E800000, 0x41DC705F14400000]> : tensor<9xf64>} : () -> tensor<9x!pphlo.pub> + %3 = "pphlo.floor"(%2) : (tensor<9x!pphlo.pub>) -> tensor<9x!pphlo.pub> + %9 = "pphlo.concatenate"(%3) {dimension = 0 : i64} : (tensor<9x!pphlo.pub>) -> tensor<9x!pphlo.pub> + %10 = "pphlo.broadcast"(%9) {broadcast_dimensions = dense<13> : tensor<1xi64>} : (tensor<9x!pphlo.pub>) -> tensor<9x!pphlo.pub> + %51 = "pphlo.constant"() {value = dense<5> : tensor} : () -> tensor> + "pphlo.return"(%51) : (tensor>) -> () +})"), + std::exception); +} + TEST_P(ExecutorTest, WithConst) { Runner r(std::get<0>(GetParam()), std::get<1>(GetParam()), std::get<2>(GetParam())); @@ -1789,18 +1806,9 @@ TEST_P(ExecutorTest, MaxPoolReduce1) { }); r.run(R"( -func.func @main(%arg0: tensor<4x6x!pphlo.pub>) -> (tensor<2x2x!pphlo.pub>, tensor<2x2x6x!pphlo.pub>) { - %0 = "pphlo.constant"() {value = dense<[[1, 0, 0, 0, 0, 0], [0, 1, 0, 0, 0, 0], [0, 0, 1, 0, 0, 0], [0, 0, 0, 1, 0, 0], [0, 0, 0, 0, 1, 0], [0, 0, 0, 0, 0, 1]]> : tensor<6x6xi8>} : () -> tensor<6x6x!pphlo.pub> - %1 = "pphlo.constant"() {value = dense<-1> : tensor} : () -> tensor> - %2 = "pphlo.constant"() {value = dense<-1> : tensor} : () -> tensor> - %4:2 = "pphlo.reduce_window"(%arg0, %0, %2, %1) ({ - ^bb0(%arg2: tensor>, %arg3: tensor>, %arg4: tensor>, %arg5: tensor>): - %6 = "pphlo.greater_equal"(%arg2, %arg4) : (tensor>, tensor>) -> tensor> - %7 = "pphlo.select"(%6, %arg2, %arg4) : (tensor>, tensor>, tensor>) -> tensor> - %8 = "pphlo.select"(%6, %arg3, %arg5) : (tensor>, tensor>, tensor>) -> tensor> - "pphlo.return"(%7, %8) : (tensor>, tensor>) -> () - }) {base_dilations = dense<1> : tensor<4xi64>, ignore_init_value = true, last_operand_is_window_mask = true, padding = dense<0> : tensor<2x2xi64>, 
window_dilations = dense<1> : tensor<2xi64>, window_dimensions = dense<[2, 3]> : tensor<2xi64>, window_strides = dense<[2, 3]> : tensor<2xi64>} : (tensor<4x6x!pphlo.pub>, tensor<6x6x!pphlo.pub>, tensor>, tensor>) -> (tensor<2x2x!pphlo.pub>, tensor<2x2x6x!pphlo.pub>) - return %4#0, %4#1: tensor<2x2x!pphlo.pub>, tensor<2x2x6x!pphlo.pub> +func.func @main(%arg0: tensor<4x6x!pphlo.pub>) -> (tensor<2x2x!pphlo.pub>, tensor<2x2x6x!pphlo.pub>) { + %4:2 = "pphlo.argmax"(%arg0) {base_dilations = dense<1> : tensor<4xi64>, padding = dense<0> : tensor<2x2xi64>, window_dilations = dense<1> : tensor<2xi64>, window_dimensions = dense<[2, 3]> : tensor<2xi64>, window_strides = dense<[2, 3]> : tensor<2xi64>} : (tensor<4x6x!pphlo.pub>) -> (tensor<2x2x!pphlo.pub>, tensor<2x2x6x!pphlo.pub>) + return %4#0, %4#1: tensor<2x2x!pphlo.pub>, tensor<2x2x6x!pphlo.pub> })", 2); @@ -1808,7 +1816,7 @@ func.func @main(%arg0: tensor<4x6x!pphlo.pub>) -> (tensor<2x2x!pphlo.pub mask = { + xt::xarray mask = { {{0, 0, 0, 0, 0, 1}, {0, 1, 0, 0, 0, 0}}, // {{0, 0, 1, 0, 0, 0}, {0, 0, 0, 0, 0, 1}}, // }; @@ -1828,18 +1836,9 @@ TEST_P(ExecutorTest, MaxPoolReduce2) { }); r.run(R"( -func.func @main(%arg0: tensor<4x5x!pphlo.pub>) -> (tensor<2x2x!pphlo.pub>, tensor<2x2x6x!pphlo.pub>) { - %0 = "pphlo.constant"() {value = dense<[[1, 0, 0, 0, 0, 0], [0, 1, 0, 0, 0, 0], [0, 0, 1, 0, 0, 0], [0, 0, 0, 1, 0, 0], [0, 0, 0, 0, 1, 0], [0, 0, 0, 0, 0, 1]]> : tensor<6x6xi8>} : () -> tensor<6x6x!pphlo.pub> - %1 = "pphlo.constant"() {value = dense<-1> : tensor} : () -> tensor> - %2 = "pphlo.constant"() {value = dense<-1> : tensor} : () -> tensor> - %4:2 = "pphlo.reduce_window"(%arg0, %0, %2, %1) ({ - ^bb0(%arg2: tensor>, %arg3: tensor>, %arg4: tensor>, %arg5: tensor>): - %6 = "pphlo.greater_equal"(%arg2, %arg4) : (tensor>, tensor>) -> tensor> - %7 = "pphlo.select"(%6, %arg2, %arg4) : (tensor>, tensor>, tensor>) -> tensor> - %8 = "pphlo.select"(%6, %arg3, %arg5) : (tensor>, tensor>, tensor>) -> tensor> - "pphlo.return"(%7, %8) : 
(tensor>, tensor>) -> () - }) {base_dilations = dense<1> : tensor<4xi64>, ignore_init_value = true, last_operand_is_window_mask = true, padding = dense<0> : tensor<2x2xi64>, window_dilations = dense<1> : tensor<2xi64>, window_dimensions = dense<[2, 3]> : tensor<2xi64>, window_strides = dense<[2, 2]> : tensor<2xi64>} : (tensor<4x5x!pphlo.pub>, tensor<6x6x!pphlo.pub>, tensor>, tensor>) -> (tensor<2x2x!pphlo.pub>, tensor<2x2x6x!pphlo.pub>) - return %4#0, %4#1: tensor<2x2x!pphlo.pub>, tensor<2x2x6x!pphlo.pub> +func.func @main(%arg0: tensor<4x5x!pphlo.pub>) -> (tensor<2x2x!pphlo.pub>, tensor<2x2x6x!pphlo.pub>) { + %4:2 = "pphlo.argmax"(%arg0) {base_dilations = dense<1> : tensor<4xi64>, padding = dense<0> : tensor<2x2xi64>, window_dilations = dense<1> : tensor<2xi64>, window_dimensions = dense<[2, 3]> : tensor<2xi64>, window_strides = dense<[2, 2]> : tensor<2xi64>} : (tensor<4x5x!pphlo.pub>) -> (tensor<2x2x!pphlo.pub>, tensor<2x2x6x!pphlo.pub>) + return %4#0, %4#1: tensor<2x2x!pphlo.pub>, tensor<2x2x6x!pphlo.pub> })", 2); @@ -1911,41 +1910,47 @@ TEST_P(ExecutorTest, MaxPoolReduce3) { Runner r(std::get<0>(GetParam()), std::get<1>(GetParam()), std::get<2>(GetParam())); - xt::xarray in = xt::reshape_view( - xt::xarray{ - {7, 2, 5, 3}, // - {3, 8, 9, 3}, // - {1, 5, 7, 5}, // - {0, 6, 2, 7} // - }, - {1, 4, 4, 1}); + { + xt::xarray in = xt::reshape_view( + xt::xarray{ + {7, 2, 5, 3}, // + {3, 8, 9, 3}, // + {1, 5, 7, 5}, // + {0, 6, 2, 7} // + }, + {1, 4, 4, 1}); + + r.addInput(in); + } - r.addInput(in); + { + xt::xarray in = xt::reshape_view( + xt::xarray{ + {10, 11, 12}, // + {13, 14, 15}, // + {16, 17, 18}, // + }, + {1, 3, 3, 1}); + + r.addInput(in); + } r.run(R"( -func.func @main(%arg0: tensor<1x4x4x1x!pphlo.pub>) -> (tensor<1x3x3x1x!pphlo.pub>, tensor<1x3x3x1x4x!pphlo.pub>) { - %0 = "pphlo.constant"() {value = dense<[[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]]> : tensor<4x4xi8>} : () -> tensor<4x4x!pphlo.pub> - %1 = "pphlo.constant"() {value = dense<-1> : 
tensor} : () -> tensor> - %2 = "pphlo.constant"() {value = dense<-1> : tensor} : () -> tensor> - %4:2 = "pphlo.reduce_window"(%arg0, %0, %2, %1) ({ - ^bb0(%arg2: tensor>, %arg3: tensor>, %arg4: tensor>, %arg5: tensor>): - %6 = "pphlo.greater_equal"(%arg2, %arg4) : (tensor>, tensor>) -> tensor> - %7 = "pphlo.select"(%6, %arg2, %arg4) : (tensor>, tensor>, tensor>) -> tensor> - %8 = "pphlo.select"(%6, %arg3, %arg5) : (tensor>, tensor>, tensor>) -> tensor> - "pphlo.return"(%7, %8) : (tensor>, tensor>) -> () - }) {base_dilations = dense<1> : tensor<4xi64>, ignore_init_value = true, last_operand_is_window_mask = true, padding = dense<0> : tensor<4x2xi64>, window_dilations = dense<1> : tensor<4xi64>, window_dimensions = dense<[1, 2, 2, 1]> : tensor<4xi64>, window_strides = dense<1> : tensor<4xi64>} : (tensor<1x4x4x1x!pphlo.pub>, tensor<4x4x!pphlo.pub>, tensor>, tensor>) -> (tensor<1x3x3x1x!pphlo.pub>, tensor<1x3x3x1x4x!pphlo.pub>) - return %4#0, %4#1: tensor<1x3x3x1x!pphlo.pub>, tensor<1x3x3x1x4x!pphlo.pub> +func.func @main(%arg0: tensor<1x4x4x1x!pphlo.pub>, %arg1: tensor<1x3x3x1x!pphlo.pub>) -> (tensor<1x3x3x1x!pphlo.pub>, tensor<1x3x3x1x4x!pphlo.pub>, tensor<1x4x4x1x!pphlo.pub>) { + %0:2 = "pphlo.argmax"(%arg0) {base_dilations = dense<1> : tensor<4xi64>, padding = dense<0> : tensor<4x2xi64>, window_dilations = dense<1> : tensor<4xi64>, window_dimensions = dense<[1, 2, 2, 1]> : tensor<4xi64>, window_strides = dense<1> : tensor<4xi64>} : (tensor<1x4x4x1x!pphlo.pub>) -> (tensor<1x3x3x1x!pphlo.pub>, tensor<1x3x3x1x4x!pphlo.pub>) + %1 = "pphlo.maxpool_scatter"(%0#1, %arg1) {padding = dense<0> : tensor<4x2xi64>, window_dimensions = dense<[1, 2, 2, 1]> : tensor<4xi64>, window_strides = dense<1> : tensor<4xi64>} : (tensor<1x3x3x1x4x!pphlo.pub>, tensor<1x3x3x1x!pphlo.pub>) -> tensor<1x4x4x1x!pphlo.pub> + return %0#0, %0#1, %1: tensor<1x3x3x1x!pphlo.pub>, tensor<1x3x3x1x4x!pphlo.pub>, tensor<1x4x4x1x!pphlo.pub> })", - 2); + 3); xt::xarray reduce_ret = {{8, 9, 9}, // {8, 9, 9}, 
{6, 7, 7}}; r.verifyOutput(reduce_ret.data(), 0); - xt::xarray mask = {{{0, 0, 0, 1}, {0, 0, 0, 1}, {0, 0, 1, 0}}, // - {{0, 1, 0, 0}, {0, 1, 0, 0}, {1, 0, 0, 0}}, // - {{0, 0, 0, 1}, {0, 1, 0, 0}, {0, 0, 0, 1}}}; + xt::xarray mask = {{{0, 0, 0, 1}, {0, 0, 0, 1}, {0, 0, 1, 0}}, // + {{0, 1, 0, 0}, {0, 1, 0, 0}, {1, 0, 0, 0}}, // + {{0, 0, 0, 1}, {0, 1, 0, 0}, {0, 0, 0, 1}}}; r.verifyOutput(mask.data(), 1); } diff --git a/spu/device/pphlo/region_executor.cc b/spu/device/pphlo/region_executor.cc deleted file mode 100644 index 68a9aba1..00000000 --- a/spu/device/pphlo/region_executor.cc +++ /dev/null @@ -1,931 +0,0 @@ -// Copyright 2022 Ant Group Co., Ltd. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#include "spu/device/pphlo/region_executor.h" - -#include "llvm/Support/raw_os_ostream.h" -#include "mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/IR/BuiltinAttributes.h" -#include "mlir/IR/Location.h" - -#include "spu/device/frame.h" -#include "spu/dialect/pphlo_ops.h" -#include "spu/kernel/hlo/basic_binary.h" -#include "spu/kernel/hlo/basic_ternary.h" -#include "spu/kernel/hlo/basic_unary.h" -#include "spu/kernel/hlo/casting.h" -#include "spu/kernel/hlo/const.h" -#include "spu/kernel/hlo/control_flow.h" -#include "spu/kernel/hlo/convolution.h" -#include "spu/kernel/hlo/dynamic_slice.h" -#include "spu/kernel/hlo/geometrical.h" -#include "spu/kernel/hlo/indexing.h" -#include "spu/kernel/hlo/rand.h" -#include "spu/kernel/hlo/reduce.h" -#include "spu/kernel/hlo/select_and_scatter.h" -#include "spu/kernel/hlo/shift.h" -#include "spu/kernel/hlo/sort.h" - -namespace { - -std::vector -convertDenseIntElementAttr(const mlir::DenseIntElementsAttr &attr) { - std::vector ret; - - for (const auto &v : attr.getValues()) { - ret.emplace_back(v); - } - - return ret; -} - -std::string printLocation(const mlir::Location &loc) { - std::string pstr; - llvm::raw_string_ostream ss(pstr); - loc->print(ss); - ss.flush(); - return pstr; -} - -spu::PtType getPtType(const mlir::Type &type) { - if (auto ft = type.dyn_cast()) { - switch (ft.getWidth()) { - case 32: - return spu::PT_F32; - case 64: - return spu::PT_F64; - } - } - if (auto it = type.dyn_cast()) { - if (it.getWidth() == 1) { - return spu::PT_BOOL; - } - // In mlir, isSigned is for si[1-9][0-9]* type, isUnsigned is for - // ui[1-9][0-9]*, i[1-9][0-9]* is signless IntegerType... So here, we only - // check for isUnsigned, signless we treat it as signed. - // See https://reviews.llvm.org/D72533 - switch (it.getWidth()) { - case 8: - return it.isUnsigned() ? spu::PT_U8 : spu::PT_I8; - case 16: - return it.isUnsigned() ? spu::PT_U16 : spu::PT_I16; - case 32: - return it.isUnsigned() ? 
spu::PT_U32 : spu::PT_I32; - case 64: - return it.isUnsigned() ? spu::PT_U64 : spu::PT_I64; - } - } - YACL_THROW("Hit unknown pt_type"); -} - -spu::DataType getDtypeFromMlirType(::mlir::Type mlir_ty) { - mlir::pphlo::TypeTools tool; - if (auto int_ty = - tool.getExpressedType(mlir_ty).dyn_cast<::mlir::IntegerType>()) { - switch (int_ty.getWidth()) { - case 1: - return spu::DT_I1; - case 8: - return int_ty.isUnsigned() ? spu::DT_U8 : spu::DT_I8; - case 16: - return int_ty.isUnsigned() ? spu::DT_U16 : spu::DT_I16; - case 32: - return int_ty.isUnsigned() ? spu::DT_U32 : spu::DT_I32; - case 64: - return int_ty.isUnsigned() ? spu::DT_U64 : spu::DT_I64; - default: - YACL_THROW("unsupported int type {}"); - } - } - auto flp_ty = tool.getExpressedType(mlir_ty).dyn_cast<::mlir::FloatType>(); - YACL_ENFORCE(flp_ty, "invalid type"); - return spu::DT_FXP; -} - -} // namespace - -namespace spu::device::pphlo { - -const spu::Value &RegionExecutor::lookupValue(::mlir::Value v) const { - const auto *val = frame_->getValue(v); - if (val == nullptr) { - // Somehow cannot find this value on stack, print a reasonable error - // message. 
- std::string str; - llvm::raw_string_ostream debug_s(str); - v.getDefiningOp()->print(debug_s); - YACL_ENFORCE(false, "Try to get a non-exist value, defined at {}", - debug_s.str()); - } - return *val; -} - -#define LOWERED_OP_IMPL(OpName) \ - void RegionExecutor::execute(mlir::pphlo::OpName &) { \ - YACL_THROW("Lowered op should not occur at backend"); \ - } - -LOWERED_OP_IMPL(ReturnOp) - -#undef LOWERED_OP_IMPL - -#define UNIMPL_OP(OpName) \ - void RegionExecutor::execute(mlir::pphlo::OpName &op) { \ - YACL_THROW("Missing Runtime Impl Op {}", op->getName().getStringRef()); \ - } - -#undef UNIMPL_OP - -std::vector -RegionExecutor::executeRegion(mlir::Region ®ion, - absl::Span inputs) { - getFrame()->enterRegion(); - if (suppress_type_check_) { - getFrame()->setTypeCheker(nullptr); - } - - YACL_ENFORCE(region.getNumArguments() == inputs.size(), - "Entrypoint function requires {} arguments, which is more than " - "actual number of inputs {}", - region.getRegionNumber(), inputs.size()); - - for (const auto &blkarg : region.getArguments()) { - getFrame()->addValue(blkarg, inputs[blkarg.getArgNumber()]); - } - - auto ret = executeBlock(region.front()); - getFrame()->leaveRegion(); - if (getContext()->rt_config().enable_type_checker()) { - getFrame()->setTypeCheker(type_checker_); - } - return ret; -} - -std::vector RegionExecutor::executeBlock(mlir::Block &block) { - for (auto &op : block.without_terminator()) { - dispatchOp< -#define GET_OP_LIST -#include "spu/dialect/pphlo_ops.cc.inc" - >(op); - } - - if (auto *termOp = block.getTerminator()) { - if (!suppress_pphlo_trace_ && hctx_->rt_config().enable_pphlo_trace()) { - debug_print(*termOp); - } - return executeTerminator(*termOp); - } - - // No terminator - return {}; -} - -void RegionExecutor::debug_print(mlir::Operation &op) { - if (hctx_->lctx() && hctx_->lctx()->Rank() == 0) { - std::string buf; - llvm::raw_string_ostream debug_stream(buf); - op.print(debug_stream); - SPDLOG_INFO("PPHLO {}", 
debug_stream.str()); - } -} - -std::vector RegionExecutor::executeTerminator(mlir::Operation &op) { - if (llvm::isa(op) || - llvm::isa(op)) { - std::vector results; - results.reserve(op.getNumOperands()); - for (const auto operand : op.getOperands()) { - results.emplace_back(lookupValue(operand)); - } - return results; - } - llvm_unreachable("Unknown block terminator"); -} - -#define STANDARD_UNARY_OP_EXEC_IMPL(OpName, KernelName) \ - void RegionExecutor::execute(mlir::pphlo::OpName &op) { \ - const auto in = lookupValue(op.getOperand()); \ - auto ret = kernel::hlo::KernelName(hctx_, in); \ - getFrame()->addValue(op.getResult(), std::move(ret)); \ - } - -STANDARD_UNARY_OP_EXEC_IMPL(ReciprocalOp, Reciprocal) -STANDARD_UNARY_OP_EXEC_IMPL(NegOp, Neg) -STANDARD_UNARY_OP_EXEC_IMPL(ExpOp, Exp) -STANDARD_UNARY_OP_EXEC_IMPL(Expm1Op, Expm1) -STANDARD_UNARY_OP_EXEC_IMPL(LogOp, Log) -STANDARD_UNARY_OP_EXEC_IMPL(Log1pOp, Log1p) -STANDARD_UNARY_OP_EXEC_IMPL(FloorOp, Floor) -STANDARD_UNARY_OP_EXEC_IMPL(CeilOp, Ceil) -STANDARD_UNARY_OP_EXEC_IMPL(AbsOp, Abs) -STANDARD_UNARY_OP_EXEC_IMPL(LogisticOp, Logistic) -STANDARD_UNARY_OP_EXEC_IMPL(TanhOp, Tanh) -STANDARD_UNARY_OP_EXEC_IMPL(NotOp, Not) -STANDARD_UNARY_OP_EXEC_IMPL(RsqrtOp, Rsqrt) -STANDARD_UNARY_OP_EXEC_IMPL(SqrtOp, Sqrt) -STANDARD_UNARY_OP_EXEC_IMPL(RoundOp, Round_AFZ) - -#undef STANDARD_UNARY_OP_EXEC_IMPL - -#define STANDARD_BINARY_OP_EXEC_IMPL(OpName, KernelName) \ - void RegionExecutor::execute(mlir::pphlo::OpName &op) { \ - getFrame()->addValue(op.getResult(), \ - kernel::hlo::KernelName(hctx_, lookupValue(op.lhs()), \ - lookupValue(op.rhs()))); \ - } - -STANDARD_BINARY_OP_EXEC_IMPL(AddOp, Add) -STANDARD_BINARY_OP_EXEC_IMPL(EqualOp, Equal) -STANDARD_BINARY_OP_EXEC_IMPL(NotEqualOp, NotEqual) -STANDARD_BINARY_OP_EXEC_IMPL(LessEqualOp, LessEqual) -STANDARD_BINARY_OP_EXEC_IMPL(GreaterEqualOp, GreaterEqual) -STANDARD_BINARY_OP_EXEC_IMPL(SubtractOp, Sub) -STANDARD_BINARY_OP_EXEC_IMPL(LessOp, Less) 
-STANDARD_BINARY_OP_EXEC_IMPL(GreaterOp, Greater) -STANDARD_BINARY_OP_EXEC_IMPL(MulOp, Mul) -STANDARD_BINARY_OP_EXEC_IMPL(PowOp, Power) -STANDARD_BINARY_OP_EXEC_IMPL(MaxOp, Max) -STANDARD_BINARY_OP_EXEC_IMPL(MinOp, Min) -STANDARD_BINARY_OP_EXEC_IMPL(AndOp, And) -STANDARD_BINARY_OP_EXEC_IMPL(OrOp, Or) -STANDARD_BINARY_OP_EXEC_IMPL(XorOp, Xor) -STANDARD_BINARY_OP_EXEC_IMPL(DivOp, Div) -STANDARD_BINARY_OP_EXEC_IMPL(ShiftLeftOp, Lshift) -STANDARD_BINARY_OP_EXEC_IMPL(ShiftRightArithmeticOp, ARshift) -STANDARD_BINARY_OP_EXEC_IMPL(ShiftRightLogicalOp, Rshift) - -#undef STANDARD_BINARY_OP_EXEC_IMPL - -void RegionExecutor::execute(mlir::pphlo::DotOp &op) { - auto ret = - kernel::hlo::Dot(hctx_, lookupValue(op.lhs()), lookupValue(op.rhs())); - - const auto ret_shape = - op.getResult().getType().dyn_cast().getShape(); - - getFrame()->addValue(op.getResult(), - kernel::hlo::Reshape(hctx_, ret, ret_shape)); -} - -void RegionExecutor::execute(mlir::pphlo::DotGeneralOp &op) { - auto dnum = op.dot_dimension_numbers(); - // Should in order - YACL_ENFORCE(dnum.getLhsBatchingDimensions().size() == 1 && - dnum.getLhsContractingDimensions().size() == 1 && - dnum.getLhsBatchingDimensions()[0] == 0 && - dnum.getLhsContractingDimensions()[0] == 2, - "LHS dims is not in order"); - YACL_ENFORCE(dnum.getRhsBatchingDimensions().size() == 1 && - dnum.getRhsContractingDimensions().size() == 1 && - dnum.getRhsBatchingDimensions()[0] == 0 && - dnum.getRhsContractingDimensions()[0] == 1, - "RHS dims is not in order"); - - auto lhs = lookupValue(op.lhs()); - auto rhs = lookupValue(op.rhs()); - YACL_ENFORCE(lhs.shape()[0] == rhs.shape()[0], "Batch dim should equal"); - int64_t num_batch = lhs.shape()[0]; - - std::vector results(num_batch); - std::vector lhs_slice_begin(3, 0); - std::vector lhs_slice_end = lhs.shape(); - std::vector rhs_slice_begin(3, 0); - std::vector rhs_slice_end = rhs.shape(); - std::vector strides(lhs.shape().size(), 1); - - std::vector lhs_slice_shape{lhs.shape()[1], 
lhs.shape()[2]}; - std::vector rhs_slice_shape{rhs.shape()[1], rhs.shape()[2]}; - std::vector ret_slice_shape{1, lhs.shape()[1], rhs.shape()[2]}; - - for (int64_t batch_idx = 0; batch_idx < num_batch; ++batch_idx) { - lhs_slice_begin[0] = batch_idx; - lhs_slice_end[0] = batch_idx + 1; - rhs_slice_begin[0] = batch_idx; - rhs_slice_end[0] = batch_idx + 1; - auto lhs_slice = kernel::hlo::Reshape( - hctx_, - kernel::hlo::Slice(hctx_, lhs, lhs_slice_begin, lhs_slice_end, strides), - lhs_slice_shape); - auto rhs_slice = kernel::hlo::Reshape( - hctx_, - kernel::hlo::Slice(hctx_, rhs, rhs_slice_begin, rhs_slice_end, strides), - rhs_slice_shape); - results[batch_idx] = kernel::hlo::Reshape( - hctx_, kernel::hlo::Dot(hctx_, lhs_slice, rhs_slice), ret_slice_shape); - } - - auto ret_type = op.getResult().getType().dyn_cast(); - auto ret = kernel::hlo::Reshape( - hctx_, kernel::hlo::Concatenate(hctx_, results, 0), ret_type.getShape()); - - getFrame()->addValue(op.getResult(), std::move(ret)); -} - -void RegionExecutor::execute(mlir::pphlo::ConvolutionOp &op) { - const auto &dnums = op.dimension_numbers(); - const size_t num_spatial_dims = dnums.getOutputSpatialDimensions().size(); - YACL_ENFORCE(num_spatial_dims == dnums.getInputSpatialDimensions().size()); - YACL_ENFORCE(num_spatial_dims == dnums.getKernelSpatialDimensions().size()); - - const auto ret_shape = - op.getResult().getType().dyn_cast().getShape(); - - auto lhs = lookupValue(op.lhs()); - auto rhs = lookupValue(op.rhs()); - - std::vector window_strides(dnums.getInputSpatialDimensions().size(), - 1); - if (op.window_strides().has_value()) { - for (const auto &iter : - llvm::enumerate(op.window_strides()->getValues())) { - window_strides[iter.index()] = iter.value(); - } - } - - kernel::hlo::ConvolutionConfig config; - config.featureGroupCount = op.feature_group_count(); - config.batchGroupCount = op.batch_group_count(); - config.window_strides = window_strides; - config.inputBatchDimension = 
dnums.getInputBatchDimension(); - config.inputFeatureDimension = dnums.getInputFeatureDimension(); - config.inputSpatialDimensions = dnums.getInputSpatialDimensions(); - config.kernelInputFeatureDimension = dnums.getKernelInputFeatureDimension(); - config.kernelOutputFeatureDimension = dnums.getKernelOutputFeatureDimension(); - config.kernelSpatialDimensions = dnums.getKernelSpatialDimensions(); - config.outputBatchDimension = dnums.getOutputBatchDimension(); - config.outputFeatureDimension = dnums.getOutputFeatureDimension(); - config.outputSpatialDimensions = dnums.getOutputSpatialDimensions(); - - spu::Value result; - - if (dnums.getInputSpatialDimensions().size() == 2) { - result = kernel::hlo::Convolution2D(hctx_, lhs, rhs, config, ret_shape); - } else { - result = kernel::hlo::Convolution(hctx_, lhs, rhs, config, ret_shape); - } - - getFrame()->addValue(op.getResult(), std::move(result)); -} - -void RegionExecutor::execute(mlir::pphlo::DynamicUpdateSliceOp &op) { - // Basic idea here, get a ref slice and update the whole slice.. 
- // Start indicies - std::vector start_indicies(op.start_indices().size()); - const auto &operand = lookupValue(op.operand()); - const auto &update = lookupValue(op.update()); - - for (const auto &idx : llvm::enumerate(op.start_indices())) { - start_indicies[idx.index()] = lookupValue(idx.value()); - } - - getFrame()->addValue( - op.getResult(), - kernel::hlo::DynamicUpdateSlice(hctx_, operand, update, start_indicies)); -} - -void RegionExecutor::execute(mlir::pphlo::DynamicSliceOp &op) { - // Start indicies - auto iter = op.slice_sizes().getValues(); - std::vector slice_size{iter.begin(), iter.end()}; - const auto &operand = lookupValue(op.operand()); - std::vector start_indicies(op.start_indices().size()); - - for (const auto &idx : llvm::enumerate(op.start_indices())) { - start_indicies[idx.index()] = lookupValue(idx.value()); - } - - getFrame()->addValue( - op.getResult(), - kernel::hlo::DynamicSlice(hctx_, operand, slice_size, start_indicies)); -} - -void RegionExecutor::execute(mlir::pphlo::GatherOp &op) { - // If input is empty, short circuit - auto operand = lookupValue(op.operand()); - auto start_indicies = lookupValue(op.start_indices()); - if (operand.numel() == 0) { - getFrame()->addValue(op.getResult(), operand); - return; - } - - const auto &output_shape = - op.getResult().getType().dyn_cast().getShape(); - - const auto &dim_numbers = op.dimension_numbers(); - - kernel::hlo::GatherConfig config; - auto ss = convertDenseIntElementAttr(op.slice_sizes()); - config.sliceSizes = ss; - config.indexVectorDim = dim_numbers.getIndexVectorDim(); - config.offsetDims = dim_numbers.getOffsetDims(); - config.collapsedSliceDims = dim_numbers.getCollapsedSliceDims(); - config.startIndexMap = dim_numbers.getStartIndexMap(); - - getFrame()->addValue(op.getResult(), - kernel::hlo::Gather(hctx_, operand, start_indicies, - config, output_shape)); -} - -void RegionExecutor::execute(mlir::pphlo::SortOp &op) { - auto sort_dim = op.dimension(); - auto is_stable = 
op.is_stable(); - std::vector inputs(op->getNumOperands()); - for (size_t idx = 0; idx < inputs.size(); ++idx) { - inputs[idx] = lookupValue(op->getOperand(idx)); - } - - // TODO(junfeng): provide comparator_ret_vis. - suppress_type_check_ = true; - auto ret = kernel::hlo::Sort(hctx_, inputs, sort_dim, is_stable, - [&](absl::Span inputs) { - auto ret = - executeRegion(op.comparator(), inputs); - return ret[0]; - }); - suppress_type_check_ = false; - - for (int64_t idx = 0; idx < op->getNumResults(); ++idx) { - getFrame()->addValue(op->getResult(idx), std::move(ret[idx])); - } -} - -void RegionExecutor::execute(mlir::pphlo::SelectAndScatterOp &op) { - auto operand = lookupValue(op.operand()); - auto source = lookupValue(op.source()); - auto init_val = lookupValue(op.init_value()); - - auto window_shape = convertDenseIntElementAttr(op.window_dimensions()); - - // build strides - std::vector window_strides(window_shape.size(), 1); - if (op.window_strides().has_value()) { - window_strides = convertDenseIntElementAttr(*op.window_strides()); - } - - // window padding - std::vector> window_padding(window_shape.size(), - {0, 0}); - if (op.padding().has_value()) { - const auto v = *op.padding(); - - YACL_ENFORCE(window_padding.size() * 2 == (size_t)v.size()); - - for (size_t idx = 0; idx < window_padding.size(); ++idx) { - window_padding[idx] = {*(v.getValues().begin() + 2 * idx), - *(v.getValues().begin() + 2 * idx + 1)}; - } - } - - suppress_pphlo_trace_ = true; - suppress_type_check_ = true; - - // auto ret = kernel::hlo::SelectAndScatterNaive( - auto ret = kernel::hlo::SelectAndScatterExpanded( - hctx_, operand, source, init_val, window_shape, window_strides, - window_padding, - [&](const spu::Value &selected, const spu::Value ¤t) { - auto ret = executeRegion(op.select(), {selected, current}); - return ret[0]; - }, - [&](const spu::Value &in, const spu::Value &scatter) { - auto ret = executeRegion(op.scatter(), {in, scatter}); - return ret[0]; - }); - - 
suppress_pphlo_trace_ = false; - suppress_type_check_ = false; - - getFrame()->addValue(op.getResult(), std::move(ret)); -} - -void RegionExecutor::execute(mlir::pphlo::MaxPoolScatterOp &op) { - auto scatter_indices = lookupValue(op.scatter_indices()); - auto update = lookupValue(op.update()); - - auto window_shape = - convertDenseIntElementAttr(op.window_dimensions().value()); - - // build strides - std::vector window_strides(window_shape.size(), 1); - if (op.window_strides().has_value()) { - window_strides = convertDenseIntElementAttr(*op.window_strides()); - } - - // window padding - std::vector> window_padding(window_shape.size(), - {0, 0}); - if (op.padding().has_value()) { - const auto v = *op.padding(); - - YACL_ENFORCE(window_padding.size() * 2 == (size_t)v.size()); - - for (size_t idx = 0; idx < window_padding.size(); ++idx) { - window_padding[idx] = {*(v.getValues().begin() + 2 * idx), - *(v.getValues().begin() + 2 * idx + 1)}; - } - } - - auto base_shape = - op.getResult().getType().dyn_cast().getShape(); - - auto ret = - kernel::hlo::MaxPoolScatter(hctx_, scatter_indices, update, window_shape, - base_shape, window_strides, window_padding); - - getFrame()->addValue(op.getResult(), std::move(ret)); -} - -void RegionExecutor::execute(mlir::pphlo::IfOp &op) { - auto conditional = lookupValue(op.condition()); - - auto results = kernel::hlo::IfElse( - hctx_, conditional, // - [&]() { return executeRegion(op.true_branch(), {}); }, - [&]() { return executeRegion(op.false_branch(), {}); }); - - // Copy output - for (const auto &ret : llvm::enumerate(op->getResults())) { - getFrame()->addValue(ret.value(), results[ret.index()]); - } -} - -void RegionExecutor::execute(mlir::pphlo::WhileOp &op) { - // First inputs vectors - std::vector inputs; - inputs.reserve(op->getNumOperands()); - - // Prepare inputs - for (const auto operand : op->getOperands()) { - inputs.emplace_back(lookupValue(operand)); - } - - auto ret = kernel::hlo::While( - hctx_, inputs, // - 
[&](absl::Span inputs) { - return executeRegion(op.cond(), inputs)[0]; - }, - [&](absl::Span inputs) { - return executeRegion(op.body(), inputs); - }); - - for (size_t idx = 0; idx < op->getNumResults(); ++idx) { - getFrame()->addValue(op->getResult(idx), std::move(ret[idx])); - } -} - -#define DISPATCH_ALL_NONE_BOOL_PT_TYPES(PT_TYPE, NAME, ...) \ - [&] { \ - switch (PT_TYPE) { \ - __CASE_PT_TYPE(spu::PT_I8, NAME, __VA_ARGS__) \ - __CASE_PT_TYPE(spu::PT_U8, NAME, __VA_ARGS__) \ - __CASE_PT_TYPE(spu::PT_I16, NAME, __VA_ARGS__) \ - __CASE_PT_TYPE(spu::PT_U16, NAME, __VA_ARGS__) \ - __CASE_PT_TYPE(spu::PT_I32, NAME, __VA_ARGS__) \ - __CASE_PT_TYPE(spu::PT_U32, NAME, __VA_ARGS__) \ - __CASE_PT_TYPE(spu::PT_I64, NAME, __VA_ARGS__) \ - __CASE_PT_TYPE(spu::PT_U64, NAME, __VA_ARGS__) \ - __CASE_PT_TYPE(spu::PT_F32, NAME, __VA_ARGS__) \ - __CASE_PT_TYPE(spu::PT_F64, NAME, __VA_ARGS__) \ - default: \ - YACL_THROW("{} not implemented for pt_type={}", #NAME, PT_TYPE); \ - } \ - }() - -void RegionExecutor::execute(mlir::pphlo::IotaOp &op) { - const auto &ret_type = - op.output().getType().dyn_cast(); - const size_t numel = ret_type.getShape()[op.iota_dimension()]; - - auto ret_el_type = type_tools_.getExpressedType(ret_type); - auto pt_type = getPtType(ret_el_type); - - spu::Value iota_ret; - DISPATCH_ALL_NONE_BOOL_PT_TYPES(pt_type, "_", [&] { - iota_ret = kernel::hlo::Iota(hctx_, numel, VIS_PUBLIC); - }); - - if (ret_type.getShape().size() > 1) { - // Need a broadcast - iota_ret = kernel::hlo::Broadcast(hctx_, iota_ret, ret_type.getShape(), {}); - } - - getFrame()->addValue(op.output(), std::move(iota_ret)); -} - -void RegionExecutor::execute(mlir::pphlo::RemOp &op) { - // FIXME: When hal has a remainder, use that - auto lhs = lookupValue(op.lhs()); - auto rhs = lookupValue(op.rhs()); - - auto ret = kernel::hlo::Remainder(hctx_, lhs, rhs); - getFrame()->addValue(op.getResult(), std::move(ret)); -} - -void RegionExecutor::execute(mlir::pphlo::TransposeOp &op) { - 
getFrame()->addValue( - op.getResult(), - kernel::hlo::Transpose(hctx_, lookupValue(op.getOperand()), - convertDenseIntElementAttr(op.permutation()))); -} - -void RegionExecutor::execute(mlir::pphlo::BroadcastOp &op) { - auto to_shape = op.getType().dyn_cast().getShape(); - getFrame()->addValue( - op.getResult(), - kernel::hlo::Broadcast( - hctx_, lookupValue(op.getOperand()), to_shape, - convertDenseIntElementAttr(op.broadcast_dimensions()))); -} - -void RegionExecutor::execute(mlir::pphlo::ReshapeOp &op) { - auto to_shape = op.getType().dyn_cast().getShape(); - getFrame()->addValue( - op.getResult(), - kernel::hlo::Reshape(hctx_, lookupValue(op.getOperand()), to_shape)); -} - -void RegionExecutor::execute(mlir::pphlo::ConcatenateOp &op) { - std::vector values(op->getNumOperands()); - - for (size_t idx = 0; idx < op->getNumOperands(); ++idx) { - values[idx] = lookupValue(op->getOperand(idx)); - } - - // set result - getFrame()->addValue(op.getResult(), - kernel::hlo::Concatenate(hctx_, values, op.dimension())); -} - -void RegionExecutor::execute(mlir::pphlo::SliceOp &op) { - getFrame()->addValue( - op.getResult(), - kernel::hlo::Slice(hctx_, lookupValue(op.getOperand()), - convertDenseIntElementAttr(op.start_indices()), - convertDenseIntElementAttr(op.limit_indices()), - convertDenseIntElementAttr(op.strides()))); -} - -void RegionExecutor::execute(mlir::pphlo::PadOp &op) { - const auto &operand = lookupValue(op.operand()); - const size_t operand_rank = operand.shape().size(); - const auto &padding_value = lookupValue(op.padding_value()); - YACL_ENFORCE(padding_value.shape().empty()); - - auto edge_padding_low = convertDenseIntElementAttr(op.edge_padding_low()); - YACL_ENFORCE(edge_padding_low.size() == operand_rank); - auto edge_padding_high = convertDenseIntElementAttr(op.edge_padding_high()); - YACL_ENFORCE(edge_padding_high.size() == operand_rank); - auto interior_padding = convertDenseIntElementAttr(op.interior_padding()); - 
YACL_ENFORCE(interior_padding.size() == operand_rank); - YACL_ENFORCE(std::all_of(interior_padding.begin(), interior_padding.end(), - [](int64_t i) { return i >= 0; })); - - getFrame()->addValue(op.getResult(), - kernel::hlo::Pad(hctx_, operand, padding_value, - edge_padding_low, edge_padding_high, - interior_padding)); -} - -void RegionExecutor::execute(mlir::pphlo::ReverseOp &op) { - getFrame()->addValue( - op.getResult(), - kernel::hlo::Reverse(hctx_, lookupValue(op.getOperand()), - convertDenseIntElementAttr(op.dimensions()))); -} - -void RegionExecutor::errorUnknownOp(mlir::Operation &op) { - // These lines of code in theory should not hit. - // If hit, make a proper error message. - std::string err_str; - llvm::raw_string_ostream err(err_str); - op.print(err); - YACL_THROW("Unhandled mlir op {} at {}", err.str(), - printLocation(op.getLoc())); -} - -void RegionExecutor::execute(mlir::pphlo::ReduceOp &op) { - int64_t num_args = op->getNumOperands() / 2; - std::vector dimensions_to_reduce = - convertDenseIntElementAttr(op.dimensions()); - - std::vector input_args(num_args); - std::vector init_values(num_args); - for (int64_t i = 0; i < num_args; ++i) { - input_args[i] = lookupValue(op.inputs()[i]); - init_values[i] = lookupValue(op.init_values()[i]); - } - - suppress_type_check_ = true; - suppress_pphlo_trace_ = true; - - std::vector ret = kernel::hlo::Reduce( - hctx_, input_args, init_values, dimensions_to_reduce, - [&](absl::Span lhs, absl::Span rhs) { - std::vector operands; - operands.reserve(lhs.size() + rhs.size()); - operands.insert(operands.end(), lhs.begin(), lhs.end()); - operands.insert(operands.end(), rhs.begin(), rhs.end()); - return executeRegion(op.body(), operands); - }); - - suppress_type_check_ = false; - suppress_pphlo_trace_ = false; - - const auto &output_shape = - op->getResultTypes()[0].dyn_cast().getShape(); - for (size_t idx = 0; idx < op->getNumResults(); ++idx) { - getFrame()->addValue(op->getResult(idx), - kernel::hlo::Reshape(hctx_, 
ret[idx], output_shape)); - } -} - -void RegionExecutor::execute(mlir::pphlo::ReduceWindowOp &op) { - int64_t num_args = op->getNumOperands() / 2; - - std::vector input_args(num_args); - std::vector init_values(num_args); - - for (int64_t i = 0; i < num_args; ++i) { - input_args[i] = lookupValue(op.inputs()[i]); - init_values[i] = lookupValue(op.init_values()[i]); - } - - auto ret_shape = op->getResults()[0] - .getType() - .dyn_cast() - .getShape(); - auto window_shape = convertDenseIntElementAttr(op.window_dimensions()); - - // build strides - std::vector window_strides(window_shape.size(), 1); - if (op.window_strides().has_value()) { - window_strides = convertDenseIntElementAttr(*op.window_strides()); - } - - // window dilation - std::vector window_dilations(window_shape.size(), 1); - if (op.window_dilations().has_value()) { - window_dilations = convertDenseIntElementAttr(*op.window_dilations()); - } - - // window padding - std::vector> window_padding(window_shape.size(), - {0, 0}); - if (op.padding().has_value()) { - const auto v = *op.padding(); - - YACL_ENFORCE(window_padding.size() * 2 == (size_t)v.size()); - - for (size_t idx = 0; idx < window_padding.size(); ++idx) { - window_padding[idx] = {*(v.getValues().begin() + 2 * idx), - *(v.getValues().begin() + 2 * idx + 1)}; - } - } - - // base dilation - std::vector base_dilation(window_shape.size(), 1); - if (op.base_dilations().has_value()) { - base_dilation = convertDenseIntElementAttr(*op.base_dilations()); - } - - kernel::hlo::ReduceWindowConfig config; - config.window_shape = window_shape; - config.window_strides = window_strides; - config.window_dilations = window_dilations; - config.window_padding = window_padding; - config.base_dilations = base_dilation; - config.last_operand_is_window_mask = op.last_operand_is_window_mask(); - config.ignore_init_value = op.ignore_init_value(); - - suppress_type_check_ = true; - suppress_pphlo_trace_ = true; - auto rets = kernel::hlo::ReduceWindow( - hctx_, input_args, 
init_values, ret_shape, config, - [&](absl::Span lhs, absl::Span rhs) { - std::vector operands; - operands.reserve(lhs.size() + rhs.size()); - operands.insert(operands.end(), lhs.begin(), lhs.end()); - operands.insert(operands.end(), rhs.begin(), rhs.end()); - return executeRegion(op.body(), operands); - }); - suppress_type_check_ = false; - suppress_pphlo_trace_ = false; - - for (int64_t idx = 0; idx < op->getNumResults(); ++idx) { - getFrame()->addValue(op->getResults()[idx], std::move(rets[idx])); - } -} - -void RegionExecutor::execute(mlir::pphlo::SelectOp &op) { - auto pred = lookupValue(op.pred()); - - for (size_t idx = 0; idx < op.on_true().size(); ++idx) { - auto on_true = lookupValue(op.on_true()[idx]); - auto on_false = lookupValue(op.on_false()[idx]); - - auto pred_ = pred; - if (suppress_type_check_) { - // FIXME: This can happen during argmax reduce window, which the mask can - // have an extra trailing dimens - if (pred.shape().size() + 1 == on_true.shape().size()) { - // Do a broadcast - auto new_shape = on_true.shape(); - new_shape.back() = 1; - pred_ = kernel::hlo::Reshape(hctx_, pred_, new_shape); - pred_ = kernel::hlo::Broadcast(hctx_, pred_, on_true.shape(), {}); - } - } - - getFrame()->addValue(op.getResults()[idx], - kernel::hlo::Select(hctx_, pred_, on_true, on_false)); - } -} - -void RegionExecutor::execute(mlir::pphlo::RngOp &op) { - auto to_shape = op.getType().dyn_cast().getShape(); - getFrame()->addValue( - op.getResult(), kernel::hlo::Uniform_rand(hctx_, lookupValue(op.a()), - lookupValue(op.b()), to_shape)); -} - -void RegionExecutor::execute(mlir::pphlo::ConvertOp &op) { - mlir::pphlo::TypeTools tool; - auto dst_dtype = getDtypeFromMlirType(op.getType()); - auto dst_vtype = tool.isMPCType(op.getType()) - ? 
VIS_PUBLIC - : VIS_SECRET; - auto in = lookupValue(op.getOperand()); - - getFrame()->addValue(op.getResult(), - kernel::hlo::Cast(hctx_, in, dst_vtype, dst_dtype)); -} - -void RegionExecutor::execute(mlir::pphlo::PreferAOp &op) { - auto in = lookupValue(op.operand()); - auto k0 = - kernel::hlo::Cast(hctx_, kernel::hlo::Constant(hctx_, 0, in.shape()), - VIS_PUBLIC, in.dtype()); - getFrame()->addValue(op.getResult(), kernel::hlo::Add(hctx_, in, k0)); -} - -void RegionExecutor::execute(mlir::pphlo::SignOp &op) { - auto in = lookupValue(op.operand()); - getFrame()->addValue(op.getResult(), kernel::hlo::Sign(hctx_, in)); -} - -void RegionExecutor::execute(mlir::pphlo::BitcastConvertOp &op) { - const auto &in_type = - op.getOperand().getType().dyn_cast(); - const auto &out_type = - op.getResult().getType().dyn_cast(); - - // bitcast should not change total #bytes, so if sizeof(in_t) != - // sizeof(out_t) will result to a shape change, thus it's enough to just - // ensure in_shape == out_shape - YACL_ENFORCE(in_type.getShape() == out_type.getShape(), - "bitcast with different size is not supported yet"); - - getFrame()->addValue(op.getResult(), - kernel::hlo::Bitcast(hctx_, lookupValue(op.getOperand()), - getDtypeFromMlirType(out_type))); -} - -void RegionExecutor::execute(mlir::pphlo::ConstantOp &op) { - const auto &val = op.value(); - const auto &dea = val.dyn_cast(); - const auto &type = val.getType().dyn_cast(); - const auto &dst_shape = type.getShape(); - const auto &pt_type = getPtType(type.getElementType()); - - PtBufferView view(dea.getRawData().data(), pt_type, - dea.isSplat() ? llvm::ArrayRef() : dst_shape, - dea.isSplat() ? 
std::vector() - : makeCompactStrides(dst_shape)); - - getFrame()->addValue(op.getResult(), - kernel::hlo::Constant(hctx_, view, dst_shape)); -} - -void RegionExecutor::execute(mlir::pphlo::ClampOp &op) { - getFrame()->addValue(op.getResult(), - kernel::hlo::Clamp(hctx_, lookupValue(op.operand()), - lookupValue(op.min()), - lookupValue(op.max()))); -} - -void RegionExecutor::execute(mlir::pphlo::DbgPrintOp &op) { - kernel::hal::dbg_print(hctx_, lookupValue(op.operand())); -} - -} // namespace spu::device::pphlo diff --git a/spu/device/pphlo/region_executor.h b/spu/device/pphlo/region_executor.h deleted file mode 100644 index 3370ed2e..00000000 --- a/spu/device/pphlo/region_executor.h +++ /dev/null @@ -1,232 +0,0 @@ -// Copyright 2022 Ant Group Co., Ltd. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#pragma once - -#include -#include -#include - -#include "spu/device/frame.h" -#include "spu/device/pphlo/type_checker.h" -#include "spu/device/pphlo/xla_verifier.h" -#include "spu/device/profiler.h" -#include "spu/dialect/pphlo_ops.h" -#include "spu/dialect/pphlo_types.h" -#include "spu/kernel/context.h" -#include "spu/kernel/hlo/casting.h" - -namespace spu::device::pphlo { - -class RegionExecutor { -public: - explicit RegionExecutor(HalContext *ctx, Frame *frame, - std::shared_ptr profiler) - : hctx_(ctx), frame_(frame), profiler_(std::move(profiler)), - type_checker_(std::make_shared()) { - frame->enterRegion(); - - if (ctx->feature_control().enable_xla_verifier) { - verifier_ = std::make_unique(ctx); - } - } - - ~RegionExecutor() { frame_->leaveRegion(); } - - std::vector executeRegion(mlir::Region ®ion, - absl::Span inputs); - - HalContext *getContext() const { return hctx_; } - -private: - std::vector executeBlock(mlir::Block &block); - std::vector executeTerminator(mlir::Operation &op); - - void debug_print(mlir::Operation &op); - - template - void dispatchOp(mlir::Operation &op) { - if (auto casted = llvm::dyn_cast(op)) { - if (!suppress_pphlo_trace_ && - (hctx_->rt_config().enable_pphlo_trace() || - hctx_->feature_control().enable_xla_verifier)) { - debug_print(op); - } - - std::optional tp; - if (hctx_->rt_config().enable_pphlo_profile()) { - tp = profiler_->start(); - } - - // Execute op - execute(casted); - - if (tp.has_value()) { - profiler_->end(op.getName().getStringRef(), *tp); - } - - if (verifier_) { - // handle mixed (int, fxp) multiplication - if constexpr (std::is_same_v or - std::is_same_v or - std::is_same_v) { - spu::Value lhs = lookupValue(casted.lhs()); - spu::Value rhs = lookupValue(casted.rhs()); - spu::Value ret = lookupValue(casted.getResult()); - mlir::pphlo::TypeTools type_tool; - auto lhs_type = type_tool.getExpressedType(casted.lhs().getType()); - auto rhs_type = type_tool.getExpressedType(casted.rhs().getType()); - auto 
ret_type = - type_tool.getExpressedType(casted.getResult().getType()); - - if (lhs_type != ret_type) { - lhs = kernel::hlo::Cast(hctx_, lhs, lhs.vtype(), ret.dtype()); - } - if (rhs_type != ret_type) { - rhs = kernel::hlo::Cast(hctx_, rhs, rhs.vtype(), ret.dtype()); - } - - verifier_->verify(casted, {lhs, rhs}, {ret}); - } else { - // Collect inputs - std::vector ins; - for (auto operand : op.getOperands()) { - ins.emplace_back(lookupValue(operand)); - } - std::vector outs; - for (auto operand : op.getResults()) { - outs.emplace_back(lookupValue(operand)); - } - - verifier_->verify(casted, ins, outs); - } - } - } else { - if constexpr (!sizeof...(MoreOpT)) { - // If there is no more op types to dispatch, and the previous cast - // fails..print error message - errorUnknownOp(op); - } else { - dispatchOp(op); - } - } - } - - /// Unary ops - void execute(mlir::pphlo::ReciprocalOp &op); - void execute(mlir::pphlo::NegOp &op); - void execute(mlir::pphlo::ExpOp &op); - void execute(mlir::pphlo::Expm1Op &op); - void execute(mlir::pphlo::LogOp &op); - void execute(mlir::pphlo::Log1pOp &op); - void execute(mlir::pphlo::CeilOp &op); - void execute(mlir::pphlo::FloorOp &op); - void execute(mlir::pphlo::AbsOp &op); - void execute(mlir::pphlo::TransposeOp &op); - void execute(mlir::pphlo::LogisticOp &op); - void execute(mlir::pphlo::NotOp &op); - void execute(mlir::pphlo::TanhOp &op); - void execute(mlir::pphlo::RsqrtOp &op); - void execute(mlir::pphlo::RoundOp &op); - void execute(mlir::pphlo::SqrtOp &op); - void execute(mlir::pphlo::SignOp &op); - - /// Binary ops - void execute(mlir::pphlo::EqualOp &op); - void execute(mlir::pphlo::LessOp &op); - void execute(mlir::pphlo::GreaterOp &op); - - void execute(mlir::pphlo::AddOp &op); - void execute(mlir::pphlo::SubtractOp &op); - void execute(mlir::pphlo::MulOp &op); - void execute(mlir::pphlo::PowOp &op); - void execute(mlir::pphlo::RemOp &op); - void execute(mlir::pphlo::MaxOp &op); - void execute(mlir::pphlo::MinOp &op); - 
void execute(mlir::pphlo::DotOp &op); - void execute(mlir::pphlo::DotGeneralOp &op); - void execute(mlir::pphlo::ShiftLeftOp &op); - void execute(mlir::pphlo::ShiftRightArithmeticOp &op); - void execute(mlir::pphlo::ShiftRightLogicalOp &op); - - /// Ternary ops - void execute(mlir::pphlo::ClampOp &op); - - /// Logical ops - void execute(mlir::pphlo::AndOp &op); - void execute(mlir::pphlo::OrOp &op); - void execute(mlir::pphlo::XorOp &op); - - /// Shape ops - void execute(mlir::pphlo::BroadcastOp &op); - void execute(mlir::pphlo::ReshapeOp &op); - void execute(mlir::pphlo::ConcatenateOp &op); - void execute(mlir::pphlo::SliceOp &op); - void execute(mlir::pphlo::GatherOp &op); - void execute(mlir::pphlo::PadOp &op); - void execute(mlir::pphlo::ReverseOp &op); - - /// Data generator ops - void execute(mlir::pphlo::ConstantOp &op); - void execute(mlir::pphlo::IotaOp &op); - - /// Other ops - void execute(mlir::pphlo::RngOp &op); - void execute(mlir::pphlo::ConvertOp &op); - void execute(mlir::pphlo::BitcastConvertOp &op); - void execute(mlir::pphlo::ConvolutionOp &op); - void execute(mlir::pphlo::SortOp &op); - void execute(mlir::pphlo::DynamicUpdateSliceOp &op); - void execute(mlir::pphlo::DynamicSliceOp &op); - void execute(mlir::pphlo::PreferAOp &op); - - /// Reduce ops - void execute(mlir::pphlo::ReduceOp &op); - void execute(mlir::pphlo::ReduceWindowOp &op); - - /// Control flow ops - void execute(mlir::pphlo::WhileOp &op); - void execute(mlir::pphlo::IfOp &op); - - /// Debug ops - void execute(mlir::pphlo::DbgPrintOp &op); - - /// Lowered ops (All these ops will throw at run time) - void execute(mlir::pphlo::SelectOp &op); - void execute(mlir::pphlo::SelectAndScatterOp &op); - void execute(mlir::pphlo::MaxPoolScatterOp &op); - void execute(mlir::pphlo::ReturnOp &op); - void execute(mlir::pphlo::NotEqualOp &op); - void execute(mlir::pphlo::LessEqualOp &op); - void execute(mlir::pphlo::GreaterEqualOp &op); - void execute(mlir::pphlo::DivOp &op); - void 
errorUnknownOp(mlir::Operation &op); - - Frame *getFrame() { return frame_; } - - const spu::Value &lookupValue(::mlir::Value v) const; - - HalContext *hctx_{nullptr}; - Frame *frame_{nullptr}; - std::shared_ptr profiler_; - mlir::pphlo::TypeTools type_tools_; - std::shared_ptr type_checker_; - std::unique_ptr verifier_; - - // - bool suppress_type_check_ = false; - bool suppress_pphlo_trace_ = false; -}; - -} // namespace spu::device::pphlo diff --git a/spu/device/pphlo/type_checker.cc b/spu/device/pphlo/type_checker.cc deleted file mode 100644 index 2e35e6ca..00000000 --- a/spu/device/pphlo/type_checker.cc +++ /dev/null @@ -1,87 +0,0 @@ -// Copyright 2021 Ant Group Co., Ltd. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "spu/device/pphlo/type_checker.h" - -#include "spu/dialect/pphlo_types.h" - -namespace spu::device { - -namespace { - -DataType getDType(const mlir::Type &type) { - if (auto ft = type.dyn_cast()) { - return DT_FXP; - } - if (auto it = type.dyn_cast()) { - if (it.getWidth() == 1) { - return DT_I1; - } - switch (it.getWidth()) { - case 8: - return it.isUnsigned() ? DT_U8 : DT_I8; - case 16: - return it.isUnsigned() ? DT_U16 : DT_I16; - case 32: - return it.isUnsigned() ? DT_U32 : DT_I32; - case 64: - return it.isUnsigned() ? 
DT_U64 : DT_I64; - } - } - YACL_THROW("Hit unknown mlir type"); -} - -} // namespace - -std::string toString(const ::mlir::Type &type) { - std::string str; - llvm::raw_string_ostream os(str); - type.print(os); - return str; -} - -void checkShape(llvm::ArrayRef mlir_shape, - const absl::Span rt_shape) { - YACL_ENFORCE(mlir_shape.size() == rt_shape.size(), - "Runtime shape mismatch, expected={}, got={}", - fmt::join(mlir_shape, "x"), fmt::join(rt_shape, "x")); - - for (size_t idx = 0; idx < mlir_shape.size(); ++idx) { - YACL_ENFORCE(mlir_shape[idx] == rt_shape[idx], - "Runtime shape mismatch at dim {}, expected={}, got={}", idx, - fmt::join(mlir_shape, "x"), fmt::join(rt_shape, "x")); - } -} - -void PPHloTypeChecker::check(::mlir::Type type, const spu::Value &v) const { - // Check shape - checkShape(type.dyn_cast<::mlir::RankedTensorType>().getShape(), v.shape()); - - // dType checker - mlir::pphlo::TypeTools tool; - auto expectedType = getDType(tool.getExpressedType(type)); - YACL_ENFORCE(expectedType == v.dtype(), "Expected Type {}, got {}", - expectedType, v.dtype()); - - // vType checker - if (tool.isMPCType<::mlir::pphlo::PublicType>(type)) { - YACL_ENFORCE(v.isPublic()); - } else if (tool.isMPCType<::mlir::pphlo::SecretType>(type)) { - YACL_ENFORCE(v.isSecret()); - } else { - YACL_ENFORCE("Unknown vtype"); - } -} - -} // namespace spu::device diff --git a/spu/device/pphlo/xla_verifier.h b/spu/device/pphlo/xla_verifier.h index 6ab63e40..2812dd05 100644 --- a/spu/device/pphlo/xla_verifier.h +++ b/spu/device/pphlo/xla_verifier.h @@ -25,11 +25,10 @@ namespace spu::device::pphlo { class XlaVerifier { private: HalContext *ctx_{nullptr}; - std::function mismatch_handler_; + std::function mismatch_handler_{[](bool) {}}; public: - explicit XlaVerifier(HalContext *ctx) - : ctx_(ctx), mismatch_handler_(ctx->feature_control().verifier_handler) {} + explicit XlaVerifier(HalContext *ctx) : ctx_(ctx) {} void setMismatchHandler(std::function f) { mismatch_handler_ = 
std::move(f); @@ -139,6 +138,7 @@ class XlaVerifier { NO_VERIFY_DEFN(ConstantOp) NO_VERIFY_DEFN(MaxPoolScatterOp) NO_VERIFY_DEFN(PreferAOp) + NO_VERIFY_DEFN(ArgMaxOp) #undef NO_VERIFY_DEFN }; diff --git a/spu/device/profiler.h b/spu/device/profiler.h deleted file mode 100644 index 4c2c3179..00000000 --- a/spu/device/profiler.h +++ /dev/null @@ -1,65 +0,0 @@ -// Copyright 2022 Ant Group Co., Ltd. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#pragma once - -#include -#include -#include -#include - -namespace spu::device { - -class Timer { - using TimePoint = decltype(std::chrono::high_resolution_clock::now()); - TimePoint start_; - -public: - Timer() { reset(); } - - void reset() { start_ = std::chrono::high_resolution_clock::now(); } - - std::chrono::duration count() const { - auto end = std::chrono::high_resolution_clock::now(); - return std::chrono::duration_cast>(end - - start_); - } -}; - -class Profiler { -public: - struct ExecutionRecord { - // total number of executation. - size_t count = 0; - // total elapsed time. 
- std::chrono::duration time = {}; - }; - using ExecutionRecordsT = std::unordered_map; - - Timer start() const { return {}; } - - void end(std::string_view id, const Timer &time) { - auto t = time.count(); - auto &record = records_[std::string(id)]; - record.count++; - record.time += t; - } - - const ExecutionRecordsT &getRecords() const { return records_; } - -private: - ExecutionRecordsT records_; -}; - -} // namespace spu::device \ No newline at end of file diff --git a/spu/device/type_checker.cc b/spu/device/type_checker.cc deleted file mode 100644 index 06345157..00000000 --- a/spu/device/type_checker.cc +++ /dev/null @@ -1,21 +0,0 @@ -// Copyright 2022 Ant Group Co., Ltd. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "spu/device/type_checker.h" - -namespace spu::device { - -TypeChecker::~TypeChecker() = default; - -} diff --git a/spu/device/type_checker.h b/spu/device/type_checker.h deleted file mode 100644 index 79945138..00000000 --- a/spu/device/type_checker.h +++ /dev/null @@ -1,31 +0,0 @@ -// Copyright 2022 Ant Group Co., Ltd. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#pragma once - -#include "mlir/IR/BuiltinTypes.h" - -#include "spu/kernel/value.h" - -namespace spu::device { - -class TypeChecker { -public: - TypeChecker() = default; - virtual ~TypeChecker(); - - virtual void check(mlir::Type, const spu::Value &value) const = 0; -}; - -} // namespace spu::device diff --git a/spu/dialect/pphlo_ops.cc b/spu/dialect/pphlo_ops.cc index c358d26b..4870617f 100644 --- a/spu/dialect/pphlo_ops.cc +++ b/spu/dialect/pphlo_ops.cc @@ -37,9 +37,17 @@ namespace mlir::pphlo { #include "spu/dialect/pphlo_patterns.cc.inc" namespace { + Type convertPtTypeToPPhloType(Type ptType) { return pphlo::PublicType::get(ptType.getContext(), ptType); } + +// Checks if the vector `nums` has duplicates. 
+bool hasDuplicates(const ArrayRef nums) { + llvm::SmallDenseSet set(nums.begin(), nums.end()); + return set.size() != nums.size(); +} + } // namespace template @@ -500,8 +508,7 @@ OpFoldResult ReciprocalOp::fold(ArrayRef operands) { values.push_back(one / it); } - return DenseFPElementsAttr::get(val.getType().dyn_cast(), - values); + return DenseFPElementsAttr::get(val.getType().dyn_cast(), values); } OpFoldResult ReshapeOp::fold(ArrayRef operands) { @@ -592,6 +599,121 @@ LogicalResult PadOp::inferReturnTypeComponents( return success(); } +LogicalResult ConcatenateOp::verify() { + RankedTensorType firstRankedType; + int firstRankedIndex; + int numOperands = getNumOperands(); + auto concatDimension = static_cast(dimension()); + if (concatDimension < 0) { + return emitOpError( + llvm::formatv("dimension {0} is negative", concatDimension)); + } + for (int i = 0; i < numOperands; i++) { + auto secondType = getOperand(i).getType().dyn_cast(); + if (!secondType.hasRank()) { + continue; + } + + if (!firstRankedType) { + firstRankedType = secondType.cast(); + firstRankedIndex = i; + if (firstRankedType.getRank() == 0) { + return emitOpError( + llvm::formatv("rank-0 values cannot be concatenated")); + } + if (concatDimension >= firstRankedType.getRank()) { + return emitOpError( + llvm::formatv("dimension {0} is out-of-bounds for input rank {1}", + concatDimension, firstRankedType.getRank())); + } + continue; + } + + if (firstRankedType.getRank() != secondType.getRank()) { + return emitOpError(llvm::formatv( + "operands ({0}) and ({1}) do not match rank", firstRankedIndex, i)); + } + + auto firstShape = firstRankedType.getShape(); + auto secondShape = secondType.getShape(); + for (int d = 0; d < firstRankedType.getRank(); ++d) { + if (!ShapedType::isDynamic(firstShape[d]) && + !ShapedType::isDynamic(secondShape[d]) && + firstShape[d] != secondShape[d] && d != concatDimension) { + return emitOpError(llvm::formatv( + "shapes of operand ({0}) and ({1}) do not match at 
non-concat " + "index: ({2}) != ({3}) at non-concat index {4}", + firstRankedIndex, i, + llvm::make_range(firstShape.begin(), firstShape.end()), + llvm::make_range(secondShape.begin(), secondShape.end()), d)); + } + } + } + return success(); +} + +LogicalResult BroadcastOp::verify() { + auto operandType = getOperand().getType().dyn_cast(); + + auto operandRank = operandType.getRank(); + + if (!broadcast_dimensions()) { + if (operandRank == 0) { + return success(); + } + return emitOpError( + llvm::formatv("broadcast_dimensions is absent, but required because " + "operand has non-zero rank ({0})", + operandRank)); + } + + auto dimensionsType = broadcast_dimensions().getType(); + auto dimensionsRank = dimensionsType.getRank(); + if (dimensionsRank != 1) { + return emitOpError(llvm::formatv( + "broadcast_dimensions has rank {0} instead of rank 1", dimensionsRank)); + } + + auto dimensionsSize = dimensionsType.getNumElements(); + if (dimensionsSize != operandRank) { + return emitOpError(llvm::formatv( + "broadcast_dimensions size ({0}) does not match operand rank ({1})", + dimensionsSize, operandRank)); + } + + auto dimensions = + llvm::to_vector(broadcast_dimensions().getValues()); + if (hasDuplicates(dimensions)) { + return emitOpError("broadcast_dimensions should not have duplicates"); + } + + auto resultType = getResult().getType().cast(); + auto resultRank = resultType.getRank(); + + for (int i = 0; i != dimensionsSize; ++i) { + auto dimIndex = dimensions[i]; + if (dimIndex >= resultRank) { + return emitOpError( + llvm::formatv("broadcast_dimensions contains invalid value {0} for " + "result with rank {1}", + dimIndex, resultRank)); + } + + if (!operandType.isDynamicDim(i)) { + auto dimSize = operandType.getDimSize(i); + auto resultDimSize = resultType.getDimSize(dimIndex); + if (dimSize != 1 && dimSize != resultDimSize) { + return emitOpError( + llvm::formatv("size of operand dimension {0} ({1}) is not equal to " + "1 or size of result dimension {2} ({3})", + i, 
dimSize, dimIndex, resultDimSize)); + } + } + } + + return success(); +} + template static void printField(AsmPrinter& printer, StringRef name, T field, StringRef& separator) { diff --git a/spu/dialect/pphlo_ops.td b/spu/dialect/pphlo_ops.td index ca1b4ed4..2acb2c4b 100644 --- a/spu/dialect/pphlo_ops.td +++ b/spu/dialect/pphlo_ops.td @@ -579,11 +579,12 @@ def PPHLO_BroadcastOp See https://www.tensorflow.org/xla/broadcasting. }]; - let arguments = (ins PPHLO_Tensor - : $operand, I64ElementsAttr - : $broadcast_dimensions); + let arguments = (ins PPHLO_Tensor: $operand, + I64ElementsAttr: $broadcast_dimensions); let results = (outs PPHLO_Tensor); + + let hasVerifier = 1; } def PPHLO_ReshapeOp @@ -613,6 +614,8 @@ def PPHLO_ConcatenateOp let arguments = (ins Variadic : $val, I64Attr : $dimension); let results = (outs PPHLO_Tensor); + + let hasVerifier = 1; } def PPHLO_DotOp : PPHLO_Op<"dot", [Pure]> { @@ -646,7 +649,7 @@ def PPHLO_DotGeneralOp: PPHLO_Op<"dot_general", [Pure]> { } def PPHLO_SelectOp - : PPHLO_Op<"select", [Pure, SameVariadicOperandSize]> { + : PPHLO_Op<"select", [Pure]> { let summary = "Select operator"; let description = [{ Constructs an output tensor from the elements of `on_true` and `on_false` @@ -658,21 +661,10 @@ def PPHLO_SelectOp }]; let arguments = (ins PPHLO_IntTensor: $pred, - Variadic: $on_true, - Variadic: $on_false); + PPHLO_Tensor: $on_true, + PPHLO_Tensor: $on_false); - let results = (outs Variadic); - - // Builder for non-variadic version of the operation. 
- let builders = [ - OpBuilder<(ins "Type":$result_type, "Value":$pred, - "Value":$on_true, - "Value":$on_false), - [{ - build($_builder, $_state, TypeRange(result_type), pred, - ValueRange(on_true), ValueRange(on_false)); - }]> - ]; + let results = (outs PPHLO_Tensor); let hasCanonicalizer = 1; } @@ -772,11 +764,7 @@ def PPHLO_ReduceWindowOp : PPHLO_Op<"reduce_window", [ OptionalAttr:$window_strides, OptionalAttr:$base_dilations, OptionalAttr:$window_dilations, - OptionalAttr:$padding, - // These are two special attributes for argmax reduce window - // Do not set unless you know what these means - DefaultValuedAttr:$last_operand_is_window_mask, - DefaultValuedAttr:$ignore_init_value + OptionalAttr:$padding ); let results = (outs Variadic); @@ -784,6 +772,29 @@ def PPHLO_ReduceWindowOp : PPHLO_Op<"reduce_window", [ let regions = (region SizedRegion<1> : $body); } +def PPHLO_ArgMaxOp: PPHLO_Op<"argmax", [Pure]> { + let summary = "ArgMax operator"; + + let description = [{ + Returns the max value and index in a window. + }]; + + let arguments = (ins + PPHLO_Tensor:$input, + I64ElementsAttr:$window_dimensions, + // If strides or dilations attributes are missing then the default value is + // one for each of the input dimensions. Similarly, padding values are zero + // for both low and high in each of the dimensions, if not specified. + OptionalAttr:$window_strides, + OptionalAttr:$base_dilations, + OptionalAttr:$window_dilations, + OptionalAttr:$padding, + DefaultValuedAttr:$onehot_index + ); + + let results = (outs PPHLO_Tensor, PPHLO_IntTensor); +} + def PPHLO_ReturnOp : PPHLO_Op<"return", [Pure, Terminator]> { let summary = [{ The `pphlo.return` operation terminates a region and returns values. 
diff --git a/spu/kernel/context.h b/spu/kernel/context.h index d62c10f3..293706a6 100644 --- a/spu/kernel/context.h +++ b/spu/kernel/context.h @@ -26,12 +26,6 @@ namespace spu { -// Toggles put here are internal features -struct FeatureControl { - bool enable_xla_verifier = false; - std::function verifier_handler{[](bool) {}}; -}; - // The hal evaluation context for all spu operators. class HalContext final { const RuntimeConfig rt_config_; @@ -42,8 +36,6 @@ class HalContext final { std::default_random_engine rand_engine_; - FeatureControl fc_; - public: explicit HalContext(RuntimeConfig config, std::shared_ptr lctx); @@ -67,8 +59,6 @@ class HalContext final { // Return current working runtime config. const RuntimeConfig& rt_config() const { return rt_config_; } - FeatureControl& feature_control() { return fc_; } - // std::default_random_engine& rand_engine() { return rand_engine_; } }; diff --git a/spu/kernel/hal/concat.cc b/spu/kernel/hal/concat.cc index e5be9a5a..86c57534 100644 --- a/spu/kernel/hal/concat.cc +++ b/spu/kernel/hal/concat.cc @@ -31,6 +31,11 @@ Value concatenate(HalContext* ctx, absl::Span values, SPU_TRACE_HAL_DISP(ctx, axis); YACL_ENFORCE(!values.empty(), "got={}", values.size()); + if (values.size() == 1) { + // Nothing to concate + return values.front(); + } + bool all_same_dtype = std::all_of( values.begin() + 1, values.end(), [&](const Value& v) { return v.dtype() == values.begin()->dtype(); }); diff --git a/spu/kernel/hal/shape_ops.cc b/spu/kernel/hal/shape_ops.cc index 32000bf0..31e55a8d 100644 --- a/spu/kernel/hal/shape_ops.cc +++ b/spu/kernel/hal/shape_ops.cc @@ -225,6 +225,11 @@ Value broadcast_to(HalContext* ctx, const Value& in, absl::Span to_shape, absl::Span in_dims) { SPU_TRACE_HAL_DISP(ctx, in, to_shape); + for (auto d : in_dims) { + YACL_ENFORCE(d < (int64_t)to_shape.size() && d >= 0, + "Broadcast dim {} out of valid range [0, {})", d, + to_shape.size()); + } std::vector new_strides(to_shape.size(), 0); diff --git 
a/spu/kernel/hlo/reduce.cc b/spu/kernel/hlo/reduce.cc index 4f65dc4c..4484962c 100644 --- a/spu/kernel/hlo/reduce.cc +++ b/spu/kernel/hlo/reduce.cc @@ -16,14 +16,18 @@ #include #include +#include #include #include #include #include "spu/core/parallel_utils.h" #include "spu/core/shape_util.h" +#include "spu/kernel/hal/concat.h" #include "spu/kernel/hal/constants.h" +#include "spu/kernel/hal/polymorphic.h" #include "spu/kernel/hal/shape_ops.h" +#include "spu/kernel/hal/type_cast.h" #include "spu/kernel/hlo/geometrical.h" #include "spu/kernel/hlo/utils.h" @@ -190,58 +194,6 @@ spu::Value ConvertToTiledLayout(HalContext *ctx, const spu::Value &in, return hal::transpose(ctx, out, perm); } -// So idea here.. -// When windows size is 2x2, tile and run parallel on window element level has -// way to much overhead (both memory and computation). -// Just do a window level parallel is good enough -// And without dilation and padding, this can be achieved through just slicing -// FIXME: This is a super special case...consider generalize it a little bit -std::vector -ReduceWindow1x2x2x1NoPaddingOneStrideWithoutDilationWithWindowMask( - HalContext *ctx, absl::Span inputs, - absl::Span init_values, - absl::Span ret_shape, const BatchedValueBinaryFn &reducer) { - std::vector start_indices = {0, 0, 0, 0}; - auto input_shape = inputs[0].shape(); - - std::vector input_slices(4); - std::vector mask_slices(4); - input_slices[0] = hal::slice( - ctx, inputs[0], {0, 0, 0, 0}, - {input_shape[0], input_shape[1] - 1, input_shape[2] - 1, input_shape[3]}, - {1, 1, 1, 1}); - input_slices[1] = hal::slice( - ctx, inputs[0], {0, 0, 1, 0}, - {input_shape[0], input_shape[1] - 1, input_shape[2], input_shape[3]}, - {1, 1, 1, 1}); - input_slices[2] = hal::slice( - ctx, inputs[0], {0, 1, 0, 0}, - {input_shape[0], input_shape[1], input_shape[2] - 1, input_shape[3]}, - {1, 1, 1, 1}); - input_slices[3] = hal::slice( - ctx, inputs[0], {0, 1, 1, 0}, - {input_shape[0], input_shape[1], input_shape[2], 
input_shape[3]}, - {1, 1, 1, 1}); - - std::vector mask_shape = {inputs[0].shape()[0], ret_shape[1], - ret_shape[2], inputs[0].shape().back(), 4}; - for (int64_t mask_idx = 0; mask_idx < 4; ++mask_idx) { - mask_slices[mask_idx] = hal::slice(ctx, inputs.back(), {mask_idx, 0}, - {mask_idx + 1, 4}, {1, 1}); - mask_slices[mask_idx] = - hal::reshape(ctx, mask_slices[mask_idx], {1, 1, 1, 1, 4}); - mask_slices[mask_idx] = - hal::broadcast_to(ctx, mask_slices[mask_idx], mask_shape); - } - - std::vector ret = {input_slices[0], mask_slices[0]}; - for (size_t i = 1; i < input_slices.size(); ++i) { - ret = reducer({input_slices[i], mask_slices[i]}, ret); - } - - return ret; -} - std::vector ReduceWindowWithoutDilation( HalContext *ctx, absl::Span inputs, absl::Span init_values, @@ -250,19 +202,6 @@ std::vector ReduceWindowWithoutDilation( absl::Span> window_padding, bool last_operand_is_window_mask, bool ignore_init_value, absl::Span ret_shape, const BatchedValueBinaryFn &reducer) { - // Add a fast 1x2x2x1, no padding fast reduce - auto no_padding = std::all_of(window_padding.begin(), window_padding.end(), - [](const std::pair &p) { - return p.first == 0 && p.second == 0; - }); - auto one_stride = std::all_of(window_strides.begin(), window_strides.end(), - [](auto v) { return v == 1; }); - if (window_shape == absl::Span{1, 2, 2, 1} && - inputs.size() == 2 && last_operand_is_window_mask && no_padding && - one_stride) { - return ReduceWindow1x2x2x1NoPaddingOneStrideWithoutDilationWithWindowMask( - ctx, inputs, init_values, ret_shape, reducer); - } const size_t nargs = last_operand_is_window_mask ? 
inputs.size() - 1 : inputs.size(); @@ -338,12 +277,12 @@ std::vector ReduceWindowWithoutDilation( return outputs; } -std::vector ReduceWindow(HalContext *ctx, - absl::Span inputs, - absl::Span init_values, - absl::Span ret_shape, - const ReduceWindowConfig &config, - const BatchedValueBinaryFn &reducer) { +std::vector ReduceWindowImpl( + HalContext *ctx, absl::Span inputs, + absl::Span init_values, + absl::Span ret_shape, const ReduceWindowConfig &config, + bool last_operand_is_window_mask, bool ignore_init_value, + const BatchedValueBinaryFn &reducer) { if (std::all_of(config.window_dilations.begin(), config.window_dilations.end(), [](const int64_t x) { return x == 1; }) && @@ -351,11 +290,11 @@ std::vector ReduceWindow(HalContext *ctx, [](const int64_t x) { return x == 1; })) { return ReduceWindowWithoutDilation( ctx, inputs, init_values, config.window_shape, config.window_strides, - config.window_padding, config.last_operand_is_window_mask, - config.ignore_init_value, ret_shape, reducer); + config.window_padding, last_operand_is_window_mask, ignore_init_value, + ret_shape, reducer); } - YACL_ENFORCE(config.last_operand_is_window_mask == false); + YACL_ENFORCE(!last_operand_is_window_mask); const int64_t ndims = inputs[0].shape().size(); std::vector window_index(ndims, 0); @@ -411,6 +350,16 @@ std::vector ReduceWindow(HalContext *ctx, return rets; } +std::vector ReduceWindow(HalContext *ctx, + absl::Span inputs, + absl::Span init_values, + absl::Span ret_shape, + const ReduceWindowConfig &config, + const BatchedValueBinaryFn &reducer) { + return ReduceWindowImpl(ctx, inputs, init_values, ret_shape, config, false, + false, reducer); +} + std::vector Reduce(HalContext *ctx, absl::Span inputs, absl::Span init_values, @@ -496,4 +445,156 @@ std::vector Reduce(HalContext *ctx, return reducer(results, broadcasted_init_values); } +// So idea here.. 
+// When windows size is 2x2, tile and run parallel on window element level has +// way to much overhead (both memory and computation). +// Just do a window level parallel is good enough +// And without dilation and padding, this can be achieved through just slicing +// FIXME: This is a super special case...consider generalize it a little bit +std::pair +ArgMax1x2x2x1NoPaddingOneStrideWithoutDilation(HalContext *ctx, + const spu::Value &input) { + auto input_shape = input.shape(); + + spu::Value h_max; + spu::Value h_idx_max; + { + // Get to horizontal slices + auto lhs = hal::slice( + ctx, input, {0, 0, 0, 0}, + {input_shape[0], input_shape[1], input_shape[2] - 1, input_shape[3]}, + {1, 1, 1, 1}); + auto rhs = hal::slice( + ctx, input, {0, 0, 1, 0}, + {input_shape[0], input_shape[1], input_shape[2], input_shape[3]}, + {1, 1, 1, 1}); + // Do a less comp + auto h_comp = hal::less(ctx, rhs, lhs); + // make comp an ashare + h_comp = + hal::add(ctx, h_comp, + hal::zeros(ctx, VIS_PUBLIC, h_comp.dtype(), h_comp.shape())); + + auto h_i_comp = hal::reshape(ctx, h_comp, + {h_comp.shape()[0], h_comp.shape()[1], + h_comp.shape()[2], h_comp.shape()[3], 1}); + + // Now do two selections + auto mask_shape = h_comp.shape(); + mask_shape.emplace_back(2); + + // Now compute horizontal max... + h_max = hal::select(ctx, h_comp, lhs, rhs); + + // Mask index + h_idx_max = + hal::concatenate(ctx, {h_i_comp, hal::logical_not(ctx, h_i_comp)}, 4); + } + + // Now do vertical compare... 
+ auto upper_value = hal::slice(ctx, h_max, {0, 0, 0, 0}, + {h_max.shape()[0], h_max.shape()[1] - 1, + h_max.shape()[2], h_max.shape()[3]}, + {1, 1, 1, 1}); + auto bottom_value = hal::slice( + ctx, h_max, {0, 1, 0, 0}, + {h_max.shape()[0], h_max.shape()[1], h_max.shape()[2], h_max.shape()[3]}, + {1, 1, 1, 1}); + + auto v_comp = hal::less(ctx, bottom_value, upper_value); + // make comp an ashare + v_comp = hal::add( + ctx, v_comp, hal::zeros(ctx, VIS_PUBLIC, v_comp.dtype(), v_comp.shape())); + + // Compute max value + auto max_ret = hal::select(ctx, v_comp, upper_value, bottom_value); + + // Compute max indicies + auto v_comp_not = hal::logical_not(ctx, v_comp); + + auto v_i_comp = hal::reshape(ctx, v_comp, + {v_comp.shape()[0], v_comp.shape()[1], + v_comp.shape()[2], v_comp.shape()[3], 1}); + v_i_comp = hal::broadcast_to(ctx, v_i_comp, + {v_i_comp.shape()[0], v_i_comp.shape()[1], + v_i_comp.shape()[2], v_i_comp.shape()[3], 2}); + + auto v_i_comp_not = + hal::reshape(ctx, v_comp_not, + {v_comp_not.shape()[0], v_comp_not.shape()[1], + v_comp_not.shape()[2], v_comp_not.shape()[3], 1}); + v_i_comp_not = + hal::broadcast_to(ctx, v_i_comp_not, + {v_i_comp_not.shape()[0], v_i_comp_not.shape()[1], + v_i_comp_not.shape()[2], v_i_comp_not.shape()[3], 2}); + + auto upper_slice = hal::slice( + ctx, h_idx_max, {0, 0, 0, 0, 0}, + {h_idx_max.shape()[0], h_idx_max.shape()[1] - 1, h_idx_max.shape()[2], + h_idx_max.shape()[3], h_idx_max.shape()[4]}, + {1, 1, 1, 1, 1}); + + auto bottom_slice = hal::slice( + ctx, h_idx_max, {0, 1, 0, 0, 0}, + {h_idx_max.shape()[0], h_idx_max.shape()[1], h_idx_max.shape()[2], + h_idx_max.shape()[3], h_idx_max.shape()[4]}, + {1, 1, 1, 1, 1}); + + upper_slice = hal::mul(ctx, v_i_comp, upper_slice); + bottom_slice = hal::mul(ctx, v_i_comp_not, bottom_slice); + + auto max_indicies = hal::concatenate(ctx, {upper_slice, bottom_slice}, 4); + + return {max_ret, max_indicies}; +} + +std::pair ArgMax(HalContext *ctx, + const spu::Value &input, + absl::Span 
ret_shape, + const ReduceWindowConfig &config) { + // Add a fast 1x2x2x1, no padding fast reduce + auto no_padding = + std::all_of(config.window_padding.begin(), config.window_padding.end(), + [](const std::pair &p) { + return p.first == 0 && p.second == 0; + }); + auto one_stride = + std::all_of(config.window_strides.begin(), config.window_strides.end(), + [](auto v) { return v == 1; }); + if (config.window_shape == absl::Span{1, 2, 2, 1} && + no_padding && one_stride) { + return ArgMax1x2x2x1NoPaddingOneStrideWithoutDilation(ctx, input); + } + + // Create eye + size_t window_size = + std::accumulate(config.window_shape.begin(), config.window_shape.end(), 1, + std::multiplies()); + xt::xarray e = xt::eye({window_size, window_size}, 0); + + auto mask = hal::constant(ctx, e); + + auto result = ReduceWindowImpl( + ctx, {input, mask}, {}, ret_shape, config, true, true, + [&](absl::Span lhs, + absl::Span rhs) -> std::vector { + YACL_ENFORCE(lhs.size() == 2); + auto c = hal::less(ctx, rhs[0], lhs[0]); + // make a share + c = hal::add(ctx, c, hal::zeros(ctx, VIS_PUBLIC, c.dtype(), c.shape())); + // Select value + auto v = hal::select(ctx, c, lhs[0], rhs[0]); + // Select index + std::vector c_i_shape = c.shape(); + c_i_shape.emplace_back(1); + auto c_i = hal::reshape(ctx, c, c_i_shape); + c_i_shape.back() = window_size; + c_i = hal::broadcast_to(ctx, c_i, c_i_shape); + auto i = hal::select(ctx, c_i, lhs[1], rhs[1]); + + return {v, i}; + }); + return {result[0], result[1]}; +} + } // namespace spu::kernel::hlo diff --git a/spu/kernel/hlo/reduce.h b/spu/kernel/hlo/reduce.h index a199dc39..19c120d8 100644 --- a/spu/kernel/hlo/reduce.h +++ b/spu/kernel/hlo/reduce.h @@ -14,6 +14,7 @@ #pragma once +#include #include #include "absl/types/span.h" @@ -46,11 +47,6 @@ struct ReduceWindowConfig { absl::Span window_dilations; absl::Span> window_padding; absl::Span base_dilations; - // This is a special attribute for - // argmax-like reduce, DO NOT set to true unless you know what - // 
this means - bool last_operand_is_window_mask{false}; - bool ignore_init_value; }; std::vector ReduceWindow(HalContext *ctx, @@ -66,4 +62,9 @@ std::vector Reduce(HalContext *ctx, absl::Span dimensions_to_reduce, const BatchedValueBinaryFn &reducer); +std::pair ArgMax(HalContext *ctx, + const spu::Value &input, + absl::Span ret_shape, + const ReduceWindowConfig &config); + } // namespace spu::kernel::hlo diff --git a/spu/kernel/hlo/select_and_scatter.cc b/spu/kernel/hlo/select_and_scatter.cc index 3965408b..7a9bda5a 100644 --- a/spu/kernel/hlo/select_and_scatter.cc +++ b/spu/kernel/hlo/select_and_scatter.cc @@ -16,7 +16,9 @@ #include "spu/kernel/hlo/select_and_scatter.h" #include +#include #include +#include #include "yacl/utils/parallel.h" @@ -26,26 +28,83 @@ #include "spu/kernel/hal/debug.h" #include "spu/kernel/hal/polymorphic.h" // for select #include "spu/kernel/hal/shape_ops.h" +#include "spu/kernel/hal/type_cast.h" #include "spu/kernel/hlo/const.h" #include "spu/kernel/hlo/reduce.h" #include "spu/kernel/hlo/utils.h" +#include "spu/kernel/value.h" namespace spu::kernel::hlo { +spu::Value MaxPoolScatter1x2x2x1NoPaddingNoStrides( + HalContext *ctx, const spu::Value &scatter_indices, + const spu::Value &source) { + std::vector slices(4); + for (int64_t idx = 0; idx < 4; ++idx) { + slices[idx] = hal::slice( + ctx, scatter_indices, {0, 0, 0, 0, idx}, + {scatter_indices.shape()[0], scatter_indices.shape()[1], + scatter_indices.shape()[2], scatter_indices.shape()[3], idx + 1}, + {1, 1, 1, 1, 1}); + slices[idx] = + hal::mul(ctx, hal::reshape(ctx, slices[idx], source.shape()), source); + + // FIXME(jint), handle int type promotion + slices[idx] = hal::dtype_cast(ctx, slices[idx], source.dtype()); + } + + auto z = hal::zeros(ctx, slices[0].vtype(), slices[0].dtype()); + + std::vector> f_slices(4); + f_slices[0] = std::async(std::launch::async, hal::pad, ctx, slices[0], z, + std::vector{0, 0, 0, 0}, + std::vector{0, 1, 1, 0}, + std::vector{0, 0, 0, 0}); + f_slices[1] = 
std::async(std::launch::async, hal::pad, ctx, slices[1], z, + std::vector{0, 0, 1, 0}, + std::vector{0, 1, 0, 0}, + std::vector{0, 0, 0, 0}); + f_slices[2] = std::async(std::launch::async, hal::pad, ctx, slices[2], z, + std::vector{0, 1, 0, 0}, + std::vector{0, 0, 1, 0}, + std::vector{0, 0, 0, 0}); + f_slices[3] = std::async(std::launch::async, hal::pad, ctx, slices[3], z, + std::vector{0, 1, 1, 0}, + std::vector{0, 0, 0, 0}, + std::vector{0, 0, 0, 0}); + + spu::Value ret = f_slices[0].get(); + for (size_t idx = 1; idx < 4; ++idx) { + ret = hal::add(ctx, ret, f_slices[idx].get()); + } + + return ret; +}; + spu::Value MaxPoolScatter( HalContext *ctx, const spu::Value &scatter_indices, const spu::Value &source, absl::Span window_shape, absl::Span base_shape, absl::Span window_strides, absl::Span> window_padding) { - // source_shape * window_numel + // Add a fast 1x2x2x1, no padding fast reduce + auto no_padding = std::all_of(window_padding.begin(), window_padding.end(), + [](const std::pair &p) { + return p.first == 0 && p.second == 0; + }); + auto one_stride = std::all_of(window_strides.begin(), window_strides.end(), + [](auto v) { return v == 1; }); + if (window_shape == absl::Span{1, 2, 2, 1} && no_padding && + one_stride) { + return MaxPoolScatter1x2x2x1NoPaddingNoStrides(ctx, scatter_indices, + source); + } + // source_shape * window_numel std::vector tiled_1d_shape = source.shape(); const int64_t window_numel = std::accumulate( window_shape.begin(), window_shape.end(), 1, std::multiplies()); tiled_1d_shape.push_back(window_numel); - auto tiled_1d_select = hal::reshape(ctx, scatter_indices, tiled_1d_shape); - std::vector broadcast_dims(source.shape().size(), 0); std::iota(broadcast_dims.begin(), broadcast_dims.end(), 0); @@ -53,7 +112,7 @@ spu::Value MaxPoolScatter( hal::broadcast_to(ctx, source, tiled_1d_shape, broadcast_dims); // selected_pos is the one hot encoding for each window. 
- auto selected = hal::mul(ctx, tiled_1d_source, tiled_1d_select); + auto selected = hal::mul(ctx, tiled_1d_source, scatter_indices); std::vector tiled_shape(source.shape().begin(), source.shape().end()); diff --git a/spu/kernel/hlo/sort.cc b/spu/kernel/hlo/sort.cc index 2827eeb1..22bae99a 100644 --- a/spu/kernel/hlo/sort.cc +++ b/spu/kernel/hlo/sort.cc @@ -14,6 +14,7 @@ #include "sort.h" +#include "absl/numeric/bits.h" #include "emp-tool/circuits/number.h" #include "spu/kernel/hal/permute_util.h" @@ -98,36 +99,130 @@ void cmpSwap(HalContext *ctx, const CompFn &comparator_body, } } -void bitonicMerge(HalContext *ctx, const CompFn &comparator_body, - std::vector *values_to_sort, size_t lo, size_t n, - bool acc) { +void sequentialBitonicMerge(HalContext *ctx, const CompFn &comparator_body, + std::vector *values_to_sort, size_t lo, + size_t n, bool acc) { if (n > 1) { size_t m = emp::greatestPowerOfTwoLessThan(n); cmpSwap(ctx, comparator_body, values_to_sort, lo, lo + m, n - m, acc); - bitonicMerge(ctx, comparator_body, values_to_sort, lo, m, acc); - bitonicMerge(ctx, comparator_body, values_to_sort, lo + m, n - m, acc); + sequentialBitonicMerge(ctx, comparator_body, values_to_sort, lo, m, acc); + sequentialBitonicMerge(ctx, comparator_body, values_to_sort, lo + m, n - m, + acc); } } -void bitonicSort(HalContext *ctx, const CompFn &comparator_body, - std::vector *values_to_sort, size_t lo, size_t n, - bool acc) { +void sequentialBitonicSort(HalContext *ctx, const CompFn &comparator_body, + std::vector *values_to_sort, size_t lo, + size_t n, bool acc) { if (n > 1) { size_t m = (n >> 1); - bitonicSort(ctx, comparator_body, values_to_sort, lo, m, !acc); - bitonicSort(ctx, comparator_body, values_to_sort, lo + m, n - m, acc); - bitonicMerge(ctx, comparator_body, values_to_sort, lo, n, acc); + sequentialBitonicSort(ctx, comparator_body, values_to_sort, lo, m, !acc); + sequentialBitonicSort(ctx, comparator_body, values_to_sort, lo + m, n - m, + acc); + 
sequentialBitonicMerge(ctx, comparator_body, values_to_sort, lo, n, acc); } } +void generateBitonicSortIndex(size_t n, + std::vector> *indices) { + YACL_ENFORCE(absl::has_single_bit(n)); + size_t stage = absl::bit_width(n) - 1; + + for (int i = static_cast(stage); i > 0; i--) { + std::vector fst; + std::vector sec; + + for (size_t j = 0; j < n; j++) { + if (((j >> (i - 1)) & 1) == 0) { + fst.emplace_back(j); + } else { + sec.emplace_back(j); + } + } + + fst.insert(fst.end(), sec.begin(), sec.end()); + indices->emplace_back(fst); + } +} + +void generateBitonicMergeIndex(size_t n, + std::vector> *indices) { + YACL_ENFORCE(absl::has_single_bit(n)); + size_t stage = absl::bit_width(n) - 1; + + for (int stage_idx = 0; stage_idx < static_cast(stage - 1); + stage_idx++) { + for (int substage_idx = static_cast(stage_idx); substage_idx > -1; + substage_idx--) { + std::vector fst; + std::vector sec; + for (size_t i = 0; i < n; i++) { + bool asc_flag = ((i >> (stage_idx + 1)) & 1) == 0; + bool fst_flag = ((i >> substage_idx) & 1) == 0; + + if (asc_flag ^ fst_flag) { + sec.emplace_back(i); + } else { + fst.emplace_back(i); + } + } + + fst.insert(fst.end(), sec.begin(), sec.end()); + indices->emplace_back(fst); + } + } +} + +std::vector parallelBitonicSort( + HalContext *ctx, const CompFn &comparator_body, + const std::vector &values_to_sort, size_t n) { + YACL_ENFORCE(absl::has_single_bit(n)); + + std::vector> indices; + generateBitonicMergeIndex(n, &indices); + generateBitonicSortIndex(n, &indices); + + std::vector target = values_to_sort; + + for (const auto &index : indices) { + // permute + std::vector permuted_values; + + for (auto v : target) { + permuted_values.emplace_back(hal::permute(ctx, v, 0, xt::adapt(index))); + } + + // cmp and swap + cmpSwap(ctx, comparator_body, &permuted_values, 0, + static_cast(n / 2), static_cast(n / 2), true); + + // inverse permute + std::vector inverse_permutation(index.size()); + std::iota(inverse_permutation.begin(), 
inverse_permutation.end(), 0); + std::sort(inverse_permutation.begin(), inverse_permutation.end(), + [&index](int left, int right) -> bool { + return index[left] < index[right]; + }); + + target.clear(); + + for (auto v : permuted_values) { + target.emplace_back( + hal::permute(ctx, v, 0, xt::adapt(inverse_permutation))); + } + } + + return target; +} + } // namespace std::vector Sort(HalContext *ctx, absl::Span inputs, int64_t sort_dim, bool is_stable, const CompFn &comparator_body, - const Visibility &comparator_ret_vis) { + Visibility comparator_ret_vis) { int64_t num_operands = inputs.size(); auto key_shape = inputs[0].shape(); auto rank = key_shape.size(); @@ -148,68 +243,64 @@ std::vector Sort(HalContext *ctx, sort_dim, increment.size()); increment[sort_dim] = sort_dim_elements; - bool use_secret_sort = comparator_ret_vis == VIS_SECRET && - !ctx->rt_config().reveal_secret_condition(); + if (comparator_ret_vis == VIS_PUBLIC) { + // Iterate through each dimension except 'sort_dim'. + forEachIndex(key_shape, zero_base, key_shape, increment, + [&](const std::vector &indices) { + // Extract a slice from each operand literal that corresponds + // to exactly the row in dimension 'sort_dim'. 
+ std::vector values_to_sort = + getValuesToSort(ctx, inputs, indices, sort_dim, + sort_dim_elements, num_operands); + + std::vector indices_to_sort(sort_dim_elements); + std::iota(indices_to_sort.begin(), indices_to_sort.end(), 0); + auto comparator = [&comparator_body, &num_operands, &ctx, + &values_to_sort](int64_t a, int64_t b) { + std::vector values; + values.reserve(2 * num_operands); + for (int64_t i = 0; i < num_operands; ++i) { + values.push_back(values_to_sort[i].getElementAt(a)); + values.push_back(values_to_sort[i].getElementAt(b)); + } + spu::Value ret = comparator_body(values); + return getConditionValue(ctx, ret); + }; - if (!use_secret_sort) { - bool warned = false; + if (is_stable) { + std::stable_sort(indices_to_sort.begin(), + indices_to_sort.end(), comparator); + } else { + std::sort(indices_to_sort.begin(), indices_to_sort.end(), + comparator); + } + std::vector start_indices(rank, 0); + for (int64_t i = 0; i < num_operands; ++i) { + auto sorted_value = hal::permute( + ctx, values_to_sort[i], 0, xt::adapt(indices_to_sort)); + sliceCopy(results[i], sorted_value, indices, sort_dim); + } + }); + } else { // Iterate through each dimension except 'sort_dim'. forEachIndex( key_shape, zero_base, key_shape, increment, [&](const std::vector &indices) { - // Extract a slice from each operand literal that corresponds to - // exactly the row in dimension 'sort_dim'. 
std::vector values_to_sort = getValuesToSort( ctx, inputs, indices, sort_dim, sort_dim_elements, num_operands); - std::vector indices_to_sort(sort_dim_elements); - std::iota(indices_to_sort.begin(), indices_to_sort.end(), 0); - auto comparator = [&comparator_body, &num_operands, &ctx, - &values_to_sort, &warned](int64_t a, int64_t b) { - std::vector values; - values.reserve(2 * num_operands); - for (int64_t i = 0; i < num_operands; ++i) { - values.push_back(values_to_sort[i].getElementAt(a)); - values.push_back(values_to_sort[i].getElementAt(b)); - } - spu::Value ret = comparator_body(values); - if (ret.isSecret()) { - ret = Reveal(ctx, ret); - if (!warned) { - SPDLOG_WARN("Reveal condition region result of SortOp"); - warned = true; - } - } - - return getConditionValue(ctx, ret); - }; - - if (is_stable) { - std::stable_sort(indices_to_sort.begin(), indices_to_sort.end(), - comparator); + int64_t n = values_to_sort[0].numel(); + if (absl::has_single_bit(static_cast(n))) { + // TODO(junfeng): add paddings when n x is not an integral power of + // two. + values_to_sort = + parallelBitonicSort(ctx, comparator_body, values_to_sort, n); } else { - std::sort(indices_to_sort.begin(), indices_to_sort.end(), - comparator); + sequentialBitonicSort(ctx, comparator_body, &values_to_sort, 0, n, + true); } - std::vector start_indices(rank, 0); - for (int64_t i = 0; i < num_operands; ++i) { - auto sorted_value = hal::permute(ctx, values_to_sort[i], 0, - xt::adapt(indices_to_sort)); - sliceCopy(results[i], sorted_value, indices, sort_dim); - } - }); - } else { - // Iterate through each dimension except 'sort_dim'. 
- forEachIndex( - key_shape, zero_base, key_shape, increment, - [&](const std::vector &indices) { - std::vector values_to_sort = getValuesToSort( - ctx, inputs, indices, sort_dim, sort_dim_elements, num_operands); - bitonicSort(ctx, comparator_body, &values_to_sort, 0, - values_to_sort[0].numel(), true); - for (int64_t i = 0; i < num_operands; ++i) { sliceCopy(results[i], values_to_sort[i], indices, sort_dim); } diff --git a/spu/kernel/hlo/sort.h b/spu/kernel/hlo/sort.h index 4f30e108..db1c6635 100644 --- a/spu/kernel/hlo/sort.h +++ b/spu/kernel/hlo/sort.h @@ -25,9 +25,10 @@ namespace spu::kernel::hlo { using CompFn = std::function)>; -std::vector Sort( - HalContext *ctx, absl::Span inputs, int64_t sort_dim, - bool is_stable, const CompFn &comparator_body, - const Visibility &comparator_ret_vis = Visibility::VIS_SECRET); +std::vector Sort(HalContext *ctx, + absl::Span inputs, + int64_t sort_dim, bool is_stable, + const CompFn &comparator_body, + Visibility comparator_ret_vis); } // namespace spu::kernel::hlo diff --git a/spu/mpc/aby3/arithmetic.cc b/spu/mpc/aby3/arithmetic.cc index 37b60523..59a2231c 100644 --- a/spu/mpc/aby3/arithmetic.cc +++ b/spu/mpc/aby3/arithmetic.cc @@ -34,44 +34,84 @@ namespace spu::mpc::aby3 { namespace { +// [zx]: Addapt this to new semantics of boolean sharing std::vector a1b_offline(size_t sender, const ArrayRef& a, FieldType field, size_t self_rank, PrgState* prg_state, size_t numel, - const ArrayRef& b1, const ArrayRef& b2) { - if (self_rank == sender) { - auto c1 = prg_state->genPrssPair(field, numel, false, true).first; - auto c2 = prg_state->genPrssPair(field, numel, true).second; - - auto m0 = ring_zeros(field, numel); - { - ring_xor_(m0, b1); - ring_xor_(m0, b2); - ring_mul_(m0, a); - ring_sub_(m0, c1); - ring_sub_(m0, c2); - } + const ArrayRef& b) { + YACL_ENFORCE(a.eltype().isa()); + YACL_ENFORCE(b.eltype().isa()); - auto m1 = ring_ones(field, numel); - { - ring_xor_(m1, b1); - ring_xor_(m1, b2); - ring_mul_(m1, a); - 
ring_sub_(m1, c1); - ring_sub_(m1, c2); - } + return DISPATCH_ALL_FIELDS(field, "_", [&]() -> std::vector { + using AShrT = ring2k_t; - return {c1, c2, m0, m1}; - } else if (self_rank == (sender + 1) % 3) { - prg_state->genPrssPair(field, numel, true, true); - auto c1 = prg_state->genPrssPair(field, numel, false, true).first; + ArrayRef m0(makeType(field), numel), + m1(makeType(field), numel); + linalg::setConstantValue(numel, &m0.at(0), m0.stride(), AShrT(0)); + linalg::setConstantValue(numel, &m1.at(0), m1.stride(), AShrT(1)); + + auto _m0 = ArrayView>(m0); + auto _m1 = ArrayView>(m1); + auto _a = ArrayView>(a); + + return DISPATCH_UINT_PT_TYPES( + b.eltype().as()->getBacktype(), "_", + [&]() -> std::vector { + using BSharT = ScalarT; + if (self_rank == sender) { + auto c1 = prg_state->genPrssPair(field, numel, false, true).first; + auto c2 = prg_state->genPrssPair(field, numel, true).second; + + auto _b = ArrayView>(b); + + auto _c1 = ArrayView>(c1); + auto _c2 = ArrayView>(c2); + + // (i \xor b1 \xor b2) * a - c1 - c2 + pforeach(0, numel, [&](int64_t idx) { + _m0[idx][0] = + (_m0[idx][0] ^ (_b[idx][0] & 0x1) ^ (_b[idx][1] & 0x1)) * + _a[idx][0] - + _c1[idx][0] - _c2[idx][0]; + }); + + pforeach(0, numel, [&](int64_t idx) { + _m1[idx][0] = + (_m1[idx][0] ^ (_b[idx][0] & 0x1) ^ (_b[idx][1] & 0x1)) * + _a[idx][0] - + _c1[idx][0] - _c2[idx][0]; + }); + + return {c1, c2, m0, m1}; + } else if (self_rank == (sender + 1) % 3) { + prg_state->genPrssPair(field, numel, true, true); + auto c1 = prg_state->genPrssPair(field, numel, false, true).first; + + return {c1}; + } else { + auto c2 = prg_state->genPrssPair(field, numel, true, false).second; + prg_state->genPrssPair(field, numel, true, true); + + return {c2}; + } + }); + }); +} - return {c1}; - } else { - auto c2 = prg_state->genPrssPair(field, numel, true, false).second; - prg_state->genPrssPair(field, numel, true, true); +std::vector ring_cast_boolean_(const ArrayRef& x) { + YACL_ENFORCE(x.eltype().isa(), "expect 
PtTy type, got={}", x.eltype()); + const size_t numel = x.numel(); + std::vector res(numel); - return {c2}; - } + DISPATCH_UINT_PT_TYPES(x.eltype().as()->pt_type(), "_", [&]() { + using BShrT = ScalarT; + auto _x = ArrayView>(x); + pforeach(0, numel, [&](int64_t idx) { + res[idx] = static_cast(_x[idx][0] & 0x1); + }); + }); + + return res; } } // namespace @@ -325,12 +365,21 @@ ArrayRef MulA1B::proc(KernelEvalContext* ctx, const ArrayRef& lhs, SPU_TRACE_MPC_LEAF(ctx, lhs, rhs); YACL_ENFORCE(lhs.numel() == rhs.numel()); - YACL_ENFORCE(lhs.eltype().isa()); - YACL_ENFORCE(rhs.eltype().isa() && - rhs.eltype().as()->nbits() == 1); + YACL_ENFORCE(lhs.eltype().isa()); + YACL_ENFORCE(rhs.eltype().isa() && + rhs.eltype().as()->nbits() == 1); + + const auto field = lhs.eltype().as()->field(); + const size_t in_nbits = rhs.eltype().as()->nbits(); + + YACL_ENFORCE(in_nbits <= SizeOf(field) * 8, "invalid nbits={}", in_nbits); + + const auto numel = rhs.numel(); + + const auto b_ty = *rhs.eltype().as(); + + ArrayRef out(makeType(field), numel); - const auto field = lhs.eltype().as()->field(); - const auto numel = lhs.numel(); auto* comm = ctx->caller()->getState(); auto* prg_state = ctx->caller()->getState(); @@ -344,9 +393,10 @@ ArrayRef MulA1B::proc(KernelEvalContext* ctx, const ArrayRef& lhs, // leave only lsb, in case the boolean value is randomized in a larger // domain. - const auto kOne = ring_ones(field, numel); - b1 = ring_and(b1, kOne); - b2 = ring_and(b2, kOne); + // NOTE: This is useless since we have n_bits to indicate valid bits + // const auto kOne = ring_ones(back_type, rhs.numel()); + // b1 = ring_and(b1, kOne); + // b2 = ring_and(b2, kOne); auto self_rank = comm->getRank(); const auto kComm = a1.elsize() * a1.numel(); @@ -364,7 +414,7 @@ ArrayRef MulA1B::proc(KernelEvalContext* ctx, const ArrayRef& lhs, // TODO: optimization for large input. // online part: tasks two rounds latency. do 3-parties OT. 
auto offline = [&](size_t sender, const ArrayRef& a) { - return a1b_offline(sender, a, field, self_rank, prg_state, numel, b1, b2); + return a1b_offline(sender, a, field, self_rank, prg_state, numel, rhs); }; // parallel online: parallel two 3-parties OT. @@ -380,6 +430,7 @@ ArrayRef MulA1B::proc(KernelEvalContext* ctx, const ArrayRef& lhs, // asymmetric cost. comm->addCommStatsManually(2, kComm * 8); + // c1, c3, m0, m1 if (self_rank == sender1) { ot1.send(data1[2], data1[3]); // 2k r1 = {data1[0], data1[1]}; @@ -389,28 +440,31 @@ ArrayRef MulA1B::proc(KernelEvalContext* ctx, const ArrayRef& lhs, r2 = {data2[0], data2[1]}; } + // helper send wc to receiver if (self_rank == (sender1 + 1) % 3) { - ot1.help(ring_cast_boolean(b2)); // 1k + ot1.help(ring_cast_boolean_(b2)); // 1k r1.first = data1[0]; } if (self_rank == (sender2 + 1) % 3) { - ot2.help(ring_cast_boolean(b2)); // 1k + ot2.help(ring_cast_boolean_(b2)); // 1k r2.first = data2[0]; } + // receiver recv c2 and send c2 to helper if (self_rank == (sender1 + 2) % 3) { // 1 latency - auto c1 = ot1.recv(ring_cast_boolean(b1)); + auto c1 = ot1.recv(ring_cast_boolean_(b1)); comm->sendAsync((sender1 + 1) % 3, c1, "ABY3-MUL-R1C1"); // 1k r1 = {c1, data1[0]}; } if (self_rank == (sender2 + 2) % 3) { // 1 latency overlapping with "ABY3-MUL-R1C1" - auto c1 = ot2.recv(ring_cast_boolean(b1)); + auto c1 = ot2.recv(ring_cast_boolean_(b1)); comm->sendAsync((sender2 + 1) % 3, c1, "ABY3-MUL-R2C1"); // 1k r2 = {c1, data2[0]}; } + // // NOTE: here two sequential rounds are required if (self_rank == (sender1 + 1) % 3) { // 1 latency r1.second = comm->recv((sender1 + 2) % 3, a1.eltype(), "ABY3-MUL-R1C1"); @@ -420,9 +474,21 @@ ArrayRef MulA1B::proc(KernelEvalContext* ctx, const ArrayRef& lhs, r2.second = comm->recv((sender2 + 2) % 3, a1.eltype(), "ABY3-MUL-R2C1"); } - ring_add_(r1.first, r2.first); - ring_add_(r1.second, r2.second); + DISPATCH_ALL_FIELDS(field, "_", [&]() { + using AShrT = ring2k_t; + auto _r1_first = 
ArrayView>(r1.first); + auto _r1_second = ArrayView>(r1.second); + auto _r2_first = ArrayView>(r2.first); + auto _r2_second = ArrayView>(r2.second); + + // r1.first = r1.first + r2.first + // r1.second = r1.second + r2.second + pforeach(0, r1.first.numel(), [&](int64_t idx) { + _r1_first[idx][0] = _r1_first[idx][0] + _r2_first[idx][0]; + _r1_second[idx][0] = _r1_second[idx][0] + _r2_second[idx][0]; + }); + }); return r1; }; @@ -430,8 +496,10 @@ ArrayRef MulA1B::proc(KernelEvalContext* ctx, const ArrayRef& lhs, auto data1 = offline(0, a2); // only sender access a1 + a2, avoid useless add for other two parties. auto data2 = offline(2, self_rank == 2 ? ring_add(a1, a2) : a2); + auto ret = parallel_online(0, data1, 2, data2); - return makeAShare(ret.first, ret.second, field); + + return makeAShare(ret.first, ret.second, field, self_rank); } //////////////////////////////////////////////////////////////////// diff --git a/spu/mpc/aby3/ot.cc b/spu/mpc/aby3/ot.cc index cfd4fc22..d41f89d7 100644 --- a/spu/mpc/aby3/ot.cc +++ b/spu/mpc/aby3/ot.cc @@ -111,6 +111,7 @@ ArrayRef Ot3::recv(const std::vector& choices) { // get masked messages from sender. auto m0 = comm_->recv(roles_.sender, ty, "m0"); auto m1 = comm_->recv(roles_.sender, ty, "m1"); + auto mc = ring_select(choices, m0, m1); // get chosen masks auto wc = comm_->recv(roles_.helper, ty, "wc"); diff --git a/spu/mpc/aby3/protocol.cc b/spu/mpc/aby3/protocol.cc index db118cbd..c7e481ca 100644 --- a/spu/mpc/aby3/protocol.cc +++ b/spu/mpc/aby3/protocol.cc @@ -56,7 +56,7 @@ std::unique_ptr makeAby3Protocol( obj->regKernel(); obj->regKernel(); // FIXME: temp disable MulA1B method. - // obj->regKernel(); + obj->regKernel(); obj->regKernel(); obj->regKernel(); obj->regKernel();