diff --git a/.bazelrc b/.bazelrc index e31a866e..a1604334 100644 --- a/.bazelrc +++ b/.bazelrc @@ -62,7 +62,3 @@ build:macos --host_macos_minimum_os=12.0 build:linux --copt=-fopenmp build:linux --linkopt=-fopenmp - -build:asan --features=asan -build:ubsan --features=ubsan - diff --git a/CHANGELOG.md b/CHANGELOG.md index efc46dea..afd2d802 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,11 +10,12 @@ > > please add your unreleased change here. -- [Feature] Add Odd-Even Merge Sort to replace the bitonic sort. -- [Feature] Add radix sort support for ABY3. -- [Feature] Integrate with secretflow/psi. -- [Feature] Add Linux aarch64 support. -- [Improvement] Optimize sort memory usage. +- [Feature] Add Odd-Even Merge Sort to replace the bitonic sort +- [Feature] Add radix sort support for ABY3 +- [Feature] Integrate with secretflow/psi +- [Feature] Add Linux aarch64 support +- [Feature] Add equal support for SEMI2K and ABY3 +- [Improvement] Optimize sort memory usage - [Deprecated] macOS 11.x is no longer supported ## 20231108 diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index ffbc39ca..12fd62dc 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -105,8 +105,8 @@ bazel build //... -c opt bazel test //... # [optional] build & test with ASAN or UBSAN, for macOS users please use configs with macOS prefix -bazel test //... --config=[macos-]asan -bazel test //... --config=[macos-]ubsan +bazel test //... --features=asan +bazel test //... 
--features=ubsan ``` ### Bazel build options diff --git a/libspu/mpc/ab_api_test.cc b/libspu/mpc/ab_api_test.cc index 6f58eda8..6bc25ffe 100644 --- a/libspu/mpc/ab_api_test.cc +++ b/libspu/mpc/ab_api_test.cc @@ -687,4 +687,74 @@ TEST_P(ConversionTest, MSB) { }); } +TEST_P(ConversionTest, EqualAA) { + const auto factory = std::get<0>(GetParam()); + const RuntimeConfig& conf = std::get<1>(GetParam()); + const size_t npc = std::get<2>(GetParam()); + + utils::simulate(npc, [&](std::shared_ptr lctx) { + auto obj = factory(conf, lctx); + + if (!obj->hasKernel("equal_aa")) { + return; + } + /* GIVEN */ + auto r0 = rand_p(obj.get(), kShape); + auto r1 = rand_p(obj.get(), kShape); + auto r2 = rand_p(obj.get(), kShape); + std::memcpy(r2.data().data(), r0.data().data(), 16); + std::vector test_values = {r0, r1, r2}; + + for (auto& test_value : test_values) { + auto l_value = p2a(obj.get(), r0); + auto r_value = p2a(obj.get(), test_value); + auto prev = obj->prot()->getState()->getStats(); + auto tmp = dynDispatch(obj.get(), "equal_aa", l_value, r_value); + auto cost = obj->prot()->getState()->getStats() - prev; + auto out_value = b2p(obj.get(), tmp); + auto t_value = equal_pp(obj.get(), r0, test_value); + + /* THEN */ + EXPECT_VALUE_EQ(out_value, t_value); + EXPECT_TRUE(verifyCost(obj->prot()->getKernel("equal_aa"), "equal_aa", + conf.field(), kShape, npc, cost)); + } + }); +} + +TEST_P(ConversionTest, EqualAP) { + const auto factory = std::get<0>(GetParam()); + const RuntimeConfig& conf = std::get<1>(GetParam()); + const size_t npc = std::get<2>(GetParam()); + + utils::simulate(npc, [&](std::shared_ptr lctx) { + auto obj = factory(conf, lctx); + + if (!obj->hasKernel("equal_ap")) { + return; + } + /* GIVEN */ + auto r0 = rand_p(obj.get(), kShape); + auto r1 = rand_p(obj.get(), kShape); + auto r2 = rand_p(obj.get(), kShape); + std::memcpy(r2.data().data(), r0.data().data(), 16); + std::vector test_values = {r0, r1, r2}; + + for (auto& test_value : test_values) { + auto 
l_value = p2a(obj.get(), r0); + auto r_value = test_value; + auto prev = obj->prot()->getState()->getStats(); + auto tmp = dynDispatch(obj.get(), "equal_ap", l_value, r_value); + auto cost = obj->prot()->getState()->getStats() - prev; + auto out_value = b2p(obj.get(), tmp); + auto t_value = equal_pp(obj.get(), r0, test_value); + + /* THEN */ + EXPECT_VALUE_EQ(out_value, t_value); + EXPECT_TRUE(verifyCost(obj->prot()->getKernel("equal_ap"), "equal_ap", + conf.field(), kShape, npc, cost)); + } + }); +} + } // namespace spu::mpc::test diff --git a/libspu/mpc/aby3/conversion.cc b/libspu/mpc/aby3/conversion.cc index f73d3037..2a95a761 100644 --- a/libspu/mpc/aby3/conversion.cc +++ b/libspu/mpc/aby3/conversion.cc @@ -27,6 +27,7 @@ #include "libspu/mpc/common/communicator.h" #include "libspu/mpc/common/prg_state.h" #include "libspu/mpc/common/pv2k.h" +#include "libspu/mpc/utils/ring_ops.h" namespace spu::mpc::aby3 { @@ -624,6 +625,292 @@ NdArrayRef MsbA2B::proc(KernelEvalContext* ctx, const NdArrayRef& in) const { } } +// Reference: +// New Primitives for Actively-Secure MPC over Rings with Applications to +// Private Machine Learning +// P8 IV.D protocol eqz +// https://eprint.iacr.org/2019/599.pdf +// +// Improved Primitives for MPC over Mixed Arithmetic-Binary Circuits +// https://eprint.iacr.org/2020/338.pdf +// +// P0 as the helper/dealer, samples r, deals [r]a and [r]b. 
+// P1 and P2 get new share [a] +// P1: [a] = x2 + x3 +// P2: [a] = x1 +// reveal c = [a]+[r]a +// check [a] == 0 <=> c == r +// c == r <=> ~c ^ rb to be bit wise all 1 +// then eqz(a) = bit_wise_and(~c ^ rb) +NdArrayRef eqz(KernelEvalContext* ctx, const NdArrayRef& in) { + auto* prg_state = ctx->getState(); + auto* comm = ctx->getState(); + + const auto field = in.eltype().as()->field(); + const PtType in_bshr_btype = calcBShareBacktype(SizeOf(field) * 8); + const auto numel = in.numel(); + + NdArrayRef out(makeType(calcBShareBacktype(8), 8), in.shape()); + + size_t pivot; + prg_state->fillPubl(absl::MakeSpan(&pivot, 1)); + size_t P0 = pivot % 3; + size_t P1 = (pivot + 1) % 3; + size_t P2 = (pivot + 2) % 3; + + DISPATCH_ALL_FIELDS(field, "_", [&]() { + using ashr_el_t = ring2k_t; + using ashr_t = std::array; + DISPATCH_UINT_PT_TYPES(in_bshr_btype, "_", [&]() { + using bshr_el_t = ScalarT; + std::vector zero_flag_3pc_0(numel); + std::vector zero_flag_3pc_1(numel); + + // algorithm begins + if (comm->getRank() == P0) { + std::vector r(numel); + prg_state->fillPriv(absl::MakeSpan(r)); + + std::vector r_arith_0(numel); + prg_state->fillPrssPair({}, r_arith_0.data(), numel, + PrgState::GenPrssCtrl::Second); + std::vector r_bool_0(numel); + prg_state->fillPrssPair({}, r_bool_0.data(), numel, + PrgState::GenPrssCtrl::Second); + + std::vector r_arith_1(numel); + pforeach(0, numel, [&](int64_t idx) { + r_arith_1[idx] = r[idx] - r_arith_0[idx]; + }); + comm->sendAsync(P2, r_arith_1, "r_arith"); + + std::vector r_bool_1(numel); + pforeach(0, numel, + [&](int64_t idx) { r_bool_1[idx] = r[idx] ^ r_bool_0[idx]; }); + comm->sendAsync(P2, r_bool_1, "r_bool"); + + // back to 3 pc + // P0 zero_flag = (rb1, rz) + pforeach(0, numel, + [&](int64_t idx) { zero_flag_3pc_0[idx] = r_bool_1[idx]; }); + + prg_state->fillPrssPair({}, zero_flag_3pc_1.data(), numel, + PrgState::GenPrssCtrl::Second); + + } else { + std::vector a_s(numel); + NdArrayView _in(in); + std::vector r_arith(numel); + 
std::vector r_bool(numel); + + if (comm->getRank() == P1) { + pforeach(0, numel, + [&](int64_t idx) { a_s[idx] = _in[idx][0] + _in[idx][1]; }); + + prg_state->fillPrssPair(r_arith.data(), {}, numel, + PrgState::GenPrssCtrl::First); + prg_state->fillPrssPair(r_bool.data(), {}, numel, + PrgState::GenPrssCtrl::First); + } else { + pforeach(0, numel, [&](int64_t idx) { a_s[idx] = _in[idx][1]; }); + prg_state->fillPrssPair({}, {}, numel, + PrgState::GenPrssCtrl::None); + prg_state->fillPrssPair({}, {}, numel, + PrgState::GenPrssCtrl::None); + r_arith = comm->recv(P0, "r_arith"); + r_bool = comm->recv(P0, "r_bool"); + } + + // c in secret share + std::vector c_s(numel); + pforeach(0, numel, + [&](int64_t idx) { c_s[idx] = r_arith[idx] + a_s[idx]; }); + + std::vector zero_flag_2pc(numel); + if (comm->getRank() == P1) { + auto c_p = comm->recv(P2, "c_s"); + + // reveal c + pforeach(0, numel, + [&](int64_t idx) { c_p[idx] = c_p[idx] + c_s[idx]; }); + // P1 zero_flag = (rz, not(c_p xor [r]b0)^ rz) + std::vector r_z(numel); + prg_state->fillPrssPair(r_z.data(), {}, numel, + PrgState::GenPrssCtrl::First); + pforeach(0, numel, [&](int64_t idx) { + zero_flag_2pc[idx] = ~(c_p[idx] ^ r_bool[idx]) ^ r_z[idx]; + }); + + comm->sendAsync(P2, zero_flag_2pc, "flag_split"); + + pforeach(0, numel, [&](int64_t idx) { + zero_flag_3pc_0[idx] = r_z[idx]; + zero_flag_3pc_1[idx] = zero_flag_2pc[idx]; + }); + } else { + comm->sendAsync(P1, c_s, "c_s"); + // P1 zero_flag = (not(c_p xor [r]b0)^ rz, rb1) + pforeach(0, numel, + [&](int64_t idx) { zero_flag_3pc_1[idx] = r_bool[idx]; }); + prg_state->fillPrssPair({}, {}, numel, + PrgState::GenPrssCtrl::None); + + auto flag_split = comm->recv(P1, "flag_split"); + pforeach(0, numel, [&](int64_t idx) { + zero_flag_3pc_0[idx] = flag_split[idx]; + }); + } + } + + // Reference: + // Improved Primitives for Secure Multiparty Integer Computation + // P10 4.1 k-ary + // https://link.springer.com/chapter/10.1007/978-3-642-15317-4_13 + // + // if a == 0, 
zero_flag supposed to be all 1 + // do log k round bit wise and + // in each round, bit wise split zero_flag in half + // compute and(left_half, right_half) + auto cur_bytes = SizeOf(field) * numel; + auto cur_bits = cur_bytes * 8; + auto cur_numel = (unsigned long)numel; + std::vector round_res_0(cur_bytes); + std::memcpy(round_res_0.data(), zero_flag_3pc_0.data(), cur_bytes); + std::vector round_res_1(cur_bytes); + std::memcpy(round_res_1.data(), zero_flag_3pc_1.data(), cur_bytes); + while (cur_bits != cur_numel) { + // byte num per element + auto byte_num_el = cur_bytes == cur_numel ? 1 : (cur_bytes / numel); + // byte num of left/right_bits + auto half_num_bytes = + cur_bytes == cur_numel ? cur_numel : (cur_bytes / 2); + + // break into left_bits and right_bits + std::vector> left_bits( + 2, std::vector(half_num_bytes)); + std::vector> right_bits( + 2, std::vector(half_num_bytes)); + + // cur_bits <= 8, use rshift to split in half + if (cur_bytes == cur_numel) { + pforeach(0, numel, [&](int64_t idx) { + left_bits[0][idx] = + round_res_0[idx] >> (cur_bits / (cur_numel * 2)); + left_bits[1][idx] = + round_res_1[idx] >> (cur_bits / (cur_numel * 2)); + right_bits[0][idx] = round_res_0[idx]; + right_bits[1][idx] = round_res_1[idx]; + }); + // cur_bits > 8 + } else { + pforeach(0, numel, [&](int64_t idx) { + auto cur_byte_idx = idx * byte_num_el; + for (size_t i = 0; i < (byte_num_el / 2); i++) { + left_bits[0][cur_byte_idx / 2 + i] = + round_res_0[cur_byte_idx + i]; + left_bits[1][cur_byte_idx / 2 + i] = + round_res_1[cur_byte_idx + i]; + } + for (size_t i = 0; i < (byte_num_el / 2); i++) { + right_bits[0][cur_byte_idx / 2 + i] = + round_res_0[cur_byte_idx + byte_num_el / 2 + i]; + right_bits[1][cur_byte_idx / 2 + i] = + round_res_1[cur_byte_idx + byte_num_el / 2 + i]; + } + }); + } + + // compute and(left_half, right_half) + std::vector r0(half_num_bytes); + std::vector r1(half_num_bytes); + prg_state->fillPrssPair(r0.data(), r1.data(), half_num_bytes, + 
PrgState::GenPrssCtrl::Both); + + // z1 = (x1 & y1) ^ (x1 & y2) ^ (x2 & y1) ^ (r0 ^ r1); + pforeach(0, half_num_bytes, [&](int64_t idx) { + r0[idx] = (left_bits[0][idx] & right_bits[0][idx]) ^ + (left_bits[0][idx] & right_bits[1][idx]) ^ + (left_bits[1][idx] & right_bits[0][idx]) ^ + (r0[idx] ^ r1[idx]); + }); + + auto temp = comm->rotate(r0, "andbb"); + r1.assign(temp.begin(), temp.end()); + + cur_bytes = cur_bytes == cur_numel ? cur_numel : (cur_bytes / 2); + cur_bits /= 2; + round_res_0.assign(r0.begin(), r0.end()); + round_res_1.assign(r1.begin(), r1.end()); + } + + NdArrayView> _out(out); + + pforeach(0, numel, [&](int64_t idx) { + _out[idx][0] = round_res_0[idx]; + _out[idx][1] = round_res_1[idx]; + }); + }); + }); + + return out; +} + +NdArrayRef EqualAA::proc(KernelEvalContext* ctx, const NdArrayRef& lhs, + const NdArrayRef& rhs) const { + const auto* lhs_ty = lhs.eltype().as(); + const auto* rhs_ty = rhs.eltype().as(); + + SPU_ENFORCE(lhs_ty->field() == rhs_ty->field()); + const auto field = lhs_ty->field(); + NdArrayRef out(makeType(field), lhs.shape()); + + DISPATCH_ALL_FIELDS(field, "_", [&]() { + using shr_t = std::array; + NdArrayView _out(out); + NdArrayView _lhs(lhs); + NdArrayView _rhs(rhs); + + pforeach(0, lhs.numel(), [&](int64_t idx) { + _out[idx][0] = _lhs[idx][0] - _rhs[idx][0]; + _out[idx][1] = _lhs[idx][1] - _rhs[idx][1]; + }); + }); + + return eqz(ctx, out); +} + +NdArrayRef EqualAP::proc(KernelEvalContext* ctx, const NdArrayRef& lhs, + const NdArrayRef& rhs) const { + auto* comm = ctx->getState(); + const auto* lhs_ty = lhs.eltype().as(); + const auto* rhs_ty = rhs.eltype().as(); + + SPU_ENFORCE(lhs_ty->field() == rhs_ty->field()); + const auto field = lhs_ty->field(); + NdArrayRef out(makeType(field), lhs.shape()); + + auto rank = comm->getRank(); + + DISPATCH_ALL_FIELDS(field, "_", [&]() { + using el_t = ring2k_t; + using shr_t = std::array; + + NdArrayView _out(out); + NdArrayView _lhs(lhs); + NdArrayView _rhs(rhs); + + pforeach(0, 
lhs.numel(), [&](int64_t idx) { + _out[idx][0] = _lhs[idx][0]; + _out[idx][1] = _lhs[idx][1]; + if (rank == 0) _out[idx][1] -= _rhs[idx]; + if (rank == 1) _out[idx][0] -= _rhs[idx]; + }); + return out; + }); + + return eqz(ctx, out); +} + void CommonTypeV::evaluate(KernelEvalContext* ctx) const { const Type& lhs = ctx->getParam(0); const Type& rhs = ctx->getParam(1); diff --git a/libspu/mpc/aby3/conversion.h b/libspu/mpc/aby3/conversion.h index 76fe752a..d6a0b747 100644 --- a/libspu/mpc/aby3/conversion.h +++ b/libspu/mpc/aby3/conversion.h @@ -120,6 +120,26 @@ class MsbA2B : public UnaryKernel { NdArrayRef proc(KernelEvalContext* ctx, const NdArrayRef& in) const override; }; +class EqualAA : public BinaryKernel { + public: + static constexpr char kBindName[] = "equal_aa"; + + Kind kind() const override { return Kind::Dynamic; } + + NdArrayRef proc(KernelEvalContext* ctx, const NdArrayRef& lhs, + const NdArrayRef& rhs) const override; +}; + +class EqualAP : public BinaryKernel { + public: + static constexpr char kBindName[] = "equal_ap"; + + Kind kind() const override { return Kind::Dynamic; } + + NdArrayRef proc(KernelEvalContext* ctx, const NdArrayRef& lhs, + const NdArrayRef& rhs) const override; +}; + class CommonTypeV : public Kernel { public: static constexpr char kBindName[] = "common_type_v"; diff --git a/libspu/mpc/aby3/protocol.cc b/libspu/mpc/aby3/protocol.cc index 3b888a28..068fe460 100644 --- a/libspu/mpc/aby3/protocol.cc +++ b/libspu/mpc/aby3/protocol.cc @@ -63,6 +63,8 @@ void regAby3Protocol(SPUContext* ctx, #endif ctx->prot()->regKernel(); + ctx->prot()->regKernel(); + ctx->prot()->regKernel(); ctx->prot()->regKernel(); ctx->prot()->regKernel(); diff --git a/libspu/mpc/aby3/protocol_test.cc b/libspu/mpc/aby3/protocol_test.cc index 79a9bb19..ab65344a 100644 --- a/libspu/mpc/aby3/protocol_test.cc +++ b/libspu/mpc/aby3/protocol_test.cc @@ -14,7 +14,9 @@ #include "libspu/mpc/aby3/protocol.h" +#include "libspu/mpc/ab_api.h" #include 
"libspu/mpc/ab_api_test.h" +#include "libspu/mpc/api.h" #include "libspu/mpc/api_test.h" namespace spu::mpc::test { diff --git a/libspu/mpc/cheetah/BUILD.bazel b/libspu/mpc/cheetah/BUILD.bazel index b30c92d6..ba05873d 100644 --- a/libspu/mpc/cheetah/BUILD.bazel +++ b/libspu/mpc/cheetah/BUILD.bazel @@ -33,11 +33,17 @@ spu_cc_library( deps = [ "//libspu/mpc/cheetah/arith:cheetah_arith", "//libspu/mpc/cheetah/nonlinear:cheetah_nonlinear", + "//libspu/mpc/cheetah/ot", "//libspu/mpc/cheetah/rlwe:cheetah_rlwe", - "//libspu/mpc/cheetah/yacl_ot:yacl_ferret_ot", ], ) +spu_cc_library( + name = "env", + srcs = ["env.cc"], + hdrs = ["env.h"], +) + spu_cc_library( name = "boolean", srcs = [ diff --git a/libspu/mpc/cheetah/env.cc b/libspu/mpc/cheetah/env.cc new file mode 100644 index 00000000..8d59b782 --- /dev/null +++ b/libspu/mpc/cheetah/env.cc @@ -0,0 +1,43 @@ +// Copyright 2021 Ant Group Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
// Reports whether the environment variable `name` is switched on.
//
// Returns true only when the variable is set and its value, compared
// case-insensitively, is exactly "1" or "on". An unset variable and any other
// value (including the empty string) yield false.
static bool IsEnvOn(const char *name) {
  const char *str = std::getenv(name);
  if (str == nullptr) {
    return false;
  }

  std::string s(str);
  // Lower-case in place. std::tolower requires its argument to be
  // representable as unsigned char (or to equal EOF); handing it a plain
  // `char` that happens to be negative (any byte >= 0x80 on platforms where
  // char is signed) is undefined behaviour. Taking the byte as unsigned char
  // avoids that, and the explicit cast narrows the int result back to char.
  std::transform(s.begin(), s.end(), s.begin(), [](unsigned char c) {
    return static_cast<char>(std::tolower(c));
  });
  return s == "1" || s == "on";
}
+ +#pragma once + +namespace spu::mpc::cheetah { +enum class EnvFlag { + // use emp/ferret instead of the yacl/ferret + SPU_CTH_ENABLE_EMP_OT, +}; + +bool TestEnvFlag(EnvFlag f); + +} // namespace spu::mpc::cheetah \ No newline at end of file diff --git a/libspu/mpc/cheetah/nonlinear/BUILD.bazel b/libspu/mpc/cheetah/nonlinear/BUILD.bazel index 5776134b..8669cf31 100644 --- a/libspu/mpc/cheetah/nonlinear/BUILD.bazel +++ b/libspu/mpc/cheetah/nonlinear/BUILD.bazel @@ -30,7 +30,7 @@ spu_cc_library( srcs = ["compare_prot.cc"], hdrs = ["compare_prot.h"], deps = [ - "//libspu/mpc/cheetah/yacl_ot:yacl_ferret_ot", + "//libspu/mpc/cheetah/ot", "@yacl//yacl/link", ], ) @@ -40,7 +40,7 @@ spu_cc_library( srcs = ["equal_prot.cc"], hdrs = ["equal_prot.h"], deps = [ - "//libspu/mpc/cheetah/yacl_ot:yacl_ferret_ot", + "//libspu/mpc/cheetah/ot", "@yacl//yacl/link", ], ) diff --git a/libspu/mpc/cheetah/nonlinear/README.md b/libspu/mpc/cheetah/nonlinear/README.md new file mode 100644 index 00000000..87f33d7f --- /dev/null +++ b/libspu/mpc/cheetah/nonlinear/README.md @@ -0,0 +1,18 @@ +# Performance Stats + +| | bit width | Send & Recv (bits) | +| ----------------------------- | -------------- | ------------------ | +| Millionaire (radix = 4) | 32 | 348.377 | +| Millionaire (radix = 4) | 40 | 428.377 | +| Millionaire (radix = 4) | 64 | 690.503 | +| TruncatePr (unknown sign bit) | k=32, fxp = 12 | 369.503 | +| TruncatePr (known sign bit) | k=32, fxp = 12 | 32.252 | +| TruncatePr (unknown sign bit) | k=64, fxp = 12 | 723.754 | +| TruncatePr (known sign bit) | k=64, fxp = 12 | 36.252 | + +* Note: The performance stats here are averaged over inputs of length at least 2^17. + +* _Millionaire_ Protocol implements the tree-based protocol from [CrypTFlow2](https://eprint.iacr.org/2020/1002). + We set the comparison radix as `4`. +* _TruncatePr_ Protocol implements the 1-bit approximated truncate protocol from [Cheetah](https://eprint.iacr.org/2022/207.pdf). 
+ We also set the comparison radix as `4` by default. diff --git a/libspu/mpc/cheetah/nonlinear/compare_prot.cc b/libspu/mpc/cheetah/nonlinear/compare_prot.cc index 1b613c5b..384d6dab 100644 --- a/libspu/mpc/cheetah/nonlinear/compare_prot.cc +++ b/libspu/mpc/cheetah/nonlinear/compare_prot.cc @@ -18,16 +18,15 @@ #include "yacl/link/link.h" #include "libspu/core/type.h" +#include "libspu/mpc/cheetah/ot/basic_ot_prot.h" +#include "libspu/mpc/cheetah/ot/ot_util.h" #include "libspu/mpc/cheetah/type.h" -#include "libspu/mpc/cheetah/yacl_ot/basic_ot_prot.h" -#include "libspu/mpc/cheetah/yacl_ot/util.h" -#include "libspu/mpc/cheetah/yacl_ot/yacl_ferret.h" #include "libspu/mpc/common/communicator.h" #include "libspu/mpc/utils/ring_ops.h" namespace spu::mpc::cheetah { -CompareProtocol::CompareProtocol(std::shared_ptr base, +CompareProtocol::CompareProtocol(const std::shared_ptr& base, size_t compare_radix) : compare_radix_(compare_radix), basic_ot_prot_(base) { SPU_ENFORCE(base != nullptr); @@ -165,8 +164,13 @@ NdArrayRef CompareProtocol::DoCompute(const NdArrayRef& inp, bool greater_than, return _gt.as(boolean_t); } -std::array CompareProtocol::TraversalANDWithEq( +std::array CompareProtocol::TraversalANDWithEqFullBinaryTree( NdArrayRef cmp, NdArrayRef eq, size_t num_input, size_t num_digits) { + SPU_ENFORCE(num_digits > 0 && absl::has_single_bit(num_digits), + "require num_digits be a 2-power"); + if (num_digits == 1) { + return {cmp, eq}; + } SPU_ENFORCE(cmp.shape().size() == 1, "need 1D array"); SPU_ENFORCE_EQ(cmp.shape(), eq.shape()); SPU_ENFORCE_EQ(cmp.numel(), eq.numel()); @@ -198,8 +202,76 @@ std::array CompareProtocol::TraversalANDWithEq( return {cmp, eq}; } -NdArrayRef CompareProtocol::TraversalAND(NdArrayRef cmp, NdArrayRef eq, - size_t num_input, size_t num_digits) { +std::array CompareProtocol::TraversalANDWithEq( + NdArrayRef cmp, NdArrayRef eq, size_t num_input, size_t num_digits) { + if (absl::has_single_bit(num_digits)) { + return 
TraversalANDWithEqFullBinaryTree(cmp, eq, num_input, num_digits); + } + + // Split the current tree into two subtrees + size_t current_num_digits = absl::bit_floor(num_digits); + + Shape current_shape({static_cast(current_num_digits * num_input)}); + NdArrayRef current_cmp(cmp.eltype(), current_shape); + NdArrayRef current_eq(eq.eltype(), current_shape); + // Copy from the CMP and EQ bits for the current sub-full-tree + pforeach(0, num_input, [&](int64_t i) { + std::memcpy(¤t_cmp.at(i * current_num_digits), + &cmp.at(i * num_digits), current_num_digits * cmp.elsize()); + std::memcpy(¤t_eq.at(i * current_num_digits), &eq.at(i * num_digits), + current_num_digits * eq.elsize()); + }); + + auto [_cmp, _eq] = TraversalANDWithEqFullBinaryTree( + current_cmp, current_eq, num_input, current_num_digits); + // NOTE(lwj): auto unbox is a C++20 feature + NdArrayRef subtree_cmp = _cmp; + NdArrayRef subtree_eq = _eq; + + // NOTE(lwj): +1 due to the AND on the sub-full-tree + size_t remain_num_digits = num_digits - current_num_digits + 1; + while (remain_num_digits > 1) { + current_num_digits = absl::bit_floor(remain_num_digits); + Shape current_shape({static_cast(current_num_digits * num_input)}); + NdArrayRef current_cmp(cmp.eltype(), current_shape); + NdArrayRef current_eq(eq.eltype(), current_shape); + + pforeach(0, num_input, [&](int64_t i) { + // copy subtree result as the 1st digit + std::memcpy(¤t_cmp.at(i * current_num_digits), &subtree_cmp.at(i), + 1 * cmp.elsize()); + std::memcpy(¤t_eq.at(i * current_num_digits), &subtree_eq.at(i), + 1 * eq.elsize()); + + // copy the remaining digits from the input 'cmp' and 'eq' + std::memcpy(¤t_cmp.at(i * current_num_digits + 1), + &cmp.at((i + 1) * num_digits - remain_num_digits + 1), + (current_num_digits - 1) * cmp.elsize()); + std::memcpy(¤t_eq.at(i * current_num_digits + 1), + &eq.at((i + 1) * num_digits - remain_num_digits + 1), + (current_num_digits - 1) * eq.elsize()); + }); + + // NOTE(lwj): current_num_digits is not a 
2-power + auto [_cmp, _eq] = TraversalANDWithEq(current_cmp, current_eq, num_input, + current_num_digits); + subtree_cmp = _cmp; + subtree_eq = _eq; + remain_num_digits = remain_num_digits - current_num_digits + 1; + } + + return {subtree_cmp, subtree_eq}; +} + +NdArrayRef CompareProtocol::TraversalANDFullBinaryTree(NdArrayRef cmp, + NdArrayRef eq, + size_t num_input, + size_t num_digits) { + SPU_ENFORCE(num_digits > 0 && absl::has_single_bit(num_digits), + "require num_digits be a 2-power"); + if (num_digits == 1) { + return cmp; + } // Tree-based traversal ANDs // lt0[0], lt0[1], ..., lt0[M], // lt1[0], lt1[1], ..., lt1[M], @@ -304,6 +376,62 @@ NdArrayRef CompareProtocol::TraversalAND(NdArrayRef cmp, NdArrayRef eq, return cmp; } +NdArrayRef CompareProtocol::TraversalAND(NdArrayRef cmp, NdArrayRef eq, + size_t num_input, size_t num_digits) { + if (absl::has_single_bit(num_digits)) { + return TraversalANDFullBinaryTree(cmp, eq, num_input, num_digits); + } + + // Split the current tree into two subtrees + size_t current_num_digits = absl::bit_floor(num_digits); + + Shape current_shape({static_cast(current_num_digits * num_input)}); + NdArrayRef current_cmp(cmp.eltype(), current_shape); + NdArrayRef current_eq(eq.eltype(), current_shape); + // Copy from the CMP and EQ bits for the current sub-full-tree + pforeach(0, num_input, [&](int64_t i) { + std::memcpy(¤t_cmp.at(i * current_num_digits), + &cmp.at(i * num_digits), current_num_digits * cmp.elsize()); + std::memcpy(¤t_eq.at(i * current_num_digits), &eq.at(i * num_digits), + current_num_digits * eq.elsize()); + }); + + NdArrayRef subtree_cmp = TraversalANDFullBinaryTree( + current_cmp, current_eq, num_input, current_num_digits); + + // NOTE(lwj): +1 due to the AND on the sub-full-tree + size_t remain_num_digits = num_digits - current_num_digits + 1; + while (remain_num_digits > 1) { + current_num_digits = absl::bit_floor(remain_num_digits); + Shape current_shape({static_cast(current_num_digits * num_input)}); + 
NdArrayRef current_cmp(cmp.eltype(), current_shape); + NdArrayRef current_eq(eq.eltype(), current_shape); + + pforeach(0, num_input, [&](int64_t i) { + // copy subtree result as the 1st digit + std::memcpy(¤t_cmp.at(i * current_num_digits), &subtree_cmp.at(i), + 1 * cmp.elsize()); + // copy the remaining digits from the input 'cmp' + std::memcpy(¤t_cmp.at(i * current_num_digits + 1), + &cmp.at((i + 1) * num_digits - remain_num_digits + 1), + (current_num_digits - 1) * cmp.elsize()); + + // copy the remaining digits from the input 'eq' + // we skip the left-most equal which is unnecessary + std::memcpy(¤t_eq.at(i * current_num_digits + 1), + &eq.at((i + 1) * num_digits - remain_num_digits + 1), + (current_num_digits - 1) * eq.elsize()); + }); + + // NOTE(lwj): current_num_digits is not a 2-power + subtree_cmp = + TraversalAND(current_cmp, current_eq, num_input, current_num_digits); + remain_num_digits = remain_num_digits - current_num_digits + 1; + } + + return subtree_cmp; +} + NdArrayRef CompareProtocol::Compute(const NdArrayRef& inp, bool greater_than, int64_t bitwidth) { int64_t bw = SizeOf(inp.eltype().as()->field()) * 8; diff --git a/libspu/mpc/cheetah/nonlinear/compare_prot.h b/libspu/mpc/cheetah/nonlinear/compare_prot.h index 763be10a..90e1a29c 100644 --- a/libspu/mpc/cheetah/nonlinear/compare_prot.h +++ b/libspu/mpc/cheetah/nonlinear/compare_prot.h @@ -40,7 +40,7 @@ class BasicOTProtocols; class CompareProtocol { public: // REQUIRE 1 <= compare_radix <= 4 - explicit CompareProtocol(std::shared_ptr base, + explicit CompareProtocol(const std::shared_ptr& base, size_t compare_radix = 4); ~CompareProtocol(); @@ -61,10 +61,19 @@ class CompareProtocol { NdArrayRef TraversalAND(NdArrayRef cmp, NdArrayRef eq, size_t num_input, size_t num_digits); + // Require num_digits to be two-power value + NdArrayRef TraversalANDFullBinaryTree(NdArrayRef cmp, NdArrayRef eq, + size_t num_input, size_t num_digits); // Require 1D array std::array TraversalANDWithEq(NdArrayRef cmp, 
NdArrayRef eq, size_t num_input, size_t num_digits); + + // Require num_digits to be two-power value + std::array TraversalANDWithEqFullBinaryTree(NdArrayRef cmp, + NdArrayRef eq, + size_t num_input, + size_t num_digits); size_t compare_radix_; bool is_sender_{false}; std::shared_ptr basic_ot_prot_; diff --git a/libspu/mpc/cheetah/nonlinear/compare_prot_test.cc b/libspu/mpc/cheetah/nonlinear/compare_prot_test.cc index 7db66e8d..06f7dd95 100644 --- a/libspu/mpc/cheetah/nonlinear/compare_prot_test.cc +++ b/libspu/mpc/cheetah/nonlinear/compare_prot_test.cc @@ -18,8 +18,8 @@ #include "gtest/gtest.h" +#include "libspu/mpc/cheetah/ot/basic_ot_prot.h" #include "libspu/mpc/cheetah/type.h" -#include "libspu/mpc/cheetah/yacl_ot/basic_ot_prot.h" #include "libspu/mpc/utils/ring_ops.h" #include "libspu/mpc/utils/simulate.h" @@ -40,7 +40,7 @@ INSTANTIATE_TEST_SUITE_P( (int)std::get<2>(p.param), (int)std::get<1>(p.param)); }); -TEST_P(CompareProtTest, Basic) { +TEST_P(CompareProtTest, Compare) { size_t kWorldSize = 2; Shape shape = {13, 2, 3}; FieldType field = std::get<0>(GetParam()); @@ -87,17 +87,17 @@ TEST_P(CompareProtTest, Basic) { }); } -TEST_P(CompareProtTest, SpecifiedBitwidth) { +TEST_P(CompareProtTest, CompareBitWidth) { size_t kWorldSize = 2; FieldType field = std::get<0>(GetParam()); size_t radix = std::get<2>(GetParam()); bool greater_than = std::get<1>(GetParam()); - int64_t bw = 16; + int64_t bw = std::min(32, SizeOf(field) * 8); NdArrayRef inp[2]; - int64_t n = 1 << 18; - inp[0] = ring_rand(field, {n * 2}); - inp[1] = ring_rand(field, {n * 2}); + int64_t n = 100; + inp[0] = ring_rand(field, {n, 2}); + inp[1] = ring_rand(field, {n, 2}); DISPATCH_ALL_FIELDS(field, "", [&]() { ring2k_t mask = (static_cast(1) << bw) - 1; @@ -212,4 +212,76 @@ TEST_P(CompareProtTest, WithEq) { }); } +TEST_P(CompareProtTest, WithEqBitWidth) { + size_t kWorldSize = 2; + Shape shape = {10, 10, 10}; + FieldType field = std::get<0>(GetParam()); + size_t radix = std::get<2>(GetParam()); + 
bool greater_than = std::get<1>(GetParam()); + + int64_t bw = std::min(32, SizeOf(field) * 8); + + NdArrayRef inp[2]; + int64_t n = 1 << 10; + inp[0] = ring_rand(field, {n, 2}); + inp[1] = ring_rand(field, {n, 2}); + + DISPATCH_ALL_FIELDS(field, "", [&]() { + ring2k_t mask = (static_cast(1) << bw) - 1; + auto xinp = NdArrayView(inp[0]); + xinp[0] = 1; + xinp[1] = 10; + xinp[2] = 100; + pforeach(0, inp[0].numel(), [&](int64_t i) { xinp[i] &= mask; }); + + xinp = NdArrayView(inp[1]); + xinp[0] = 1; + xinp[1] = 9; + xinp[2] = 1000; + pforeach(0, inp[0].numel(), [&](int64_t i) { xinp[i] &= mask; }); + }); + + NdArrayRef cmp_oup[2]; + NdArrayRef eq_oup[2]; + utils::simulate(kWorldSize, [&](std::shared_ptr ctx) { + auto conn = std::make_shared(ctx); + int rank = ctx->Rank(); + auto base = std::make_shared(conn); + + [[maybe_unused]] auto b0 = ctx->GetStats()->sent_bytes.load(); + [[maybe_unused]] auto s0 = ctx->GetStats()->sent_actions.load(); + + CompareProtocol comp_prot(base, radix); + auto [_c, _e] = comp_prot.ComputeWithEq(inp[rank], greater_than, bw); + + [[maybe_unused]] auto b1 = ctx->GetStats()->sent_bytes.load(); + [[maybe_unused]] auto s1 = ctx->GetStats()->sent_actions.load(); + + SPDLOG_DEBUG( + "CompareWithEq {} bits {} elements sent {} bytes, {} bits each #sent " + "{}", + bw, inp[0].numel(), (b1 - b0), (b1 - b0) * 8. / inp[0].numel(), + (s1 - s0)); + + cmp_oup[rank] = _c; + eq_oup[rank] = _e; + }); + + DISPATCH_ALL_FIELDS(field, "", [&]() { + auto xout0 = NdArrayView(cmp_oup[0]); + auto xout1 = NdArrayView(cmp_oup[1]); + auto xeq0 = NdArrayView(eq_oup[0]); + auto xeq1 = NdArrayView(eq_oup[1]); + auto xinp0 = NdArrayView(inp[0]); + auto xinp1 = NdArrayView(inp[1]); + + for (int64_t i = 0; i < shape.numel(); ++i) { + bool expected = greater_than ? 
xinp0[i] > xinp1[i] : xinp0[i] < xinp1[i]; + bool got_cmp = xout0[i] ^ xout1[i]; + bool got_eq = xeq0[i] ^ xeq1[i]; + EXPECT_EQ(expected, got_cmp); + EXPECT_EQ((xinp0[i] == xinp1[i]), got_eq); + } + }); +} } // namespace spu::mpc::cheetah diff --git a/libspu/mpc/cheetah/nonlinear/equal_prot.cc b/libspu/mpc/cheetah/nonlinear/equal_prot.cc index 4aa4a73e..2e393e00 100644 --- a/libspu/mpc/cheetah/nonlinear/equal_prot.cc +++ b/libspu/mpc/cheetah/nonlinear/equal_prot.cc @@ -18,16 +18,15 @@ #include "yacl/link/link.h" #include "libspu/core/type.h" +#include "libspu/mpc/cheetah/ot/basic_ot_prot.h" +#include "libspu/mpc/cheetah/ot/ot_util.h" #include "libspu/mpc/cheetah/type.h" -#include "libspu/mpc/cheetah/yacl_ot/basic_ot_prot.h" -#include "libspu/mpc/cheetah/yacl_ot/util.h" -#include "libspu/mpc/cheetah/yacl_ot/yacl_ferret.h" #include "libspu/mpc/common/communicator.h" #include "libspu/mpc/utils/ring_ops.h" namespace spu::mpc::cheetah { -EqualProtocol::EqualProtocol(std::shared_ptr base, +EqualProtocol::EqualProtocol(const std::shared_ptr& base, size_t compare_radix) : compare_radix_(compare_radix), basic_ot_prot_(base) { SPU_ENFORCE(base != nullptr); diff --git a/libspu/mpc/cheetah/nonlinear/equal_prot.h b/libspu/mpc/cheetah/nonlinear/equal_prot.h index 9dd93fd3..78de9559 100644 --- a/libspu/mpc/cheetah/nonlinear/equal_prot.h +++ b/libspu/mpc/cheetah/nonlinear/equal_prot.h @@ -38,7 +38,7 @@ class BasicOTProtocols; class EqualProtocol { public: // REQUIRE 1 <= compare_radix <= 8. 
- explicit EqualProtocol(std::shared_ptr base, + explicit EqualProtocol(const std::shared_ptr& base, size_t compare_radix = 4); ~EqualProtocol(); diff --git a/libspu/mpc/cheetah/nonlinear/equal_prot_test.cc b/libspu/mpc/cheetah/nonlinear/equal_prot_test.cc index 525613c2..13dcbc6d 100644 --- a/libspu/mpc/cheetah/nonlinear/equal_prot_test.cc +++ b/libspu/mpc/cheetah/nonlinear/equal_prot_test.cc @@ -18,9 +18,8 @@ #include "gtest/gtest.h" -#include "libspu/core/xt_helper.h" +#include "libspu/mpc/cheetah/ot/basic_ot_prot.h" #include "libspu/mpc/cheetah/type.h" -#include "libspu/mpc/cheetah/yacl_ot/basic_ot_prot.h" #include "libspu/mpc/utils/ring_ops.h" #include "libspu/mpc/utils/simulate.h" @@ -47,9 +46,9 @@ TEST_P(EqualProtTest, Basic) { inp[1] = ring_rand(field, shape); DISPATCH_ALL_FIELDS(field, "", [&]() { - auto xinp0 = xt_mutable_adapt(inp[0]); - auto xinp1 = xt_mutable_adapt(inp[1]); - std::copy_n(xinp1.data(), 5, xinp0.data()); + auto xinp0 = NdArrayView(inp[0]); + auto xinp1 = NdArrayView(inp[1]); + std::copy_n(&xinp1[0], 5, &xinp0[0]); }); NdArrayRef eq_oup[2]; diff --git a/libspu/mpc/cheetah/nonlinear/truncate_prot.cc b/libspu/mpc/cheetah/nonlinear/truncate_prot.cc index 6d7d6537..3edb5117 100644 --- a/libspu/mpc/cheetah/nonlinear/truncate_prot.cc +++ b/libspu/mpc/cheetah/nonlinear/truncate_prot.cc @@ -15,14 +15,15 @@ #include "libspu/core/type.h" #include "libspu/mpc/cheetah/nonlinear/compare_prot.h" +#include "libspu/mpc/cheetah/ot/basic_ot_prot.h" +#include "libspu/mpc/cheetah/ot/ot_util.h" #include "libspu/mpc/cheetah/type.h" -#include "libspu/mpc/cheetah/yacl_ot/basic_ot_prot.h" -#include "libspu/mpc/cheetah/yacl_ot/util.h" #include "libspu/mpc/utils/ring_ops.h" namespace spu::mpc::cheetah { -TruncateProtocol::TruncateProtocol(std::shared_ptr base) +TruncateProtocol::TruncateProtocol( + const std::shared_ptr& base) : basic_ot_prot_(base) { SPU_ENFORCE(base != nullptr); } diff --git a/libspu/mpc/cheetah/nonlinear/truncate_prot.h 
b/libspu/mpc/cheetah/nonlinear/truncate_prot.h index b2487d77..8b56996f 100644 --- a/libspu/mpc/cheetah/nonlinear/truncate_prot.h +++ b/libspu/mpc/cheetah/nonlinear/truncate_prot.h @@ -47,7 +47,7 @@ class TruncateProtocol { size_t shift_bits = 0; }; - explicit TruncateProtocol(std::shared_ptr base); + explicit TruncateProtocol(const std::shared_ptr &base); ~TruncateProtocol(); @@ -62,7 +62,7 @@ class TruncateProtocol { // w = msbA & msbB NdArrayRef MSB1ToWrap(const NdArrayRef &inp, size_t shift_bits); - std::shared_ptr basic_ot_prot_{nullptr}; + std::shared_ptr basic_ot_prot_ = nullptr; }; -} // namespace spu::mpc::cheetah \ No newline at end of file +} // namespace spu::mpc::cheetah diff --git a/libspu/mpc/cheetah/nonlinear/truncate_prot_test.cc b/libspu/mpc/cheetah/nonlinear/truncate_prot_test.cc index b187e978..2a524ce8 100644 --- a/libspu/mpc/cheetah/nonlinear/truncate_prot_test.cc +++ b/libspu/mpc/cheetah/nonlinear/truncate_prot_test.cc @@ -18,8 +18,8 @@ #include "gtest/gtest.h" +#include "libspu/mpc/cheetah/ot/basic_ot_prot.h" #include "libspu/mpc/cheetah/type.h" -#include "libspu/mpc/cheetah/yacl_ot/basic_ot_prot.h" #include "libspu/mpc/utils/ring_ops.h" #include "libspu/mpc/utils/simulate.h" @@ -50,8 +50,8 @@ bool SignBit(T x) { TEST_P(TruncateProtTest, Basic) { size_t kWorldSize = 2; - int64_t n = 1024; - size_t shift = 13; + int64_t n = 100; + size_t shift = 12; FieldType field = std::get<0>(GetParam()); bool signed_arith = std::get<1>(GetParam()); std::string msb = std::get<2>(GetParam()); @@ -93,8 +93,21 @@ TEST_P(TruncateProtTest, Basic) { meta.sign = sign; meta.signed_arith = signed_arith; meta.shift_bits = shift; + meta.use_heuristic = false; + + [[maybe_unused]] auto b0 = ctx->GetStats()->sent_bytes.load(); + [[maybe_unused]] auto s0 = ctx->GetStats()->sent_actions.load(); + oup[rank] = trunc_prot.Compute(inp[rank], meta); + + [[maybe_unused]] auto b1 = ctx->GetStats()->sent_bytes.load(); + [[maybe_unused]] auto s1 = 
ctx->GetStats()->sent_actions.load(); + + SPDLOG_DEBUG("Truncate {} bits share by {} bits {} bits each #sent {}", + SizeOf(field) * 8, meta.shift_bits, + (b1 - b0) * 8. / inp[0].numel(), (s1 - s0)); }); + EXPECT_EQ(oup[0].shape(), oup[1].shape()); DISPATCH_ALL_FIELDS(field, "", [&]() { @@ -198,9 +211,6 @@ TEST_P(TruncateProtTest, Heuristic) { } } }); - - // printf("0 %d %f, +1 %d %f -1 %d %f\n", count_zero, count_zero * 1. / n, - // count_pos, count_pos * 1. / n, count_neg, count_neg * 1. / n); } } // namespace spu::mpc::cheetah diff --git a/libspu/mpc/cheetah/ot/BUILD.bazel b/libspu/mpc/cheetah/ot/BUILD.bazel index 2f6e94c8..e841eb41 100644 --- a/libspu/mpc/cheetah/ot/BUILD.bazel +++ b/libspu/mpc/cheetah/ot/BUILD.bazel @@ -18,36 +18,36 @@ load("@yacl//bazel:yacl.bzl", "AES_COPT_FLAGS") package(default_visibility = ["//visibility:public"]) spu_cc_library( - name = "cheetah_ot", - srcs = [ - "basic_ot_prot.cc", - "ferret.cc", - "util.cc", - ], - hdrs = [ - "basic_ot_prot.h", - "ferret.h", - "mitccrh_exp.h", - "util.h", - ], - copts = AES_COPT_FLAGS + ["-Wno-ignored-attributes"], + name = "ferret_ot_interface", + hdrs = ["ferret_ot_interface.h"], +) + +spu_cc_library( + name = "ot_util", + srcs = ["ot_util.cc"], + hdrs = ["ot_util.h"], deps = [ - "//libspu/core:xt_helper", - "//libspu/mpc/cheetah:type", + "//libspu/core:ndarray_ref", + "//libspu/core:prelude", "//libspu/mpc/common:communicator", - "@com_github_emptoolkit_emp_ot//:emp-ot", - "@com_github_emptoolkit_emp_tool//:emp-tool", + "@com_google_absl//absl/types:span", "@yacl//yacl/base:int128", - "@yacl//yacl/link", ], ) -spu_cc_test( - name = "ferret_test", - srcs = ["ferret_test.cc"], +spu_cc_library( + name = "ot", + srcs = ["basic_ot_prot.cc"], + hdrs = ["basic_ot_prot.h"], deps = [ - ":cheetah_ot", - "//libspu/mpc/utils:simulate", + ":ot_util", + "//libspu/mpc/cheetah:env", + "//libspu/mpc/cheetah:type", + "//libspu/mpc/cheetah/ot/emp:ferret", + "//libspu/mpc/cheetah/ot/yacl:ferret", + 
"//libspu/mpc/common:communicator", + "@yacl//yacl/base:int128", + "@yacl//yacl/link", ], ) @@ -59,15 +59,16 @@ spu_cc_test( "exclusive-if-local", ], deps = [ - ":cheetah_ot", + ":ot", "//libspu/mpc/utils:simulate", ], ) spu_cc_test( - name = "util_test", - srcs = ["util_test.cc"], + name = "ot_util_test", + srcs = ["ot_util_test.cc"], deps = [ - ":cheetah_ot", + ":ot_util", + "//libspu/mpc/utils:ring_ops", ], ) diff --git a/libspu/mpc/cheetah/ot/basic_ot_prot.cc b/libspu/mpc/cheetah/ot/basic_ot_prot.cc index 293ebac0..2d9115bf 100644 --- a/libspu/mpc/cheetah/ot/basic_ot_prot.cc +++ b/libspu/mpc/cheetah/ot/basic_ot_prot.cc @@ -14,7 +14,14 @@ #include "libspu/mpc/cheetah/ot/basic_ot_prot.h" -#include "libspu/mpc/cheetah/ot/util.h" +#include + +#include "ot_util.h" + +#include "libspu/mpc/cheetah/env.h" +#include "libspu/mpc/cheetah/ot/emp/ferret.h" +#include "libspu/mpc/cheetah/ot/ot_util.h" +#include "libspu/mpc/cheetah/ot/yacl/ferret.h" #include "libspu/mpc/cheetah/type.h" #include "libspu/mpc/common/communicator.h" #include "libspu/mpc/utils/ring_ops.h" @@ -24,12 +31,24 @@ namespace spu::mpc::cheetah { BasicOTProtocols::BasicOTProtocols(std::shared_ptr conn) : conn_(std::move(conn)) { SPU_ENFORCE(conn_ != nullptr); - if (conn_->getRank() == 0) { - ferret_sender_ = std::make_shared(conn_, true); - ferret_receiver_ = std::make_shared(conn_, false); + if (TestEnvFlag(EnvFlag::SPU_CTH_ENABLE_EMP_OT)) { + using Ot = EmpFerretOt; + if (conn_->getRank() == 0) { + ferret_sender_ = std::make_shared(conn_, true); + ferret_receiver_ = std::make_shared(conn_, false); + } else { + ferret_receiver_ = std::make_shared(conn_, false); + ferret_sender_ = std::make_shared(conn_, true); + } } else { - ferret_receiver_ = std::make_shared(conn_, false); - ferret_sender_ = std::make_shared(conn_, true); + using Ot = YaclFerretOt; + if (conn_->getRank() == 0) { + ferret_sender_ = std::make_shared(conn_, true); + ferret_receiver_ = std::make_shared(conn_, false); + } else { + 
ferret_receiver_ = std::make_shared(conn_, false); + ferret_sender_ = std::make_shared(conn_, true); + } } } @@ -81,9 +100,7 @@ NdArrayRef BasicOTProtocols::PackedB2A(const NdArrayRef &inp) { auto rand = convert_from_bits_form(rand_bits); // open c = x ^ r - // FIXME(juhou): Actually, we only want to exchange the low-end bits. - auto opened = - conn_->allReduce(ReduceOp::XOR, ring_xor(inp, rand), "B2AFull_open"); + auto opened = OpenShare(ring_xor(inp, rand), ReduceOp::XOR, nbits, conn_); // compute c + (1 - 2*c)* NdArrayRef oup = ring_zeros(field, inp.shape()); @@ -167,7 +184,6 @@ NdArrayRef BasicOTProtocols::SingleB2A(const NdArrayRef &inp, int bit_width) { // Random bit r \in {0, 1} and return as AShr NdArrayRef BasicOTProtocols::RandBits(FieldType filed, const Shape &shape) { - // TODO(juhou): profile ring_randbit performance auto r = ring_randbit(filed, shape).as(makeType(filed, 1)); return SingleB2A(r); } @@ -188,61 +204,17 @@ NdArrayRef BasicOTProtocols::BitwiseAnd(const NdArrayRef &lhs, auto field = lhs.eltype().as()->field(); const auto *shareType = lhs.eltype().as(); - size_t numel = lhs.numel(); auto [a, b, c] = AndTriple(field, lhs.shape(), shareType->nbits()); - NdArrayRef x_a = ring_xor(lhs, a); - NdArrayRef y_b = ring_xor(rhs, b); - size_t pack_load = 8 * SizeOf(field) / shareType->nbits(); - - if (pack_load == 1) { - // Open x^a, y^b - auto res = vmap({x_a, y_b}, [&](const NdArrayRef &s) { - return conn_->allReduce(ReduceOp::XOR, s, "BitwiseAnd"); - }); - x_a = std::move(res[0]); - y_b = std::move(res[1]); - } else { - // Open x^a, y^b - // pack multiple nbits() into single field element before sending through - // network - SPU_ENFORCE(x_a.isCompact() && y_b.isCompact()); - int64_t packed_sze = CeilDiv(numel, pack_load); - - NdArrayRef packed_xa(x_a.eltype(), {packed_sze}); - NdArrayRef packed_yb(y_b.eltype(), {packed_sze}); - - DISPATCH_ALL_FIELDS(field, "_", [&]() { - auto xa_wrap = absl::MakeSpan(&x_a.at(0), numel); - auto yb_wrap = 
absl::MakeSpan(&y_b.at(0), numel); - auto packed_xa_wrap = - absl::MakeSpan(&packed_xa.at(0), packed_sze); - auto packed_yb_wrap = - absl::MakeSpan(&packed_yb.at(0), packed_sze); - - int64_t used = - ZipArray(xa_wrap, shareType->nbits(), packed_xa_wrap); - (void)ZipArray(yb_wrap, shareType->nbits(), packed_yb_wrap); - SPU_ENFORCE_EQ(used, packed_sze); - - // open x^a, y^b - auto res = vmap({packed_xa, packed_yb}, [&](const NdArrayRef &s) { - return conn_->allReduce(ReduceOp::XOR, s, "BitwiseAnd"); - }); - - packed_xa = std::move(res[0]); - packed_yb = std::move(res[1]); - packed_xa_wrap = absl::MakeSpan(&packed_xa.at(0), packed_sze); - packed_yb_wrap = absl::MakeSpan(&packed_yb.at(0), packed_sze); - UnzipArray(packed_xa_wrap, shareType->nbits(), xa_wrap); - UnzipArray(packed_yb_wrap, shareType->nbits(), yb_wrap); - }); - } + // open x^a, y^b + int nbits = shareType->nbits(); + auto xa = OpenShare(ring_xor(lhs, a), ReduceOp::XOR, nbits, conn_); + auto yb = OpenShare(ring_xor(rhs, b), ReduceOp::XOR, nbits, conn_); // Zi = Ci ^ ((X ^ A) & Bi) ^ ((Y ^ B) & Ai) ^ <(X ^ A) & (Y ^ B)> - auto z = ring_xor(ring_xor(ring_and(x_a, b), ring_and(y_b, a)), c); + auto z = ring_xor(ring_xor(ring_and(xa, b), ring_and(yb, a)), c); if (conn_->getRank() == 0) { - ring_xor_(z, ring_and(x_a, y_b)); + ring_xor_(z, ring_and(xa, yb)); } return z.as(lhs.eltype()); @@ -261,14 +233,10 @@ std::array BasicOTProtocols::CorrelatedBitwiseAnd( auto [a, b0, c0, b1, c1] = CorrelatedAndTriple(field, lhs.shape()); // open x^a, y^b0, y1^b1 - auto res = - vmap({ring_xor(lhs, a), ring_xor(rhs0, b0), ring_xor(rhs1, b1)}, - [&](const NdArrayRef &s) { - return conn_->allReduce(ReduceOp::XOR, s, "CorrelatedBitwiseAnd"); - }); - auto xa = std::move(res[0]); - auto y0b0 = std::move(res[1]); - auto y1b1 = std::move(res[2]); + int nbits = shareType->nbits(); + auto xa = OpenShare(ring_xor(lhs, a), ReduceOp::XOR, nbits, conn_); + auto y0b0 = OpenShare(ring_xor(rhs0, b0), ReduceOp::XOR, nbits, conn_); + auto y1b1 = 
OpenShare(ring_xor(rhs1, b1), ReduceOp::XOR, nbits, conn_); // Zi = Ci ^ ((X ^ A) & Bi) ^ ((Y ^ B) & Ai) ^ <(X ^ A) & (Y ^ B)> auto z0 = ring_xor(ring_xor(ring_and(xa, b0), ring_and(y0b0, a)), c0); diff --git a/libspu/mpc/cheetah/ot/basic_ot_prot.h b/libspu/mpc/cheetah/ot/basic_ot_prot.h index 7c09b974..3051797a 100644 --- a/libspu/mpc/cheetah/ot/basic_ot_prot.h +++ b/libspu/mpc/cheetah/ot/basic_ot_prot.h @@ -15,7 +15,8 @@ #pragma once #include "libspu/core/ndarray_ref.h" -#include "libspu/mpc/cheetah/ot/ferret.h" +#include "libspu/mpc/cheetah/ot/ferret_ot_interface.h" +#include "libspu/mpc/common/communicator.h" namespace spu::mpc::cheetah { @@ -40,14 +41,10 @@ class BasicOTProtocols { // Create `numel` of AND-triple. Each element contains `k` bits // 1 <= k <= field size - // std::array AndTriple(FieldType field, size_t numel, size_t k); - std::array AndTriple(FieldType field, const Shape &shape, size_t k); // [a, b, b', c, c'] such that c = a*b and c' = a*b' for the same a - // std::array CorrelatedAndTriple(FieldType field, size_t numel); - std::array CorrelatedAndTriple(FieldType field, const Shape &shape); @@ -60,9 +57,11 @@ class BasicOTProtocols { const NdArrayRef &rhs0, const NdArrayRef &rhs1); - std::shared_ptr GetSenderCOT() { return ferret_sender_; } + std::shared_ptr GetSenderCOT() { return ferret_sender_; } - std::shared_ptr GetReceiverCOT() { return ferret_receiver_; } + std::shared_ptr GetReceiverCOT() { + return ferret_receiver_; + } void Flush(); @@ -73,8 +72,8 @@ class BasicOTProtocols { private: std::shared_ptr conn_; - std::shared_ptr ferret_sender_; - std::shared_ptr ferret_receiver_; + std::shared_ptr ferret_sender_; + std::shared_ptr ferret_receiver_; }; } // namespace spu::mpc::cheetah diff --git a/libspu/mpc/cheetah/ot/basic_ot_prot_test.cc b/libspu/mpc/cheetah/ot/basic_ot_prot_test.cc index e112d5d3..a5128af9 100644 --- a/libspu/mpc/cheetah/ot/basic_ot_prot_test.cc +++ b/libspu/mpc/cheetah/ot/basic_ot_prot_test.cc @@ -18,12 +18,24 @@ 
#include "gtest/gtest.h" -#include "libspu/core/xt_helper.h" #include "libspu/mpc/cheetah/type.h" #include "libspu/mpc/utils/ring_ops.h" #include "libspu/mpc/utils/simulate.h" namespace spu::mpc::cheetah::test { +template +T makeBitsMask(size_t nbits) { + size_t max = sizeof(T) * 8; + if (nbits == 0) { + nbits = max; + } + SPU_ENFORCE(nbits <= max); + T mask = static_cast(-1); + if (nbits < max) { + mask = (static_cast(1) << nbits) - 1; + } + return mask; +} class BasicOTProtTest : public ::testing::TestWithParam { void SetUp() override {} @@ -51,12 +63,12 @@ TEST_P(BasicOTProtTest, SingleB2A) { auto mask = static_cast(-1); if (nbits > 0) { mask = (static_cast(1) << packed_nbits) - 1; - auto xb0 = xt_mutable_adapt(bshr0); - auto xb1 = xt_mutable_adapt(bshr1); - std::transform(xb0.data(), xb0.data() + xb0.size(), xb0.data(), - [&](auto x) { return x & mask; }); - std::transform(xb1.data(), xb1.data() + xb1.size(), xb1.data(), - [&](auto x) { return x & mask; }); + NdArrayView xb0(bshr0); + NdArrayView xb1(bshr1); + pforeach(0, xb0.numel(), [&](int64_t i) { + xb0[i] &= mask; + xb1[i] &= mask; + }); } }); @@ -76,10 +88,10 @@ TEST_P(BasicOTProtTest, SingleB2A) { EXPECT_EQ(shape, ashr0.shape()); DISPATCH_ALL_FIELDS(field, "", [&]() { - auto b0 = xt_adapt(bshr0); - auto b1 = xt_adapt(bshr1); - auto a0 = xt_adapt(ashr0); - auto a1 = xt_adapt(ashr1); + NdArrayView b0(bshr0); + NdArrayView b1(bshr1); + NdArrayView a0(ashr0); + NdArrayView a1(ashr1); auto mask = static_cast(-1); if (nbits > 0) { mask = (static_cast(1) << packed_nbits) - 1; @@ -107,12 +119,12 @@ TEST_P(BasicOTProtTest, PackedB2A) { auto mask = static_cast(-1); if (nbits > 0) { mask = (static_cast(1) << packed_nbits) - 1; - auto xb0 = xt_mutable_adapt(bshr0); - auto xb1 = xt_mutable_adapt(bshr1); - std::transform(xb0.data(), xb0.data() + xb0.size(), xb0.data(), - [&](auto x) { return x & mask; }); - std::transform(xb1.data(), xb1.data() + xb1.size(), xb1.data(), - [&](auto x) { return x & mask; }); + 
NdArrayView xb0(bshr0); + NdArrayView xb1(bshr1); + pforeach(0, xb0.numel(), [&](int64_t i) { + xb0[i] &= mask; + xb1[i] &= mask; + }); } }); @@ -131,10 +143,10 @@ TEST_P(BasicOTProtTest, PackedB2A) { EXPECT_EQ(ashr0.shape(), shape); DISPATCH_ALL_FIELDS(field, "", [&]() { - auto b0 = xt_adapt(bshr0); - auto b1 = xt_adapt(bshr1); - auto a0 = xt_adapt(ashr0); - auto a1 = xt_adapt(ashr1); + NdArrayView b0(bshr0); + NdArrayView b1(bshr1); + NdArrayView a0(ashr0); + NdArrayView a1(ashr1); auto mask = static_cast(-1); if (nbits > 0) { @@ -165,12 +177,12 @@ TEST_P(BasicOTProtTest, PackedB2AFull) { auto mask = static_cast(-1); if (nbits > 0) { mask = (static_cast(1) << packed_nbits) - 1; - auto xb0 = xt_mutable_adapt(bshr0); - auto xb1 = xt_mutable_adapt(bshr1); - std::transform(xb0.data(), xb0.data() + xb0.size(), xb0.data(), - [&](auto x) { return x & mask; }); - std::transform(xb1.data(), xb1.data() + xb1.size(), xb1.data(), - [&](auto x) { return x & mask; }); + auto xb0 = NdArrayView(bshr0); + auto xb1 = NdArrayView(bshr1); + pforeach(0, xb0.numel(), [&](int64_t i) { + xb0[i] &= mask; + xb1[i] &= mask; + }); } }); @@ -190,10 +202,10 @@ TEST_P(BasicOTProtTest, PackedB2AFull) { EXPECT_EQ(ashr0.shape(), shape); DISPATCH_ALL_FIELDS(field, "", [&]() { - auto b0 = xt_adapt(bshr0); - auto b1 = xt_adapt(bshr1); - auto a0 = xt_adapt(ashr0); - auto a1 = xt_adapt(ashr1); + NdArrayView b0(bshr0); + NdArrayView b1(bshr1); + NdArrayView a0(ashr0); + NdArrayView a1(ashr1); auto mask = static_cast(-1); if (nbits > 0) { mask = (static_cast(1) << packed_nbits) - 1; @@ -228,12 +240,12 @@ TEST_P(BasicOTProtTest, AndTripleSparse) { DISPATCH_ALL_FIELDS(field, "", [&]() { ring2k_t max = static_cast(1) << target_nbits; - auto a0 = xt_adapt(triple[0][0]); - auto b0 = xt_adapt(triple[0][1]); - auto c0 = xt_adapt(triple[0][2]); - auto a1 = xt_adapt(triple[1][0]); - auto b1 = xt_adapt(triple[1][1]); - auto c1 = xt_adapt(triple[1][2]); + NdArrayView a0(triple[0][0]); + NdArrayView 
b0(triple[0][1]); + NdArrayView c0(triple[0][2]); + NdArrayView a1(triple[1][0]); + NdArrayView b1(triple[1][1]); + NdArrayView c1(triple[1][2]); for (int64_t i = 0; i < shape.numel(); ++i) { EXPECT_TRUE(a0[i] < max && a1[i] < max); @@ -248,6 +260,50 @@ TEST_P(BasicOTProtTest, AndTripleSparse) { } } +TEST_P(BasicOTProtTest, BitwiseAnd) { + size_t kWorldSize = 2; + Shape shape = {55}; + FieldType field = GetParam(); + int bw = SizeOf(field) * 8; + auto boolean_t = makeType(field, bw); + + NdArrayRef lhs[2]; + NdArrayRef rhs[2]; + NdArrayRef out[2]; + + for (int i : {0, 1}) { + lhs[i] = ring_rand(field, shape).as(boolean_t); + rhs[i] = ring_rand(field, shape).as(boolean_t); + DISPATCH_ALL_FIELDS(field, "mask", [&]() { + ring2k_t mask = makeBitsMask(bw); + NdArrayView L(lhs[i]); + NdArrayView R(rhs[i]); + + pforeach(0, shape.numel(), [&](int64_t j) { L[j] &= mask; }); + pforeach(0, shape.numel(), [&](int64_t j) { R[j] &= mask; }); + }); + } + + utils::simulate(kWorldSize, [&](std::shared_ptr ctx) { + auto conn = std::make_shared(ctx); + BasicOTProtocols ot_prot(conn); + int r = ctx->Rank(); + out[r] = ot_prot.BitwiseAnd(lhs[r].clone(), rhs[r].clone()); + }); + + auto expected = ring_and(ring_xor(lhs[0], lhs[1]), ring_xor(rhs[0], rhs[1])); + auto got = ring_xor(out[0], out[1]); + + DISPATCH_ALL_FIELDS(field, "", [&]() { + NdArrayView e(expected); + NdArrayView g(got); + + for (int64_t i = 0; i < shape.numel(); ++i) { + ASSERT_EQ(e[i], g[i]); + } + }); +} + TEST_P(BasicOTProtTest, AndTripleFull) { size_t kWorldSize = 2; Shape shape = {55, 11}; @@ -264,14 +320,14 @@ TEST_P(BasicOTProtTest, AndTripleFull) { }); DISPATCH_ALL_FIELDS(field, "", [&]() { - auto a0 = xt_adapt(packed_triple[0][0]); - auto b0 = xt_adapt(packed_triple[0][1]); - auto c0 = xt_adapt(packed_triple[0][2]); - auto a1 = xt_adapt(packed_triple[1][0]); - auto b1 = xt_adapt(packed_triple[1][1]); - auto c1 = xt_adapt(packed_triple[1][2]); - - size_t nn = a0.size(); + NdArrayView a0(packed_triple[0][0]); + 
NdArrayView b0(packed_triple[0][1]); + NdArrayView c0(packed_triple[0][2]); + NdArrayView a1(packed_triple[1][0]); + NdArrayView b1(packed_triple[1][1]); + NdArrayView c1(packed_triple[1][2]); + + size_t nn = a0.numel(); EXPECT_TRUE(nn * 8 * SizeOf(field) >= (size_t)shape.numel()); for (size_t i = 0; i < nn; ++i) { @@ -297,12 +353,12 @@ TEST_P(BasicOTProtTest, Multiplexer) { DISPATCH_ALL_FIELDS(field, "", [&]() { auto mask = static_cast(1); - auto xb0 = xt_mutable_adapt(bshr0); - auto xb1 = xt_mutable_adapt(bshr1); - std::transform(xb0.data(), xb0.data() + xb0.size(), xb0.data(), - [&](auto x) { return x & mask; }); - std::transform(xb1.data(), xb1.data() + xb1.size(), xb1.data(), - [&](auto x) { return x & mask; }); + NdArrayView xb0(bshr0); + NdArrayView xb1(bshr1); + pforeach(0, xb0.numel(), [&](int64_t i) { + xb0[i] &= mask; + xb1[i] &= mask; + }); }); NdArrayRef computed[2]; @@ -320,12 +376,12 @@ TEST_P(BasicOTProtTest, Multiplexer) { EXPECT_EQ(computed[0].shape(), shape); DISPATCH_ALL_FIELDS(field, "", [&]() { - auto a0 = xt_adapt(ashr0); - auto a1 = xt_adapt(ashr1); - auto b0 = xt_adapt(bshr0); - auto b1 = xt_adapt(bshr1); - auto c0 = xt_adapt(computed[0]); - auto c1 = xt_adapt(computed[1]); + NdArrayView a0(ashr0); + NdArrayView a1(ashr1); + NdArrayView b0(bshr0); + NdArrayView b1(bshr1); + NdArrayView c0(computed[0]); + NdArrayView c1(computed[1]); for (int64_t i = 0; i < shape.numel(); ++i) { ring2k_t msg = (a0[i] + a1[i]); diff --git a/libspu/mpc/cheetah/ot/emp/BUILD.bazel b/libspu/mpc/cheetah/ot/emp/BUILD.bazel new file mode 100644 index 00000000..11c289d1 --- /dev/null +++ b/libspu/mpc/cheetah/ot/emp/BUILD.bazel @@ -0,0 +1,48 @@ +# Copyright 2022 Ant Group Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +load("//bazel:spu.bzl", "spu_cc_library", "spu_cc_test") +load("@yacl//bazel:yacl.bzl", "AES_COPT_FLAGS") + +package(default_visibility = ["//visibility:public"]) + +spu_cc_library( + name = "ferret", + srcs = ["ferret.cc"], + hdrs = [ + "emp_util.h", + "ferret.h", + "mitccrh_exp.h", + ], + copts = AES_COPT_FLAGS + ["-Wno-ignored-attributes"], + deps = [ + "//libspu/mpc/cheetah:type", + "//libspu/mpc/cheetah/ot:ferret_ot_interface", + "//libspu/mpc/cheetah/ot:ot_util", + "//libspu/mpc/common:communicator", + "@com_github_emptoolkit_emp_ot//:emp-ot", + "@com_github_emptoolkit_emp_tool//:emp-tool", + "@yacl//yacl/base:int128", + "@yacl//yacl/link", + ], +) + +spu_cc_test( + name = "ferret_test", + srcs = ["ferret_test.cc"], + deps = [ + ":ferret", + "//libspu/mpc/utils:simulate", + ], +) diff --git a/libspu/mpc/cheetah/ot/emp/emp_util.h b/libspu/mpc/cheetah/ot/emp/emp_util.h new file mode 100644 index 00000000..3643c04c --- /dev/null +++ b/libspu/mpc/cheetah/ot/emp/emp_util.h @@ -0,0 +1,47 @@ +// Copyright 2021 Ant Group Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "absl/types/span.h" +#include "emp-tool/utils/block.h" +#include "yacl/base/int128.h" + +#include "libspu/core/prelude.h" + +namespace spu::mpc::cheetah { + +template +inline emp::block ConvToBlock(T x) { + return _mm_set_epi64x(0, static_cast(x)); +} + +template <> +inline emp::block ConvToBlock(uint128_t x) { + return emp::makeBlock(/*hi64*/ static_cast(x >> 64), + /*lo64*/ static_cast(x)); +} + +template +inline T ConvFromBlock(const emp::block& x) { + return static_cast(_mm_extract_epi64(x, 0)); +} + +template <> +inline uint128_t ConvFromBlock(const emp::block& x) { + return yacl::MakeUint128(/*hi64*/ _mm_extract_epi64(x, 1), + /*lo64*/ _mm_extract_epi64(x, 0)); +} + +} // namespace spu::mpc::cheetah \ No newline at end of file diff --git a/libspu/mpc/cheetah/ot/ferret.cc b/libspu/mpc/cheetah/ot/emp/ferret.cc similarity index 88% rename from libspu/mpc/cheetah/ot/ferret.cc rename to libspu/mpc/cheetah/ot/emp/ferret.cc index 6d071529..46bc5a87 100644 --- a/libspu/mpc/cheetah/ot/ferret.cc +++ b/libspu/mpc/cheetah/ot/emp/ferret.cc @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. 
-#include "libspu/mpc/cheetah/ot/ferret.h" +#include "libspu/mpc/cheetah/ot/emp/ferret.h" #include @@ -23,8 +23,9 @@ #include "yacl/base/buffer.h" #include "yacl/link/link.h" -#include "libspu/mpc/cheetah/ot/mitccrh_exp.h" -#include "libspu/mpc/cheetah/ot/util.h" +#include "libspu/mpc/cheetah/ot/emp/emp_util.h" +#include "libspu/mpc/cheetah/ot/emp/mitccrh_exp.h" +#include "libspu/mpc/cheetah/ot/ot_util.h" #define PRE_OT_DATA_REG_SEND_FILE_ALICE "pre_ferret_data_reg_send_alice" #define PRE_OT_DATA_REG_SEND_FILE_BOB "pre_ferret_data_reg_send_bob" @@ -166,7 +167,7 @@ class CheetahIO : public emp::IOChannel { } }; -struct FerretOT::Impl { +struct EmpFerretOt::Impl { private: const bool is_sender_; @@ -308,12 +309,12 @@ struct FerretOT::Impl { std::vector rcm_output(n); RecvRandCorrelatedMsgChosenChoice(choices, absl::MakeSpan(rcm_output)); - size_t pack_load = 8 * sizeof(T) / bit_width; std::array pad; std::vector corr_output(kOTBatchSize); std::vector packed_corr_output; - if (pack_load > 1) { - packed_corr_output.resize(CeilDiv(corr_output.size(), pack_load)); + size_t packed_sze = CeilDiv(corr_output.size() * bit_width, sizeof(T) * 8); + if (packed_sze < corr_output.size()) { + packed_corr_output.resize(packed_sze); } for (size_t i = 0; i < n; i += kOTBatchSize) { @@ -323,8 +324,8 @@ struct FerretOT::Impl { this_batch * sizeof(OtBaseTyp)); ferret_->mitccrh.template hash(pad.data()); - if (pack_load > 1) { - size_t used = CeilDiv(this_batch, pack_load); + if (!packed_corr_output.empty()) { + size_t used = CeilDiv(this_batch * bit_width, sizeof(T) * 8); io_->recv_data(packed_corr_output.data(), sizeof(T) * used); UnzipArray({packed_corr_output.data(), used}, bit_width, {corr_output.data(), this_batch}); @@ -446,12 +447,12 @@ struct FerretOT::Impl { SendRandCorrelatedMsgChosenChoice(rcm_output.get(), n); - size_t pack_load = 8 * sizeof(T) / bit_width; std::array pad; std::vector corr_output(kOTBatchSize); std::vector packed_corr_output; - if (pack_load > 1) { - 
packed_corr_output.resize(CeilDiv(corr_output.size(), pack_load)); + size_t packed_sze = CeilDiv(corr_output.size() * bit_width, sizeof(T) * 8); + if (packed_sze < corr_output.size()) { + packed_corr_output.resize(packed_sze); } for (size_t i = 0; i < n; i += kOTBatchSize) { @@ -469,10 +470,10 @@ struct FerretOT::Impl { corr_output[j] += corr[i + j] + output[i + j]; } - if (pack_load > 1) { + if (not packed_corr_output.empty()) { size_t used = ZipArray({corr_output.data(), this_batch}, bit_width, absl::MakeSpan(packed_corr_output)); - SPU_ENFORCE(used == CeilDiv(this_batch, pack_load)); + SPU_ENFORCE(used == CeilDiv(this_batch * bit_width, sizeof(T) * 8)); io_->send_data(packed_corr_output.data(), used * sizeof(T)); } else { io_->send_data(corr_output.data(), sizeof(T) * this_batch); @@ -531,19 +532,19 @@ struct FerretOT::Impl { SendRandCorrelatedMsgChosenChoice(rcm_data.get(), n); // async a random seed - emp::block seed; - ferret_->prg.random_block(&seed, 1); - io_->send_block(&seed, 1); - io_->flush(); - ferret_->mitccrh.setS(seed); + // NOTE(lwj): shall we really need to sync seed for each call ? 
+ // emp::block seed; + // ferret_->prg.random_block(&seed, 1); + // io_->send_block(&seed, 1); + // io_->flush(); + // ferret_->mitccrh.setS(seed); - const size_t pack_load = 8 * sizeof(T) / bit_width; std::vector pad(kOTBatchSize * N); std::vector to_send(kOTBatchSize * N); + size_t packed_sze = CeilDiv(to_send.size() * bit_width, sizeof(T) * 8); std::vector packed_to_send; - if (pack_load > 1) { - // NOTE: pack bit chunks into single T element if possible - packed_to_send.resize(CeilDiv(to_send.size(), pack_load)); + if (packed_sze < to_send.size()) { + packed_to_send.resize(packed_sze); } for (size_t i = 0; i < n; i += kOTBatchSize) { @@ -563,13 +564,13 @@ struct FerretOT::Impl { to_send[2 * j + 1] = ConvFromBlock(pad[2 * j + 1]) ^ this_msg[1]; } - if (pack_load == 1) { - io_->send_data(to_send.data(), sizeof(T) * this_batch * N); - } else { + if (not packed_to_send.empty()) { size_t used = ZipArray({to_send.data(), N * this_batch}, bit_width, absl::MakeSpan(packed_to_send)); - SPU_ENFORCE(used == CeilDiv(N * this_batch, pack_load)); + SPU_ENFORCE(used == CeilDiv(N * this_batch * bit_width, sizeof(T) * 8)); io_->send_data(packed_to_send.data(), used * sizeof(T)); + } else { + io_->send_data(to_send.data(), sizeof(T) * this_batch * N); } } } @@ -591,18 +592,18 @@ struct FerretOT::Impl { RecvRandCorrelatedMsgChosenChoice(choices, absl::MakeSpan(rcm_data)); // async a seed from sender - emp::block seed; - io_->recv_block(&seed, 1); - ferret_->mitccrh.setS(seed); + // emp::block seed; + // io_->recv_block(&seed, 1); + // ferret_->mitccrh.setS(seed); const T msg_mask = makeBitsMask(bit_width); - const size_t pack_load = 8 * sizeof(T) / bit_width; std::vector pad(kOTBatchSize); std::vector recv(kOTBatchSize * N); + size_t packed_sze = CeilDiv(recv.size() * bit_width, sizeof(T) * 8); std::vector packed_recv; - if (pack_load > 1) { - packed_recv.resize(CeilDiv(recv.size(), pack_load)); + if (packed_sze < recv.size()) { + packed_recv.resize(packed_sze); } for (size_t i 
= 0; i < n; i += kOTBatchSize) { @@ -612,10 +613,10 @@ struct FerretOT::Impl { } ferret_->mitccrh.template hash(pad.data()); - if (pack_load == 1) { + if (packed_recv.empty()) { io_->recv_data(recv.data(), N * this_batch * sizeof(T)); } else { - size_t used = CeilDiv(N * this_batch, pack_load); + size_t used = CeilDiv(N * this_batch * bit_width, sizeof(T) * 8); io_->recv_data(packed_recv.data(), used * sizeof(T)); UnzipArray({packed_recv.data(), used}, bit_width, {recv.data(), N * this_batch}); @@ -667,11 +668,10 @@ struct FerretOT::Impl { std::vector pad(kOTBatchSize * N); const T msg_mask = makeBitsMask(bit_width); - bool packable = 8 * sizeof(T) > bit_width; - size_t packed_sze = CeilDiv(bit_width * N * kOTBatchSize, sizeof(T) * 8); std::vector to_send(kOTBatchSize * N); std::vector packed_to_send; - if (packable) { + size_t packed_sze = CeilDiv(to_send.size() * bit_width, sizeof(T) * 8); + if (packed_sze < to_send.size()) { // NOTE: pack bit chunks into single T element if possible packed_to_send.resize(packed_sze); } @@ -711,7 +711,7 @@ struct FerretOT::Impl { } } - if (packable) { + if (not packed_to_send.empty()) { size_t used = ZipArrayBit({to_send.data(), N * this_batch}, bit_width, absl::MakeSpan(packed_to_send)); SPU_ENFORCE(used == CeilDiv(N * this_batch * bit_width, sizeof(T) * 8)); @@ -823,16 +823,16 @@ struct FerretOT::Impl { } }; -FerretOT::FerretOT(std::shared_ptr conn, bool is_sender, - bool malicious) { +EmpFerretOt::EmpFerretOt(std::shared_ptr conn, bool is_sender, + bool malicious) { impl_ = std::make_shared(conn, is_sender, malicious); } -int FerretOT::Rank() const { return impl_->Rank(); } +int EmpFerretOt::Rank() const { return impl_->Rank(); } -void FerretOT::Flush() { impl_->Flush(); } +void EmpFerretOt::Flush() { impl_->Flush(); } -FerretOT::~FerretOT() { impl_->Flush(); } +EmpFerretOt::~EmpFerretOt() { impl_->Flush(); } template size_t CheckBitWidth(size_t bw) { @@ -844,52 +844,52 @@ size_t CheckBitWidth(size_t bw) { return bw; } 
-#define DEF_SEND_RECV(T) \ - void FerretOT::SendCAMCC(absl::Span corr, absl::Span output, \ - int bw) { \ - impl_->SendCorrelatedMsgChosenChoice(corr, output, bw); \ - } \ - void FerretOT::RecvCAMCC(absl::Span choices, \ - absl::Span output, int bw) { \ - impl_->RecvCorrelatedMsgChosenChoice(choices, output, bw); \ - } \ - void FerretOT::SendRMRC(absl::Span output0, absl::Span output1, \ - size_t bit_width) { \ - bit_width = CheckBitWidth(bit_width); \ - impl_->SendRandMsgRandChoice(output0, output1, bit_width); \ - } \ - void FerretOT::RecvRMRC(absl::Span choices, absl::Span output, \ - size_t bit_width) { \ - bit_width = CheckBitWidth(bit_width); \ - impl_->RecvRandMsgRandChoice(choices, output, bit_width); \ - } \ - void FerretOT::SendCMCC(absl::Span msg_array, size_t N, \ - size_t bit_width) { \ - bit_width = CheckBitWidth(bit_width); \ - if (N == 2) { \ - impl_->SendChosenTwoMsgChosenChoice(msg_array, bit_width); \ - return; \ - } \ - impl_->SendChosenMsgChosenChoice(msg_array, N, bit_width); \ - } \ - void FerretOT::RecvCMCC(absl::Span choices, size_t N, \ - absl::Span output, size_t bit_width) { \ - bit_width = CheckBitWidth(bit_width); \ - if (N == 2) { \ - impl_->RecvChosenTwoMsgChosenChoice(choices, output, bit_width); \ - return; \ - } \ - impl_->RecvChosenMsgChosenChoice(choices, N, output, bit_width); \ - } \ - void FerretOT::SendRMCC(absl::Span output0, absl::Span output1, \ - size_t bit_width) { \ - bit_width = CheckBitWidth(bit_width); \ - impl_->SendRMCC(output0, output1, bit_width); \ - } \ - void FerretOT::RecvRMCC(absl::Span choices, \ - absl::Span output, size_t bit_width) { \ - bit_width = CheckBitWidth(bit_width); \ - impl_->RecvRMCC(choices, output, bit_width); \ +#define DEF_SEND_RECV(T) \ + void EmpFerretOt::SendCAMCC(absl::Span corr, absl::Span output, \ + int bw) { \ + impl_->SendCorrelatedMsgChosenChoice(corr, output, bw); \ + } \ + void EmpFerretOt::RecvCAMCC(absl::Span choices, \ + absl::Span output, int bw) { \ + 
impl_->RecvCorrelatedMsgChosenChoice(choices, output, bw); \ + } \ + void EmpFerretOt::SendRMRC(absl::Span output0, absl::Span output1, \ + size_t bit_width) { \ + bit_width = CheckBitWidth(bit_width); \ + impl_->SendRandMsgRandChoice(output0, output1, bit_width); \ + } \ + void EmpFerretOt::RecvRMRC(absl::Span choices, \ + absl::Span output, size_t bit_width) { \ + bit_width = CheckBitWidth(bit_width); \ + impl_->RecvRandMsgRandChoice(choices, output, bit_width); \ + } \ + void EmpFerretOt::SendCMCC(absl::Span msg_array, size_t N, \ + size_t bit_width) { \ + bit_width = CheckBitWidth(bit_width); \ + if (N == 2) { \ + impl_->SendChosenTwoMsgChosenChoice(msg_array, bit_width); \ + return; \ + } \ + impl_->SendChosenMsgChosenChoice(msg_array, N, bit_width); \ + } \ + void EmpFerretOt::RecvCMCC(absl::Span choices, size_t N, \ + absl::Span output, size_t bit_width) { \ + bit_width = CheckBitWidth(bit_width); \ + if (N == 2) { \ + impl_->RecvChosenTwoMsgChosenChoice(choices, output, bit_width); \ + return; \ + } \ + impl_->RecvChosenMsgChosenChoice(choices, N, output, bit_width); \ + } \ + void EmpFerretOt::SendRMCC(absl::Span output0, absl::Span output1, \ + size_t bit_width) { \ + bit_width = CheckBitWidth(bit_width); \ + impl_->SendRMCC(output0, output1, bit_width); \ + } \ + void EmpFerretOt::RecvRMCC(absl::Span choices, \ + absl::Span output, size_t bit_width) { \ + bit_width = CheckBitWidth(bit_width); \ + impl_->RecvRMCC(choices, output, bit_width); \ } DEF_SEND_RECV(uint8_t) diff --git a/libspu/mpc/cheetah/ot/ferret.h b/libspu/mpc/cheetah/ot/emp/ferret.h similarity index 74% rename from libspu/mpc/cheetah/ot/ferret.h rename to libspu/mpc/cheetah/ot/emp/ferret.h index 01b61d98..82bbae9e 100644 --- a/libspu/mpc/cheetah/ot/ferret.h +++ b/libspu/mpc/cheetah/ot/emp/ferret.h @@ -20,104 +20,105 @@ #include "absl/types/span.h" #include "yacl/base/int128.h" +#include "libspu/mpc/cheetah/ot/ferret_ot_interface.h" #include "libspu/mpc/common/communicator.h" namespace 
spu::mpc::cheetah { -class FerretOT { +class EmpFerretOt : public FerretOtInterface { private: struct Impl; std::shared_ptr impl_; public: - FerretOT(std::shared_ptr conn, bool is_sender, - bool malicious = false); + EmpFerretOt(std::shared_ptr conn, bool is_sender, + bool malicious = false); - ~FerretOT(); + ~EmpFerretOt(); - int Rank() const; + int Rank() const override; - void Flush(); + void Flush() override; // One-of-N OT where msg_array is a Nxn array. // choice \in [0, N-1] void SendCMCC(absl::Span msg_array, size_t N, - size_t bit_width = 0); + size_t bit_width = 0) override; void SendCMCC(absl::Span msg_array, size_t N, - size_t bit_width = 0); + size_t bit_width = 0) override; void SendCMCC(absl::Span msg_array, size_t N, - size_t bit_width = 0); + size_t bit_width = 0) override; void SendCMCC(absl::Span msg_array, size_t N, - size_t bit_width = 0); + size_t bit_width = 0) override; void RecvCMCC(absl::Span one_oo_N_choices, size_t N, - absl::Span output, size_t bit_width = 0); + absl::Span output, size_t bit_width = 0) override; void RecvCMCC(absl::Span one_oo_N_choices, size_t N, - absl::Span output, size_t bit_width = 0); + absl::Span output, size_t bit_width = 0) override; void RecvCMCC(absl::Span one_oo_N_choices, size_t N, - absl::Span output, size_t bit_width = 0); + absl::Span output, size_t bit_width = 0) override; void RecvCMCC(absl::Span one_oo_N_choices, size_t N, - absl::Span output, size_t bit_width = 0); + absl::Span output, size_t bit_width = 0) override; // Random Message Random Choice void SendRMRC(absl::Span output0, absl::Span output1, - size_t bit_width = 0); + size_t bit_width = 0) override; void SendRMRC(absl::Span output0, absl::Span output1, - size_t bit_width = 0); + size_t bit_width = 0) override; void SendRMRC(absl::Span output0, absl::Span output1, - size_t bit_width = 0); + size_t bit_width = 0) override; void SendRMRC(absl::Span output0, absl::Span output1, - size_t bit_width = 0); + size_t bit_width = 0) override; void 
RecvRMRC(absl::Span binary_choices, absl::Span output, - size_t bit_width = 0); + size_t bit_width = 0) override; void RecvRMRC(absl::Span binary_choices, absl::Span output, - size_t bit_width = 0); + size_t bit_width = 0) override; void RecvRMRC(absl::Span binary_choices, absl::Span output, - size_t bit_width = 0); + size_t bit_width = 0) override; void RecvRMRC(absl::Span binary_choices, - absl::Span output, size_t bit_width = 0); + absl::Span output, size_t bit_width = 0) override; // correlated additive message, chosen choice // (x, x + corr * choice) <- (corr, choice) // Can use bit_width to further indicate output ring. `bit_width = 0` means to // use the full range. void SendCAMCC(absl::Span corr, absl::Span output, - int bit_width = 0); + int bit_width = 0) override; void SendCAMCC(absl::Span corr, absl::Span output, - int bit_width = 0); + int bit_width = 0) override; void SendCAMCC(absl::Span corr, absl::Span output, - int bit_width = 0); + int bit_width = 0) override; void SendCAMCC(absl::Span corr, absl::Span output, - int bit_width = 0); + int bit_width = 0) override; void RecvCAMCC(absl::Span binary_choices, - absl::Span output, int bit_width = 0); + absl::Span output, int bit_width = 0) override; void RecvCAMCC(absl::Span binary_choices, - absl::Span output, int bit_width = 0); + absl::Span output, int bit_width = 0) override; void RecvCAMCC(absl::Span binary_choices, - absl::Span output, int bit_width = 0); + absl::Span output, int bit_width = 0) override; void RecvCAMCC(absl::Span binary_choices, - absl::Span output, int bit_width = 0); + absl::Span output, int bit_width = 0) override; // Random Message Chosen Choice void SendRMCC(absl::Span output0, absl::Span output1, - size_t bit_width = 0); + size_t bit_width = 0) override; void SendRMCC(absl::Span output0, absl::Span output1, - size_t bit_width = 0); + size_t bit_width = 0) override; void SendRMCC(absl::Span output0, absl::Span output1, - size_t bit_width = 0); + size_t bit_width = 0) 
override; void SendRMCC(absl::Span output0, absl::Span output1, - size_t bit_width = 0); + size_t bit_width = 0) override; void RecvRMCC(absl::Span binary_choices, - absl::Span output, size_t bit_width = 0); + absl::Span output, size_t bit_width = 0) override; void RecvRMCC(absl::Span binary_choices, - absl::Span output, size_t bit_width = 0); + absl::Span output, size_t bit_width = 0) override; void RecvRMCC(absl::Span binary_choices, - absl::Span output, size_t bit_width = 0); + absl::Span output, size_t bit_width = 0) override; void RecvRMCC(absl::Span binary_choices, - absl::Span output, size_t bit_width = 0); + absl::Span output, size_t bit_width = 0) override; }; } // namespace spu::mpc::cheetah diff --git a/libspu/mpc/cheetah/ot/emp/ferret_test.cc b/libspu/mpc/cheetah/ot/emp/ferret_test.cc new file mode 100644 index 00000000..af6cb96a --- /dev/null +++ b/libspu/mpc/cheetah/ot/emp/ferret_test.cc @@ -0,0 +1,208 @@ +// Copyright 2022 Ant Group Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "libspu/mpc/cheetah/ot/emp/ferret.h" + +#include + +#include "gtest/gtest.h" + +#include "libspu/mpc/utils/ring_ops.h" +#include "libspu/mpc/utils/simulate.h" + +namespace spu::mpc::cheetah::test { + +class FerretCOTTest : public testing::TestWithParam {}; + +INSTANTIATE_TEST_SUITE_P( + Cheetah, FerretCOTTest, + testing::Values(FieldType::FM32, FieldType::FM64, FieldType::FM128), + [](const testing::TestParamInfo &p) { + return fmt::format("{}", p.param); + }); + +template +absl::Span makeSpan(NdArrayView a) { + return {&a[0], (size_t)a.numel()}; +} + +template +absl::Span makeConstSpan(NdArrayView a) { + return {&a[0], (size_t)a.numel()}; +} + +TEST_P(FerretCOTTest, ChosenCorrelationChosenChoice) { + size_t kWorldSize = 2; + int64_t n = 10; + auto field = GetParam(); + + auto _correlation = ring_rand(field, {n}); + std::vector choices(n); + std::default_random_engine rdv; + std::uniform_int_distribution uniform(0, -1); + std::generate_n(choices.begin(), n, [&]() -> uint8_t { + return static_cast(uniform(rdv) & 1); + }); + + DISPATCH_ALL_FIELDS(field, "", [&]() { + NdArrayView correlation(_correlation); + std::vector computed[2]; + utils::simulate(kWorldSize, [&](std::shared_ptr ctx) { + auto conn = std::make_shared(ctx); + int rank = ctx->Rank(); + computed[rank].resize(n); + EmpFerretOt ferret(conn, rank == 0); + if (rank == 0) { + ferret.SendCAMCC(makeConstSpan(correlation), + absl::MakeSpan(computed[0])); + ferret.Flush(); + } else { + ferret.RecvCAMCC(absl::MakeSpan(choices), absl::MakeSpan(computed[1])); + } + }); + + for (int64_t i = 0; i < n; ++i) { + ring2k_t c = -computed[0][i] + computed[1][i]; + ring2k_t e = choices[i] ? 
correlation[i] : 0; + EXPECT_EQ(e, c); + } + }); +} + +TEST_P(FerretCOTTest, RndMsgRndChoice) { + size_t kWorldSize = 2; + auto field = GetParam(); + constexpr size_t bw = 2; + + size_t n = 10; + DISPATCH_ALL_FIELDS(field, "", [&]() { + std::vector msg0(n); + std::vector msg1(n); + ring2k_t max = static_cast(1) << bw; + + std::vector choices(n); + std::vector selected(n); + + utils::simulate(kWorldSize, [&](std::shared_ptr ctx) { + auto conn = std::make_shared(ctx); + int rank = ctx->Rank(); + EmpFerretOt ferret(conn, rank == 0); + if (rank == 0) { + ferret.SendRMRC(absl::MakeSpan(msg0), absl::MakeSpan(msg1), bw); + ferret.Flush(); + } else { + ferret.RecvRMRC(absl::MakeSpan(choices), absl::MakeSpan(selected), bw); + } + }); + + for (size_t i = 0; i < n; ++i) { + ring2k_t e = choices[i] ? msg1[i] : msg0[i]; + ring2k_t c = selected[i]; + EXPECT_TRUE(choices[i] < 2); + EXPECT_LT(e, max); + EXPECT_LT(c, max); + EXPECT_EQ(e, c); + } + }); +} + +TEST_P(FerretCOTTest, RndMsgChosenChoice) { + size_t kWorldSize = 2; + auto field = GetParam(); + constexpr size_t bw = 2; + + size_t n = 10; + DISPATCH_ALL_FIELDS(field, "", [&]() { + std::vector msg0(n); + std::vector msg1(n); + ring2k_t max = static_cast(1) << bw; + + std::vector choices(n); + std::default_random_engine rdv; + std::uniform_int_distribution uniform(0, -1); + std::generate_n(choices.begin(), n, [&]() -> uint8_t { + return static_cast(uniform(rdv) & 1); + }); + + std::vector selected(n); + + utils::simulate(kWorldSize, [&](std::shared_ptr ctx) { + auto conn = std::make_shared(ctx); + int rank = ctx->Rank(); + EmpFerretOt ferret(conn, rank == 0); + if (rank == 0) { + ferret.SendRMCC(absl::MakeSpan(msg0), absl::MakeSpan(msg1), bw); + ferret.Flush(); + } else { + ferret.RecvRMCC(absl::MakeSpan(choices), absl::MakeSpan(selected), bw); + } + }); + + for (size_t i = 0; i < n; ++i) { + ring2k_t e = choices[i] ? 
msg1[i] : msg0[i]; + ring2k_t c = selected[i]; + EXPECT_LT(e, max); + EXPECT_LT(c, max); + EXPECT_EQ(e, c); + } + }); +} + +TEST_P(FerretCOTTest, ChosenMsgChosenChoice) { + size_t kWorldSize = 2; + int64_t n = 100; + auto field = GetParam(); + DISPATCH_ALL_FIELDS(field, "", [&]() { + using scalar_t = ring2k_t; + std::default_random_engine rdv; + std::uniform_int_distribution uniform(0, -1); + for (size_t bw : {2UL, 4UL, sizeof(scalar_t) * 8}) { + scalar_t mask = (static_cast(1) << bw) - 1; + for (int64_t N : {2, 3, 8}) { + auto _msg = ring_rand(field, {N * n}); + NdArrayView msg(_msg); + pforeach(0, msg.numel(), [&](int64_t i) { msg[i] &= mask; }); + + std::vector choices(n); + std::generate_n(choices.begin(), n, [&]() -> uint8_t { + return static_cast(uniform(rdv) % N); + }); + + std::vector selected(n); + + utils::simulate( + kWorldSize, [&](std::shared_ptr ctx) { + auto conn = std::make_shared(ctx); + int rank = ctx->Rank(); + EmpFerretOt ferret(conn, rank == 0); + if (rank == 0) { + ferret.SendCMCC(makeConstSpan(msg), N, bw); + ferret.Flush(); + } else { + ferret.RecvCMCC(absl::MakeSpan(choices), N, + absl::MakeSpan(selected), bw); + } + }); + + for (int64_t i = 0; i < n; ++i) { + scalar_t e = msg[i * N + choices[i]]; + scalar_t c = selected[i]; + EXPECT_EQ(e, c); + } + } + } + }); +} + +} // namespace spu::mpc::cheetah::test diff --git a/libspu/mpc/cheetah/ot/mitccrh_exp.h b/libspu/mpc/cheetah/ot/emp/mitccrh_exp.h similarity index 100% rename from libspu/mpc/cheetah/ot/mitccrh_exp.h rename to libspu/mpc/cheetah/ot/emp/mitccrh_exp.h diff --git a/libspu/mpc/cheetah/ot/util.cc b/libspu/mpc/cheetah/ot/emp/util.cc similarity index 100% rename from libspu/mpc/cheetah/ot/util.cc rename to libspu/mpc/cheetah/ot/emp/util.cc diff --git a/libspu/mpc/cheetah/ot/util.h b/libspu/mpc/cheetah/ot/emp/util.h similarity index 100% rename from libspu/mpc/cheetah/ot/util.h rename to libspu/mpc/cheetah/ot/emp/util.h diff --git a/libspu/mpc/cheetah/ot/ferret_ot_interface.h 
b/libspu/mpc/cheetah/ot/ferret_ot_interface.h new file mode 100644 index 00000000..6abec8f3 --- /dev/null +++ b/libspu/mpc/cheetah/ot/ferret_ot_interface.h @@ -0,0 +1,107 @@ +// Copyright 2022 Ant Group Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#pragma once + +namespace spu::mpc::cheetah { + +class FerretOtInterface { + public: + virtual ~FerretOtInterface() = default; + + virtual int Rank() const = 0; + virtual void Flush() = 0; + + // One-of-N OT where msg_array is a Nxn array. 
+ // choice \in [0, N-1] + virtual void SendCMCC(absl::Span msg_array, size_t N, + size_t bit_width = 0) = 0; + virtual void SendCMCC(absl::Span msg_array, size_t N, + size_t bit_width = 0) = 0; + virtual void SendCMCC(absl::Span msg_array, size_t N, + size_t bit_width = 0) = 0; + virtual void SendCMCC(absl::Span msg_array, size_t N, + size_t bit_width = 0) = 0; + + virtual void RecvCMCC(absl::Span one_oo_N_choices, size_t N, + absl::Span output, size_t bit_width = 0) = 0; + virtual void RecvCMCC(absl::Span one_oo_N_choices, size_t N, + absl::Span output, size_t bit_width = 0) = 0; + virtual void RecvCMCC(absl::Span one_oo_N_choices, size_t N, + absl::Span output, size_t bit_width = 0) = 0; + virtual void RecvCMCC(absl::Span one_oo_N_choices, size_t N, + absl::Span output, size_t bit_width = 0) = 0; + + // Random Message Random Choice + virtual void SendRMRC(absl::Span output0, + absl::Span output1, size_t bit_width = 0) = 0; + virtual void SendRMRC(absl::Span output0, + absl::Span output1, size_t bit_width = 0) = 0; + virtual void SendRMRC(absl::Span output0, + absl::Span output1, size_t bit_width = 0) = 0; + virtual void SendRMRC(absl::Span output0, + absl::Span output1, + size_t bit_width = 0) = 0; + + virtual void RecvRMRC(absl::Span binary_choices, + absl::Span output, size_t bit_width = 0) = 0; + virtual void RecvRMRC(absl::Span binary_choices, + absl::Span output, size_t bit_width = 0) = 0; + virtual void RecvRMRC(absl::Span binary_choices, + absl::Span output, size_t bit_width = 0) = 0; + virtual void RecvRMRC(absl::Span binary_choices, + absl::Span output, size_t bit_width = 0) = 0; + + // correlated additive message, chosen choice + // (x, x + corr * choice) <- (corr, choice) + // Can use bit_width=0 to further indicate output ring. `bit_width=0` means to + // use the full range. 
+ virtual void SendCAMCC(absl::Span corr, + absl::Span output, int bit_width = 0) = 0; + virtual void SendCAMCC(absl::Span corr, + absl::Span output, int bit_width = 0) = 0; + virtual void SendCAMCC(absl::Span corr, + absl::Span output, int bit_width = 0) = 0; + virtual void SendCAMCC(absl::Span corr, + absl::Span output, int bit_width = 0) = 0; + + virtual void RecvCAMCC(absl::Span binary_choices, + absl::Span output, int bit_width = 0) = 0; + virtual void RecvCAMCC(absl::Span binary_choices, + absl::Span output, int bit_width = 0) = 0; + virtual void RecvCAMCC(absl::Span binary_choices, + absl::Span output, int bit_width = 0) = 0; + virtual void RecvCAMCC(absl::Span binary_choices, + absl::Span output, int bit_width = 0) = 0; + + // Random Message Chosen Choice + virtual void SendRMCC(absl::Span output0, + absl::Span output1, size_t bit_width = 0) = 0; + virtual void SendRMCC(absl::Span output0, + absl::Span output1, size_t bit_width = 0) = 0; + virtual void SendRMCC(absl::Span output0, + absl::Span output1, size_t bit_width = 0) = 0; + virtual void SendRMCC(absl::Span output0, + absl::Span output1, + size_t bit_width = 0) = 0; + + virtual void RecvRMCC(absl::Span binary_choices, + absl::Span output, size_t bit_width = 0) = 0; + virtual void RecvRMCC(absl::Span binary_choices, + absl::Span output, size_t bit_width = 0) = 0; + virtual void RecvRMCC(absl::Span binary_choices, + absl::Span output, size_t bit_width = 0) = 0; + virtual void RecvRMCC(absl::Span binary_choices, + absl::Span output, size_t bit_width = 0) = 0; +}; +} // namespace spu::mpc::cheetah \ No newline at end of file diff --git a/libspu/mpc/cheetah/ot/ot_util.cc b/libspu/mpc/cheetah/ot/ot_util.cc new file mode 100644 index 00000000..7ba19ec2 --- /dev/null +++ b/libspu/mpc/cheetah/ot/ot_util.cc @@ -0,0 +1,82 @@ +// Copyright 2021 Ant Group Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "libspu/mpc/cheetah/ot/ot_util.h" + +#include + +#include "libspu/core/prelude.h" + +namespace spu::mpc::cheetah { + +uint8_t BoolToU8(absl::Span bits) { + size_t len = bits.size(); + SPU_ENFORCE(len >= 1 && len <= 8); + return std::accumulate( + bits.data(), bits.data() + len, + /*init*/ static_cast(0), + [](uint8_t init, uint8_t next) { return (init << 1) | (next & 1); }); +} + +void U8ToBool(absl::Span bits, uint8_t u8) { + size_t len = std::min(8UL, bits.size()); + SPU_ENFORCE(len >= 1); + for (size_t i = 0; i < len; ++i) { + bits[i] = (u8 & 1); + u8 >>= 1; + } +} + +NdArrayRef OpenShare(const NdArrayRef &shr, ReduceOp op, size_t nbits, + std::shared_ptr conn) { + SPU_ENFORCE(conn != nullptr); + SPU_ENFORCE(shr.eltype().isa()); + SPU_ENFORCE(op == ReduceOp::ADD or op == ReduceOp::XOR); + + auto field = shr.eltype().as()->field(); + size_t fwidth = SizeOf(field) * 8; + if (nbits == 0) { + nbits = fwidth; + } + SPU_ENFORCE(nbits <= fwidth, "nbits out-of-bound"); + bool packable = fwidth > nbits; + if (not packable) { + return conn->allReduce(op, shr, "open"); + } + + size_t numel = shr.numel(); + size_t compact_numel = CeilDiv(numel * nbits, fwidth); + + NdArrayRef out(shr.eltype(), {(int64_t)numel}); + DISPATCH_ALL_FIELDS(field, "zip", [&]() { + auto inp = absl::MakeConstSpan(&shr.at(0), numel); + auto oup = absl::MakeSpan(&out.at(0), compact_numel); + + size_t used = ZipArray(inp, nbits, oup); + SPU_ENFORCE_EQ(used, compact_numel); + + std::vector opened; + if (op == ReduceOp::XOR) { + opened = conn->allReduce(oup, "open"); + } else 
{ + opened = conn->allReduce(oup, "open"); + } + + oup = absl::MakeSpan(&out.at(0), numel); + UnzipArray(absl::MakeConstSpan(opened), nbits, oup); + }); + return out.reshape(shr.shape()); +} + +} // namespace spu::mpc::cheetah diff --git a/libspu/mpc/cheetah/ot/ot_util.h b/libspu/mpc/cheetah/ot/ot_util.h new file mode 100644 index 00000000..de1f2378 --- /dev/null +++ b/libspu/mpc/cheetah/ot/ot_util.h @@ -0,0 +1,138 @@ +// Copyright 2021 Ant Group Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include "absl/types/span.h" +#include "yacl/base/int128.h" + +#include "libspu/core/ndarray_ref.h" +#include "libspu/core/prelude.h" +#include "libspu/mpc/common/communicator.h" + +namespace spu::mpc::cheetah { + +template +T makeBitsMask(size_t nbits) { + size_t max = sizeof(T) * 8; + if (nbits == 0) { + nbits = max; + } + SPU_ENFORCE(nbits <= max); + T mask = static_cast(-1); + if (nbits < max) { + mask = (static_cast(1) << nbits) - 1; + } + return mask; +} + +template +inline T CeilDiv(T a, T b) { + return (a + b - 1) / b; +} + +template +size_t ZipArray(absl::Span inp, size_t bit_width, absl::Span oup) { + size_t width = sizeof(T) * 8; + SPU_ENFORCE(bit_width > 0 && width >= bit_width); + size_t numel = inp.size(); + size_t packed_sze = CeilDiv(numel * bit_width, width); + + SPU_ENFORCE(oup.size() >= packed_sze); + + const T mask = makeBitsMask(bit_width); + for (size_t i = 0; i < packed_sze; ++i) { + oup[i] = 0; + } + // shift will in [0, 2 * width] + for (size_t i = 0, has_done = 0; i < numel; i += 1, has_done += bit_width) { + T real_data = inp[i] & mask; + size_t packed_index0 = i * bit_width / width; + size_t shft0 = has_done % width; + oup[packed_index0] |= (real_data << shft0); + if (shft0 + bit_width > width) { + size_t shft1 = width - shft0; + size_t packed_index1 = packed_index0 + 1; + oup[packed_index1] |= (real_data >> shft1); + } + } + return packed_sze; +} + +template +size_t UnzipArray(absl::Span inp, size_t bit_width, + absl::Span oup) { + size_t width = sizeof(T) * 8; + SPU_ENFORCE(bit_width > 0 && bit_width <= width); + + size_t packed_sze = inp.size(); + size_t n = oup.size(); + size_t raw_sze = packed_sze * width / bit_width; + SPU_ENFORCE(n > 0 && n <= raw_sze); + + const T mask = makeBitsMask(bit_width); + for (size_t i = 0, has_done = 0; i < n; i += 1, has_done += bit_width) { + size_t packed_index0 = i * bit_width / width; + size_t shft0 = has_done % width; + oup[i] = (inp[packed_index0] >> shft0); + if (shft0 + 
bit_width > width) { + size_t shft1 = width - shft0; + size_t packed_index1 = packed_index0 + 1; + oup[i] |= (inp[packed_index1] << shft1); + } + oup[i] &= mask; + } + return n; +} + +template +size_t ZipArrayBit(absl::Span inp, size_t bit_width, + absl::Span oup) { + return ZipArray(inp, bit_width, oup); +} + +template +size_t UnzipArrayBit(absl::Span inp, size_t bit_width, + absl::Span oup) { + return UnzipArray(inp, bit_width, oup); +} + +template +size_t PackU8Array(absl::Span u8array, absl::Span packed) { + constexpr size_t elsze = sizeof(T); + const size_t nbytes = u8array.size(); + const size_t numel = CeilDiv(nbytes, elsze); + + SPU_ENFORCE(packed.size() >= numel); + + for (size_t i = 0; i < nbytes; i += elsze) { + size_t this_batch = std::min(nbytes - i, elsze); + T acc{0}; + for (size_t j = 0; j < this_batch; ++j) { + acc = (acc << 8) | u8array[i + j]; + } + packed[i / elsze] = acc; + } + + return numel; +} + +NdArrayRef OpenShare(const NdArrayRef &shr, ReduceOp op, size_t nbits, + std::shared_ptr conn); + +uint8_t BoolToU8(absl::Span bits); + +void U8ToBool(absl::Span bits, uint8_t u8); + +} // namespace spu::mpc::cheetah diff --git a/libspu/mpc/cheetah/ot/util_test.cc b/libspu/mpc/cheetah/ot/ot_util_test.cc similarity index 92% rename from libspu/mpc/cheetah/ot/util_test.cc rename to libspu/mpc/cheetah/ot/ot_util_test.cc index 6a7293a2..ea5e0aac 100644 --- a/libspu/mpc/cheetah/ot/util_test.cc +++ b/libspu/mpc/cheetah/ot/ot_util_test.cc @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. 
-#include "libspu/mpc/cheetah/ot/util.h" +#include "libspu/mpc/cheetah/ot/ot_util.h" #include @@ -22,18 +22,18 @@ namespace spu::mpc::cheetah::test { -class UtilTest : public ::testing::TestWithParam { +class OtUtilTest : public ::testing::TestWithParam { void SetUp() override {} }; INSTANTIATE_TEST_SUITE_P( - Cheetah, UtilTest, + Cheetah, OtUtilTest, testing::Values(FieldType::FM32, FieldType::FM64, FieldType::FM128), - [](const testing::TestParamInfo &p) { + [](const testing::TestParamInfo &p) { return fmt::format("{}", p.param); }); -TEST_P(UtilTest, ZipArray) { +TEST_P(OtUtilTest, ZipArray) { const int64_t n = 200; const auto field = GetParam(); const size_t elsze = SizeOf(field); @@ -63,7 +63,7 @@ TEST_P(UtilTest, ZipArray) { }); } -TEST_P(UtilTest, ZipArrayBit) { +TEST_P(OtUtilTest, ZipArrayBit) { const size_t n = 1000; const auto field = GetParam(); diff --git a/libspu/mpc/cheetah/yacl_ot/BUILD.bazel b/libspu/mpc/cheetah/ot/yacl/BUILD.bazel similarity index 67% rename from libspu/mpc/cheetah/yacl_ot/BUILD.bazel rename to libspu/mpc/cheetah/ot/yacl/BUILD.bazel index 6a4e28df..53957f43 100644 --- a/libspu/mpc/cheetah/yacl_ot/BUILD.bazel +++ b/libspu/mpc/cheetah/ot/yacl/BUILD.bazel @@ -1,4 +1,4 @@ -# Copyright 2023 Ant Group Co., Ltd. +# Copyright 2022 Ant Group Co., Ltd. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
@@ -17,37 +17,26 @@ load("@yacl//bazel:yacl.bzl", "AES_COPT_FLAGS") package(default_visibility = ["//visibility:public"]) -spu_cc_test( - name = "yacl_ferret_test", - srcs = ["yacl_ferret_test.cc"], - deps = [ - ":yacl_ferret_ot", - "//libspu/mpc/utils:simulate", - ], -) - spu_cc_library( - name = "yacl_ferret_ot", + name = "ferret", srcs = [ - "basic_ot_prot.cc", - "util.cc", - "yacl_ferret.cc", + "ferret.cc", "yacl_ote_adapter.cc", + "yacl_util.cc", ], hdrs = [ - "basic_ot_prot.h", + "ferret.h", "mitccrh_exp.h", - "util.h", - "yacl_ferret.h", "yacl_ote_adapter.h", + "yacl_util.h", ], copts = AES_COPT_FLAGS + ["-Wno-ignored-attributes"], deps = [ - "//libspu/core:xt_helper", "//libspu/mpc/cheetah:type", + "//libspu/mpc/cheetah/ot:ferret_ot_interface", + "//libspu/mpc/cheetah/ot:ot_util", "//libspu/mpc/common:communicator", "//libspu/mpc/semi2k:conversion", - "@com_github_emptoolkit_emp_tool//:emp-tool", "@yacl//yacl/base:dynamic_bitset", "@yacl//yacl/base:int128", "@yacl//yacl/crypto/base/aes:aes_opt", @@ -61,22 +50,10 @@ spu_cc_library( ) spu_cc_test( - name = "basic_ot_prot_test", - size = "large", - srcs = ["basic_ot_prot_test.cc"], - tags = [ - "exclusive-if-local", - ], + name = "ferret_test", + srcs = ["ferret_test.cc"], deps = [ - ":yacl_ferret_ot", + ":ferret", "//libspu/mpc/utils:simulate", ], ) - -spu_cc_test( - name = "util_test", - srcs = ["util_test.cc"], - deps = [ - ":yacl_ferret_ot", - ], -) diff --git a/libspu/mpc/cheetah/yacl_ot/yacl_ferret.cc b/libspu/mpc/cheetah/ot/yacl/ferret.cc similarity index 85% rename from libspu/mpc/cheetah/yacl_ot/yacl_ferret.cc rename to libspu/mpc/cheetah/ot/yacl/ferret.cc index 4b3545b4..f8b0f030 100644 --- a/libspu/mpc/cheetah/yacl_ot/yacl_ferret.cc +++ b/libspu/mpc/cheetah/ot/yacl/ferret.cc @@ -12,27 +12,25 @@ // See the License for the specific language governing permissions and // limitations under the License. 
-#include "libspu/mpc/cheetah/yacl_ot/yacl_ferret.h" +#include "libspu/mpc/cheetah/ot/yacl/ferret.h" #include -#include "emp-tool/io/io_channel.h" #include "spdlog/spdlog.h" #include "yacl/base/buffer.h" #include "yacl/crypto/tools/random_permutation.h" #include "yacl/link/link.h" -#include "libspu/mpc/cheetah/yacl_ot/mitccrh_exp.h" -#include "libspu/mpc/cheetah/yacl_ot/util.h" -#include "libspu/mpc/cheetah/yacl_ot/yacl_ote_adapter.h" +#include "libspu/mpc/cheetah/ot/ot_util.h" +#include "libspu/mpc/cheetah/ot/yacl/mitccrh_exp.h" +#include "libspu/mpc/cheetah/ot/yacl/yacl_ote_adapter.h" +#include "libspu/mpc/cheetah/ot/yacl/yacl_util.h" namespace spu::mpc::cheetah { constexpr size_t kOTBatchSize = 8; // emp-ot/cot.h -// A concrete class for emp::IOChannel -// Because emp::FerretCOT needs a uniform API of emp::IOChannel -class CheetahIO : public emp::IOChannel { +class BufferedIO { public: std::shared_ptr conn_; @@ -47,7 +45,7 @@ class CheetahIO : public emp::IOChannel { std::vector recv_buffer_; uint64_t recv_buffer_used_; - explicit CheetahIO(std::shared_ptr conn) + explicit BufferedIO(std::shared_ptr conn) : conn_(std::move(conn)), send_op_(0), recv_op_(0), @@ -56,7 +54,7 @@ class CheetahIO : public emp::IOChannel { send_buffer_.resize(SEND_BUFFER_SIZE); } - ~CheetahIO() { + ~BufferedIO() { try { flush(); } catch (const std::exception& e) { @@ -72,7 +70,7 @@ class CheetahIO : public emp::IOChannel { conn_->sendAsync( conn_->nextRank(), absl::Span{send_buffer_.data(), send_buffer_used_}, - fmt::format("CheetahIO send:{}", send_op_++)); + fmt::format("BufferedIO send:{}", send_op_++)); std::memset(send_buffer_.data(), 0, SEND_BUFFER_SIZE); send_buffer_used_ = 0; @@ -80,11 +78,11 @@ class CheetahIO : public emp::IOChannel { void fill_recv() { recv_buffer_ = conn_->recv( - conn_->nextRank(), fmt::format("CheetahIO recv:{}", recv_op_++)); + conn_->nextRank(), fmt::format("BufferedIO recv:{}", recv_op_++)); recv_buffer_used_ = 0; } - void send_data_internal(const void* 
data, int len) { + void send_data(const void* data, int len) { size_t send_buffer_left = SEND_BUFFER_SIZE - send_buffer_used_; if (send_buffer_left <= static_cast(len)) { std::memcpy(send_buffer_.data() + send_buffer_used_, data, @@ -92,15 +90,15 @@ class CheetahIO : public emp::IOChannel { send_buffer_used_ += send_buffer_left; flush(); - send_data_internal(static_cast(data) + send_buffer_left, - len - send_buffer_left); + send_data(static_cast(data) + send_buffer_left, + len - send_buffer_left); } else { std::memcpy(send_buffer_.data() + send_buffer_used_, data, len); send_buffer_used_ += len; } } - void recv_data_internal(void* data, int len) { + void recv_data(void* data, int len) { if (send_buffer_used_ > 0) { flush(); } @@ -116,57 +114,17 @@ class CheetahIO : public emp::IOChannel { } fill_recv(); - recv_data_internal(static_cast(data) + recv_buffer_left, - len - recv_buffer_left); - } - } - - template - void send_data_partial(const T* data, int len, int bitlength) { - if (bitlength == sizeof(T) * 8) { - send_data_internal(static_cast(data), len * sizeof(T)); - return; - } - - int compact_len = (bitlength + 7) / 8; - std::vector bytes(len); - for (int i = 0; i < compact_len; i++) { - for (int j = 0; j < len; j++) { - bytes[j] = uint8_t(data[j] >> (i * 8)); - } - send_data_internal(bytes.data(), len); - } - } - - template - void recv_data_partial(T* data, int len, int bitlength) { - if (bitlength == sizeof(T) * 8) { - recv_data_internal(static_cast(data), len * sizeof(T)); - return; - } - std::memset(data, 0, len * sizeof(T)); - - int compact_len = (bitlength + 7) / 8; - std::vector bytes(len); - for (int i = 0; i < compact_len; i++) { - recv_data_internal(bytes.data(), len); - for (int j = 0; j < len; j++) { - data[j] |= T(bytes[j]) << (i * 8); - } - } - T mask = (T(1) << bitlength) - 1; - for (int i = 0; i < len; i++) { - data[i] &= mask; + recv_data(static_cast(data) + recv_buffer_left, + len - recv_buffer_left); } } }; -struct YaclFerretOT::Impl { +struct 
YaclFerretOt::Impl { private: const bool is_sender_; - std::shared_ptr io_{nullptr}; - std::array io_holder_; + std::shared_ptr io_{nullptr}; std::shared_ptr ferret_{nullptr}; MITCCRHExp<8> mitccrh_exp_{}; @@ -200,8 +158,7 @@ struct YaclFerretOT::Impl { "YACL does NOT support malicious ferret ote"); SPU_ENFORCE(conn != nullptr); - io_ = std::make_shared(conn); - io_holder_[0] = io_.get(); + io_ = std::make_shared(conn); ferret_ = std::make_shared(conn->lctx(), is_sender); ferret_->OneTimeSetup(); } @@ -252,6 +209,60 @@ struct YaclFerretOT::Impl { yc::ParaCrHashInplace_128(output); } + template + void SendCorrelatedMsgChosenChoice(absl::Span corr, + absl::Span output, int bit_width) { + size_t n = corr.size(); + SPU_ENFORCE_EQ(n, output.size()); + if (bit_width == 0) { + bit_width = 8 * sizeof(T); + } + SPU_ENFORCE(bit_width > 0 && bit_width <= (int)(8 * sizeof(T)), + "bit_width={} out-of-range T={} bits", bit_width, + sizeof(T) * 8); + + yacl::AlignedVector rcm_output(n); + + SendRandCorrelatedMsgChosenChoice(rcm_output.data(), n); + + std::array pad; + std::vector corr_output(kOTBatchSize); + + size_t eltsize = 8 * sizeof(T); + bool packable = eltsize > (size_t)bit_width; + size_t packed_size = CeilDiv(kOTBatchSize * bit_width, eltsize); + + std::vector packed_corr_output; + if (packable) { + packed_corr_output.resize(packed_size); + } + + for (size_t i = 0; i < n; i += kOTBatchSize) { + size_t this_batch = std::min(kOTBatchSize, n - i); + for (size_t j = 0; j < this_batch; ++j) { + pad[2 * j] = rcm_output[i + j]; + pad[2 * j + 1] = rcm_output[i + j] ^ ferret_->GetDelta(); + } + + yc::ParaCrHashInplace_128(absl::MakeSpan(pad)); + + for (size_t j = 0; j < this_batch; ++j) { + output[i + j] = (T)(pad[2 * j]); + corr_output[j] = (T)(pad[2 * j + 1]); + corr_output[j] += corr[i + j] + output[i + j]; + } + + if (packable) { + size_t used = ZipArray({corr_output.data(), this_batch}, bit_width, + absl::MakeSpan(packed_corr_output)); + SPU_ENFORCE(used == 
CeilDiv(this_batch * bit_width, eltsize)); + io_->send_data(packed_corr_output.data(), used * sizeof(T)); + } else { + io_->send_data(corr_output.data(), sizeof(T) * this_batch); + } + } + } + template void RecvCorrelatedMsgChosenChoice(absl::Span choices, absl::Span output, int bit_width) { @@ -267,12 +278,16 @@ struct YaclFerretOT::Impl { yacl::AlignedVector rcm_output(n); RecvRandCorrelatedMsgChosenChoice(choices, absl::MakeSpan(rcm_output)); - size_t pack_load = 8 * sizeof(T) / bit_width; std::array pad; std::vector corr_output(kOTBatchSize); + + size_t eltsize = 8 * sizeof(T); + bool packable = eltsize > (size_t)bit_width; + size_t packed_size = CeilDiv(kOTBatchSize * bit_width, eltsize); + std::vector packed_corr_output; - if (pack_load > 1) { - packed_corr_output.resize(CeilDiv(corr_output.size(), pack_load)); + if (packable) { + packed_corr_output.resize(packed_size); } for (size_t i = 0; i < n; i += kOTBatchSize) { @@ -283,8 +298,8 @@ struct YaclFerretOT::Impl { // Use CrHash yc::ParaCrHashInplace_128(absl::MakeSpan(pad)); - if (pack_load > 1) { - size_t used = CeilDiv(this_batch, pack_load); + if (packable) { + size_t used = CeilDiv(this_batch * bit_width, eltsize); io_->recv_data(packed_corr_output.data(), sizeof(T) * used); UnzipArray({packed_corr_output.data(), used}, bit_width, {corr_output.data(), this_batch}); @@ -375,56 +390,6 @@ struct YaclFerretOT::Impl { } } - template - void SendCorrelatedMsgChosenChoice(absl::Span corr, - absl::Span output, int bit_width) { - size_t n = corr.size(); - SPU_ENFORCE_EQ(n, output.size()); - if (bit_width == 0) { - bit_width = 8 * sizeof(T); - } - SPU_ENFORCE(bit_width > 0 && bit_width <= (int)(8 * sizeof(T)), - "bit_width={} out-of-range T={} bits", bit_width, - sizeof(T) * 8); - - yacl::AlignedVector rcm_output(n); - - SendRandCorrelatedMsgChosenChoice(rcm_output.data(), n); - - size_t pack_load = 8 * sizeof(T) / bit_width; - std::array pad; - std::vector corr_output(kOTBatchSize); - std::vector 
packed_corr_output; - if (pack_load > 1) { - packed_corr_output.resize(CeilDiv(corr_output.size(), pack_load)); - } - - for (size_t i = 0; i < n; i += kOTBatchSize) { - size_t this_batch = std::min(kOTBatchSize, n - i); - for (size_t j = 0; j < this_batch; ++j) { - pad[2 * j] = rcm_output[i + j]; - pad[2 * j + 1] = rcm_output[i + j] ^ ferret_->GetDelta(); - } - - yc::ParaCrHashInplace_128(absl::MakeSpan(pad)); - - for (size_t j = 0; j < this_batch; ++j) { - output[i + j] = (T)(pad[2 * j]); - corr_output[j] = (T)(pad[2 * j + 1]); - corr_output[j] += corr[i + j] + output[i + j]; - } - - if (pack_load > 1) { - size_t used = ZipArray({corr_output.data(), this_batch}, bit_width, - absl::MakeSpan(packed_corr_output)); - SPU_ENFORCE(used == CeilDiv(this_batch, pack_load)); - io_->send_data(packed_corr_output.data(), used * sizeof(T)); - } else { - io_->send_data(corr_output.data(), sizeof(T) * this_batch); - } - } - } - template void SendRandMsgRandChoice(absl::Span output0, absl::Span output1, size_t bit_width = 0) { @@ -499,12 +464,15 @@ struct YaclFerretOT::Impl { yacl::AlignedVector pad(kOTBatchSize * N); const T msg_mask = makeBitsMask(bit_width); - size_t pack_load = 8 * sizeof(T) / bit_width; + size_t eltsize = 8 * sizeof(T); + bool packable = eltsize > (size_t)bit_width; + std::vector to_send(kOTBatchSize * N); + size_t packed_size = CeilDiv(kOTBatchSize * N * bit_width, eltsize); std::vector packed_to_send; - if (pack_load > 1) { + if (packable) { // NOTE: pack bit chunks into single T element if possible - packed_to_send.resize(CeilDiv(to_send.size(), pack_load)); + packed_to_send.resize(packed_size); } for (size_t i = 0; i < n; i += kOTBatchSize) { @@ -543,10 +511,10 @@ struct YaclFerretOT::Impl { } } - if (pack_load > 1) { + if (packable) { size_t used = ZipArray({to_send.data(), N * this_batch}, bit_width, absl::MakeSpan(packed_to_send)); - SPU_ENFORCE(used == CeilDiv(N * this_batch, pack_load)); + SPU_ENFORCE(used == CeilDiv(N * this_batch * bit_width, 
eltsize)); io_->send_data(packed_to_send.data(), used * sizeof(T)); } else { io_->send_data(to_send.data(), N * this_batch * sizeof(T)); @@ -588,16 +556,26 @@ struct YaclFerretOT::Impl { yacl::AlignedVector pad(kOTBatchSize); const T msg_mask = makeBitsMask(bit_width); - const size_t pack_load = 8 * sizeof(T) / bit_width; + size_t eltsize = 8 * sizeof(T); + bool packable = eltsize > (size_t)bit_width; + size_t packed_size = CeilDiv(kOTBatchSize * N * bit_width, eltsize); + std::vector recv(kOTBatchSize * N); - std::vector packed_recv(CeilDiv(recv.size(), pack_load)); + std::vector packed_recv; + if (packable) { + packed_recv.resize(packed_size); + } for (size_t i = 0; i < n; i += kOTBatchSize) { size_t this_batch = std::min(kOTBatchSize, n - i); - size_t used = CeilDiv(N * this_batch, pack_load); - io_->recv_data(packed_recv.data(), used * sizeof(T)); - UnzipArray({packed_recv.data(), used}, bit_width, - {recv.data(), N * this_batch}); + size_t used = CeilDiv(N * this_batch * bit_width, eltsize); + if (packable) { + io_->recv_data(packed_recv.data(), used * sizeof(T)); + UnzipArray({packed_recv.data(), used}, bit_width, + {recv.data(), N * this_batch}); + } else { + io_->recv_data(recv.data(), N * this_batch * sizeof(T)); + } std::memset(pad.data(), 0, kOTBatchSize * sizeof(uint128_t)); for (size_t j = 0; j < this_batch; ++j) { @@ -722,19 +700,16 @@ struct YaclFerretOT::Impl { } }; -YaclFerretOT::YaclFerretOT(std::shared_ptr conn, bool is_sender, +YaclFerretOt::YaclFerretOt(std::shared_ptr conn, bool is_sender, bool malicious) { impl_ = std::make_shared(conn, is_sender, malicious); } -int YaclFerretOT::Rank() const { return impl_->Rank(); } +int YaclFerretOt::Rank() const { return impl_->Rank(); } -void YaclFerretOT::Flush() { impl_->Flush(); } +void YaclFerretOt::Flush() { impl_->Flush(); } -YaclFerretOT::~YaclFerretOT() { - impl_->Flush(); - // SPDLOG_INFO(fmt::format("Party {} - Destroying YaclFerretOT", (Rank()))); -} +YaclFerretOt::~YaclFerretOt() { 
impl_->Flush(); } template size_t CheckBitWidth(size_t bw) { @@ -747,40 +722,40 @@ size_t CheckBitWidth(size_t bw) { } #define DEF_SEND_RECV(T) \ - void YaclFerretOT::SendCAMCC(absl::Span corr, absl::Span output, \ + void YaclFerretOt::SendCAMCC(absl::Span corr, absl::Span output, \ int bw) { \ impl_->SendCorrelatedMsgChosenChoice(corr, output, bw); \ } \ - void YaclFerretOT::RecvCAMCC(absl::Span choices, \ + void YaclFerretOt::RecvCAMCC(absl::Span choices, \ absl::Span output, int bw) { \ impl_->RecvCorrelatedMsgChosenChoice(choices, output, bw); \ } \ - void YaclFerretOT::SendRMRC(absl::Span output0, absl::Span output1, \ + void YaclFerretOt::SendRMRC(absl::Span output0, absl::Span output1, \ size_t bit_width) { \ bit_width = CheckBitWidth(bit_width); \ impl_->SendRandMsgRandChoice(output0, output1, bit_width); \ } \ - void YaclFerretOT::RecvRMRC(absl::Span choices, \ + void YaclFerretOt::RecvRMRC(absl::Span choices, \ absl::Span output, size_t bit_width) { \ bit_width = CheckBitWidth(bit_width); \ impl_->RecvRandMsgRandChoice(choices, output, bit_width); \ } \ - void YaclFerretOT::SendCMCC(absl::Span msg_array, size_t N, \ + void YaclFerretOt::SendCMCC(absl::Span msg_array, size_t N, \ size_t bit_width) { \ bit_width = CheckBitWidth(bit_width); \ impl_->SendChosenMsgChosenChoice(msg_array, N, bit_width); \ } \ - void YaclFerretOT::RecvCMCC(absl::Span choices, size_t N, \ + void YaclFerretOt::RecvCMCC(absl::Span choices, size_t N, \ absl::Span output, size_t bit_width) { \ bit_width = CheckBitWidth(bit_width); \ impl_->RecvChosenMsgChosenChoice(choices, N, output, bit_width); \ } \ - void YaclFerretOT::SendRMCC(absl::Span output0, absl::Span output1, \ + void YaclFerretOt::SendRMCC(absl::Span output0, absl::Span output1, \ size_t bit_width) { \ bit_width = CheckBitWidth(bit_width); \ impl_->SendRMCC(output0, output1, bit_width); \ } \ - void YaclFerretOT::RecvRMCC(absl::Span choices, \ + void YaclFerretOt::RecvRMCC(absl::Span choices, \ absl::Span output, size_t 
bit_width) { \ bit_width = CheckBitWidth(bit_width); \ impl_->RecvRMCC(choices, output, bit_width); \ diff --git a/libspu/mpc/cheetah/yacl_ot/yacl_ferret.h b/libspu/mpc/cheetah/ot/yacl/ferret.h similarity index 75% rename from libspu/mpc/cheetah/yacl_ot/yacl_ferret.h rename to libspu/mpc/cheetah/ot/yacl/ferret.h index f571cfee..6b0daf1d 100644 --- a/libspu/mpc/cheetah/yacl_ot/yacl_ferret.h +++ b/libspu/mpc/cheetah/ot/yacl/ferret.h @@ -19,104 +19,105 @@ #include "absl/types/span.h" #include "yacl/base/int128.h" +#include "libspu/mpc/cheetah/ot/ferret_ot_interface.h" #include "libspu/mpc/common/communicator.h" namespace spu::mpc::cheetah { -class YaclFerretOT { +class YaclFerretOt : public spu::mpc::cheetah::FerretOtInterface { private: struct Impl; std::shared_ptr impl_; public: - YaclFerretOT(std::shared_ptr conn, bool is_sender, + YaclFerretOt(std::shared_ptr conn, bool is_sender, bool malicious = false); - ~YaclFerretOT(); + ~YaclFerretOt(); - int Rank() const; + int Rank() const override; - void Flush(); + void Flush() override; // One-of-N OT where msg_array is a Nxn array. 
// choice \in [0, N-1] void SendCMCC(absl::Span msg_array, size_t N, - size_t bit_width = 0); + size_t bit_width = 0) override; void SendCMCC(absl::Span msg_array, size_t N, - size_t bit_width = 0); + size_t bit_width = 0) override; void SendCMCC(absl::Span msg_array, size_t N, - size_t bit_width = 0); + size_t bit_width = 0) override; void SendCMCC(absl::Span msg_array, size_t N, - size_t bit_width = 0); + size_t bit_width = 0) override; void RecvCMCC(absl::Span one_oo_N_choices, size_t N, - absl::Span output, size_t bit_width = 0); + absl::Span output, size_t bit_width = 0) override; void RecvCMCC(absl::Span one_oo_N_choices, size_t N, - absl::Span output, size_t bit_width = 0); + absl::Span output, size_t bit_width = 0) override; void RecvCMCC(absl::Span one_oo_N_choices, size_t N, - absl::Span output, size_t bit_width = 0); + absl::Span output, size_t bit_width = 0) override; void RecvCMCC(absl::Span one_oo_N_choices, size_t N, - absl::Span output, size_t bit_width = 0); + absl::Span output, size_t bit_width = 0) override; // Random Message Random Choice void SendRMRC(absl::Span output0, absl::Span output1, - size_t bit_width = 0); + size_t bit_width = 0) override; void SendRMRC(absl::Span output0, absl::Span output1, - size_t bit_width = 0); + size_t bit_width = 0) override; void SendRMRC(absl::Span output0, absl::Span output1, - size_t bit_width = 0); + size_t bit_width = 0) override; void SendRMRC(absl::Span output0, absl::Span output1, - size_t bit_width = 0); + size_t bit_width = 0) override; void RecvRMRC(absl::Span binary_choices, absl::Span output, - size_t bit_width = 0); + size_t bit_width = 0) override; void RecvRMRC(absl::Span binary_choices, absl::Span output, - size_t bit_width = 0); + size_t bit_width = 0) override; void RecvRMRC(absl::Span binary_choices, absl::Span output, - size_t bit_width = 0); + size_t bit_width = 0) override; void RecvRMRC(absl::Span binary_choices, - absl::Span output, size_t bit_width = 0); + absl::Span output, size_t 
bit_width = 0) override; // correlated additive message, chosen choice // (x, x + corr * choice) <- (corr, choice) // Can use bit_width to further indicate output ring. `bit_width = 0` means to // use the full range. void SendCAMCC(absl::Span corr, absl::Span output, - int bit_width = 0); + int bit_width = 0) override; void SendCAMCC(absl::Span corr, absl::Span output, - int bit_width = 0); + int bit_width = 0) override; void SendCAMCC(absl::Span corr, absl::Span output, - int bit_width = 0); + int bit_width = 0) override; void SendCAMCC(absl::Span corr, absl::Span output, - int bit_width = 0); + int bit_width = 0) override; void RecvCAMCC(absl::Span binary_choices, - absl::Span output, int bit_width = 0); + absl::Span output, int bit_width = 0) override; void RecvCAMCC(absl::Span binary_choices, - absl::Span output, int bit_width = 0); + absl::Span output, int bit_width = 0) override; void RecvCAMCC(absl::Span binary_choices, - absl::Span output, int bit_width = 0); + absl::Span output, int bit_width = 0) override; void RecvCAMCC(absl::Span binary_choices, - absl::Span output, int bit_width = 0); + absl::Span output, int bit_width = 0) override; // Random Message Chosen Choice void SendRMCC(absl::Span output0, absl::Span output1, - size_t bit_width = 0); + size_t bit_width = 0) override; void SendRMCC(absl::Span output0, absl::Span output1, - size_t bit_width = 0); + size_t bit_width = 0) override; void SendRMCC(absl::Span output0, absl::Span output1, - size_t bit_width = 0); + size_t bit_width = 0) override; void SendRMCC(absl::Span output0, absl::Span output1, - size_t bit_width = 0); + size_t bit_width = 0) override; void RecvRMCC(absl::Span binary_choices, - absl::Span output, size_t bit_width = 0); + absl::Span output, size_t bit_width = 0) override; void RecvRMCC(absl::Span binary_choices, - absl::Span output, size_t bit_width = 0); + absl::Span output, size_t bit_width = 0) override; void RecvRMCC(absl::Span binary_choices, - absl::Span output, size_t 
bit_width = 0); + absl::Span output, size_t bit_width = 0) override; void RecvRMCC(absl::Span binary_choices, - absl::Span output, size_t bit_width = 0); + absl::Span output, size_t bit_width = 0) override; }; } // namespace spu::mpc::cheetah diff --git a/libspu/mpc/cheetah/ot/ferret_test.cc b/libspu/mpc/cheetah/ot/yacl/ferret_test.cc similarity index 84% rename from libspu/mpc/cheetah/ot/ferret_test.cc rename to libspu/mpc/cheetah/ot/yacl/ferret_test.cc index d5301f25..ac805d9d 100644 --- a/libspu/mpc/cheetah/ot/ferret_test.cc +++ b/libspu/mpc/cheetah/ot/yacl/ferret_test.cc @@ -1,4 +1,4 @@ -// Copyright 2022 Ant Group Co., Ltd. +// Copyright 2023 Ant Group Co., Ltd. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -12,14 +12,13 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "libspu/mpc/cheetah/ot/ferret.h" +#include "libspu/mpc/cheetah/ot/yacl/ferret.h" #include #include "gtest/gtest.h" -#include "libspu/core/xt_helper.h" -#include "libspu/mpc/cheetah/ot/basic_ot_prot.h" +#include "libspu/mpc/semi2k/type.h" #include "libspu/mpc/utils/ring_ops.h" #include "libspu/mpc/utils/simulate.h" @@ -34,6 +33,11 @@ INSTANTIATE_TEST_SUITE_P( return fmt::format("{}", p.param); }); +template +absl::Span makeConstSpan(NdArrayView a) { + return {&a[0], (size_t)a.numel()}; +} + TEST_P(FerretCOTTest, ChosenCorrelationChosenChoice) { size_t kWorldSize = 2; int64_t n = 10; @@ -48,15 +52,15 @@ TEST_P(FerretCOTTest, ChosenCorrelationChosenChoice) { }); DISPATCH_ALL_FIELDS(field, "", [&]() { - auto correlation = xt_adapt(_correlation); + NdArrayView correlation(_correlation); std::vector computed[2]; utils::simulate(kWorldSize, [&](std::shared_ptr ctx) { auto conn = std::make_shared(ctx); int rank = ctx->Rank(); computed[rank].resize(n); - FerretOT ferret(conn, rank == 0); + YaclFerretOt ferret(conn, rank == 0); if (rank == 0) { 
- ferret.SendCAMCC({correlation.data(), correlation.size()}, + ferret.SendCAMCC(makeConstSpan(correlation), absl::MakeSpan(computed[0])); ferret.Flush(); } else { @@ -89,7 +93,7 @@ TEST_P(FerretCOTTest, RndMsgRndChoice) { utils::simulate(kWorldSize, [&](std::shared_ptr ctx) { auto conn = std::make_shared(ctx); int rank = ctx->Rank(); - FerretOT ferret(conn, rank == 0); + YaclFerretOt ferret(conn, rank == 0); if (rank == 0) { ferret.SendRMRC(absl::MakeSpan(msg0), absl::MakeSpan(msg1), bw); ferret.Flush(); @@ -132,7 +136,7 @@ TEST_P(FerretCOTTest, RndMsgChosenChoice) { utils::simulate(kWorldSize, [&](std::shared_ptr ctx) { auto conn = std::make_shared(ctx); int rank = ctx->Rank(); - FerretOT ferret(conn, rank == 0); + YaclFerretOt ferret(conn, rank == 0); if (rank == 0) { ferret.SendRMCC(absl::MakeSpan(msg0), absl::MakeSpan(msg1), bw); ferret.Flush(); @@ -153,20 +157,20 @@ TEST_P(FerretCOTTest, RndMsgChosenChoice) { TEST_P(FerretCOTTest, ChosenMsgChosenChoice) { size_t kWorldSize = 2; - int64_t n = 100; + int64_t n = 1 << 18; auto field = GetParam(); DISPATCH_ALL_FIELDS(field, "", [&]() { using scalar_t = ring2k_t; std::default_random_engine rdv; std::uniform_int_distribution uniform(0, -1); - for (size_t bw : {2UL, 4UL, sizeof(scalar_t) * 8}) { - scalar_t mask = (static_cast(1) << bw) - 1; - for (int64_t N : {2, 3, 8}) { + for (int64_t N : {2, 4, 8}) { + for (size_t bw : {4UL, 8UL, 32UL}) { + scalar_t mask = (static_cast(1) << bw) - 1; auto _msg = ring_rand(field, {N * n}); - auto msg = xt_mutable_adapt(_msg); - msg &= mask; - std::vector choices(n); + NdArrayView msg(_msg); + pforeach(0, msg.numel(), [&](int64_t i) { msg[i] &= mask; }); + std::vector choices(n); std::generate_n(choices.begin(), n, [&]() -> uint8_t { return static_cast(uniform(rdv) % N); }); @@ -177,14 +181,16 @@ TEST_P(FerretCOTTest, ChosenMsgChosenChoice) { [&](std::shared_ptr ctx) { auto conn = std::make_shared(ctx); int rank = ctx->Rank(); - FerretOT ferret(conn, rank == 0); + YaclFerretOt 
ferret(conn, rank == 0); + size_t sent = ctx->GetStats()->sent_bytes; if (rank == 0) { - ferret.SendCMCC({msg.data(), msg.size()}, N, bw); + ferret.SendCMCC(makeConstSpan(msg), N, bw); ferret.Flush(); } else { ferret.RecvCMCC(absl::MakeSpan(choices), N, absl::MakeSpan(selected), bw); } + sent = ctx->GetStats()->sent_bytes - sent; }); for (int64_t i = 0; i < n; ++i) { diff --git a/libspu/mpc/cheetah/yacl_ot/mitccrh_exp.h b/libspu/mpc/cheetah/ot/yacl/mitccrh_exp.h similarity index 99% rename from libspu/mpc/cheetah/yacl_ot/mitccrh_exp.h rename to libspu/mpc/cheetah/ot/yacl/mitccrh_exp.h index 2c37671f..84bfc28d 100644 --- a/libspu/mpc/cheetah/yacl_ot/mitccrh_exp.h +++ b/libspu/mpc/cheetah/ot/yacl/mitccrh_exp.h @@ -21,7 +21,7 @@ #include "yacl/crypto/base/aes/aes_opt.h" -#include "libspu/mpc/cheetah/yacl_ot/util.h" +#include "libspu/mpc/cheetah/ot/yacl/yacl_util.h" namespace spu::mpc::cheetah { diff --git a/libspu/mpc/cheetah/yacl_ot/yacl_ferret_test.cc b/libspu/mpc/cheetah/ot/yacl/yacl_ferret_test.cc similarity index 91% rename from libspu/mpc/cheetah/yacl_ot/yacl_ferret_test.cc rename to libspu/mpc/cheetah/ot/yacl/yacl_ferret_test.cc index 5f5e5030..a9298a69 100644 --- a/libspu/mpc/cheetah/yacl_ot/yacl_ferret_test.cc +++ b/libspu/mpc/cheetah/ot/yacl/yacl_ferret_test.cc @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. 
-#include "libspu/mpc/cheetah/yacl_ot/yacl_ferret.h" +#include "libspu/mpc/cheetah/ot/yacl/yacl_ferret.h" #include @@ -26,16 +26,16 @@ namespace spu::mpc::cheetah::test { -class FerretCOTTest : public testing::TestWithParam {}; +class YaclFerretTest : public testing::TestWithParam {}; INSTANTIATE_TEST_SUITE_P( - Cheetah, FerretCOTTest, + Cheetah, YaclFerretTest, testing::Values(FieldType::FM32, FieldType::FM64, FieldType::FM128), [](const testing::TestParamInfo &p) { return fmt::format("{}", p.param); }); -TEST_P(FerretCOTTest, ChosenCorrelationChosenChoice) { +TEST_P(YaclFerretTest, ChosenCorrelationChosenChoice) { size_t kWorldSize = 2; int64_t n = 10; auto field = GetParam(); @@ -49,13 +49,13 @@ TEST_P(FerretCOTTest, ChosenCorrelationChosenChoice) { }); DISPATCH_ALL_FIELDS(field, "", [&]() { - auto correlation = xt_adapt(_correlation); + NdArrayView correlation(_correlation); std::vector computed[2]; utils::simulate(kWorldSize, [&](std::shared_ptr ctx) { auto conn = std::make_shared(ctx); int rank = ctx->Rank(); computed[rank].resize(n); - YaclFerretOT ferret(conn, rank == 0); + YaclFerretOt ferret(conn, rank == 0); if (rank == 0) { ferret.SendCAMCC({correlation.data(), correlation.size()}, absl::MakeSpan(computed[0])); @@ -73,7 +73,7 @@ TEST_P(FerretCOTTest, ChosenCorrelationChosenChoice) { }); } -TEST_P(FerretCOTTest, RndMsgRndChoice) { +TEST_P(YaclFerretTest, RndMsgRndChoice) { size_t kWorldSize = 2; auto field = GetParam(); constexpr size_t bw = 2; @@ -90,7 +90,7 @@ TEST_P(FerretCOTTest, RndMsgRndChoice) { utils::simulate(kWorldSize, [&](std::shared_ptr ctx) { auto conn = std::make_shared(ctx); int rank = ctx->Rank(); - YaclFerretOT ferret(conn, rank == 0); + YaclFerretOt ferret(conn, rank == 0); if (rank == 0) { ferret.SendRMRC(absl::MakeSpan(msg0), absl::MakeSpan(msg1), bw); ferret.Flush(); @@ -110,7 +110,7 @@ TEST_P(FerretCOTTest, RndMsgRndChoice) { }); } -TEST_P(FerretCOTTest, RndMsgChosenChoice) { +TEST_P(YaclFerretTest, RndMsgChosenChoice) { size_t 
kWorldSize = 2; auto field = GetParam(); constexpr size_t bw = 2; @@ -133,7 +133,7 @@ TEST_P(FerretCOTTest, RndMsgChosenChoice) { utils::simulate(kWorldSize, [&](std::shared_ptr ctx) { auto conn = std::make_shared(ctx); int rank = ctx->Rank(); - YaclFerretOT ferret(conn, rank == 0); + YaclFerretOt ferret(conn, rank == 0); if (rank == 0) { ferret.SendRMCC(absl::MakeSpan(msg0), absl::MakeSpan(msg1), bw); ferret.Flush(); @@ -152,7 +152,7 @@ TEST_P(FerretCOTTest, RndMsgChosenChoice) { }); } -TEST_P(FerretCOTTest, ChosenMsgChosenChoice) { +TEST_P(YaclFerretTest, ChosenMsgChosenChoice) { size_t kWorldSize = 2; int64_t n = 106; auto field = GetParam(); @@ -184,7 +184,7 @@ TEST_P(FerretCOTTest, ChosenMsgChosenChoice) { auto conn = std::make_shared(ctx); int rank = ctx->Rank(); { - YaclFerretOT ferret(conn, rank == 0); + YaclFerretOt ferret(conn, rank == 0); if (rank == 0) { ferret.SendCMCC({msg.data(), msg.size()}, N, bw); ferret.Flush(); diff --git a/libspu/mpc/cheetah/yacl_ot/yacl_ote_adapter.cc b/libspu/mpc/cheetah/ot/yacl/yacl_ote_adapter.cc similarity index 98% rename from libspu/mpc/cheetah/yacl_ot/yacl_ote_adapter.cc rename to libspu/mpc/cheetah/ot/yacl/yacl_ote_adapter.cc index cfb19990..19c3f2e2 100644 --- a/libspu/mpc/cheetah/yacl_ot/yacl_ote_adapter.cc +++ b/libspu/mpc/cheetah/ot/yacl/yacl_ote_adapter.cc @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "libspu/mpc/cheetah/yacl_ot/yacl_ote_adapter.h" +#include "libspu/mpc/cheetah/ot/yacl/yacl_ote_adapter.h" namespace spu::mpc::cheetah { @@ -146,8 +146,7 @@ void YaclFerretOTeAdapter::rcot(absl::Span data) { // Require_num is greater then "buff_upper_bound_ - reserve_num_" // which means that an extra "Bootstrap" is needed. if (require_num > (buff_upper_bound_ - reserve_num_)) { - SPDLOG_INFO("[YACL] Worst Case Occured!!! current require_num {}", - require_num); + SPDLOG_WARN("[YACL] Worst Case!!! 
current require_num {}", require_num); // Bootstrap would reset buff_used_num_ memcpy(data.data() + data_offset, ot_buff_.data() + buff_used_num_, (buff_upper_bound_ - reserve_num_) * sizeof(uint128_t)); diff --git a/libspu/mpc/cheetah/yacl_ot/yacl_ote_adapter.h b/libspu/mpc/cheetah/ot/yacl/yacl_ote_adapter.h similarity index 98% rename from libspu/mpc/cheetah/yacl_ot/yacl_ote_adapter.h rename to libspu/mpc/cheetah/ot/yacl/yacl_ote_adapter.h index cee80cc7..1904a9c1 100644 --- a/libspu/mpc/cheetah/yacl_ot/yacl_ote_adapter.h +++ b/libspu/mpc/cheetah/ot/yacl/yacl_ote_adapter.h @@ -22,7 +22,8 @@ #include "yacl/crypto/utils/rand.h" #include "libspu/core/prelude.h" -#include "libspu/mpc/cheetah/yacl_ot/util.h" +#include "libspu/mpc/cheetah/ot/ot_util.h" +#include "libspu/mpc/cheetah/ot/yacl/yacl_util.h" namespace spu::mpc::cheetah { @@ -59,7 +60,7 @@ class YaclFerretOTeAdapter : public YaclOTeAdapter { } ~YaclFerretOTeAdapter() { - SPDLOG_INFO( + SPDLOG_DEBUG( "[FerretAdapter {}]({}), comsume OT {}, total time {:.3e} ms, " "invoke bootstrap {} ( {:.2e} ms per bootstrap, {:.2e} ms per ot )", id_, (is_sender_ ? fmt::format("Sender") : fmt::format("Receiver")), diff --git a/libspu/mpc/cheetah/yacl_ot/util.cc b/libspu/mpc/cheetah/ot/yacl/yacl_util.cc similarity index 96% rename from libspu/mpc/cheetah/yacl_ot/util.cc rename to libspu/mpc/cheetah/ot/yacl/yacl_util.cc index 820c3841..74f57185 100644 --- a/libspu/mpc/cheetah/yacl_ot/util.cc +++ b/libspu/mpc/cheetah/ot/yacl/yacl_util.cc @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "libspu/mpc/cheetah/yacl_ot/util.h" +#include "libspu/mpc/cheetah/ot/yacl/yacl_util.h" #include diff --git a/libspu/mpc/cheetah/ot/yacl/yacl_util.h b/libspu/mpc/cheetah/ot/yacl/yacl_util.h new file mode 100644 index 00000000..55987aae --- /dev/null +++ b/libspu/mpc/cheetah/ot/yacl/yacl_util.h @@ -0,0 +1,60 @@ +// Copyright 2023 Ant Group Co., Ltd. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "absl/types/span.h" +#include "yacl/base/dynamic_bitset.h" +#include "yacl/base/int128.h" + +#include "libspu/core/prelude.h" + +namespace spu::mpc::cheetah { + +// Add by @wenfan +inline void VecU8toBitset(absl::Span bits, + yacl::dynamic_bitset& bitset) { + SPU_ENFORCE(bits.size() == bitset.size()); + uint64_t bits_num = bits.size(); + // low efficiency + for (uint64_t i = 0; i < bits_num; ++i) { + bitset[i] = (bool)bits[i]; + } +} + +inline yacl::dynamic_bitset VecU8toBitset( + absl::Span bits) { + yacl::dynamic_bitset bitset(bits.size()); + VecU8toBitset(bits, bitset); + return bitset; +} + +inline void BitsettoVecU8(const yacl::dynamic_bitset& bitset, + absl::Span bits) { + SPU_ENFORCE(bits.size() == bitset.size()); + uint64_t bits_num = bitset.size(); + // low efficiency + for (uint64_t i = 0; i < bits_num; ++i) { + bits[i] = bitset[i]; + } +} + +inline std::vector BitsettoVecU8( + const yacl::dynamic_bitset& bitset) { + std::vector bits(bitset.size()); + BitsettoVecU8(bitset, absl::MakeSpan(bits)); + return bits; +} + +} // namespace spu::mpc::cheetah diff --git a/libspu/mpc/cheetah/state.h b/libspu/mpc/cheetah/state.h index 4d2691e0..d4035e16 100644 --- a/libspu/mpc/cheetah/state.h +++ b/libspu/mpc/cheetah/state.h @@ -20,7 +20,7 @@ #include "libspu/core/object.h" #include "libspu/mpc/cheetah/arith/cheetah_dot.h" #include "libspu/mpc/cheetah/arith/cheetah_mul.h" -#include 
"libspu/mpc/cheetah/yacl_ot/basic_ot_prot.h" +#include "libspu/mpc/cheetah/ot/basic_ot_prot.h" namespace spu::mpc::cheetah { diff --git a/libspu/mpc/cheetah/yacl_ot/basic_ot_prot.cc b/libspu/mpc/cheetah/yacl_ot/basic_ot_prot.cc deleted file mode 100644 index 4f41a1e4..00000000 --- a/libspu/mpc/cheetah/yacl_ot/basic_ot_prot.cc +++ /dev/null @@ -1,470 +0,0 @@ -// Copyright 2023 Ant Group Co., Ltd. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "libspu/mpc/cheetah/yacl_ot/basic_ot_prot.h" - -#include "libspu/mpc/cheetah/type.h" -#include "libspu/mpc/cheetah/yacl_ot/util.h" -#include "libspu/mpc/common/communicator.h" -#include "libspu/mpc/utils/ring_ops.h" - -namespace spu::mpc::cheetah { - -BasicOTProtocols::BasicOTProtocols(std::shared_ptr conn) - : conn_(std::move(conn)) { - SPU_ENFORCE(conn_ != nullptr); - if (conn_->getRank() == 0) { - ferret_sender_ = std::make_shared(conn_, true); - ferret_receiver_ = std::make_shared(conn_, false); - } else { - ferret_receiver_ = std::make_shared(conn_, false); - ferret_sender_ = std::make_shared(conn_, true); - } -} - -BasicOTProtocols::~BasicOTProtocols() { Flush(); } - -void BasicOTProtocols::Flush() { - if (ferret_sender_) { - ferret_sender_->Flush(); - } -} - -NdArrayRef BasicOTProtocols::B2A(const NdArrayRef &inp) { - const auto *share_t = inp.eltype().as(); - if (share_t->nbits() == 1) { - return SingleB2A(inp); - } - return PackedB2A(inp); -} - -NdArrayRef BasicOTProtocols::PackedB2A(const 
NdArrayRef &inp) { - const auto *share_t = inp.eltype().as(); - auto field = inp.eltype().as()->field(); - const size_t nbits = share_t->nbits(); - SPU_ENFORCE(nbits > 0 && nbits <= 8 * SizeOf(field)); - - auto convert_from_bits_form = [&](NdArrayRef _bits) { - SPU_ENFORCE(_bits.isCompact(), "need compact input"); - const int64_t n = _bits.numel() / nbits; - // init as all 0s. - auto iform = ring_zeros(field, inp.shape()); - DISPATCH_ALL_FIELDS(field, "conv_to_bits", [&]() { - auto bits = NdArrayView(_bits); - auto digit = NdArrayView(iform); - for (int64_t i = 0; i < n; ++i) { - // LSB is bits[0]; MSB is bits[nbits - 1] - // We iterate the bits in reversed order - const size_t offset = i * nbits; - digit[i] = 0; - for (size_t j = nbits; j > 0; --j) { - digit[i] = (digit[i] << 1) | (bits[offset + j - 1] & 1); - } - } - }); - return iform; - }; - - const int64_t n = inp.numel(); - auto rand_bits = RandBits(field, Shape{n * static_cast(nbits)}); - auto rand = convert_from_bits_form(rand_bits); - - // open c = x ^ r - // FIXME(juhou): Actually, we only want to exchange the low-end bits. 
- auto opened = - conn_->allReduce(ReduceOp::XOR, ring_xor(inp, rand), "B2AFull_open"); - - // compute c + (1 - 2*c)* - NdArrayRef oup = ring_zeros(field, inp.shape()); - DISPATCH_ALL_FIELDS(field, "packed_b2a", [&]() { - using u2k = std::make_unsigned::type; - int rank = Rank(); - auto xr = NdArrayView(rand_bits); - auto xc = NdArrayView(opened); - auto xo = NdArrayView(oup); - - for (int64_t i = 0; i < n; ++i) { - const size_t offset = i * nbits; - u2k this_elt = xc[i]; - for (size_t j = 0; j < nbits; ++j, this_elt >>= 1) { - u2k c_ij = this_elt & 1; - ring2k_t one_bit = (1 - c_ij * 2) * xr[offset + j]; - if (rank == 0) { - one_bit += c_ij; - } - xo[i] += (one_bit << j); - } - } - }); - return oup; -} - -// Math: -// b0^b1 = b0 + b1 - 2*b0*b1 -// Sender set corr = -2*b0 -// Recv set choice b1 -// Sender gets x0 -// Recv gets x1 = x0 + corr*b1 = x0 - 2*b0*b1 -// -// b0 - x0 + b1 + x1 -// = b0 - x0 + b1 + x0 - 2*b0*b1 -NdArrayRef BasicOTProtocols::SingleB2A(const NdArrayRef &inp, int bit_width) { - const auto *share_t = inp.eltype().as(); - SPU_ENFORCE_EQ(share_t->nbits(), 1UL); - auto field = inp.eltype().as()->field(); - if (bit_width == 0) { - bit_width = SizeOf(field) * 8; - } - const int64_t n = inp.numel(); - - NdArrayRef oup = ring_zeros(field, inp.shape()); - DISPATCH_ALL_FIELDS(field, "single_b2a", [&]() { - using u2k = std::make_unsigned::type; - auto input = NdArrayView(inp); - // NOTE(lwj): oup is compact, so we just use Span - auto output = absl::MakeSpan(&oup.at(0), n); - - SPU_ENFORCE(oup.isCompact()); - - if (Rank() == 0) { - std::vector corr_data(n); - // NOTE(lwj): Masking to make sure there is only single bit. 
- for (int64_t i = 0; i < n; ++i) { - // corr=-2*xi - corr_data[i] = -((input[i] & 1) << 1); - } - ferret_sender_->SendCAMCC(absl::MakeSpan(corr_data), output, bit_width); - ferret_sender_->Flush(); - - for (int64_t i = 0; i < n; ++i) { - output[i] = (input[i] & 1) - output[i]; - } - } else { - std::vector choices(n); - for (int64_t i = 0; i < n; ++i) { - choices[i] = static_cast(input[i] & 1); - } - ferret_receiver_->RecvCAMCC(absl::MakeSpan(choices), output, bit_width); - - for (int64_t i = 0; i < n; ++i) { - output[i] = (input[i] & 1) + output[i]; - } - } - }); - return oup; -} - -// Random bit r \in {0, 1} and return as AShr -NdArrayRef BasicOTProtocols::RandBits(FieldType filed, const Shape &shape) { - // TODO(juhou): profile ring_randbit performance - auto r = ring_randbit(filed, shape).as(makeType(filed, 1)); - return SingleB2A(r); -} - -NdArrayRef BasicOTProtocols::B2ASingleBitWithSize(const NdArrayRef &inp, - int bit_width) { - const auto *share_t = inp.eltype().as(); - SPU_ENFORCE(share_t->nbits() == 1, "Support for 1bit boolean only"); - auto field = inp.eltype().as()->field(); - SPU_ENFORCE(bit_width > 1 && bit_width < (int)(8 * SizeOf(field)), - "bit_width={} is invalid", bit_width); - return SingleB2A(inp, bit_width); -} - -NdArrayRef BasicOTProtocols::BitwiseAnd(const NdArrayRef &lhs, - const NdArrayRef &rhs) { - SPU_ENFORCE_EQ(lhs.shape(), rhs.shape()); - - auto field = lhs.eltype().as()->field(); - const auto *shareType = lhs.eltype().as(); - size_t numel = lhs.numel(); - auto [a, b, c] = AndTriple(field, lhs.shape(), shareType->nbits()); - - NdArrayRef x_a = ring_xor(lhs, a); - NdArrayRef y_b = ring_xor(rhs, b); - size_t pack_load = 8 * SizeOf(field) / shareType->nbits(); - - if (pack_load == 1) { - // Open x^a, y^b - auto res = vmap({x_a, y_b}, [&](const NdArrayRef &s) { - return conn_->allReduce(ReduceOp::XOR, s, "BitwiseAnd"); - }); - x_a = std::move(res[0]); - y_b = std::move(res[1]); - } else { - // Open x^a, y^b - // pack multiple nbits() 
into single field element before sending through - // network - SPU_ENFORCE(x_a.isCompact() && y_b.isCompact()); - int64_t packed_sze = CeilDiv(numel, pack_load); - - NdArrayRef packed_xa(x_a.eltype(), {packed_sze}); - NdArrayRef packed_yb(y_b.eltype(), {packed_sze}); - - DISPATCH_ALL_FIELDS(field, "_", [&]() { - auto xa_wrap = absl::MakeSpan(&x_a.at(0), numel); - auto yb_wrap = absl::MakeSpan(&y_b.at(0), numel); - auto packed_xa_wrap = - absl::MakeSpan(&packed_xa.at(0), packed_sze); - auto packed_yb_wrap = - absl::MakeSpan(&packed_yb.at(0), packed_sze); - - int64_t used = - ZipArray(xa_wrap, shareType->nbits(), packed_xa_wrap); - (void)ZipArray(yb_wrap, shareType->nbits(), packed_yb_wrap); - SPU_ENFORCE_EQ(used, packed_sze); - - // open x^a, y^b - auto res = vmap({packed_xa, packed_yb}, [&](const NdArrayRef &s) { - return conn_->allReduce(ReduceOp::XOR, s, "BitwiseAnd"); - }); - - packed_xa = std::move(res[0]); - packed_yb = std::move(res[1]); - packed_xa_wrap = absl::MakeSpan(&packed_xa.at(0), packed_sze); - packed_yb_wrap = absl::MakeSpan(&packed_yb.at(0), packed_sze); - UnzipArray(packed_xa_wrap, shareType->nbits(), xa_wrap); - UnzipArray(packed_yb_wrap, shareType->nbits(), yb_wrap); - }); - } - - // Zi = Ci ^ ((X ^ A) & Bi) ^ ((Y ^ B) & Ai) ^ <(X ^ A) & (Y ^ B)> - auto z = ring_xor(ring_xor(ring_and(x_a, b), ring_and(y_b, a)), c); - if (conn_->getRank() == 0) { - ring_xor_(z, ring_and(x_a, y_b)); - } - - return z.as(lhs.eltype()); -} - -std::array BasicOTProtocols::CorrelatedBitwiseAnd( - const NdArrayRef &lhs, const NdArrayRef &rhs0, const NdArrayRef &rhs1) { - SPU_ENFORCE_EQ(lhs.shape(), rhs0.shape()); - SPU_ENFORCE(lhs.eltype() == rhs0.eltype()); - SPU_ENFORCE_EQ(lhs.shape(), rhs1.shape()); - SPU_ENFORCE(lhs.eltype() == rhs1.eltype()); - - auto field = lhs.eltype().as()->field(); - const auto *shareType = lhs.eltype().as(); - SPU_ENFORCE_EQ(shareType->nbits(), 1UL); - auto [a, b0, c0, b1, c1] = CorrelatedAndTriple(field, lhs.shape()); - - // open x^a, y^b0, 
y1^b1 - auto res = - vmap({ring_xor(lhs, a), ring_xor(rhs0, b0), ring_xor(rhs1, b1)}, - [&](const NdArrayRef &s) { - return conn_->allReduce(ReduceOp::XOR, s, "CorrelatedBitwiseAnd"); - }); - auto xa = std::move(res[0]); - auto y0b0 = std::move(res[1]); - auto y1b1 = std::move(res[2]); - - // Zi = Ci ^ ((X ^ A) & Bi) ^ ((Y ^ B) & Ai) ^ <(X ^ A) & (Y ^ B)> - auto z0 = ring_xor(ring_xor(ring_and(xa, b0), ring_and(y0b0, a)), c0); - auto z1 = ring_xor(ring_xor(ring_and(xa, b1), ring_and(y1b1, a)), c1); - if (conn_->getRank() == 0) { - ring_xor_(z0, ring_and(xa, y0b0)); - ring_xor_(z1, ring_and(xa, y1b1)); - } - - return {z0.as(lhs.eltype()), z1.as(lhs.eltype())}; -} - -// Ref: https://eprint.iacr.org/2013/552.pdf -// Algorithm 1. AND triple using 1-of-2 ROT. -// Math -// ROT sender obtains x_0, x_1 -// ROT recevier obtains x_a, a for a \in {0, 1} -// -// Sender set (b = x0 ^ x1, v = x0) -// Recevier set (a, u = x_a) -// a & b = a & (x0 ^ x1) -// = a & (x0 ^ x1) ^ (x0 ^ x0) <- zero m0 ^ m0 -// = (a & (x0 ^ x1) ^ x0) ^ x0 -// = (x_a) ^ x0 -// = u ^ v -// -// P0 acts as S to obtain (a0, u0) -// P1 acts as R to obtain (b1, v1) -// such that a0 & b1 = u0 ^ v1 -// -// Flip the role -// P1 obtains (a1, u1) -// P0 obtains (b0, v0) -// such that a1 & b0 = u1 ^ v0 -// -// Pi sets ci = ai & bi ^ ui ^ vi -// such that (a0 ^ a1) & (b0 ^ b1) = (c0 ^ c1) -std::array BasicOTProtocols::AndTriple(FieldType field, - const Shape &shape, - size_t nbits_each) { - int64_t numel = shape.numel(); - SPU_ENFORCE(numel > 0); - SPU_ENFORCE(nbits_each >= 1 && nbits_each <= SizeOf(field) * 8, - "invalid packing load {} for one AND", nbits_each); - - // NOTE(juhou): we use uint8_t to store 1-bit ROT - constexpr size_t ot_msg_width = 1; - std::vector a(numel * nbits_each); - std::vector b(numel * nbits_each); - std::vector v(numel * nbits_each); - std::vector u(numel * nbits_each); - if (0 == Rank()) { - ferret_receiver_->RecvRMRC(absl::MakeSpan(a), absl::MakeSpan(u), - ot_msg_width); - 
ferret_sender_->SendRMRC(absl::MakeSpan(v), absl::MakeSpan(b), - ot_msg_width); - ferret_sender_->Flush(); - } else { - ferret_sender_->SendRMRC(absl::MakeSpan(v), absl::MakeSpan(b), - ot_msg_width); - ferret_sender_->Flush(); - ferret_receiver_->RecvRMRC(absl::MakeSpan(a), absl::MakeSpan(u), - ot_msg_width); - } - - std::vector c(numel * nbits_each); - pforeach(0, c.size(), [&](int64_t i) { - b[i] = b[i] ^ v[i]; - c[i] = (a[i] & b[i]) ^ u[i] ^ v[i]; - }); - - // init as zero - auto AND_a = ring_zeros(field, shape); - auto AND_b = ring_zeros(field, shape); - auto AND_c = ring_zeros(field, shape); - - DISPATCH_ALL_FIELDS(field, "AndTriple", [&]() { - auto AND_xa = NdArrayView(AND_a); - auto AND_xb = NdArrayView(AND_b); - auto AND_xc = NdArrayView(AND_c); - pforeach(0, numel, [&](int64_t i) { - int64_t bgn = i * nbits_each; - int64_t end = bgn + nbits_each; - for (int64_t j = bgn; j < end; ++j) { - AND_xa[i] = (AND_xa[i] << 1) | (a[j] & 1); - AND_xb[i] = (AND_xb[i] << 1) | (b[j] & 1); - AND_xc[i] = (AND_xc[i] << 1) | (c[j] & 1); - } - }); - }); - - return {AND_a, AND_b, AND_c}; -} - -std::array BasicOTProtocols::CorrelatedAndTriple( - FieldType field, const Shape &shape) { - int64_t numel = shape.numel(); - SPU_ENFORCE(numel > 0); - // NOTE(juhou): we use uint8_t to store 2-bit ROT - constexpr size_t ot_msg_width = 2; - std::vector a(numel); - std::vector b(numel); - std::vector v(numel); - std::vector u(numel); - // random choice a is 1-bit - // random messages b, v and u are 2-bit - if (0 == Rank()) { - ferret_receiver_->RecvRMRC(/*choice*/ absl::MakeSpan(a), absl::MakeSpan(u), - ot_msg_width); - ferret_sender_->SendRMRC(absl::MakeSpan(v), absl::MakeSpan(b), - ot_msg_width); - ferret_sender_->Flush(); - } else { - ferret_sender_->SendRMRC(absl::MakeSpan(v), absl::MakeSpan(b), - ot_msg_width); - ferret_sender_->Flush(); - ferret_receiver_->RecvRMRC(/*choice*/ absl::MakeSpan(a), absl::MakeSpan(u), - ot_msg_width); - } - - std::vector c(numel); - pforeach(0, c.size(), 
[&](int64_t i) { - b[i] = b[i] ^ v[i]; - // broadcast to 2-bit AND - c[i] = (((a[i] << 1) | a[i]) & b[i]) ^ u[i] ^ v[i]; - }); - - auto AND_a = ring_zeros(field, shape); - auto AND_b0 = ring_zeros(field, shape); - auto AND_c0 = ring_zeros(field, shape); - auto AND_b1 = ring_zeros(field, shape); - auto AND_c1 = ring_zeros(field, shape); - - DISPATCH_ALL_FIELDS(field, "AndTriple", [&]() { - auto AND_xa = NdArrayView(AND_a); - auto AND_xb0 = NdArrayView(AND_b0); - auto AND_xc0 = NdArrayView(AND_c0); - auto AND_xb1 = NdArrayView(AND_b1); - auto AND_xc1 = NdArrayView(AND_c1); - pforeach(0, numel, [&](int64_t i) { - AND_xa[i] = a[i] & 1; - AND_xb0[i] = b[i] & 1; - AND_xc0[i] = c[i] & 1; - AND_xb1[i] = (b[i] >> 1) & 1; - AND_xc1[i] = (c[i] >> 1) & 1; - }); - }); - - return {AND_a, AND_b0, AND_c0, AND_b1, AND_c1}; -} - -int BasicOTProtocols::Rank() const { return ferret_sender_->Rank(); } - -NdArrayRef BasicOTProtocols::Multiplexer(const NdArrayRef &msg, - const NdArrayRef &select) { - SPU_ENFORCE_EQ(msg.shape(), select.shape()); - const auto *shareType = select.eltype().as(); - SPU_ENFORCE_EQ(shareType->nbits(), 1UL); - - const auto field = msg.eltype().as()->field(); - const int64_t size = msg.numel(); - - auto _corr_data = ring_zeros(field, msg.shape()); - auto _sent = ring_zeros(field, msg.shape()); - auto _recv = ring_zeros(field, msg.shape()); - std::vector sel(size); - // Compute (x0 + x1) * (b0 ^ b1) - // Also b0 ^ b1 = 1 - 2*b0*b1 - return DISPATCH_ALL_FIELDS(field, "Multiplexer", [&]() { - NdArrayView _msg(msg); - NdArrayView _sel(select); - auto corr_data = absl::MakeSpan(&_corr_data.at(0), size); - auto sent = absl::MakeSpan(&_sent.at(0), size); - auto recv = absl::MakeSpan(&_recv.at(0), size); - - pforeach(0, size, [&](int64_t i) { - sel[i] = static_cast(_sel[i] & 1); - corr_data[i] = _msg[i] * (1 - 2 * sel[i]); - }); - - if (Rank() == 0) { - ferret_sender_->SendCAMCC(corr_data, sent); - ferret_sender_->Flush(); - 
ferret_receiver_->RecvCAMCC(absl::MakeSpan(sel), recv); - } else { - ferret_receiver_->RecvCAMCC(absl::MakeSpan(sel), recv); - ferret_sender_->SendCAMCC(corr_data, sent); - ferret_sender_->Flush(); - } - - pforeach(0, size, [&](int64_t i) { - recv[i] = _msg[i] * static_cast(sel[i]) - sent[i] + recv[i]; - }); - - return _recv; - }); -} - -} // namespace spu::mpc::cheetah diff --git a/libspu/mpc/cheetah/yacl_ot/basic_ot_prot.h b/libspu/mpc/cheetah/yacl_ot/basic_ot_prot.h deleted file mode 100644 index 8bb3d7c4..00000000 --- a/libspu/mpc/cheetah/yacl_ot/basic_ot_prot.h +++ /dev/null @@ -1,80 +0,0 @@ -// Copyright 2023 Ant Group Co., Ltd. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#pragma once - -#include "libspu/core/ndarray_ref.h" -#include "libspu/mpc/cheetah/yacl_ot/yacl_ferret.h" - -namespace spu::mpc::cheetah { - -class BasicOTProtocols { - public: - explicit BasicOTProtocols(std::shared_ptr conn); - - ~BasicOTProtocols(); - - int Rank() const; - - NdArrayRef B2A(const NdArrayRef &inp); - - NdArrayRef RandBits(FieldType filed, const Shape &shape); - - // NOTE(lwj): compute the B2A(b) and output to the specified ring - // Require: input is 1-bit boolean and 1 <= bit_width < k. - NdArrayRef B2ASingleBitWithSize(const NdArrayRef &inp, int bit_width); - - // msg * select for select \in {0, 1} - NdArrayRef Multiplexer(const NdArrayRef &msg, const NdArrayRef &select); - - // Create `numel` of AND-triple. 
Each element contains `k` bits - // 1 <= k <= field size - // std::array AndTriple(FieldType field, size_t numel, size_t k); - - std::array AndTriple(FieldType field, const Shape &shape, - size_t k); - - // [a, b, b', c, c'] such that c = a*b and c' = a*b' for the same a - // std::array CorrelatedAndTriple(FieldType field, size_t numel); - - std::array CorrelatedAndTriple(FieldType field, - const Shape &shape); - - // ArrayRef BitwiseAnd(const ArrayRef &lhs, const ArrayRef &rhs); - - NdArrayRef BitwiseAnd(const NdArrayRef &lhs, const NdArrayRef &rhs); - - // Compute the ANDs `lhs & rhs0` and `lhs & rhs1` - std::array CorrelatedBitwiseAnd(const NdArrayRef &lhs, - const NdArrayRef &rhs0, - const NdArrayRef &rhs1); - - std::shared_ptr GetSenderCOT() { return ferret_sender_; } - - std::shared_ptr GetReceiverCOT() { return ferret_receiver_; } - - void Flush(); - - protected: - NdArrayRef SingleB2A(const NdArrayRef &inp, int bit_width = 0); - - NdArrayRef PackedB2A(const NdArrayRef &inp); - - private: - std::shared_ptr conn_; - std::shared_ptr ferret_sender_; - std::shared_ptr ferret_receiver_; -}; - -} // namespace spu::mpc::cheetah diff --git a/libspu/mpc/cheetah/yacl_ot/basic_ot_prot_test.cc b/libspu/mpc/cheetah/yacl_ot/basic_ot_prot_test.cc deleted file mode 100644 index 70aa47e3..00000000 --- a/libspu/mpc/cheetah/yacl_ot/basic_ot_prot_test.cc +++ /dev/null @@ -1,390 +0,0 @@ -// Copyright 2023 Ant Group Co., Ltd. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-// See the License for the specific language governing permissions and -// limitations under the License. - -#include "libspu/mpc/cheetah/yacl_ot/basic_ot_prot.h" - -#include - -#include "gtest/gtest.h" - -#include "libspu/core/xt_helper.h" -#include "libspu/mpc/cheetah/type.h" -#include "libspu/mpc/utils/ring_ops.h" -#include "libspu/mpc/utils/simulate.h" - -namespace spu::mpc::cheetah::test { - -class BasicOTProtTest : public ::testing::TestWithParam { - void SetUp() override {} -}; - -INSTANTIATE_TEST_SUITE_P( - Cheetah, BasicOTProtTest, - testing::Values(FieldType::FM32, FieldType::FM64, FieldType::FM128), - [](const testing::TestParamInfo& p) { - return fmt::format("{}", p.param); - }); - -TEST_P(BasicOTProtTest, SingleB2A) { - size_t kWorldSize = 2; - Shape shape = {10, 30}; - FieldType field = GetParam(); - - size_t nbits = 8 * SizeOf(field) - 1; - size_t packed_nbits = 8 * SizeOf(field) - nbits; - auto boolean_t = makeType(field, packed_nbits); - - auto bshr0 = ring_rand(field, shape).as(boolean_t); - auto bshr1 = ring_rand(field, shape).as(boolean_t); - DISPATCH_ALL_FIELDS(field, "", [&]() { - auto mask = static_cast(-1); - if (nbits > 0) { - mask = (static_cast(1) << packed_nbits) - 1; - auto xb0 = xt_mutable_adapt(bshr0); - auto xb1 = xt_mutable_adapt(bshr1); - std::transform(xb0.data(), xb0.data() + xb0.size(), xb0.data(), - [&](auto x) { return x & mask; }); - std::transform(xb1.data(), xb1.data() + xb1.size(), xb1.data(), - [&](auto x) { return x & mask; }); - } - }); - - NdArrayRef ashr0; - NdArrayRef ashr1; - utils::simulate(kWorldSize, [&](std::shared_ptr ctx) { - auto conn = std::make_shared(ctx); - BasicOTProtocols ot_prot(conn); - if (ctx->Rank() == 0) { - ashr0 = ot_prot.B2A(bshr0); - } else { - ashr1 = ot_prot.B2A(bshr1); - } - }); - - EXPECT_EQ(ashr0.shape(), ashr1.shape()); - EXPECT_EQ(shape, ashr0.shape()); - - DISPATCH_ALL_FIELDS(field, "", [&]() { - auto b0 = xt_adapt(bshr0); - auto b1 = xt_adapt(bshr1); - auto a0 = xt_adapt(ashr0); - 
auto a1 = xt_adapt(ashr1); - auto mask = static_cast(-1); - if (nbits > 0) { - mask = (static_cast(1) << packed_nbits) - 1; - } - for (int64_t i = 0; i < shape.numel(); ++i) { - ring2k_t e = b0[i] ^ b1[i]; - ring2k_t c = (a0[i] + a1[i]) & mask; - EXPECT_EQ(e, c); - } - }); -} - -TEST_P(BasicOTProtTest, PackedB2A) { - size_t kWorldSize = 2; - Shape shape = {11, 12, 13}; - FieldType field = GetParam(); - - for (size_t nbits : {1, 2}) { - size_t packed_nbits = 8 * SizeOf(field) - nbits; - auto boolean_t = makeType(field, packed_nbits); - - auto bshr0 = ring_rand(field, shape).as(boolean_t); - auto bshr1 = ring_rand(field, shape).as(boolean_t); - DISPATCH_ALL_FIELDS(field, "", [&]() { - auto mask = static_cast(-1); - if (nbits > 0) { - mask = (static_cast(1) << packed_nbits) - 1; - auto xb0 = xt_mutable_adapt(bshr0); - auto xb1 = xt_mutable_adapt(bshr1); - std::transform(xb0.data(), xb0.data() + xb0.size(), xb0.data(), - [&](auto x) { return x & mask; }); - std::transform(xb1.data(), xb1.data() + xb1.size(), xb1.data(), - [&](auto x) { return x & mask; }); - } - }); - - NdArrayRef ashr0; - NdArrayRef ashr1; - utils::simulate(kWorldSize, [&](std::shared_ptr ctx) { - auto conn = std::make_shared(ctx); - BasicOTProtocols ot_prot(conn); - if (ctx->Rank() == 0) { - ashr0 = ot_prot.B2A(bshr0); - } else { - ashr1 = ot_prot.B2A(bshr1); - } - }); - EXPECT_EQ(ashr0.shape(), ashr1.shape()); - EXPECT_EQ(ashr0.shape(), shape); - - DISPATCH_ALL_FIELDS(field, "", [&]() { - auto b0 = xt_adapt(bshr0); - auto b1 = xt_adapt(bshr1); - auto a0 = xt_adapt(ashr0); - auto a1 = xt_adapt(ashr1); - auto mask = static_cast(-1); - - if (nbits > 0) { - mask = (static_cast(1) << packed_nbits) - 1; - } - - for (int64_t i = 0; i < shape.numel(); ++i) { - ring2k_t e = b0[i] ^ b1[i]; - ring2k_t c = (a0[i] + a1[i]) & mask; - EXPECT_EQ(e, c); - } - }); - } -} - -TEST_P(BasicOTProtTest, PackedB2AFull) { - size_t kWorldSize = 2; - Shape shape = {1, 2, 3, 4, 5}; - FieldType field = GetParam(); - - for 
(size_t nbits : {0}) { - size_t packed_nbits = 8 * SizeOf(field) - nbits; - auto boolean_t = makeType(field, packed_nbits); - - auto bshr0 = ring_rand(field, shape).as(boolean_t); - auto bshr1 = ring_rand(field, shape).as(boolean_t); - DISPATCH_ALL_FIELDS(field, "", [&]() { - auto mask = static_cast(-1); - if (nbits > 0) { - mask = (static_cast(1) << packed_nbits) - 1; - auto xb0 = xt_mutable_adapt(bshr0); - auto xb1 = xt_mutable_adapt(bshr1); - std::transform(xb0.data(), xb0.data() + xb0.size(), xb0.data(), - [&](auto x) { return x & mask; }); - std::transform(xb1.data(), xb1.data() + xb1.size(), xb1.data(), - [&](auto x) { return x & mask; }); - } - }); - - NdArrayRef ashr0; - NdArrayRef ashr1; - utils::simulate(kWorldSize, [&](std::shared_ptr ctx) { - auto conn = std::make_shared(ctx); - BasicOTProtocols ot_prot(conn); - if (ctx->Rank() == 0) { - ashr0 = ot_prot.B2A(bshr0); - } else { - ashr1 = ot_prot.B2A(bshr1); - } - }); - - EXPECT_EQ(ashr0.shape(), ashr1.shape()); - EXPECT_EQ(ashr0.shape(), shape); - - DISPATCH_ALL_FIELDS(field, "", [&]() { - auto b0 = xt_adapt(bshr0); - auto b1 = xt_adapt(bshr1); - auto a0 = xt_adapt(ashr0); - auto a1 = xt_adapt(ashr1); - auto mask = static_cast(-1); - if (nbits > 0) { - mask = (static_cast(1) << packed_nbits) - 1; - } - for (int64_t i = 0; i < shape.numel(); ++i) { - ring2k_t e = b0[i] ^ b1[i]; - ring2k_t c = (a0[i] + a1[i]) & mask; - EXPECT_EQ(e, c); - } - }); - } -} - -TEST_P(BasicOTProtTest, AndTripleSparse) { - size_t kWorldSize = 2; - Shape shape = {55, 100}; - FieldType field = GetParam(); - size_t max_bit = 8 * SizeOf(field); - - for (size_t sparse : {1UL, 7UL, max_bit - 1}) { - const size_t target_nbits = max_bit - sparse; - std::vector triple[2]; - - utils::simulate(kWorldSize, [&](std::shared_ptr ctx) { - auto conn = std::make_shared(ctx); - BasicOTProtocols ot_prot(conn); - - for (const auto& t : ot_prot.AndTriple(field, shape, target_nbits)) { - triple[ctx->Rank()].emplace_back(t); - } - }); - - 
DISPATCH_ALL_FIELDS(field, "", [&]() { - ring2k_t max = static_cast(1) << target_nbits; - auto a0 = xt_adapt(triple[0][0]); - auto b0 = xt_adapt(triple[0][1]); - auto c0 = xt_adapt(triple[0][2]); - auto a1 = xt_adapt(triple[1][0]); - auto b1 = xt_adapt(triple[1][1]); - auto c1 = xt_adapt(triple[1][2]); - - for (int64_t i = 0; i < shape.numel(); ++i) { - EXPECT_TRUE(a0[i] < max && a1[i] < max); - EXPECT_TRUE(b0[i] < max && b1[i] < max); - EXPECT_TRUE(c0[i] < max && c1[i] < max); - - ring2k_t e = (a0[i] ^ a1[i]) & (b0[i] ^ b1[i]); - ring2k_t c = (c0[i] ^ c1[i]); - EXPECT_EQ(e, c); - } - }); - } -} - -TEST_P(BasicOTProtTest, AndTripleFull) { - size_t kWorldSize = 2; - Shape shape = {55, 11}; - FieldType field = GetParam(); - - std::vector packed_triple[2]; - - utils::simulate(kWorldSize, [&](std::shared_ptr ctx) { - auto conn = std::make_shared(ctx); - BasicOTProtocols ot_prot(conn); - for (const auto& t : ot_prot.AndTriple(field, shape, SizeOf(field) * 8)) { - packed_triple[ctx->Rank()].emplace_back(t); - } - }); - - DISPATCH_ALL_FIELDS(field, "", [&]() { - auto a0 = xt_adapt(packed_triple[0][0]); - auto b0 = xt_adapt(packed_triple[0][1]); - auto c0 = xt_adapt(packed_triple[0][2]); - auto a1 = xt_adapt(packed_triple[1][0]); - auto b1 = xt_adapt(packed_triple[1][1]); - auto c1 = xt_adapt(packed_triple[1][2]); - - size_t nn = a0.size(); - EXPECT_TRUE(nn * 8 * SizeOf(field) >= (size_t)shape.numel()); - - for (size_t i = 0; i < nn; ++i) { - ring2k_t e = (a0[i] ^ a1[i]) & (b0[i] ^ b1[i]); - ring2k_t c = (c0[i] ^ c1[i]); - - EXPECT_EQ(e, c); - } - }); -} - -TEST_P(BasicOTProtTest, Multiplexer) { - size_t kWorldSize = 2; - Shape shape = {3, 4, 1, 3}; - FieldType field = GetParam(); - - auto boolean_t = makeType(field, 1); - - auto ashr0 = ring_rand(field, shape); - auto ashr1 = ring_rand(field, shape); - auto bshr0 = ring_rand(field, shape).as(boolean_t); - auto bshr1 = ring_rand(field, shape).as(boolean_t); - - DISPATCH_ALL_FIELDS(field, "", [&]() { - auto mask = 
static_cast(1); - auto xb0 = xt_mutable_adapt(bshr0); - auto xb1 = xt_mutable_adapt(bshr1); - std::transform(xb0.data(), xb0.data() + xb0.size(), xb0.data(), - [&](auto x) { return x & mask; }); - std::transform(xb1.data(), xb1.data() + xb1.size(), xb1.data(), - [&](auto x) { return x & mask; }); - }); - - NdArrayRef computed[2]; - utils::simulate(kWorldSize, [&](std::shared_ptr ctx) { - auto conn = std::make_shared(ctx); - BasicOTProtocols ot_prot(conn); - if (ctx->Rank() == 0) { - computed[0] = ot_prot.Multiplexer(ashr0, bshr0); - } else { - computed[1] = ot_prot.Multiplexer(ashr1, bshr1); - } - }); - - EXPECT_EQ(computed[0].shape(), computed[1].shape()); - EXPECT_EQ(computed[0].shape(), shape); - - DISPATCH_ALL_FIELDS(field, "", [&]() { - auto a0 = xt_adapt(ashr0); - auto a1 = xt_adapt(ashr1); - auto b0 = xt_adapt(bshr0); - auto b1 = xt_adapt(bshr1); - auto c0 = xt_adapt(computed[0]); - auto c1 = xt_adapt(computed[1]); - - for (int64_t i = 0; i < shape.numel(); ++i) { - ring2k_t msg = (a0[i] + a1[i]); - ring2k_t sel = (b0[i] ^ b1[i]); - ring2k_t exp = msg * sel; - ring2k_t got = (c0[i] + c1[i]); - EXPECT_EQ(exp, got); - } - }); -} - -TEST_P(BasicOTProtTest, CorrelatedAndTriple) { - size_t kWorldSize = 2; - Shape shape = {10 * 8}; - FieldType field = GetParam(); - - std::array corr_triple[2]; - - utils::simulate(kWorldSize, [&](std::shared_ptr ctx) { - auto conn = std::make_shared(ctx); - BasicOTProtocols ot_prot(conn); - corr_triple[ctx->Rank()] = ot_prot.CorrelatedAndTriple(field, shape); - }); - - EXPECT_EQ(corr_triple[0][0].shape(), corr_triple[1][0].shape()); - for (int i = 1; i < 5; ++i) { - EXPECT_EQ(corr_triple[0][0].shape(), corr_triple[0][i].shape()); - EXPECT_EQ(corr_triple[1][0].shape(), corr_triple[1][i].shape()); - } - - DISPATCH_ALL_FIELDS(field, "", [&]() { - auto a0 = NdArrayView(corr_triple[0][0]); - auto b0 = NdArrayView(corr_triple[0][1]); - auto c0 = NdArrayView(corr_triple[0][2]); - auto d0 = NdArrayView(corr_triple[0][3]); - auto e0 = 
NdArrayView(corr_triple[0][4]); - - auto a1 = NdArrayView(corr_triple[1][0]); - auto b1 = NdArrayView(corr_triple[1][1]); - auto c1 = NdArrayView(corr_triple[1][2]); - auto d1 = NdArrayView(corr_triple[1][3]); - auto e1 = NdArrayView(corr_triple[1][4]); - - for (int64_t i = 0; i < shape.numel(); ++i) { - EXPECT_TRUE(a0[i] < 2 && a1[i] < 2); - EXPECT_TRUE(b0[i] < 2 && b1[i] < 2); - EXPECT_TRUE(c0[i] < 2 && c1[i] < 2); - EXPECT_TRUE(d0[i] < 2 && d1[i] < 2); - EXPECT_TRUE(e0[i] < 2 && e1[i] < 2); - - ring2k_t e = (a0[i] ^ a1[i]) & (b0[i] ^ b1[i]); - ring2k_t c = (c0[i] ^ c1[i]); - EXPECT_EQ(e, c); - - e = (a0[i] ^ a1[i]) & (d0[i] ^ d1[i]); - c = (e0[i] ^ e1[i]); - EXPECT_EQ(e, c); - } - }); -} - -} // namespace spu::mpc::cheetah::test diff --git a/libspu/mpc/cheetah/yacl_ot/util.h b/libspu/mpc/cheetah/yacl_ot/util.h deleted file mode 100644 index 9529b52f..00000000 --- a/libspu/mpc/cheetah/yacl_ot/util.h +++ /dev/null @@ -1,178 +0,0 @@ -// Copyright 2023 Ant Group Co., Ltd. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#pragma once - -#include "absl/types/span.h" -#include "yacl/base/dynamic_bitset.h" -#include "yacl/base/int128.h" - -#include "libspu/core/prelude.h" -#include "libspu/core/xt_helper.h" - -namespace spu::mpc::cheetah { - -template -inline T makeBitsMask(size_t nbits) { - size_t max = sizeof(T) * 8; - if (nbits == 0) { - nbits = max; - } - SPU_ENFORCE(nbits <= max); - T mask = static_cast(-1); - if (nbits < max) { - mask = (static_cast(1) << nbits) - 1; - } - return mask; -} - -template -inline T CeilDiv(T a, T b) { - return (a + b - 1) / b; -} - -template -inline size_t ZipArray(absl::Span inp, size_t bit_width, - absl::Span oup) { - size_t width = sizeof(T) * 8; - SPU_ENFORCE(bit_width > 0 && width >= bit_width); - size_t shft = bit_width; - size_t pack_load = width / shft; - size_t numel = inp.size(); - size_t packed_sze = CeilDiv(numel, pack_load); - SPU_ENFORCE(oup.size() >= packed_sze); - - const T mask = makeBitsMask(bit_width); - for (size_t i = 0; i < numel; i += pack_load) { - size_t this_batch = std::min(pack_load, numel - i); - T acc{0}; - for (size_t j = 0; j < this_batch; ++j) { - acc = (acc << shft) | (inp[i + j] & mask); - } - oup[i / pack_load] = acc; - } - return packed_sze; -} - -template -inline size_t UnzipArray(absl::Span inp, size_t bit_width, - absl::Span oup) { - size_t width = sizeof(T) * 8; - SPU_ENFORCE(bit_width > 0 && bit_width <= width); - - size_t shft = bit_width; - size_t pack_load = width / shft; - size_t packed_sze = inp.size(); - size_t n = oup.size(); - SPU_ENFORCE(n > 0 && n <= pack_load * packed_sze); - - const T mask = makeBitsMask(bit_width); - for (size_t i = 0; i < packed_sze; ++i) { - size_t j0 = std::min(n, i * pack_load); - size_t j1 = std::min(n, j0 + pack_load); - size_t this_batch = j1 - j0; - T package = inp[i]; - // NOTE (reversed order) - for (size_t j = 0; j < this_batch; ++j) { - oup[j1 - 1 - j] = package & mask; - package >>= shft; - } - } - - return n; -} - -template -inline size_t PackU8Array(absl::Span 
u8array, - absl::Span packed) { - constexpr size_t elsze = sizeof(T); - const size_t nbytes = u8array.size(); - const size_t numel = CeilDiv(nbytes, elsze); - - SPU_ENFORCE(packed.size() >= numel); - - for (size_t i = 0; i < nbytes; i += elsze) { - size_t this_batch = std::min(nbytes - i, elsze); - T acc{0}; - for (size_t j = 0; j < this_batch; ++j) { - acc = (acc << 8) | u8array[i + j]; - } - packed[i / elsze] = acc; - } - - return numel; -} - -template -inline size_t UnpackU8Array(absl::Span input, - absl::Span u8array) { - using UT = typename std::make_unsigned::type; - constexpr size_t elsze = sizeof(T); - const size_t numel = input.size(); - const size_t nbytes = u8array.size(); - SPU_ENFORCE(CeilDiv(nbytes, elsze) >= numel); - - constexpr T mask = (static_cast(1) << 8) - 1; - for (size_t i = 0; i < nbytes; i += elsze) { - size_t this_batch = std::min(nbytes - i, elsze); - UT acc = static_cast(input[i / elsze]); - for (size_t j = 0; j < this_batch; ++j) { - u8array[i + this_batch - 1 - j] = acc & mask; - acc >>= 8; - } - } - - return nbytes; -} - -inline uint8_t BoolToU8(absl::Span bits); - -inline void U8ToBool(absl::Span bits, uint8_t u8); - -// Add by @wenfan -inline void VecU8toBitset(absl::Span bits, - yacl::dynamic_bitset& bitset) { - SPU_ENFORCE(bits.size() == bitset.size()); - uint64_t bits_num = bits.size(); - // low efficiency - for (uint64_t i = 0; i < bits_num; ++i) { - bitset[i] = (bool)bits[i]; - } -} - -inline yacl::dynamic_bitset VecU8toBitset( - absl::Span bits) { - yacl::dynamic_bitset bitset(bits.size()); - VecU8toBitset(bits, bitset); - return bitset; -} - -inline void BitsettoVecU8(const yacl::dynamic_bitset& bitset, - absl::Span bits) { - SPU_ENFORCE(bits.size() == bitset.size()); - uint64_t bits_num = bitset.size(); - // low efficiency - for (uint64_t i = 0; i < bits_num; ++i) { - bits[i] = bitset[i]; - } -} - -inline std::vector BitsettoVecU8( - const yacl::dynamic_bitset& bitset) { - std::vector bits(bitset.size()); - 
BitsettoVecU8(bitset, absl::MakeSpan(bits)); - return bits; -} - -} // namespace spu::mpc::cheetah diff --git a/libspu/mpc/cheetah/yacl_ot/util_test.cc b/libspu/mpc/cheetah/yacl_ot/util_test.cc deleted file mode 100644 index 880b0c94..00000000 --- a/libspu/mpc/cheetah/yacl_ot/util_test.cc +++ /dev/null @@ -1,90 +0,0 @@ -// Copyright 2023 Ant Group Co., Ltd. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "libspu/mpc/cheetah/yacl_ot/util.h" - -#include - -#include "gtest/gtest.h" - -#include "libspu/mpc/utils/ring_ops.h" - -namespace spu::mpc::cheetah::test { - -class UtilTest : public ::testing::TestWithParam { - void SetUp() override {} -}; - -INSTANTIATE_TEST_SUITE_P( - Cheetah, UtilTest, - testing::Values(FieldType::FM32, FieldType::FM64, FieldType::FM128), - [](const testing::TestParamInfo &p) { - return fmt::format("{}", p.param); - }); - -TEST_P(UtilTest, ZipArray) { - const int64_t n = 20; - const auto field = GetParam(); - const size_t elsze = SizeOf(field); - - auto unzip = ring_zeros(field, {n}); - - DISPATCH_ALL_FIELDS(field, "", [&]() { - for (size_t bw : {1, 2, 4, 7, 15, 16}) { - int64_t pack_load = elsze * 8 / bw; - auto zip = ring_zeros(field, {(n + pack_load - 1) / pack_load}); - auto array = ring_rand(field, {n}); - auto inp = xt_mutable_adapt(array); - auto mask = makeBitsMask(bw); - inp &= mask; - - auto _zip = xt_mutable_adapt(zip); - auto _unzip = xt_mutable_adapt(unzip); - size_t zip_sze = ZipArray({inp.data(), inp.size()}, 
bw, - {_zip.data(), _zip.size()}); - - UnzipArray({_zip.data(), zip_sze}, bw, - {_unzip.data(), _unzip.size()}); - - for (size_t i = 0; i < n; ++i) { - EXPECT_EQ(inp[i], _unzip[i]); - } - } - }); -} - -TEST_P(UtilTest, PackU8Array) { - const int64_t num_bytes = 223; - const auto field = GetParam(); - const int64_t elsze = SizeOf(field); - - std::uniform_int_distribution uniform(0, -1); - std::default_random_engine rdv; - std::vector u8array(num_bytes); - std::generate_n(u8array.data(), u8array.size(), - [&]() { return uniform(rdv); }); - - auto packed = ring_zeros(field, {(num_bytes + elsze - 1) / elsze}); - - DISPATCH_ALL_FIELDS(field, "", [&]() { - auto xp = xt_mutable_adapt(packed); - PackU8Array(absl::MakeSpan(u8array), {xp.data(), xp.size()}); - std::vector _u8(num_bytes, -1); - UnpackU8Array({xp.data(), xp.size()}, absl::MakeSpan(_u8)); - - EXPECT_TRUE(std::memcmp(_u8.data(), u8array.data(), num_bytes) == 0); - }); -} - -} // namespace spu::mpc::cheetah::test diff --git a/libspu/mpc/semi2k/beaver/beaver_interface.h b/libspu/mpc/semi2k/beaver/beaver_interface.h index 1f5b89e2..4e2fbc22 100644 --- a/libspu/mpc/semi2k/beaver/beaver_interface.h +++ b/libspu/mpc/semi2k/beaver/beaver_interface.h @@ -70,6 +70,10 @@ class Beaver { absl::Span perm_vec) = 0; virtual std::unique_ptr Spawn() = 0; + + // ret[0] (in a share) = ret[1] (in b share) + // ref: https://eprint.iacr.org/2020/338 + virtual Pair Eqz(FieldType field, const Shape& shape) = 0; }; } // namespace spu::mpc::semi2k diff --git a/libspu/mpc/semi2k/beaver/beaver_test.cc b/libspu/mpc/semi2k/beaver/beaver_test.cc index 0ce17d9f..b150dac9 100644 --- a/libspu/mpc/semi2k/beaver/beaver_test.cc +++ b/libspu/mpc/semi2k/beaver/beaver_test.cc @@ -416,6 +416,34 @@ TEST_P(BeaverTest, Randbit) { }); } +TEST_P(BeaverTest, Eqz) { + const auto factory = std::get<0>(GetParam()).first; + const size_t kWorldSize = std::get<1>(GetParam()); + const FieldType kField = std::get<2>(GetParam()); + const int64_t kNumel = 2; + + 
std::vector pairs; + pairs.resize(kWorldSize); + utils::simulate(kWorldSize, + [&](const std::shared_ptr& lctx) { + auto beaver = factory(lctx, ttp_options_); + pairs[lctx->Rank()] = beaver->Eqz(kField, {kNumel}); + yacl::link::Barrier(lctx, "BeaverUT"); + }); + EXPECT_EQ(pairs.size(), kWorldSize); + auto sum_a = ring_zeros(kField, {kNumel}); + auto sum_b = ring_zeros(kField, {kNumel}); + for (Rank r = 0; r < kWorldSize; r++) { + const auto& [a, b] = pairs[r]; + EXPECT_EQ(a.numel(), kNumel); + EXPECT_EQ(b.numel(), kNumel); + + ring_add_(sum_a, a); + ring_xor_(sum_b, b); + } + EXPECT_TRUE(ring_all_equal(sum_a, sum_b)); +} + TEST_P(BeaverTest, PermPair) { const auto factory = std::get<0>(GetParam()).first; const size_t kWorldSize = std::get<1>(GetParam()); diff --git a/libspu/mpc/semi2k/beaver/beaver_tfp.cc b/libspu/mpc/semi2k/beaver/beaver_tfp.cc index b695d34d..db15af63 100644 --- a/libspu/mpc/semi2k/beaver/beaver_tfp.cc +++ b/libspu/mpc/semi2k/beaver/beaver_tfp.cc @@ -164,4 +164,16 @@ std::unique_ptr BeaverTfpUnsafe::Spawn() { return std::make_unique(lctx_->Spawn()); } +BeaverTfpUnsafe::Pair BeaverTfpUnsafe::Eqz(FieldType field, + const Shape& shape) { + std::vector descs(2); + auto a = prgCreateArray(field, shape, seed_, &counter_, descs.data()); + auto b = prgCreateArray(field, shape, seed_, &counter_, &descs[1]); + if (lctx_->Rank() == 0) { + auto adjust = TrustedParty::adjustEqz(descs, seeds_); + ring_xor_(b, adjust); + } + return {a, b}; +} + } // namespace spu::mpc::semi2k diff --git a/libspu/mpc/semi2k/beaver/beaver_tfp.h b/libspu/mpc/semi2k/beaver/beaver_tfp.h index cfbbd7c5..424bd744 100644 --- a/libspu/mpc/semi2k/beaver/beaver_tfp.h +++ b/libspu/mpc/semi2k/beaver/beaver_tfp.h @@ -59,6 +59,8 @@ class BeaverTfpUnsafe final : public Beaver { absl::Span perm_vec) override; std::unique_ptr Spawn() override; + + Pair Eqz(FieldType field, const Shape& shape) override; }; } // namespace spu::mpc::semi2k diff --git a/libspu/mpc/semi2k/beaver/beaver_ttp.cc 
b/libspu/mpc/semi2k/beaver/beaver_ttp.cc index 213534c5..dad4f6be 100644 --- a/libspu/mpc/semi2k/beaver/beaver_ttp.cc +++ b/libspu/mpc/semi2k/beaver/beaver_ttp.cc @@ -87,6 +87,9 @@ std::vector RpcCall(brpc::Channel& channel, AdjustRequest req, AdjustRequest, beaver::ttp_server::AdjustRandBitRequest>) { stub.AdjustRandBit(&cntl, &req, &rsp, nullptr); + } else if constexpr (std::is_same_v) { + stub.AdjustEqz(&cntl, &req, &rsp, nullptr); } else if constexpr (std::is_same_v) { stub.AdjustPerm(&cntl, &req, &rsp, nullptr); @@ -337,4 +340,20 @@ std::unique_ptr BeaverTtp::Spawn() { return std::make_unique(lctx_->Spawn(), std::move(new_options)); } +BeaverTtp::Pair BeaverTtp::Eqz(FieldType field, const Shape& shape) { + std::vector descs(2); + + auto a = prgCreateArray(field, shape, seed_, &counter_, descs.data()); + auto b = prgCreateArray(field, shape, seed_, &counter_, &descs[1]); + + if (lctx_->Rank() == options_.adjust_rank) { + auto req = BuildAdjustRequest( + options_.session_id, descs); + auto adjusts = RpcCall(channel_, req, field); + SPU_ENFORCE_EQ(adjusts.size(), 1U); + ring_xor_(b, adjusts[0].reshape(shape)); + } + + return {a, b}; +} } // namespace spu::mpc::semi2k diff --git a/libspu/mpc/semi2k/beaver/beaver_ttp.h b/libspu/mpc/semi2k/beaver/beaver_ttp.h index c5ea9677..ab4d8434 100644 --- a/libspu/mpc/semi2k/beaver/beaver_ttp.h +++ b/libspu/mpc/semi2k/beaver/beaver_ttp.h @@ -76,6 +76,8 @@ class BeaverTtp final : public Beaver { absl::Span perm_vec) override; std::unique_ptr Spawn() override; + + Pair Eqz(FieldType field, const Shape& shape) override; }; } // namespace spu::mpc::semi2k diff --git a/libspu/mpc/semi2k/beaver/trusted_party.cc b/libspu/mpc/semi2k/beaver/trusted_party.cc index 3469a245..8f1b7adc 100644 --- a/libspu/mpc/semi2k/beaver/trusted_party.cc +++ b/libspu/mpc/semi2k/beaver/trusted_party.cc @@ -126,6 +126,15 @@ NdArrayRef TrustedParty::adjustRandBit(Descs descs, Seeds seeds) { return ring_sub(ring_randbit(descs[0].field, descs[0].shape), 
rs[0]); } +NdArrayRef TrustedParty::adjustEqz(Descs descs, Seeds seeds) { + SPU_ENFORCE_EQ(descs.size(), 2U); + checkDescs(descs); + auto rs_a = reconstruct(RecOp::ADD, seeds, descs.subspan(0, 1)); + auto rs_b = reconstruct(RecOp::XOR, seeds, descs.subspan(1, 2)); + // adjust = rs[0] ^ rs[1]; + return ring_xor(rs_a[0], rs_b[0]); +} + NdArrayRef TrustedParty::adjustPerm(Descs descs, Seeds seeds, absl::Span perm_vec) { SPU_ENFORCE_EQ(descs.size(), 2U); diff --git a/libspu/mpc/semi2k/beaver/trusted_party.h b/libspu/mpc/semi2k/beaver/trusted_party.h index 0c4de181..a60499bd 100644 --- a/libspu/mpc/semi2k/beaver/trusted_party.h +++ b/libspu/mpc/semi2k/beaver/trusted_party.h @@ -43,6 +43,8 @@ class TrustedParty { static NdArrayRef adjustRandBit(Descs descs, Seeds seeds); + static NdArrayRef adjustEqz(Descs descs, Seeds seeds); + static NdArrayRef adjustPerm(Descs descs, Seeds seeds, absl::Span perm_vec); }; diff --git a/libspu/mpc/semi2k/beaver/ttp_server/beaver_server.cc b/libspu/mpc/semi2k/beaver/ttp_server/beaver_server.cc index d550bd81..ddb8e0be 100644 --- a/libspu/mpc/semi2k/beaver/ttp_server/beaver_server.cc +++ b/libspu/mpc/semi2k/beaver/ttp_server/beaver_server.cc @@ -85,6 +85,9 @@ std::vector AdjustImpl(const AdjustRequest& req, } else if constexpr (std::is_same_v) { auto adjust = TrustedParty::adjustRandBit(descs, seeds); ret.push_back(std::move(adjust)); + } else if constexpr (std::is_same_v) { + auto adjust = TrustedParty::adjustEqz(descs, seeds); + ret.push_back(std::move(adjust)); } else if constexpr (std::is_same_v) { PermVector pv; for (auto p : req.perm_vec()) { @@ -317,12 +320,18 @@ class ServiceImpl final : public BeaverService { Adjust(controller, req, rsp, done); } + void AdjustEqz(::google::protobuf::RpcController* controller, + const AdjustEqzRequest* req, AdjustResponse* rsp, + ::google::protobuf::Closure* done) override { + Adjust(controller, req, rsp, done); + } + void AdjustPerm(::google::protobuf::RpcController* controller, const 
AdjustPermRequest* req, AdjustResponse* rsp, ::google::protobuf::Closure* done) override { Adjust(controller, req, rsp, done); } -}; +}; // class ServiceImpl std::unique_ptr RunServer(int32_t port) { brpc::FLAGS_max_body_size = std::numeric_limits::max(); diff --git a/libspu/mpc/semi2k/beaver/ttp_server/service.proto b/libspu/mpc/semi2k/beaver/ttp_server/service.proto index 81beda82..c25f3c91 100644 --- a/libspu/mpc/semi2k/beaver/ttp_server/service.proto +++ b/libspu/mpc/semi2k/beaver/ttp_server/service.proto @@ -83,6 +83,8 @@ service BeaverService { rpc AdjustRandBit(AdjustRandBitRequest) returns (AdjustResponse); + rpc AdjustEqz(AdjustEqzRequest) returns (AdjustResponse); + rpc AdjustPerm(AdjustPermRequest) returns (AdjustResponse); } @@ -175,6 +177,19 @@ message AdjustRandBitRequest { // (adjust_a + ra) = random 0/1 array } +message AdjustEqzRequest { + string session_id = 1; + // input two prg buffer + // reconstruct all parties' share get: ra / rb + repeated PrgBufferMeta prg_inputs = 2; + // use which field to interpret buffer. 
details see: spu.FieldType + int32 field = 3; + // output + // adjust_b = ra ^ rb + // make + // ra(in a share) = (adjust_b ^ rb)(in b share) +} + message AdjustPermRequest { string session_id = 1; // input two prg buffer diff --git a/libspu/mpc/semi2k/conversion.cc b/libspu/mpc/semi2k/conversion.cc index f87e06d2..e202587c 100644 --- a/libspu/mpc/semi2k/conversion.cc +++ b/libspu/mpc/semi2k/conversion.cc @@ -36,6 +36,12 @@ static NdArrayRef wrap_a2b(SPUContext* ctx, const NdArrayRef& x) { return UnwrapValue(a2b(ctx, WrapValue(x))); } +static NdArrayRef wrap_and_bb(SPUContext* ctx, const NdArrayRef& x, + const NdArrayRef& y) { + SPU_ENFORCE(x.shape() == y.shape()); + return UnwrapValue(and_bb(ctx, WrapValue(x), WrapValue(y))); +} + NdArrayRef A2B::proc(KernelEvalContext* ctx, const NdArrayRef& x) const { const auto field = x.eltype().as()->field(); auto* comm = ctx->getState(); @@ -185,6 +191,91 @@ NdArrayRef MsbA2B::proc(KernelEvalContext* ctx, const NdArrayRef& in) const { } } +NdArrayRef eqz(KernelEvalContext* ctx, const NdArrayRef& in) { + auto* prg_state = ctx->getState(); + auto* comm = ctx->getState(); + auto* beaver = ctx->getState()->beaver(); + + const auto field = in.eltype().as()->field(); + const auto numel = in.numel(); + + NdArrayRef out(makeType(field), in.shape()); + + size_t pivot; + prg_state->fillPubl(absl::MakeSpan(&pivot, 1)); + pivot %= comm->getWorldSize(); + // beaver samples r and deals [r]a and [r]b + // reveal c = a+r + // check a == 0 <=> c == r + DISPATCH_ALL_FIELDS(field, "_", [&]() { + using el_t = ring2k_t; + auto [ra, rb] = beaver->Eqz(field, in.shape()); + + // c in secret share + ring_add_(ra, in); + // reveal c + NdArrayRef c_p = comm->allReduce(ReduceOp::ADD, ra, "reveal c "); + + if (comm->getRank() == pivot) { + ring_xor_(rb, c_p); + ring_not_(rb); + } + + // if a == 0, ~(a+ra) ^ rb supposed to be all 1 + // do log(k) rounds of bitwise AND + // TODO: fix AND triple + // in beaver->AND(field, shape), min FM32, need min 1byte to reduce comm + 
NdArrayRef round_out = rb.as(makeType(field)); + size_t cur_bits = round_out.eltype().as()->nbits(); + while (cur_bits != 1) { + cur_bits /= 2; + round_out = + wrap_and_bb(ctx->sctx(), round_out, ring_rshift(round_out, cur_bits)); + } + + // 1 bit info in lsb + NdArrayView _out(out); + NdArrayView _round_out(round_out); + pforeach(0, numel, [&](int64_t idx) { _out[idx] = _round_out[idx] & 1; }); + }); + + return out; +} + +NdArrayRef EqualAA::proc(KernelEvalContext* ctx, const NdArrayRef& lhs, + const NdArrayRef& rhs) const { + const auto* lhs_ty = lhs.eltype().as(); + const auto* rhs_ty = rhs.eltype().as(); + + SPU_ENFORCE(lhs_ty->field() == rhs_ty->field()); + const auto field = lhs_ty->field(); + NdArrayRef out(makeType(field), lhs.shape()); + + out = ring_sub(lhs, rhs); + + return eqz(ctx, out); +} + +NdArrayRef EqualAP::proc(KernelEvalContext* ctx, const NdArrayRef& lhs, + const NdArrayRef& rhs) const { + auto* comm = ctx->getState(); + const auto* lhs_ty = lhs.eltype().as(); + const auto* rhs_ty = rhs.eltype().as(); + + SPU_ENFORCE(lhs_ty->field() == rhs_ty->field()); + const auto field = lhs_ty->field(); + NdArrayRef out(makeType(field), lhs.shape()); + + auto rank = comm->getRank(); + if (rank == 0) { + out = ring_sub(lhs, rhs); + } else { + out = lhs; + }; + + return eqz(ctx, out); +} + void CommonTypeV::evaluate(KernelEvalContext* ctx) const { const Type& lhs = ctx->getParam(0); const Type& rhs = ctx->getParam(1); diff --git a/libspu/mpc/semi2k/conversion.h b/libspu/mpc/semi2k/conversion.h index dad01a57..18033c16 100644 --- a/libspu/mpc/semi2k/conversion.h +++ b/libspu/mpc/semi2k/conversion.h @@ -93,6 +93,40 @@ class MsbA2B : public UnaryKernel { NdArrayRef proc(KernelEvalContext* ctx, const NdArrayRef& in) const override; }; +class EqualAA : public BinaryKernel { + public: + static constexpr char kBindName[] = "equal_aa"; + + ce::CExpr latency() const override { + // 1 * edabits + logk * andbb + return Log(ce::K()) + 1; + } + + ce::CExpr comm() const 
override { + return (2 * Log(ce::K()) + 1) * ce::K() * (ce::N() - 1); + } + + NdArrayRef proc(KernelEvalContext* ctx, const NdArrayRef& lhs, + const NdArrayRef& rhs) const override; +}; + +class EqualAP : public BinaryKernel { + public: + static constexpr char kBindName[] = "equal_ap"; + + ce::CExpr latency() const override { + // 1 * edabits + logk * andbb + return Log(ce::K()) + 1; + } + + ce::CExpr comm() const override { + return (2 * Log(ce::K()) + 1) * ce::K() * (ce::N() - 1); + } + + NdArrayRef proc(KernelEvalContext* ctx, const NdArrayRef& lhs, + const NdArrayRef& rhs) const override; +}; + class CommonTypeV : public Kernel { public: static constexpr char kBindName[] = "common_type_v"; diff --git a/libspu/mpc/semi2k/protocol.cc b/libspu/mpc/semi2k/protocol.cc index c6e2496f..748f7ed6 100644 --- a/libspu/mpc/semi2k/protocol.cc +++ b/libspu/mpc/semi2k/protocol.cc @@ -91,6 +91,9 @@ void regSemi2kProtocol(SPUContext* ctx, ctx->prot()->regKernel(); ctx->prot()->regKernel(); ctx->prot()->regKernel(); + + ctx->prot()->regKernel(); + ctx->prot()->regKernel(); } std::unique_ptr makeSemi2kProtocol( diff --git a/libspu/mpc/spdz2k/BUILD.bazel b/libspu/mpc/spdz2k/BUILD.bazel index 8be7c3b3..85c0699d 100644 --- a/libspu/mpc/spdz2k/BUILD.bazel +++ b/libspu/mpc/spdz2k/BUILD.bazel @@ -112,6 +112,7 @@ spu_cc_library( ":state", ":type", "//libspu/core:vectorize", + "//libspu/core:xt_helper", "//libspu/mpc:ab_api", "//libspu/mpc:api", "//libspu/mpc:kernel", diff --git a/libspu/mpc/spdz2k/ot/BUILD.bazel b/libspu/mpc/spdz2k/ot/BUILD.bazel index 26689908..50651d6f 100644 --- a/libspu/mpc/spdz2k/ot/BUILD.bazel +++ b/libspu/mpc/spdz2k/ot/BUILD.bazel @@ -21,7 +21,7 @@ spu_cc_library( name = "ferret", hdrs = ["ferret.h"], deps = [ - "//libspu/mpc/cheetah/ot:cheetah_ot", + "//libspu/mpc/cheetah/ot", "//libspu/mpc/common:communicator", ], ) diff --git a/libspu/mpc/spdz2k/ot/ferret.h b/libspu/mpc/spdz2k/ot/ferret.h index 9cee9684..578039e5 100644 --- a/libspu/mpc/spdz2k/ot/ferret.h +++ 
b/libspu/mpc/spdz2k/ot/ferret.h @@ -23,12 +23,12 @@ #include "absl/types/span.h" #include "yacl/base/int128.h" -#include "libspu/mpc/cheetah/ot/ferret.h" +#include "libspu/mpc/cheetah/ot/emp/ferret.h" #include "libspu/mpc/common/communicator.h" namespace spu::mpc::spdz2k { -class FerretOT : public cheetah::FerretOT { +class FerretOT : public cheetah::EmpFerretOt { private: struct Impl; std::shared_ptr impl_; @@ -36,9 +36,9 @@ class FerretOT : public cheetah::FerretOT { public: FerretOT(std::shared_ptr conn, bool is_sender, bool malicious = true) - : cheetah::FerretOT::FerretOT(conn, is_sender, malicious) {} + : EmpFerretOt(std::move(conn), is_sender, malicious) {} - ~FerretOT() = default; + ~FerretOT() override = default; // VOLE, only for SPDZ2K // data[i] = data0[i] + a[i] * corr[i] diff --git a/setup.py b/setup.py index 87e1982d..1a18f8fb 100644 --- a/setup.py +++ b/setup.py @@ -192,9 +192,23 @@ def remove_file(target_dir, filename): return 0 +def fix_pb(file, old_pattern, new_pattern): + os.chmod(file, 0o666) + with open(file, 'r+') as f: + content = f.read() + content = content.replace(old_pattern, new_pattern) + + with open(file, 'w+') as f: + f.write(content) + + def pip_run(build_ext): build(True, True) + # Change __module__ in psi_pb2.py and pir_pb2.py + fix_pb('bazel-bin/spu/psi_pb2.py', 'psi.psi.psi_pb2', 'spu.psi_pb2') + fix_pb('bazel-bin/spu/pir_pb2.py', 'psi.pir.pir_pb2', 'spu.pir_pb2') + setup_spec.files_to_include += spu_lib_files # Copy over the autogenerated protobuf Python bindings.