Skip to content

Commit

Permalink
Add fuzzer coverage for BN_mod_exp_mont_consttime_x2
Browse files Browse the repository at this point in the history
  • Loading branch information
pittma committed Oct 30, 2023
1 parent 64ae838 commit 64b3a80
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 5 deletions.
3 changes: 3 additions & 0 deletions crypto/fipsmodule/bn/bn_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -856,6 +856,9 @@ static void TestModExp2(BIGNUMFileTest *t, BN_CTX *ctx) {
bssl::UniquePtr<BIGNUM> ret2(BN_new());
ASSERT_TRUE(ret2);

ASSERT_TRUE(BN_nnmod(a1.get(), a1.get(), m1.get(), ctx));
ASSERT_TRUE(BN_nnmod(a2.get(), a2.get(), m2.get(), ctx));

ASSERT_TRUE(BN_mod_exp_mont_consttime_x2(ret1.get(), a1.get(), e1.get(), m1.get(), NULL,
ret2.get(), a2.get(), e2.get(), m2.get(), NULL,
ctx));
Expand Down
55 changes: 50 additions & 5 deletions fuzz/bn_mod_exp.cc
Original file line number Diff line number Diff line change
Expand Up @@ -58,16 +58,23 @@ int mod_exp(BIGNUM *r, const BIGNUM *a, const BIGNUM *p, const BIGNUM *m,
}

extern "C" int LLVMFuzzerTestOneInput(const uint8_t *buf, size_t len) {
CBS cbs, child0, child1, child2;
uint8_t sign;
CBS cbs, child0, child1, child2, child3, child4, child5;
uint8_t sign, sign1;
CBS_init(&cbs, buf, len);
if (!CBS_get_u16_length_prefixed(&cbs, &child0) ||
!CBS_get_u8(&child0, &sign) ||
CBS_len(&child0) == 0 ||
!CBS_get_u16_length_prefixed(&cbs, &child1) ||
CBS_len(&child1) == 0 ||
!CBS_get_u16_length_prefixed(&cbs, &child2) ||
CBS_len(&child2) == 0) {
CBS_len(&child2) == 0 ||
!CBS_get_u16_length_prefixed(&cbs, &child3) ||
!CBS_get_u8(&child3, &sign1) ||
CBS_len(&child3) == 0 ||
!CBS_get_u16_length_prefixed(&cbs, &child4) ||
CBS_len(&child4) == 0 ||
!CBS_get_u16_length_prefixed(&cbs, &child5) ||
CBS_len(&child5) == 0) {
return 0;
}

Expand All @@ -76,7 +83,10 @@ extern "C" int LLVMFuzzerTestOneInput(const uint8_t *buf, size_t len) {
// fuzzers to spend a lot of time exploring timeouts.
if (CBS_len(&child0) > 512 ||
CBS_len(&child1) > 512 ||
CBS_len(&child2) > 512) {
CBS_len(&child2) > 512 ||
CBS_len(&child3) > 512 ||
CBS_len(&child4) > 512 ||
CBS_len(&child5) > 512) {
return 0;
}

Expand All @@ -88,16 +98,28 @@ extern "C" int LLVMFuzzerTestOneInput(const uint8_t *buf, size_t len) {
bssl::UniquePtr<BIGNUM> modulus(
BN_bin2bn(CBS_data(&child2), CBS_len(&child2), nullptr));

if (BN_is_zero(modulus.get())) {
bssl::UniquePtr<BIGNUM> base1(
BN_bin2bn(CBS_data(&child3), CBS_len(&child3), nullptr));
BN_set_negative(base.get(), sign % 2);
bssl::UniquePtr<BIGNUM> power1(
BN_bin2bn(CBS_data(&child4), CBS_len(&child4), nullptr));
bssl::UniquePtr<BIGNUM> modulus1(
BN_bin2bn(CBS_data(&child5), CBS_len(&child5), nullptr));

if (BN_is_zero(modulus.get()) || BN_is_zero(modulus1.get())) {
return 0;
}

bssl::UniquePtr<BN_CTX> ctx(BN_CTX_new());
bssl::UniquePtr<BIGNUM> result(BN_new());
bssl::UniquePtr<BIGNUM> expected(BN_new());
bssl::UniquePtr<BIGNUM> result1(BN_new());
bssl::UniquePtr<BIGNUM> expected1(BN_new());
CHECK(ctx);
CHECK(result);
CHECK(expected);
CHECK(result1);
CHECK(expected1);

CHECK(mod_exp(expected.get(), base.get(), power.get(), modulus.get(),
ctx.get()));
Expand All @@ -117,6 +139,29 @@ extern "C" int LLVMFuzzerTestOneInput(const uint8_t *buf, size_t len) {
CHECK(BN_mod_exp_mont_consttime(result.get(), base.get(), power.get(),
modulus.get(), ctx.get(), mont.get()));
CHECK(BN_cmp(result.get(), expected.get()) == 0);

// If modulus1 is also odd, use the AVX-512 implementation.
if (BN_is_odd(modulus1.get())) {
bssl::UniquePtr<BN_MONT_CTX> mont1(
BN_MONT_CTX_new_for_modulus(modulus1.get(), ctx.get()));
CHECK(mont1);

// Calculate the expected result for the second set of inputs.
CHECK(mod_exp(expected1.get(), base1.get(), power1.get(), modulus1.get(),
ctx.get()));

// Normalize them.
CHECK(BN_nnmod(base1.get(), base1.get(), modulus1.get(), ctx.get()));

// Run the x2 implementation and compare the results.
CHECK(BN_mod_exp_mont_consttime_x2(result.get(), base.get(), power.get(),
modulus.get(), mont.get(),
result1.get(), base1.get(), power1.get(),
modulus1.get(), mont1.get(),
ctx.get()));
CHECK(BN_cmp(result.get(), expected.get()) == 0);
CHECK(BN_cmp(result1.get(), expected1.get()) == 0);
}
}

uint8_t *data = (uint8_t *)OPENSSL_malloc(BN_num_bytes(result.get()));
Expand Down

0 comments on commit 64b3a80

Please sign in to comment.