Skip to content

Commit

Permalink
Repo sync (#826)
Browse files Browse the repository at this point in the history
  • Loading branch information
anakinxc authored Aug 21, 2024
1 parent 28fef7d commit d51f4e7
Show file tree
Hide file tree
Showing 30 changed files with 98 additions and 134 deletions.
4 changes: 2 additions & 2 deletions .circleci/asan-config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ version: 2.1
parameters:
run-asan:
type: boolean
default: false
default: true

# Define a job to be invoked later in a workflow.
# See: https://circleci.com/docs/2.0/configuration-reference/#jobs
Expand Down Expand Up @@ -55,7 +55,7 @@ jobs:
command: |
set +e
declare -i test_status
bazel test //libspu/... --features=asan --ui_event_filters=-info,-debug,-warning --test_output=errors | tee test_result.log; test_status=${PIPESTATUS[0]}
bazel test //libspu/... --features=asan --test_timeout=500 --ui_event_filters=-info,-debug,-warning --test_output=errors | tee test_result.log; test_status=${PIPESTATUS[0]}
sh ../devtools/rename-junit-xml.sh
find bazel-testlogs/ -type f -name "test.log" -print0 | xargs -0 tar -cvzf test_logs.tar.gz
Expand Down
4 changes: 4 additions & 0 deletions libspu/kernel/hal/prot_wrapper.cc
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,10 @@ Value _s2v(SPUContext* ctx, const Value& in, int owner) {
MAP_UNARY_OP(not_p)
MAP_UNARY_OP(not_s)
MAP_UNARY_OP(not_v)
// Negate family
MAP_UNARY_OP(negate_p)
MAP_UNARY_OP(negate_s)
MAP_UNARY_OP(negate_v)
// Msb family
MAP_UNARY_OP(msb_p)
MAP_UNARY_OP(msb_s)
Expand Down
4 changes: 4 additions & 0 deletions libspu/kernel/hal/prot_wrapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,10 @@ Value _not_p(SPUContext* ctx, const Value& in);
Value _not_s(SPUContext* ctx, const Value& in);
Value _not_v(SPUContext* ctx, const Value& in);

Value _negate_p(SPUContext* ctx, const Value& in);
Value _negate_s(SPUContext* ctx, const Value& in);
Value _negate_v(SPUContext* ctx, const Value& in);

Value _msb_p(SPUContext* ctx, const Value& in);
Value _msb_s(SPUContext* ctx, const Value& in);
Value _msb_v(SPUContext* ctx, const Value& in);
Expand Down
9 changes: 1 addition & 8 deletions libspu/kernel/hal/ring.cc
Original file line number Diff line number Diff line change
Expand Up @@ -77,8 +77,8 @@ Value _cast_type(SPUContext* ctx, const Value& x, const Type& to) {
SPU_THROW("unsupport unary op={} for {}", #Name, in); \
} \
}

IMPL_UNARY_OP(_not)
IMPL_UNARY_OP(_negate)
IMPL_UNARY_OP(_msb)
IMPL_UNARY_OP(_square)

Expand Down Expand Up @@ -438,13 +438,6 @@ Value _equal(SPUContext* ctx, const Value& x, const Value& y) {
_xor(ctx, _less(ctx, y, x), _k1));
}

Value _negate(SPUContext* ctx, const Value& x) {
SPU_TRACE_HAL_LEAF(ctx, x);

// negate(x) = not(x) + 1
return _add(ctx, _not(ctx, x), _constant(ctx, 1, x.shape()));
}

Value _sign(SPUContext* ctx, const Value& x) {
SPU_TRACE_HAL_LEAF(ctx, x);

Expand Down
2 changes: 1 addition & 1 deletion libspu/mpc/ab_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ Value rand_b(SPUContext* ctx, const Shape& shape) {
FORCE_DISPATCH(ctx, shape);
}

Value not_a(SPUContext* ctx, const Value& x) { FORCE_DISPATCH(ctx, x); }
Value negate_a(SPUContext* ctx, const Value& x) { FORCE_DISPATCH(ctx, x); }

Value add_ap(SPUContext* ctx, const Value& x, const Value& y) {
FORCE_DISPATCH(ctx, x, y);
Expand Down
2 changes: 1 addition & 1 deletion libspu/mpc/ab_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ Value msb_a2b(SPUContext* ctx, const Value& x);
Value rand_a(SPUContext* ctx, const Shape& shape);
Value rand_b(SPUContext* ctx, const Shape& shape);

Value not_a(SPUContext* ctx, const Value& x);
Value negate_a(SPUContext* ctx, const Value& x);

Value equal_ap(SPUContext* ctx, const Value& x, const Value& y);
Value equal_aa(SPUContext* ctx, const Value& x, const Value& y);
Expand Down
8 changes: 4 additions & 4 deletions libspu/mpc/ab_api_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -436,7 +436,7 @@ TEST_P(ArithmeticTest, MatMulAV) {
});
}

TEST_P(ArithmeticTest, NotA) {
TEST_P(ArithmeticTest, NegateA) {
const auto factory = std::get<0>(GetParam());
const RuntimeConfig& conf = std::get<1>(GetParam());
const size_t npc = std::get<2>(GetParam());
Expand All @@ -450,15 +450,15 @@ TEST_P(ArithmeticTest, NotA) {

/* WHEN */
auto prev = obj->prot()->getState<Communicator>()->getStats();
auto r_a = not_a(obj.get(), a0);
auto r_a = negate_a(obj.get(), a0);
auto cost = obj->prot()->getState<Communicator>()->getStats() - prev;

auto r_p = a2p(obj.get(), r_a);
auto r_pp = a2p(obj.get(), not_a(obj.get(), a0));
auto r_pp = a2p(obj.get(), negate_a(obj.get(), a0));

/* THEN */
EXPECT_VALUE_EQ(r_p, r_pp);
EXPECT_TRUE(verifyCost(obj->prot()->getKernel("not_a"), "not_a",
EXPECT_TRUE(verifyCost(obj->prot()->getKernel("negate_a"), "negate_a",
conf.field(), kShape, npc, cost));
});
}
Expand Down
12 changes: 1 addition & 11 deletions libspu/mpc/aby3/arithmetic.cc
Original file line number Diff line number Diff line change
Expand Up @@ -300,13 +300,10 @@ NdArrayRef V2A::proc(KernelEvalContext* ctx, const NdArrayRef& in) const {
});
}

NdArrayRef NotA::proc(KernelEvalContext* ctx, const NdArrayRef& in) const {
auto* comm = ctx->getState<Communicator>();
NdArrayRef NegateA::proc(KernelEvalContext* ctx, const NdArrayRef& in) const {
const auto* in_ty = in.eltype().as<AShrTy>();
const auto field = in_ty->field();

auto rank = comm->getRank();

return DISPATCH_ALL_FIELDS(field, [&]() {
using el_t = std::make_unsigned_t<ring2k_t>;
using shr_t = std::array<el_t, 2>;
Expand All @@ -315,16 +312,9 @@ NdArrayRef NotA::proc(KernelEvalContext* ctx, const NdArrayRef& in) const {
NdArrayView<shr_t> _out(out);
NdArrayView<shr_t> _in(in);

// neg(x) = not(x) + 1
// not(x) = neg(x) - 1
pforeach(0, in.numel(), [&](int64_t idx) {
_out[idx][0] = -_in[idx][0];
_out[idx][1] = -_in[idx][1];
if (rank == 0) {
_out[idx][1] -= 1;
} else if (rank == 1) {
_out[idx][0] -= 1;
}
});

return out;
Expand Down
4 changes: 2 additions & 2 deletions libspu/mpc/aby3/arithmetic.h
Original file line number Diff line number Diff line change
Expand Up @@ -116,9 +116,9 @@ class RandA : public RandKernel {
NdArrayRef proc(KernelEvalContext* ctx, const Shape& shape) const override;
};

class NotA : public UnaryKernel {
class NegateA : public UnaryKernel {
public:
static constexpr char kBindName[] = "not_a";
static constexpr char kBindName[] = "negate_a";

ce::CExpr latency() const override { return ce::Const(0); }

Expand Down
2 changes: 1 addition & 1 deletion libspu/mpc/aby3/protocol.cc
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ void regAby3Protocol(SPUContext* ctx,
aby3::B2P, aby3::P2B, aby3::A2B, // Conversion2
aby3::B2ASelector, /*aby3::B2AByOT, aby3::B2AByPPA*/ // B2A
aby3::CastTypeB, // Cast
aby3::NotA, // Not
aby3::NegateA, // Negate
aby3::AddAP, aby3::AddAA, // Add
aby3::MulAP, aby3::MulAA, aby3::MulA1B, // Mul
aby3::MatMulAP, aby3::MatMulAA, // MatMul
Expand Down
40 changes: 30 additions & 10 deletions libspu/mpc/api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -273,22 +273,42 @@ Value rand_s(SPUContext* ctx, const Shape& shape) {
return rand_a(ctx, shape);
}

// only works for Z2k.
// Neg(x) = Not(x) + 1
// Not(x) = Neg(x) - 1
Value not_v(SPUContext* ctx, const Value& x) {
SPU_TRACE_MPC_DISP(ctx, x);
auto k1 = make_p(ctx, 1, x.shape());
return add_vp(ctx, negate_v(ctx, x), negate_p(ctx, k1));
}

Value not_p(SPUContext* ctx, const Value& x) {
SPU_TRACE_MPC_DISP(ctx, x);
auto k1 = make_p(ctx, 1, x.shape());
return add_pp(ctx, negate_p(ctx, x), negate_p(ctx, k1));
}

Value not_s(SPUContext* ctx, const Value& x) {
SPU_TRACE_MPC_DISP(ctx, x);
if (x.storage_type().isa<BShare>()) {
auto ones = make_p(ctx, -1, x.shape());
return xor_bp(ctx, x, ones);
} else {
SPU_ENFORCE(x.storage_type().isa<Secret>());
auto k1 = make_p(ctx, 1, x.shape());
return add_sp(ctx, negate_s(ctx, x), negate_p(ctx, k1));
}
}

Value negate_s(SPUContext* ctx, const Value& x) {
SPU_TRACE_MPC_DISP(ctx, x);
TRY_DISPATCH(ctx, x);
// TODO: Both A&B could handle not(invert).
// if (x.eltype().isa<BShare>()) {
// return not_b(ctx, x);
//} else {
// SPU_ENFORCE(x.eltype().isa<AShare>());
// return not_a(ctx, x);
//}
return not_a(ctx, _2a(ctx, x));
return negate_a(ctx, _2a(ctx, x));
}

Value not_v(SPUContext* ctx, const Value& x) { FORCE_DISPATCH(ctx, x); }
Value negate_v(SPUContext* ctx, const Value& x) { FORCE_DISPATCH(ctx, x); }

Value not_p(SPUContext* ctx, const Value& x) { FORCE_DISPATCH(ctx, x); }
Value negate_p(SPUContext* ctx, const Value& x) { FORCE_DISPATCH(ctx, x); }

//////////////////////////////////////////////////////////////////////////////

Expand Down
7 changes: 6 additions & 1 deletion libspu/mpc/api.h
Original file line number Diff line number Diff line change
Expand Up @@ -91,11 +91,16 @@ Value make_p(SPUContext* ctx, uint128_t init, const Shape& shape);
Value rand_p(SPUContext* ctx, const Shape& shape);
Value rand_s(SPUContext* ctx, const Shape& shape);

// Compute bitwise_not(invert) of a value in ring 2k space.
// Compute bitwise not of a value.
Value not_p(SPUContext* ctx, const Value& x);
Value not_s(SPUContext* ctx, const Value& x);
Value not_v(SPUContext* ctx, const Value& x);

// Compute negate of a value.
Value negate_p(SPUContext* ctx, const Value& x);
Value negate_s(SPUContext* ctx, const Value& x);
Value negate_v(SPUContext* ctx, const Value& x);

Value msb_p(SPUContext* ctx, const Value& x);
Value msb_s(SPUContext* ctx, const Value& x);
Value msb_v(SPUContext* ctx, const Value& x);
Expand Down
1 change: 1 addition & 0 deletions libspu/mpc/api_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -245,6 +245,7 @@ TEST_BINARY_OP(xor)
TEST_UNARY_OP_V(OP) \
TEST_UNARY_OP_P(OP)

TEST_UNARY_OP(negate)
TEST_UNARY_OP(not )
TEST_UNARY_OP_V(msb)
TEST_UNARY_OP_P(msb)
Expand Down
4 changes: 2 additions & 2 deletions libspu/mpc/cheetah/arithmetic.h
Original file line number Diff line number Diff line change
Expand Up @@ -63,9 +63,9 @@ class V2A : public UnaryKernel {
NdArrayRef proc(KernelEvalContext* ctx, const NdArrayRef& in) const override;
};

class NotA : public UnaryKernel {
class NegateA : public UnaryKernel {
public:
static constexpr char kBindName[] = "not_a";
static constexpr char kBindName[] = "negate_a";

ce::CExpr latency() const override { return ce::Const(0); }

Expand Down
7 changes: 1 addition & 6 deletions libspu/mpc/cheetah/arithmetic_semi2k.cc
Original file line number Diff line number Diff line change
Expand Up @@ -101,13 +101,8 @@ NdArrayRef V2A::proc(KernelEvalContext* ctx, const NdArrayRef& in) const {
return x.as(makeType<AShrTy>(field));
}

NdArrayRef NotA::proc(KernelEvalContext* ctx, const NdArrayRef& in) const {
auto* comm = ctx->getState<Communicator>();
NdArrayRef NegateA::proc(KernelEvalContext* ctx, const NdArrayRef& in) const {
auto res = ring_neg(in);
if (comm->getRank() == 0) {
const auto field = in.eltype().as<Ring2k>()->field();
ring_add_(res, ring_not(ring_zeros(field, in.shape())));
}

return res.as(in.eltype());
}
Expand Down
2 changes: 1 addition & 1 deletion libspu/mpc/cheetah/protocol.cc
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ void regCheetahProtocol(SPUContext* ctx,
ctx->prot()
->regKernel<cheetah::P2A, cheetah::A2P, cheetah::V2A, cheetah::A2V, //
cheetah::B2P, cheetah::P2B, cheetah::A2B, cheetah::B2A, //
cheetah::NotA, //
cheetah::NegateA, //
cheetah::AddAP, cheetah::AddAA, //
cheetah::MulAP, cheetah::MulAA, cheetah::MulAV, //
cheetah::SquareA, //
Expand Down
14 changes: 7 additions & 7 deletions libspu/mpc/common/pv2k.cc
Original file line number Diff line number Diff line change
Expand Up @@ -136,31 +136,31 @@ class RandP : public RandKernel {
}
};

class NotP : public UnaryKernel {
class NegateP : public UnaryKernel {
public:
static constexpr char kBindName[] = "not_p";
static constexpr char kBindName[] = "negate_p";

ce::CExpr latency() const override { return ce::Const(0); }

ce::CExpr comm() const override { return ce::Const(0); }

NdArrayRef proc(KernelEvalContext*, const NdArrayRef& in) const override {
const auto field = in.eltype().as<Ring2k>()->field();
return ring_not(in).as(makeType<Pub2kTy>(field));
return ring_neg(in).as(makeType<Pub2kTy>(field));
}
};

class NotV : public UnaryKernel {
class NegateV : public UnaryKernel {
public:
static constexpr char kBindName[] = "not_v";
static constexpr char kBindName[] = "negate_v";

ce::CExpr latency() const override { return ce::Const(0); }

ce::CExpr comm() const override { return ce::Const(0); }

NdArrayRef proc(KernelEvalContext* ctx, const NdArrayRef& in) const override {
if (isOwner(ctx, in.eltype())) {
return ring_not(in).as(in.eltype());
return ring_neg(in).as(in.eltype());
} else {
return in;
}
Expand Down Expand Up @@ -954,7 +954,7 @@ void regPV2kTypes() {
void regPV2kKernels(Object* obj) {
obj->regKernel<V2P, P2V, //
MakeP, RandP, //
NotV, NotP, //
NegateV, NegateP, //
EqualVVV, EqualVP, EqualPP, //
AddVVV, AddVP, AddPP, //
MulVVV, MulVP, MulPP, //
Expand Down
8 changes: 4 additions & 4 deletions libspu/mpc/ref2k/ref2k.cc
Original file line number Diff line number Diff line change
Expand Up @@ -200,17 +200,17 @@ class Ref2kRandS : public RandKernel {
}
};

class Ref2kNotS : public UnaryKernel {
class Ref2kNegateS : public UnaryKernel {
public:
static constexpr char kBindName[] = "not_s";
static constexpr char kBindName[] = "negate_s";

ce::CExpr latency() const override { return ce::Const(0); }

ce::CExpr comm() const override { return ce::Const(0); }

NdArrayRef proc(KernelEvalContext* ctx, const NdArrayRef& in) const override {
const auto field = in.eltype().as<Ring2k>()->field();
return ring_not(in).as(makeType<Ref2kSecrTy>(field));
return ring_neg(in).as(makeType<Ref2kSecrTy>(field));
}
};

Expand Down Expand Up @@ -488,7 +488,7 @@ void regRef2kProtocol(SPUContext* ctx,
ctx->prot()
->regKernel<Ref2kCommonTypeS, Ref2kCommonTypeV, Ref2kCastTypeS, //
Ref2kP2S, Ref2kS2P, Ref2kV2S, Ref2kS2V, //
Ref2kNotS, //
Ref2kNegateS, //
Ref2kAddSS, Ref2kAddSP, //
Ref2kMulSS, Ref2kMulSP, //
Ref2kMatMulSS, Ref2kMatMulSP, //
Expand Down
24 changes: 1 addition & 23 deletions libspu/mpc/securenn/arithmetic.cc
Original file line number Diff line number Diff line change
Expand Up @@ -132,30 +132,8 @@ NdArrayRef A2P::proc(KernelEvalContext* ctx, const NdArrayRef& in) const {
return out.as(makeType<Pub2kTy>(field));
}

NdArrayRef NotA::proc(KernelEvalContext* ctx, const NdArrayRef& in) const {
auto* comm = ctx->getState<Communicator>();

// First, let's show negate could be locally processed.
// let X = sum(Xi) % M
// let Yi = neg(Xi) = M-Xi
//
// we get
// Y = sum(Yi) % M
// = n*M - sum(Xi) % M
// = -sum(Xi) % M
// = -X % M
//
// 'not' could be processed accordingly.
// not(X)
// = M-1-X # by definition, not is the complement of 2^k
// = neg(X) + M-1
//
NdArrayRef NegateA::proc(KernelEvalContext* ctx, const NdArrayRef& in) const {
auto res = ring_neg(in);
if (comm->getRank() == 0) {
const auto field = in.eltype().as<Ring2k>()->field();
ring_add_(res, ring_not(ring_zeros(field, in.shape())));
}

return res.as(in.eltype());
}

Expand Down
4 changes: 2 additions & 2 deletions libspu/mpc/securenn/arithmetic.h
Original file line number Diff line number Diff line change
Expand Up @@ -80,9 +80,9 @@ class A2P : public UnaryKernel {
NdArrayRef proc(KernelEvalContext* ctx, const NdArrayRef& in) const override;
};

class NotA : public UnaryKernel {
class NegateA : public UnaryKernel {
public:
static constexpr char kBindName[] = "not_a";
static constexpr char kBindName[] = "negate_a";

ce::CExpr latency() const override { return ce::Const(0); }

Expand Down
Loading

0 comments on commit d51f4e7

Please sign in to comment.