diff --git a/src/multi_point.zig b/src/multi_point.zig new file mode 100644 index 0000000..74282b4 --- /dev/null +++ b/src/multi_point.zig @@ -0,0 +1,142 @@ +const std = @import("std"); + +/// Create MultiPoint struct, this follows non-std implementation of Rust binding +/// equivalent to blst/bindings/rust/src/pippenger-no_std.rs +/// IT: input type, for example PublicKey +/// OT: output type, for example AggregatePublicKey +pub fn createMultiPoint(comptime IT: type, comptime OT: type, it_default_fn: anytype, ot_default_fn: anytype, out_eql_fn: anytype, add_fn: anytype, multi_scalar_mult_fn: anytype, scratch_sizeof_fn: anytype, mult_fn: anytype, generator_fn: anytype, to_affines_fn: anytype, add_or_double_fn: anytype) type { + const MultiPoint = struct { + /// Skip from([]OT) api + /// Rust accepts []OT here which make it convenient for test_add + /// bringing that here makes us deal with allocator + /// instead of that, it accepts []IT, the conversion of []IT to []OT is done at consumer side + pub fn add(points: []*const IT) !OT { + if (points.len == 0) { + return error.ZeroPoints; + } + + var result = ot_default_fn(); + // consumer usually need to convert []IT to []*IT which is not required in Rust + add_fn(&result, &points[0], points.len); + return result; + } + + /// scratch parameter is designed to be reused here + pub fn mult(points: []*const IT, scalars: []*const u8, n_bits: usize, scratch: []u64) !OT { + if (points.len == 0) { + return error.ZeroPoints; + } + + const n_points = points.len; + // this is different from Rust but it helps the test passed + if (scalars.len < n_points) { + return error.ScalarLenMismatch; + } + + if (scratch.len < (scratch_sizeof_fn(n_points) / 8)) { + return error.ScratchLenMismatch; + } + + var result = ot_default_fn(); + multi_scalar_mult_fn(&result, &points[0], points.len, &scalars[0], n_bits, &scratch[0]); + + return result; + } + }; + + return struct { + MultiPoint: MultiPoint, + + pub fn testAdd() !void { + const n_points = 2000; + const n_bits = 32; + const n_bytes = (n_bits + 7) / 8; + + var scalars = [_]u8{0} ** (n_points * n_bytes); + + var rng = std.rand.DefaultPrng.init(12345); + rng.random().bytes(scalars[0..]); + + var points: [n_points]OT = undefined; + var naive: OT = ot_default_fn(); + + for (0..n_points) |i| { + mult_fn(&points[i], generator_fn(), &scalars[i * n_bytes], 32); + add_or_double_fn(&naive, &naive, &points[i]); + } + + var points_refs: [n_points]*OT = undefined; + for (points[0..], 0..) |*point, i| { + points_refs[i] = point; + } + + // convert []OT to []IT + var aff_points_refs: [n_points]*IT = undefined; + var aff_points: [n_points]IT = [_]IT{it_default_fn()} ** n_points; + for (aff_points[0..], 0..) |*point, i| { + aff_points_refs[i] = point; + } + + to_affines_fn(aff_points_refs[0], &points_refs, n_points); + + const add_res = MultiPoint.add(aff_points_refs[0..]) catch return error.TestAddFailed; + try std.testing.expect(out_eql_fn(&naive, &add_res)); + } + + pub fn testMult() !void { + const n_points = 2000; + const n_bits = 160; + const n_bytes = (n_bits + 7) / 8; + + var scalars = [_]u8{0} ** (n_points * n_bytes); + var rng = std.rand.DefaultPrng.init(12345); + rng.random().bytes(scalars[0..]); + + var scalars_refs: [n_points]*const u8 = undefined; + for (0..n_points) |i| { + scalars_refs[i] = &scalars[i * n_bytes]; + } + + // std.debug.print("scratch_sizeof_fn(n_points) / 8: {}\n", .{scratch_sizeof_fn(n_points) / 8}); + var allocator = std.testing.allocator; + const scratch = try allocator.alloc(u64, scratch_sizeof_fn(n_points) / 8); + defer allocator.free(scratch); + + var points: [n_points]OT = [_]OT{ot_default_fn()} ** n_points; + + // convert []OT to []IT + var aff_points_refs: [n_points]*IT = undefined; + var aff_points: [n_points]IT = [_]IT{it_default_fn()} ** n_points; + for (aff_points[0..], 0..) |*point, i| { + aff_points_refs[i] = point; + } + + var naive = ot_default_fn(); + var points_refs: [n_points]*OT = undefined; + + for (0..n_points) |i| { + mult_fn(&points[i], generator_fn(), &scalars[i * n_bytes], @min(32, n_bits)); + points_refs[i] = &points[i]; + var t = ot_default_fn(); + mult_fn(&t, &points[i], &scalars[i * n_bytes], n_bits); + add_or_double_fn(&naive, &naive, &t); + + // TODO: this is not efficient as it contains duplicate works + to_affines_fn(aff_points_refs[0], &points_refs, (i + 1)); + if (i < 27) { + const mult_res = MultiPoint.mult(aff_points_refs[0..(i + 1)], scalars_refs[0..], n_bits, scratch) catch return error.TestMultFailed; + try std.testing.expect(out_eql_fn(&naive, &mult_res)); + } + } + + for (points[0..], 0..) |*point, i| { + points_refs[i] = point; + } + + to_affines_fn(aff_points_refs[0], &points_refs, n_points); + + const mult_res = MultiPoint.mult(aff_points_refs[0..], scalars_refs[0..], n_bits, scratch) catch return error.TestMultFailed; + try std.testing.expect(out_eql_fn(&naive, &mult_res)); + } + }; +} diff --git a/src/sig_variant.zig b/src/sig_variant.zig index e914837..9636457 100644 --- a/src/sig_variant.zig +++ b/src/sig_variant.zig @@ -15,6 +15,7 @@ const toBlstError = util.toBlstError; /// generic implementation for both min_pk and min_sig /// this is equivalent to Rust binding in blst/bindings/rust/src/lib.rs pub fn createSigVariant( + // Zig specific default functions default_pubkey_fn: anytype, default_agg_pubkey_fn: anytype, default_sig_fn: anytype, @@ -29,6 +30,9 @@ pub fn createSigVariant( sign_fn: anytype, pk_eq_fn: anytype, sig_eq_fn: anytype, + // 2 new zig specific eq functions + agg_pk_eq_fn: anytype, + agg_sig_eq_fn: anytype, verify_fn: anytype, pk_in_group_fn: anytype, pk_to_aff_fn: anytype, @@ -55,6 +59,19 @@ pub fn createSigVariant( pk_is_inf_fn: anytype, sig_is_inf_fn: anytype, sig_aggr_in_group_fn: anytype, + // Zig specific multi_points + pk_add_fn: anytype, + pk_multi_scalar_mult_fn: anytype, + pk_scratch_size_of_fn: anytype, + pk_mult_fn: anytype, + pk_generator_fn: anytype, + pk_to_affines_fn: anytype, + sig_add_fn: anytype, + sig_multi_scalar_mult_fn: anytype, + sig_scratch_size_of_fn: anytype, + sig_mult_fn: anytype, + sig_generator_fn: anytype, + sig_to_affines_fn: anytype, ) type { // TODO: implement MultiPoint const Pairing = struct { @@ -254,6 +271,10 @@ pub fn createSigVariant( pk_add_or_dbl_aff_fn(&self.point, &self.point, &pk.point); } + + pub fn isEqual(self: *const @This(), other: *const @This()) bool { + return agg_pk_eq_fn(&self.point, &other.point); + } }; const Signature = struct { @@ -512,6 +533,10 @@ pub fn createSigVariant( pub fn subgroupCheck(self: *const @This()) bool { return sig_aggr_in_group_fn(&self.point); } + + pub fn isEqual(self: *const @This(), other: *const @This()) bool { + return agg_sig_eq_fn(&self.point, &other.point); + } }; const SecretKey = struct { @@ -646,6 +671,39 @@ pub fn createSigVariant( } }; + // for PublicKey and AggregatePublicKey + const pk_multi_point = @import("./multi_point.zig").createMultiPoint( + pk_aff_type, + pk_type, + default_pubkey_fn, + default_agg_pubkey_fn, + agg_pk_eq_fn, + pk_add_fn, + pk_multi_scalar_mult_fn, + pk_scratch_size_of_fn, + pk_mult_fn, + pk_generator_fn, + pk_to_affines_fn, + pk_add_or_dbl_fn, + ); + + const sig_multi_point = @import("./multi_point.zig").createMultiPoint( + sig_aff_type, + sig_type, + default_sig_fn, + default_agg_sig_fn, + agg_sig_eq_fn, + sig_add_fn, + sig_multi_scalar_mult_fn, + sig_scratch_size_of_fn, + sig_mult_fn, + sig_generator_fn, + sig_to_affines_fn, + sig_add_or_dbl_fn, + ); + + // TODO: consume the above struct to work with public data structures + return struct { pub fn createSecretKey() type { return SecretKey; @@ -982,6 +1040,41 @@ pub fn createSigVariant( try std.testing.expect(sig.isEqual(&sig2)); } + /// additional tests in Zig to make sure our wrapped types point to the same memory as the original types + /// for example, given a slice of PublicKey, we can pass pointer to the first element to the C function which expect *const pk_aff_type + pub fn testTypeAlignment() !void { + // alignOf + try std.testing.expect(@alignOf(SecretKey) == @alignOf(c.blst_scalar)); + try std.testing.expect(@alignOf(PublicKey) == @alignOf(pk_aff_type)); + try std.testing.expect(@alignOf(AggregatePublicKey) == @alignOf(pk_type)); + try std.testing.expect(@alignOf(Signature) == @alignOf(sig_aff_type)); + try std.testing.expect(@alignOf(AggregateSignature) == @alignOf(sig_type)); + + // sizeOf + try std.testing.expect(@sizeOf(SecretKey) == @sizeOf(c.blst_scalar)); + try std.testing.expect(@sizeOf(PublicKey) == @sizeOf(pk_aff_type)); + try std.testing.expect(@sizeOf(AggregatePublicKey) == @sizeOf(pk_type)); + try std.testing.expect(@sizeOf(Signature) == @sizeOf(sig_aff_type)); + try std.testing.expect(@sizeOf(AggregateSignature) == @sizeOf(sig_type)); + } + + /// multi point + pub fn testAddPubkey() !void { + try pk_multi_point.testAdd(); + } + + pub fn testMultPubkey() !void { + try pk_multi_point.testMult(); + } + + pub fn testAddSig() !void { + try sig_multi_point.testAdd(); + } + + pub fn testMultSig() !void { + try sig_multi_point.testMult(); + } + fn getRandomKey(rng: *Xoshiro256) SecretKey { var value: [32]u8 = [_]u8{0} ** 32; rng.random().bytes(value[0..]); diff --git a/src/sig_variant_min_pk.zig b/src/sig_variant_min_pk.zig index 7169d7e..b70fb4d 100644 --- a/src/sig_variant_min_pk.zig +++ b/src/sig_variant_min_pk.zig @@ -26,6 +26,9 @@ const SigVariant = createSigVariant( c.blst_sign_pk2_in_g1, c.blst_p1_affine_is_equal, c.blst_p2_affine_is_equal, + // 2 new zig specific eq functions + c.blst_p1_is_equal, + c.blst_p2_is_equal, c.blst_core_verify_pk_in_g1, c.blst_p1_affine_in_g1, c.blst_p1_to_affine, @@ -52,6 +55,19 @@ const SigVariant = createSigVariant( c.blst_p1_affine_is_inf, c.blst_p2_affine_is_inf, c.blst_p2_in_g2, + // multi_point + c.blst_p1s_add, + c.blst_p1s_mult_pippenger, + c.blst_p1s_mult_pippenger_scratch_sizeof, + c.blst_p1_mult, + c.blst_p1_generator, + c.blst_p1s_to_affine, + c.blst_p2s_add, + c.blst_p2s_mult_pippenger, + c.blst_p2s_mult_pippenger_scratch_sizeof, + c.blst_p2_mult, + c.blst_p2_generator, + c.blst_p2s_to_affine, ); pub const min_pk = struct { @@ -82,4 +98,25 @@ test "test_serde" { try SigVariant.testSerde(); } +// prerequisite for test_multi_point +test "multi_point_test_type_alignment" { + try SigVariant.testTypeAlignment(); +} + +test "multi_point_test_add_pubkey" { + try SigVariant.testAddPubkey(); +} + +test "multi_point_test_mult_pubkey" { + try SigVariant.testMultPubkey(); +} + +test "multi_point_test_add_signature" { + try SigVariant.testAddSig(); +} + +test "multi_point_test_mult_signature" { + try SigVariant.testMultSig(); +} + // TODO test_multi_point diff --git a/src/sig_variant_min_sig.zig b/src/sig_variant_min_sig.zig index 3b1c92d..6ce35de 100644 --- a/src/sig_variant_min_sig.zig +++ b/src/sig_variant_min_sig.zig @@ -26,6 +26,9 @@ const SigVariant = createSigVariant( c.blst_sign_pk2_in_g2, c.blst_p2_affine_is_equal, c.blst_p1_affine_is_equal, + // 2 new zig specific eq functions + c.blst_p2_is_equal, + c.blst_p1_is_equal, c.blst_core_verify_pk_in_g2, c.blst_p2_affine_in_g2, c.blst_p2_to_affine, @@ -52,6 +55,19 @@ const SigVariant = createSigVariant( c.blst_p2_affine_is_inf, c.blst_p1_affine_is_inf, c.blst_p1_in_g1, + // multi_point + c.blst_p2s_add, + c.blst_p2s_mult_pippenger, + c.blst_p2s_mult_pippenger_scratch_sizeof, + c.blst_p2_mult, + c.blst_p2_generator, + c.blst_p2s_to_affine, + c.blst_p1s_add, + c.blst_p1s_mult_pippenger, + c.blst_p1s_mult_pippenger_scratch_sizeof, + c.blst_p1_mult, + c.blst_p1_generator, + c.blst_p1s_to_affine, ); pub const min_sig = struct { @@ -82,4 +98,25 @@ test "test_serde" { try SigVariant.testSerde(); } +// prerequisite for test_multi_point +test "test_type_alignment" { + try SigVariant.testTypeAlignment(); +} + +test "multi_point_test_add_pubkey" { + try SigVariant.testAddPubkey(); +} + +test "multi_point_test_mult_pubkey" { + try SigVariant.testMultPubkey(); +} + +test "multi_point_test_add_signature" { + try SigVariant.testAddSig(); +} + +test "multi_point_test_mult_signature" { + try SigVariant.testMultSig(); +} + // TODO test_multi_point