Skip to content

Commit

Permalink
feat: implement MultiPoint struct (#11)
Browse files Browse the repository at this point in the history
  • Loading branch information
twoeths authored Dec 25, 2024
1 parent 6348fe7 commit 75448cf
Show file tree
Hide file tree
Showing 4 changed files with 309 additions and 0 deletions.
142 changes: 142 additions & 0 deletions src/multi_point.zig
Original file line number Diff line number Diff line change
@@ -0,0 +1,142 @@
const std = @import("std");

/// Create MultiPoint struct, this follows non-std implementation of Rust binding
/// equivalent to blst/bindings/rust/src/pippenger-no_std.rs
/// IT: input type, for example PublicKey
/// OT: output type, for example AggregatePublicKey
pub fn createMultiPoint(comptime IT: type, comptime OT: type, it_default_fn: anytype, ot_default_fn: anytype, out_eql_fn: anytype, add_fn: anytype, multi_scalar_mult_fn: anytype, scratch_sizeof_fn: anytype, mult_fn: anytype, generator_fn: anytype, to_affines_fn: anytype, add_or_double_fn: anytype) type {
const MultiPoint = struct {
/// Skip from([]OT) api
/// Rust accepts []OT here which make it convenient for test_add
/// bringing that here makes us deal with allocator
/// instead of that, it accepts []IT, the conversion of []IT to []OT is done at consumer side
pub fn add(points: []*const IT) !OT {
if (points.len == 0) {
return error.ZeroPoints;
}

var result = ot_default_fn();
// consumer usually need to convert []IT to []*IT which is not required in Rust
add_fn(&result, &points[0], points.len);
return result;
}

/// scratch parameter is designed to be reused here
pub fn mult(points: []*const IT, scalars: []*const u8, n_bits: usize, scratch: []u64) !OT {
if (points.len == 0) {
return error.ZeroPoints;
}

const n_points = points.len;
// this is different from Rust but it helps the test passed
if (scalars.len < n_points) {
return error.ScalarLenMismatch;
}

if (scratch.len < (scratch_sizeof_fn(n_points) / 8)) {
return error.ScratchLenMismatch;
}

var result = ot_default_fn();
multi_scalar_mult_fn(&result, &points[0], points.len, &scalars[0], n_bits, &scratch[0]);

return result;
}
};

return struct {
MultiPoint: MultiPoint,

pub fn testAdd() !void {
const n_points = 2000;
const n_bits = 32;
const n_bytes = (n_bits + 7) / 8;

var scalars = [_]u8{0} ** (n_points * n_bytes);

var rng = std.rand.DefaultPrng.init(12345);
rng.random().bytes(scalars[0..]);

var points: [n_points]OT = undefined;
var naive: OT = ot_default_fn();

for (0..n_points) |i| {
mult_fn(&points[i], generator_fn(), &scalars[i * n_bytes], 32);
add_or_double_fn(&naive, &naive, &points[i]);
}

var points_refs: [n_points]*OT = undefined;
for (points[0..], 0..) |*point, i| {
points_refs[i] = point;
}

// convert []OT to []IT
var aff_points_refs: [n_points]*IT = undefined;
var aff_points: [n_points]IT = [_]IT{it_default_fn()} ** n_points;
for (aff_points[0..], 0..) |*point, i| {
aff_points_refs[i] = point;
}

to_affines_fn(aff_points_refs[0], &points_refs, n_points);

const add_res = MultiPoint.add(aff_points_refs[0..]) catch return error.TestAddFailed;
try std.testing.expect(out_eql_fn(&naive, &add_res));
}

pub fn testMult() !void {
const n_points = 2000;
const n_bits = 160;
const n_bytes = (n_bits + 7) / 8;

var scalars = [_]u8{0} ** (n_points * n_bytes);
var rng = std.rand.DefaultPrng.init(12345);
rng.random().bytes(scalars[0..]);

var scalars_refs: [n_points]*const u8 = undefined;
for (0..n_points) |i| {
scalars_refs[i] = &scalars[i * n_bytes];
}

// std.debug.print("scratch_sizeof_fn(n_points) / 8: {}\n", .{scratch_sizeof_fn(n_points) / 8});
var allocator = std.testing.allocator;
const scratch = try allocator.alloc(u64, scratch_sizeof_fn(n_points) / 8);
defer allocator.free(scratch);

var points: [n_points]OT = [_]OT{ot_default_fn()} ** n_points;

// convert []OT to []IT
var aff_points_refs: [n_points]*IT = undefined;
var aff_points: [n_points]IT = [_]IT{it_default_fn()} ** n_points;
for (aff_points[0..], 0..) |*point, i| {
aff_points_refs[i] = point;
}

var naive = ot_default_fn();
var points_refs: [n_points]*OT = undefined;

for (0..n_points) |i| {
mult_fn(&points[i], generator_fn(), &scalars[i * n_bytes], @min(32, n_bits));
points_refs[i] = &points[i];
var t = ot_default_fn();
mult_fn(&t, &points[i], &scalars[i * n_bytes], n_bits);
add_or_double_fn(&naive, &naive, &t);

// TODO: this is not efficient as it contains duplicate works
to_affines_fn(aff_points_refs[0], &points_refs, (i + 1));
if (i < 27) {
const mult_res = MultiPoint.mult(aff_points_refs[0..(i + 1)], scalars_refs[0..], n_bits, scratch) catch return error.TestMultFailed;
try std.testing.expect(out_eql_fn(&naive, &mult_res));
}
}

for (points[0..], 0..) |*point, i| {
points_refs[i] = point;
}

to_affines_fn(aff_points_refs[0], &points_refs, n_points);

const mult_res = MultiPoint.mult(aff_points_refs[0..], scalars_refs[0..], n_bits, scratch) catch return error.TestMultFailed;
try std.testing.expect(out_eql_fn(&naive, &mult_res));
}
};
}
93 changes: 93 additions & 0 deletions src/sig_variant.zig
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ const toBlstError = util.toBlstError;
/// generic implementation for both min_pk and min_sig
/// this is equivalent to Rust binding in blst/bindings/rust/src/lib.rs
pub fn createSigVariant(
// Zig specific default functions
default_pubkey_fn: anytype,
default_agg_pubkey_fn: anytype,
default_sig_fn: anytype,
Expand All @@ -29,6 +30,9 @@ pub fn createSigVariant(
sign_fn: anytype,
pk_eq_fn: anytype,
sig_eq_fn: anytype,
// 2 new zig specific eq functions
agg_pk_eq_fn: anytype,
agg_sig_eq_fn: anytype,
verify_fn: anytype,
pk_in_group_fn: anytype,
pk_to_aff_fn: anytype,
Expand All @@ -55,6 +59,19 @@ pub fn createSigVariant(
pk_is_inf_fn: anytype,
sig_is_inf_fn: anytype,
sig_aggr_in_group_fn: anytype,
// Zig specific multi_points
pk_add_fn: anytype,
pk_multi_scalar_mult_fn: anytype,
pk_scratch_size_of_fn: anytype,
pk_mult_fn: anytype,
pk_generator_fn: anytype,
pk_to_affines_fn: anytype,
sig_add_fn: anytype,
sig_multi_scalar_mult_fn: anytype,
sig_scratch_size_of_fn: anytype,
sig_mult_fn: anytype,
sig_generator_fn: anytype,
sig_to_affines_fn: anytype,
) type {
// TODO: implement MultiPoint
const Pairing = struct {
Expand Down Expand Up @@ -254,6 +271,10 @@ pub fn createSigVariant(

pk_add_or_dbl_aff_fn(&self.point, &self.point, &pk.point);
}

pub fn isEqual(self: *const @This(), other: *const @This()) bool {
return agg_pk_eq_fn(&self.point, &other.point);
}
};

const Signature = struct {
Expand Down Expand Up @@ -512,6 +533,10 @@ pub fn createSigVariant(
pub fn subgroupCheck(self: *const @This()) bool {
return sig_aggr_in_group_fn(&self.point);
}

pub fn isEqual(self: *const @This(), other: *const @This()) bool {
return agg_sig_eq_fn(&self.point, &other.point);
}
};

const SecretKey = struct {
Expand Down Expand Up @@ -646,6 +671,39 @@ pub fn createSigVariant(
}
};

// for PublicKey and AggregatePublicKey
const pk_multi_point = @import("./multi_point.zig").createMultiPoint(
pk_aff_type,
pk_type,
default_pubkey_fn,
default_agg_pubkey_fn,
agg_pk_eq_fn,
pk_add_fn,
pk_multi_scalar_mult_fn,
pk_scratch_size_of_fn,
pk_mult_fn,
pk_generator_fn,
pk_to_affines_fn,
pk_add_or_dbl_fn,
);

const sig_multi_point = @import("./multi_point.zig").createMultiPoint(
sig_aff_type,
sig_type,
default_sig_fn,
default_agg_sig_fn,
agg_sig_eq_fn,
sig_add_fn,
sig_multi_scalar_mult_fn,
sig_scratch_size_of_fn,
sig_mult_fn,
sig_generator_fn,
sig_to_affines_fn,
sig_add_or_dbl_fn,
);

// TODO: consume the above struct to work with public data structures

return struct {
pub fn createSecretKey() type {
return SecretKey;
Expand Down Expand Up @@ -982,6 +1040,41 @@ pub fn createSigVariant(
try std.testing.expect(sig.isEqual(&sig2));
}

/// additional tests in Zig to make sure our wrapped types point to the same memory as the original types
/// for example, given a slice of PublicKey, we can pass pointer to the first element to the C function which expect *const pk_aff_type
pub fn testTypeAlignment() !void {
// alignOf
try std.testing.expect(@alignOf(SecretKey) == @alignOf(c.blst_scalar));
try std.testing.expect(@alignOf(PublicKey) == @alignOf(pk_aff_type));
try std.testing.expect(@alignOf(AggregatePublicKey) == @alignOf(pk_type));
try std.testing.expect(@alignOf(Signature) == @alignOf(sig_aff_type));
try std.testing.expect(@alignOf(AggregateSignature) == @alignOf(sig_type));

// sizeOf
try std.testing.expect(@sizeOf(SecretKey) == @sizeOf(c.blst_scalar));
try std.testing.expect(@sizeOf(PublicKey) == @sizeOf(pk_aff_type));
try std.testing.expect(@sizeOf(AggregatePublicKey) == @sizeOf(pk_type));
try std.testing.expect(@sizeOf(Signature) == @sizeOf(sig_aff_type));
try std.testing.expect(@sizeOf(AggregateSignature) == @sizeOf(sig_type));
}

/// multi point
pub fn testAddPubkey() !void {
try pk_multi_point.testAdd();
}

pub fn testMultPubkey() !void {
try pk_multi_point.testMult();
}

pub fn testAddSig() !void {
try sig_multi_point.testAdd();
}

pub fn testMultSig() !void {
try sig_multi_point.testMult();
}

fn getRandomKey(rng: *Xoshiro256) SecretKey {
var value: [32]u8 = [_]u8{0} ** 32;
rng.random().bytes(value[0..]);
Expand Down
37 changes: 37 additions & 0 deletions src/sig_variant_min_pk.zig
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,9 @@ const SigVariant = createSigVariant(
c.blst_sign_pk2_in_g1,
c.blst_p1_affine_is_equal,
c.blst_p2_affine_is_equal,
// 2 new zig specific eq functions
c.blst_p1_is_equal,
c.blst_p2_is_equal,
c.blst_core_verify_pk_in_g1,
c.blst_p1_affine_in_g1,
c.blst_p1_to_affine,
Expand All @@ -52,6 +55,19 @@ const SigVariant = createSigVariant(
c.blst_p1_affine_is_inf,
c.blst_p2_affine_is_inf,
c.blst_p2_in_g2,
// multi_point
c.blst_p1s_add,
c.blst_p1s_mult_pippenger,
c.blst_p1s_mult_pippenger_scratch_sizeof,
c.blst_p1_mult,
c.blst_p1_generator,
c.blst_p1s_to_affine,
c.blst_p2s_add,
c.blst_p2s_mult_pippenger,
c.blst_p2s_mult_pippenger_scratch_sizeof,
c.blst_p2_mult,
c.blst_p2_generator,
c.blst_p2s_to_affine,
);

pub const min_pk = struct {
Expand Down Expand Up @@ -82,4 +98,25 @@ test "test_serde" {
try SigVariant.testSerde();
}

// prerequisite for test_multi_point
test "multi_point_test_type_alignment" {
try SigVariant.testTypeAlignment();
}

test "multi_point_test_add_pubkey" {
try SigVariant.testAddPubkey();
}

test "multi_point_test_mult_pubkey" {
try SigVariant.testMultPubkey();
}

test "multi_point_test_add_signature" {
try SigVariant.testAddSig();
}

test "multi_point_test_mult_signature" {
try SigVariant.testMultSig();
}

// TODO test_multi_point
37 changes: 37 additions & 0 deletions src/sig_variant_min_sig.zig
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,9 @@ const SigVariant = createSigVariant(
c.blst_sign_pk2_in_g2,
c.blst_p2_affine_is_equal,
c.blst_p1_affine_is_equal,
// 2 new zig specific eq functions
c.blst_p2_is_equal,
c.blst_p1_is_equal,
c.blst_core_verify_pk_in_g2,
c.blst_p2_affine_in_g2,
c.blst_p2_to_affine,
Expand All @@ -52,6 +55,19 @@ const SigVariant = createSigVariant(
c.blst_p2_affine_is_inf,
c.blst_p1_affine_is_inf,
c.blst_p1_in_g1,
// multi_point
c.blst_p2s_add,
c.blst_p2s_mult_pippenger,
c.blst_p2s_mult_pippenger_scratch_sizeof,
c.blst_p2_mult,
c.blst_p2_generator,
c.blst_p2s_to_affine,
c.blst_p1s_add,
c.blst_p1s_mult_pippenger,
c.blst_p1s_mult_pippenger_scratch_sizeof,
c.blst_p1_mult,
c.blst_p1_generator,
c.blst_p1s_to_affine,
);

pub const min_sig = struct {
Expand Down Expand Up @@ -82,4 +98,25 @@ test "test_serde" {
try SigVariant.testSerde();
}

// prerequisite for test_multi_point
test "test_type_alignment" {
try SigVariant.testTypeAlignment();
}

test "multi_point_test_add_pubkey" {
try SigVariant.testAddPubkey();
}

test "multi_point_test_mult_pubkey" {
try SigVariant.testMultPubkey();
}

test "multi_point_test_add_signature" {
try SigVariant.testAddSig();
}

test "multi_point_test_mult_signature" {
try SigVariant.testMultSig();
}

// TODO test_multi_point

0 comments on commit 75448cf

Please sign in to comment.