diff --git a/libspu/kernel/hal/fxp_cleartext.cc b/libspu/kernel/hal/fxp_cleartext.cc
index a7636e81..91b48b0d 100644
--- a/libspu/kernel/hal/fxp_cleartext.cc
+++ b/libspu/kernel/hal/fxp_cleartext.cc
@@ -145,4 +145,10 @@ Value f_erf_p(SPUContext* ctx, const Value& in) {
   return applyFloatingPointFn(ctx, in, [](float x) { return std::erf(x); });
 }
 
+Value f_pow_p(SPUContext* ctx, const Value& x, const Value& y) {
+  SPU_TRACE_HAL_DISP(ctx, x, y);
+  return applyFloatingPointFn(ctx, x, y,
+                              [](float a, float b) { return std::pow(a, b); });
+}
+
 }  // namespace spu::kernel::hal
diff --git a/libspu/kernel/hal/fxp_cleartext.h b/libspu/kernel/hal/fxp_cleartext.h
index a38354cf..1b4a264d 100644
--- a/libspu/kernel/hal/fxp_cleartext.h
+++ b/libspu/kernel/hal/fxp_cleartext.h
@@ -40,4 +40,6 @@ Value f_cosine_p(SPUContext* ctx, const Value& in);
 
 Value f_erf_p(SPUContext* ctx, const Value& in);
 
+Value f_pow_p(SPUContext* ctx, const Value& x, const Value& y);
+
 }  // namespace spu::kernel::hal
diff --git a/libspu/kernel/hal/polymorphic.cc b/libspu/kernel/hal/polymorphic.cc
index 2542eb59..f63dacbe 100644
--- a/libspu/kernel/hal/polymorphic.cc
+++ b/libspu/kernel/hal/polymorphic.cc
@@ -19,6 +19,7 @@
 #include "libspu/core/trace.h"
 #include "libspu/kernel/hal/fxp_approx.h"
 #include "libspu/kernel/hal/fxp_base.h"
+#include "libspu/kernel/hal/fxp_cleartext.h"
 #include "libspu/kernel/hal/integer.h"
 #include "libspu/kernel/hal/ring.h"  // for fast fxp x int
 #include "libspu/kernel/hal/type_cast.h"
@@ -329,15 +330,36 @@ Value min(SPUContext* ctx, const Value& x, const Value& y) {
 Value power(SPUContext* ctx, const Value& x, const Value& y) {
   SPU_TRACE_HAL_DISP(ctx, x, y);
 
-  if (x.isInt() && y.isInt()) {
+  if (x.isInt() || y.isInt()) {
     auto x_f = dtype_cast(ctx, x, DT_F32);
     auto y_f = dtype_cast(ctx, y, DT_F32);
     auto ret = power(ctx, x_f, y_f);
-    return dtype_cast(ctx, ret, x.dtype());
+    return ret;
+  }
+
+  if (x.isPublic() && y.isPublic()) {
+    return f_pow_p(ctx, x, y);
   }
 
+  auto msb = _msb(ctx, x);
+  auto msb_a = _prefer_a(ctx, msb);
+  auto x_abs = _mux(ctx, msb_a, _negate(ctx, x), x).setDtype(x.dtype());
+
+  // if x=0 is public, then log(x) gives -inf and a wrong output is produced
+  // after multiplying by y. So we force x to be secret; computing log(x) then
+  // yields a small negative number, so exp(y*log(x)) = 0.
+  auto x_s = x.isPublic() ? hal::seal(ctx, x_abs) : x_abs;
+
   // x^y = e^(y*ln(x))
-  return exp(ctx, mul(ctx, y, log(ctx, x)));
+  // the precision is highly dependent on the precision of exp and log, so we
+  // choose the most precise methods here.
+  auto val = detail::exp_pade(ctx, mul(ctx, y, detail::log_minmax(ctx, x_s)));
+
+  // the final sign is decided by both the sign of x and the parity of y,
+  // e.g. when x<0 and y is odd: (-2)^3 = -8
+  auto odd = _and(ctx, _rshift(ctx, y, ctx->getFxpBits()),
+                  _constant(ctx, 1, y.shape()));
+  auto sign = _and(ctx, msb, odd);
+
+  return _mux(ctx, sign, _negate(ctx, val), val).setDtype(x.dtype());
 }
 
 Value idiv(SPUContext* ctx, const Value& x, const Value& y) {
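Aside on the secret-operand path added to power() above: it evaluates exp(y * log(|x|)) and then restores the sign from the MSB of x and the parity of the integer part of y. The following is only a minimal cleartext sketch of that recipe, not the fixed-point MPC kernel itself; the helper name pow_reference is hypothetical.

#include <cmath>

// Cleartext sketch of the math behind the secret power() path:
// |x|^y via exp(y * log(|x|)), negated when x < 0 and the integer part
// of y is odd, e.g. (-2)^3 = -8.
double pow_reference(double x, double y) {
  const double val = std::exp(y * std::log(std::fabs(x)));
  const bool y_odd = (static_cast<long long>(y) % 2) != 0;  // parity of integer part of y
  return (x < 0.0 && y_odd) ? -val : val;
}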
diff --git a/libspu/kernel/hal/polymorphic_test.cc b/libspu/kernel/hal/polymorphic_test.cc
index c623e013..e4a7014e 100644
--- a/libspu/kernel/hal/polymorphic_test.cc
+++ b/libspu/kernel/hal/polymorphic_test.cc
@@ -406,26 +406,42 @@ TYPED_TEST(MathTest, Pow) {
   using LHS_VT = typename std::tuple_element<1, TypeParam>::type;
   using RHS_DT = typename std::tuple_element<2, TypeParam>::type;
   using RHS_VT = typename std::tuple_element<3, TypeParam>::type;
-  using RES_DT = typename std::tuple_element<4, TypeParam>::type;
+  // using RES_DT = typename std::tuple_element<4, TypeParam>::type;
 
-  if constexpr (!std::is_same_v<LHS_DT, RHS_DT> ||
-                !std::is_same_v<LHS_DT, RES_DT> || std::is_integral_v<LHS_DT>) {
-    return;
+  // GIVEN
+  xt::xarray<float> x;
+  xt::xarray<float> y;
+  {
+    // random test
+    x = test::xt_random<float>({5, 6}, 0, 100);
+    y = test::xt_random<float>({5, 6}, -2, 2);
+
+    // WHAT
+    auto z = test::evalBinaryOp<float>(LHS_VT(), RHS_VT(), power, x, y);
+
+    // THEN
+    auto expected = xt::pow(x, y);
+    EXPECT_TRUE(xt::allclose(expected, z, 0.3, 0.03)) << x << std::endl
+                                                      << y << std::endl
+                                                      << expected << std::endl
+                                                      << z << std::endl;
   }
 
-  // GIVEN
-  const xt::xarray<LHS_DT> x = test::xt_random<LHS_DT>({5, 6}, 0, 100);
-  const xt::xarray<RHS_DT> y = test::xt_random<RHS_DT>({5, 6}, 0, 2);
+  {
+    // some fixed corner cases
+    x = {-1, -1, -3, 1, -3, 0, 1, 1, 5, 0};
+    y = {1, 0, -3, -3, 3, 0, 0, 2, 5, 2};
 
-  // WHAT
-  auto z = test::evalBinaryOp<RES_DT>(LHS_VT(), RHS_VT(), power, x, y);
+    // WHAT
+    auto z = test::evalBinaryOp<float>(LHS_VT(), RHS_VT(), power, x, y);
 
-  // THEN
-  auto expected = xt::pow(x, y);
-  EXPECT_TRUE(xt::allclose(expected, z, 0.3, 0.03)) << x << std::endl
-                                                    << y << std::endl
-                                                    << expected << std::endl
-                                                    << z << std::endl;
+    // THEN
+    auto expected = xt::pow(x, y);
+    EXPECT_TRUE(xt::allclose(expected, z, 0.3, 0.03)) << x << std::endl
+                                                      << y << std::endl
+                                                      << expected << std::endl
+                                                      << z << std::endl;
+  }
 }
 
 using MathUnaryTestTypes = ::testing::Types<
diff --git a/libspu/mpc/cheetah/state.cc b/libspu/mpc/cheetah/state.cc
index 05c8c978..cc41365c 100644
--- a/libspu/mpc/cheetah/state.cc
+++ b/libspu/mpc/cheetah/state.cc
@@ -28,7 +28,7 @@ namespace spu::mpc::cheetah {
 namespace {
 // Return num_workers for the given size of jobs
 size_t InitOTState(KernelEvalContext* ctx, size_t njobs) {
-  constexpr size_t kMinWorkSize = 5000;
+  constexpr size_t kMinWorkSize = 2048;
   if (njobs == 0) {
     return 0;
   }
@@ -139,86 @@ std::array<NdArrayRef, 3> CheetahMulState::TakeCachedBeaver(FieldType field,
 
 NdArrayRef TiledDispatchOTFunc(KernelEvalContext* ctx, const NdArrayRef& x,
                                OTUnaryFunc func) {
-  Shape shape = x.shape();
+  const Shape& shape = x.shape();
+  SPU_ENFORCE(shape.numel() > 0);
   // (lazy) init OT
   int64_t numel = x.numel();
   int64_t nworker = InitOTState(ctx, numel);
   int64_t workload = nworker == 0 ? 0 : CeilDiv(numel, nworker);
-  int64_t slicing_dim = -1;
-  int64_t slice_numel = 1;
-  for (int64_t dim = shape.size() - 1; dim >= 0; dim--) {
-    slice_numel *= shape[dim];
-    if (slice_numel > workload) {
-      slice_numel /= shape[dim];
-      slicing_dim = dim;
-      break;
-    }
-  }
-
-  // get the slice num in the left outer dimensions
-  int64_t num_slice = 1;
-  for (int64_t dim = 0; dim < slicing_dim; dim++) {
-    num_slice *= shape[dim];
-  }
-
-  int64_t slice_stride = (workload + slice_numel - 1) / slice_numel;
-  if (slice_stride == 1) {
-    return func(x, ctx->getState<CheetahOTState>()->get(0));
-  }
-
-  int64_t num_slice_dim = shape[slicing_dim] / slice_stride +
-                          ((shape[slicing_dim] % slice_stride) != 0 ? 1 : 0);
-
-  // initialize slice indices
-  Index start_indices(shape.size());
-  Index end_indices(shape.begin(), shape.end());
-  end_indices[slicing_dim] = slice_stride;
-  for (int64_t dim = slicing_dim - 1; dim >= 0; dim--) {
-    end_indices[dim] = 1;
+  if (shape.ndim() != 1) {
+    // TiledDispatchOTFunc over the flattened input
+    return TiledDispatchOTFunc(ctx, x.reshape({numel}), func)
+        .reshape(x.shape());
   }
 
-  SPU_ENFORCE_LE(num_slice * num_slice_dim, nworker);
-  nworker = num_slice * num_slice_dim;
-
   std::vector<NdArrayRef> outs(nworker);
   std::vector<std::future<void>> futures;
 
-  Index sidx = start_indices;
-  Index eidx = end_indices;
-  for (int64_t wi = 0; wi < nworker; ++wi) {
-    auto slice_input = x.slice(sidx, eidx, {});
+  int64_t slice_end = 0;
+  for (int64_t wi = 0; wi + 1 < nworker; ++wi) {
+    int64_t slice_bgn = wi * workload;
+    slice_end = std::min(numel, slice_bgn + workload);
+    auto slice_input = x.slice({slice_bgn}, {slice_end}, {});
     futures.emplace_back(std::async(
         [&](int64_t idx, const NdArrayRef& input) {
           auto ot_instance = ctx->getState<CheetahOTState>()->get(idx);
           outs[idx] = func(input, ot_instance);
         },
         wi, slice_input));
-
-    // update indices
-    if (0 == (eidx[slicing_dim] % shape[slicing_dim])) {
-      // carray out
-      sidx[slicing_dim] = 0;
-      eidx[slicing_dim] = slice_stride;
-      for (int64_t dim = slicing_dim - 1; dim >= 0; dim--) {
-        sidx[dim] = (sidx[dim] + 1) % shape[dim];
-        eidx[dim] = eidx[dim] % shape[dim] + 1;
-        if (eidx[dim] != 1) {
-          break;
-        }
-      }
-    } else {
-      sidx[slicing_dim] += slice_stride;
-      eidx[slicing_dim] += slice_stride;
-      eidx[slicing_dim] = std::min(shape[slicing_dim], eidx[slicing_dim]);
-    }
   }
 
+  auto slice_input = x.slice({slice_end}, {numel}, {1});
+  auto ot_instance = ctx->getState<CheetahOTState>()->get(nworker - 1);
+  outs[nworker - 1] = func(slice_input, ot_instance);
+
   for (auto&& f : futures) {
     f.get();
   }
 
-  NdArrayRef out(x.eltype(), x.shape());
+  NdArrayRef out(outs[0].eltype(), x.shape());
 
   int64_t offset = 0;
   for (auto& out_slice : outs) {
@@ -232,89 @@ NdArrayRef TiledDispatchOTFunc(KernelEvalContext* ctx, const NdArrayRef& x,
 
 NdArrayRef TiledDispatchOTFunc(KernelEvalContext* ctx, const NdArrayRef& x,
                                const NdArrayRef& y, OTBinaryFunc func) {
-  Shape shape = x.shape();
-  SPU_ENFORCE_EQ(x.shape(), y.shape());
+  const Shape& shape = x.shape();
+  SPU_ENFORCE(shape.numel() > 0);
+  SPU_ENFORCE_EQ(shape, y.shape());
   // (lazy) init OT
   int64_t numel = x.numel();
   int64_t nworker = InitOTState(ctx, numel);
   int64_t workload = nworker == 0 ? 0 : CeilDiv(numel, nworker);
-  int64_t slicing_dim = -1;
-  int64_t slice_numel = 1;
-  for (int64_t dim = shape.size() - 1; dim >= 0; dim--) {
-    slice_numel *= shape[dim];
-    if (slice_numel > workload) {
-      slice_numel /= shape[dim];
-      slicing_dim = dim;
-      break;
-    }
+  if (shape.ndim() != 1) {
+    // TiledDispatchOTFunc over the flattened input
+    return TiledDispatchOTFunc(ctx, x.reshape({numel}), y.reshape({numel}),
+                               func)
+        .reshape(x.shape());
   }
 
-  // get the slice num in the left outer dimensions
-  int64_t num_slice = 1;
-  for (int64_t dim = 0; dim < slicing_dim; dim++) {
-    num_slice *= shape[dim];
-  }
-
-  int64_t slice_stride = (workload + slice_numel - 1) / slice_numel;
-  if (slice_stride == 1) {
-    return func(x, y, ctx->getState<CheetahOTState>()->get(0));
-  }
-
-  int64_t num_slice_dim = shape[slicing_dim] / slice_stride +
-                          ((shape[slicing_dim] % slice_stride) != 0 ? 1 : 0);
-
-  // initialize slice indices
-  Index start_indices(shape.size());
-  Index end_indices(shape.begin(), shape.end());
-  end_indices[slicing_dim] = slice_stride;
-  for (int64_t dim = slicing_dim - 1; dim >= 0; dim--) {
-    end_indices[dim] = 1;
-  }
-
-  SPU_ENFORCE_LE(num_slice * num_slice_dim, nworker);
-  nworker = num_slice * num_slice_dim;
-
   std::vector<NdArrayRef> outs(nworker);
   std::vector<std::future<void>> futures;
 
-  Index sidx = start_indices;
-  Index eidx = end_indices;
-  for (int64_t wi = 0; wi < nworker; ++wi) {
-    auto x_slice = x.slice(sidx, eidx, {});
-    auto y_slice = y.slice(sidx, eidx, {});
-
+  int64_t slice_end = 0;
+  for (int64_t wi = 0; wi + 1 < nworker; ++wi) {
+    int64_t slice_bgn = wi * workload;
+    slice_end = std::min(numel, slice_bgn + workload);
+    auto x_slice = x.slice({slice_bgn}, {slice_end}, {1});
+    auto y_slice = y.slice({slice_bgn}, {slice_end}, {1});
     futures.emplace_back(std::async(
-        [&](int64_t idx, const NdArrayRef& input0, const NdArrayRef& input1) {
+        [&](int64_t idx, const NdArrayRef& inp0, const NdArrayRef& inp1) {
           auto ot_instance = ctx->getState<CheetahOTState>()->get(idx);
-          outs[idx] = func(input0, input1, ot_instance);
+          outs[idx] = func(inp0, inp1, ot_instance);
         },
         wi, x_slice, y_slice));
-
-    // update indices
-    if (0 == (eidx[slicing_dim] % shape[slicing_dim])) {
-      // carray out
-      sidx[slicing_dim] = 0;
-      eidx[slicing_dim] = slice_stride;
-      for (int64_t dim = slicing_dim - 1; dim >= 0; dim--) {
-        sidx[dim] = (sidx[dim] + 1) % shape[dim];
-        eidx[dim] = eidx[dim] % shape[dim] + 1;
-        if (eidx[dim] != 1) {
-          break;
-        }
-      }
-    } else {
-      sidx[slicing_dim] += slice_stride;
-      eidx[slicing_dim] += slice_stride;
-      eidx[slicing_dim] = std::min(shape[slicing_dim], eidx[slicing_dim]);
-    }
   }
+
+  auto x_slice = x.slice({slice_end}, {numel}, {});
+  auto y_slice = y.slice({slice_end}, {numel}, {});
+  auto ot_instance = ctx->getState<CheetahOTState>()->get(nworker - 1);
+  outs[nworker - 1] = func(x_slice, y_slice, ot_instance);
+
   for (auto&& f : futures) {
     f.get();
   }
 
-  NdArrayRef out(x.eltype(), x.shape());
+  NdArrayRef out(outs[0].eltype(), x.shape());
 
   int64_t offset = 0;
+
   for (auto& out_slice : outs) {
     std::memcpy(out.data() + offset, out_slice.data(),
                 out_slice.numel() * out.elsize());
diff --git a/libspu/mpc/cheetah/state.h b/libspu/mpc/cheetah/state.h
index e891d454..36277138 100644
--- a/libspu/mpc/cheetah/state.h
+++ b/libspu/mpc/cheetah/state.h
@@ -25,6 +25,8 @@
 #include "libspu/mpc/cheetah/ot/basic_ot_prot.h"
 #include "libspu/mpc/cheetah/rlwe/utils.h"
 
+#include "libspu/spu.pb.h"
+
 namespace spu::mpc::cheetah {
 
 using OTUnaryFunc = std::function<NdArrayRef(
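For reference, the rewritten TiledDispatchOTFunc flattens the input to 1-D and hands each OT worker a contiguous range of roughly CeilDiv(numel, nworker) elements, with the last worker taking the remainder on the calling thread. Below is a standalone sketch of just that range arithmetic; make_ranges is a hypothetical helper, not part of the patch, and it assumes numel > 0 and nworker > 0, mirroring the SPU_ENFORCE added above.

#include <algorithm>
#include <cstdint>
#include <utility>
#include <vector>

// Sketch of the contiguous 1-D partitioning used by the new dispatch loop:
// worker wi covers [wi * workload, min(numel, (wi + 1) * workload)), so every
// range except possibly the last holds exactly `workload` elements.
std::vector<std::pair<int64_t, int64_t>> make_ranges(int64_t numel, int64_t nworker) {
  const int64_t workload = (numel + nworker - 1) / nworker;  // CeilDiv
  std::vector<std::pair<int64_t, int64_t>> ranges;
  for (int64_t wi = 0; wi < nworker; ++wi) {
    const int64_t bgn = wi * workload;
    const int64_t end = std::min(numel, bgn + workload);
    ranges.emplace_back(bgn, end);
  }
  return ranges;
}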