diff --git a/crypto/fipsmodule/bn/bn_test.cc b/crypto/fipsmodule/bn/bn_test.cc index 2f4c642ed12..b837b39b54c 100644 --- a/crypto/fipsmodule/bn/bn_test.cc +++ b/crypto/fipsmodule/bn/bn_test.cc @@ -856,6 +856,9 @@ static void TestModExp2(BIGNUMFileTest *t, BN_CTX *ctx) { bssl::UniquePtr ret2(BN_new()); ASSERT_TRUE(ret2); + ASSERT_TRUE(BN_nnmod(a1.get(), a1.get(), m1.get(), ctx)); + ASSERT_TRUE(BN_nnmod(a2.get(), a2.get(), m2.get(), ctx)); + ASSERT_TRUE(BN_mod_exp_mont_consttime_x2(ret1.get(), a1.get(), e1.get(), m1.get(), NULL, ret2.get(), a2.get(), e2.get(), m2.get(), NULL, ctx)); diff --git a/fuzz/bn_mod_exp.cc b/fuzz/bn_mod_exp.cc index 0bfa5a85c72..f34708fb5a8 100644 --- a/fuzz/bn_mod_exp.cc +++ b/fuzz/bn_mod_exp.cc @@ -58,8 +58,8 @@ int mod_exp(BIGNUM *r, const BIGNUM *a, const BIGNUM *p, const BIGNUM *m, } extern "C" int LLVMFuzzerTestOneInput(const uint8_t *buf, size_t len) { - CBS cbs, child0, child1, child2; - uint8_t sign; + CBS cbs, child0, child1, child2, child3, child4, child5; + uint8_t sign, sign1; CBS_init(&cbs, buf, len); if (!CBS_get_u16_length_prefixed(&cbs, &child0) || !CBS_get_u8(&child0, &sign) || @@ -67,7 +67,14 @@ extern "C" int LLVMFuzzerTestOneInput(const uint8_t *buf, size_t len) { !CBS_get_u16_length_prefixed(&cbs, &child1) || CBS_len(&child1) == 0 || !CBS_get_u16_length_prefixed(&cbs, &child2) || - CBS_len(&child2) == 0) { + CBS_len(&child2) == 0 || + !CBS_get_u16_length_prefixed(&cbs, &child3) || + !CBS_get_u8(&child3, &sign1) || + CBS_len(&child3) == 0 || + !CBS_get_u16_length_prefixed(&cbs, &child4) || + CBS_len(&child4) == 0 || + !CBS_get_u16_length_prefixed(&cbs, &child5) || + CBS_len(&child5) == 0) { return 0; } @@ -76,7 +83,10 @@ extern "C" int LLVMFuzzerTestOneInput(const uint8_t *buf, size_t len) { // fuzzers to spend a lot of time exploring timeouts. if (CBS_len(&child0) > 512 || CBS_len(&child1) > 512 || - CBS_len(&child2) > 512) { + CBS_len(&child2) > 512 || + CBS_len(&child3) > 512 || + CBS_len(&child4) > 512 || + CBS_len(&child5) > 512) { return 0; } @@ -88,16 +98,28 @@ extern "C" int LLVMFuzzerTestOneInput(const uint8_t *buf, size_t len) { bssl::UniquePtr modulus( BN_bin2bn(CBS_data(&child2), CBS_len(&child2), nullptr)); - if (BN_is_zero(modulus.get())) { + bssl::UniquePtr base1( + BN_bin2bn(CBS_data(&child3), CBS_len(&child3), nullptr)); + BN_set_negative(base.get(), sign % 2); + bssl::UniquePtr power1( + BN_bin2bn(CBS_data(&child4), CBS_len(&child4), nullptr)); + bssl::UniquePtr modulus1( + BN_bin2bn(CBS_data(&child5), CBS_len(&child5), nullptr)); + + if (BN_is_zero(modulus.get()) || BN_is_zero(modulus1.get())) { return 0; } bssl::UniquePtr ctx(BN_CTX_new()); bssl::UniquePtr result(BN_new()); bssl::UniquePtr expected(BN_new()); + bssl::UniquePtr result1(BN_new()); + bssl::UniquePtr expected1(BN_new()); CHECK(ctx); CHECK(result); CHECK(expected); + CHECK(result1); + CHECK(expected1); CHECK(mod_exp(expected.get(), base.get(), power.get(), modulus.get(), ctx.get())); @@ -117,6 +139,29 @@ extern "C" int LLVMFuzzerTestOneInput(const uint8_t *buf, size_t len) { CHECK(BN_mod_exp_mont_consttime(result.get(), base.get(), power.get(), modulus.get(), ctx.get(), mont.get())); CHECK(BN_cmp(result.get(), expected.get()) == 0); + + // If modulus1 is also odd, use the AVX-512 implementation. + if (BN_is_odd(modulus1.get())) { + bssl::UniquePtr mont1( + BN_MONT_CTX_new_for_modulus(modulus1.get(), ctx.get())); + CHECK(mont1); + + // Calculate the expected result for the second set of inputs. + CHECK(mod_exp(expected1.get(), base1.get(), power1.get(), modulus1.get(), + ctx.get())); + + // Normalize them. + CHECK(BN_nnmod(base1.get(), base1.get(), modulus1.get(), ctx.get())); + + // Run the x2 implementation and compare the results. + CHECK(BN_mod_exp_mont_consttime_x2(result.get(), base.get(), power.get(), + modulus.get(), mont.get(), + result1.get(), base1.get(), power1.get(), + modulus1.get(), mont1.get(), + ctx.get())); + CHECK(BN_cmp(result.get(), expected.get()) == 0); + CHECK(BN_cmp(result1.get(), expected1.get()) == 0); + } } uint8_t *data = (uint8_t *)OPENSSL_malloc(BN_num_bytes(result.get()));