diff --git a/cardano-crypto-class/src/Cardano/Crypto/EllipticCurve/BLS12_381.hs b/cardano-crypto-class/src/Cardano/Crypto/EllipticCurve/BLS12_381.hs index 7f4e83df0..cd9121bec 100644 --- a/cardano-crypto-class/src/Cardano/Crypto/EllipticCurve/BLS12_381.hs +++ b/cardano-crypto-class/src/Cardano/Crypto/EllipticCurve/BLS12_381.hs @@ -23,6 +23,7 @@ module Cardano.Crypto.EllipticCurve.BLS12_381 ( blsMult, blsCneg, blsNeg, + blsMSM, blsCompress, blsSerialize, blsUncompress, diff --git a/cardano-crypto-class/src/Cardano/Crypto/EllipticCurve/BLS12_381/Internal.hs b/cardano-crypto-class/src/Cardano/Crypto/EllipticCurve/BLS12_381/Internal.hs index e7aacdbf0..e813db81c 100644 --- a/cardano-crypto-class/src/Cardano/Crypto/EllipticCurve/BLS12_381/Internal.hs +++ b/cardano-crypto-class/src/Cardano/Crypto/EllipticCurve/BLS12_381/Internal.hs @@ -1,4 +1,5 @@ {-# LANGUAGE BangPatterns #-} +{-# LANGUAGE CPP #-} {-# LANGUAGE DataKinds #-} {-# LANGUAGE FlexibleContexts #-} {-# LANGUAGE ForeignFunctionInterface #-} @@ -6,6 +7,9 @@ {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE TypeApplications #-} {-# LANGUAGE TypeFamilies #-} +#if MIN_VERSION_base(4,20,0) +{-# OPTIONS_GHC -Wno-x-data-list-nonempty-unzip #-} +#endif module Cardano.Crypto.EllipticCurve.BLS12_381.Internal ( -- * Unsafe Types @@ -54,6 +58,8 @@ module Cardano.Crypto.EllipticCurve.BLS12_381.Internal ( c_blst_add_or_double, c_blst_mult, c_blst_cneg, + c_blst_scratch_sizeof, + c_blst_mult_pippenger, c_blst_hash, c_blst_compress, c_blst_serialize, @@ -129,6 +135,7 @@ module Cardano.Crypto.EllipticCurve.BLS12_381.Internal ( blsMult, blsCneg, blsNeg, + blsMSM, blsCompress, blsSerialize, blsUncompress, @@ -165,11 +172,19 @@ import Data.ByteString (ByteString) import qualified Data.ByteString as BS import qualified Data.ByteString.Internal as BSI import qualified Data.ByteString.Unsafe as BSU +import qualified Data.List.NonEmpty as NonEmpty + +#if MIN_VERSION_base(4,22,0) +import qualified Data.Functor (unzip) +#endif + import Data.Proxy (Proxy (..)) import Data.Void +import Foreign (poke, sizeOf) import Foreign.C.String import Foreign.C.Types import Foreign.ForeignPtr +import Foreign.Marshal (advancePtr) import Foreign.Marshal.Alloc (allocaBytes) import Foreign.Marshal.Utils (copyBytes) import Foreign.Ptr (Ptr, castPtr, nullPtr, plusPtr) @@ -189,10 +204,14 @@ type Point1Ptr = PointPtr Curve1 type Point2Ptr = PointPtr Curve2 newtype AffinePtr curve = AffinePtr (Ptr Void) +newtype AffinePtrVector curve = AffinePtrVector (Ptr Void) type Affine1Ptr = AffinePtr Curve1 type Affine2Ptr = AffinePtr Curve2 +type Affine1PtrVector = AffinePtrVector Curve1 +type Affine2PtrVector = AffinePtrVector Curve2 + newtype PTPtr = PTPtr (Ptr Void) unsafePointFromPointPtr :: PointPtr curve -> Point curve @@ -288,6 +307,22 @@ withNewAffine_ = fmap fst . withNewAffine withNewAffine' :: BLS curve => (AffinePtr curve -> IO a) -> IO (Affine curve) withNewAffine' = fmap snd . withNewAffine +withAffineVector :: NonEmpty.NonEmpty (Affine curve) -> (AffinePtrVector curve -> IO a) -> IO a +withAffineVector affines go = do + let numAffines = NonEmpty.length affines + sizeReference = sizeOf (undefined :: Ptr ()) + allocaBytes (numAffines * sizeReference) $ \ptr -> + -- The accumulate function ensures that each `withAffine` call is properly nested. + -- This guarantees that the foreign pointers remain valid while we populate `ptr`. + -- If we instead used `zipWithM_` for example, the pointers could be finalized too early. + -- By nesting `withAffine` calls in `accumulate`, we ensure they stay in scope until `go` is executed. + let accumulate [] = go (AffinePtrVector (castPtr ptr)) + accumulate ((ix, affine) : rest) = + withAffine affine $ \(AffinePtr aPtr) -> do + poke (ptr `advancePtr` ix) aPtr + accumulate rest + in accumulate (zip [0 ..] (NonEmpty.toList affines)) + withPT :: PT -> (PTPtr -> IO a) -> IO a withPT (PT pt) go = withForeignPtr pt (go . PTPtr) @@ -317,6 +352,10 @@ class BLS curve where c_blst_mult :: PointPtr curve -> PointPtr curve -> ScalarPtr -> CSize -> IO () c_blst_cneg :: PointPtr curve -> Bool -> IO () + c_blst_scratch_sizeof :: Proxy curve -> CSize -> CSize + c_blst_mult_pippenger :: + PointPtr curve -> AffinePtrVector curve -> CSize -> ScalarPtrVector -> CSize -> ScratchPtr -> IO () + c_blst_hash :: PointPtr curve -> Ptr CChar -> CSize -> Ptr CChar -> CSize -> Ptr CChar -> CSize -> IO () c_blst_compress :: Ptr CChar -> PointPtr curve -> IO () @@ -345,6 +384,9 @@ instance BLS Curve1 where c_blst_mult = c_blst_p1_mult c_blst_cneg = c_blst_p1_cneg + c_blst_scratch_sizeof _ = c_blst_p1s_mult_pippenger_scratch_sizeof + c_blst_mult_pippenger = c_blst_p1s_mult_pippenger + c_blst_hash = c_blst_hash_to_g1 c_blst_compress = c_blst_p1_compress c_blst_serialize = c_blst_p1_serialize @@ -373,6 +415,9 @@ instance BLS Curve2 where c_blst_mult = c_blst_p2_mult c_blst_cneg = c_blst_p2_cneg + c_blst_scratch_sizeof _ = c_blst_p2s_mult_pippenger_scratch_sizeof + c_blst_mult_pippenger = c_blst_p2s_mult_pippenger + c_blst_hash = c_blst_hash_to_g2 c_blst_compress = c_blst_p2_compress c_blst_serialize = c_blst_p2_serialize @@ -428,6 +473,22 @@ withNewScalar_ = fmap fst . withNewScalar withNewScalar' :: (ScalarPtr -> IO a) -> IO Scalar withNewScalar' = fmap snd . withNewScalar +withScalarVector :: NonEmpty.NonEmpty Scalar -> (ScalarPtrVector -> IO a) -> IO a +withScalarVector scalars go = do + let numScalars = NonEmpty.length scalars + sizeReference = sizeOf (undefined :: Ptr ()) + allocaBytes (numScalars * sizeReference) $ \ptr -> + -- The accumulate function ensures that each `withScalar` call is properly nested. + -- This guarantees that the foreign pointers remain valid while we populate `ptr`. + -- If we instead used `zipWithM_` for example, the pointers could be finalized too early. + -- By nesting `withScalar` calls in `accumulate`, we ensure they stay in scope until `go` is executed. + let accumulate [] = go (ScalarPtrVector (castPtr ptr)) + accumulate ((ix, scalar) : rest) = + withScalar scalar $ \(ScalarPtr sPtr) -> do + poke (ptr `advancePtr` ix) sPtr + accumulate rest + in accumulate (zip [0 ..] (NonEmpty.toList scalars)) + cloneScalar :: Scalar -> IO Scalar cloneScalar (Scalar a) = do b <- mallocForeignPtrBytes sizeScalar @@ -512,7 +573,9 @@ scalarFromInteger n = do ---- Unsafe types newtype ScalarPtr = ScalarPtr (Ptr Void) +newtype ScalarPtrVector = ScalarPtrVector (Ptr Void) newtype FrPtr = FrPtr (Ptr Void) +newtype ScratchPtr = ScratchPtr (Ptr Void) ---- Raw Scalar / Fr functions @@ -555,6 +618,12 @@ foreign import ccall "blst_p1_generator" c_blst_p1_generator :: Point1Ptr foreign import ccall "blst_p1_is_equal" c_blst_p1_is_equal :: Point1Ptr -> Point1Ptr -> IO Bool foreign import ccall "blst_p1_is_inf" c_blst_p1_is_inf :: Point1Ptr -> IO Bool +foreign import ccall "blst_p1s_mult_pippenger_scratch_sizeof" + c_blst_p1s_mult_pippenger_scratch_sizeof :: CSize -> CSize +foreign import ccall "blst_p1s_mult_pippenger" + c_blst_p1s_mult_pippenger :: + Point1Ptr -> Affine1PtrVector -> CSize -> ScalarPtrVector -> CSize -> ScratchPtr -> IO () + ---- Raw Point2 functions foreign import ccall "size_blst_p2" c_size_blst_p2 :: CSize @@ -582,6 +651,12 @@ foreign import ccall "blst_p2_generator" c_blst_p2_generator :: Point2Ptr foreign import ccall "blst_p2_is_equal" c_blst_p2_is_equal :: Point2Ptr -> Point2Ptr -> IO Bool foreign import ccall "blst_p2_is_inf" c_blst_p2_is_inf :: Point2Ptr -> IO Bool +foreign import ccall "blst_p2s_mult_pippenger_scratch_sizeof" + c_blst_p2s_mult_pippenger_scratch_sizeof :: CSize -> CSize +foreign import ccall "blst_p2s_mult_pippenger" + c_blst_p2s_mult_pippenger :: + Point2Ptr -> Affine2PtrVector -> CSize -> ScalarPtrVector -> CSize -> ScratchPtr -> IO () + ---- Affine operations foreign import ccall "size_blst_affine1" c_size_blst_affine1 :: CSize @@ -824,7 +899,8 @@ blsZero = error $ "Unexpected failure deserialising point at infinity on BLS12_381.G1: " ++ show err Right infinity -> infinity -- The zero point on this curve is chosen to be the point at infinity. - ---- Scalar / Fr operations + +---- Scalar / Fr operations scalarFromFr :: Fr -> IO Scalar scalarFromFr fr = @@ -875,6 +951,51 @@ scalarCanonical scalar = unsafePerformIO $ withScalar scalar c_blst_scalar_fr_check +---- MSM operations + +-- | A small convenience helper for unzipping a 'NonEmpty' list of @(p, i)@ pairs +-- into two 'NonEmpty' lists. We dispatch to 'Data.Functor.unzip' when base >= 4.22, +-- and to 'NonEmpty.unzip' otherwise. Having this in one place under CPP avoids +-- duplicating large code blocks (and also keeps Fourmolu happy). +unzipPointsAndScalars :: + NonEmpty.NonEmpty (p, i) -> + (NonEmpty.NonEmpty p, NonEmpty.NonEmpty i) +#if MIN_VERSION_base(4,22,0) +unzipPointsAndScalars = Data.Functor.unzip +#else +unzipPointsAndScalars = NonEmpty.unzip +#endif + +-- | Multi-scalar multiplication using the Pippenger algorithm. +-- The scalar will be brought into the range of modular arithmetic +-- by means of a modulo operation over the 'scalarPeriod'. +-- Negative number will also be brought to the range +-- [0, 'scalarPeriod' - 1] via modular reduction. +blsMSM :: forall curve. BLS curve => NonEmpty.NonEmpty (Point curve, Integer) -> Point curve +blsMSM psAndSs = + unsafePerformIO $ do + let (points, scalarsAsInt) = unzipPointsAndScalars psAndSs + numPoints = length points + nonEmptyAffinePoints = fmap toAffine points + nonEmptyScalars <- mapM scalarFromInteger scalarsAsInt + + withAffineVector nonEmptyAffinePoints $ \affineVectorPtr -> do + withScalarVector nonEmptyScalars $ \scalarVectorPtr -> do + let numPoints' :: CSize + numPoints' = fromIntegral numPoints + scratchSize :: Int + scratchSize = fromIntegral @CSize @Int $ c_blst_scratch_sizeof (Proxy @curve) numPoints' + + allocaBytes scratchSize $ \scratchPtr -> do + withNewPoint' $ \resultPtr -> do + c_blst_mult_pippenger + resultPtr + affineVectorPtr + (fromIntegral numPoints) + scalarVectorPtr + 255 + (ScratchPtr scratchPtr) + ---- PT operations ptMult :: PT -> PT -> PT diff --git a/cardano-crypto-tests/src/Test/Crypto/EllipticCurve.hs b/cardano-crypto-tests/src/Test/Crypto/EllipticCurve.hs index 032d0c8ab..622032258 100644 --- a/cardano-crypto-tests/src/Test/Crypto/EllipticCurve.hs +++ b/cardano-crypto-tests/src/Test/Crypto/EllipticCurve.hs @@ -1,8 +1,10 @@ +{-# LANGUAGE InstanceSigs #-} {-# LANGUAGE LambdaCase #-} {-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE TypeApplications #-} {-# OPTIONS_GHC -Wno-orphans #-} +{-# OPTIONS_GHC -Wno-redundant-constraints #-} module Test.Crypto.EllipticCurve where @@ -19,11 +21,13 @@ import qualified Data.ByteString as BS import qualified Data.ByteString.Base16 as Base16 import qualified Data.ByteString.Char8 as BS8 import qualified Data.Foldable as F (foldl') +import qualified Data.List.NonEmpty as NonEmpty import Data.Proxy (Proxy (..)) import System.IO.Unsafe (unsafePerformIO) import Test.Crypto.Instances () import Test.QuickCheck ( Arbitrary (..), + NonEmptyList (..), Property, choose, chooseAny, @@ -132,6 +136,10 @@ testBLSCurve name _ = BLS.blsMult (BLS.blsMult a b) c === BLS.blsMult (BLS.blsMult a c) b , testProperty "scalar mult distributive left" $ \(a :: BLS.Point curve) (BigInteger b) (BigInteger c) -> BLS.blsMult a (b + c) === BLS.blsAddOrDouble (BLS.blsMult a b) (BLS.blsMult a c) + , testProperty "MSM matches naive approach" $ \(NonEmpty (psAndSs :: [(BLS.Point curve, BigInteger)])) -> + let pairs = NonEmpty.fromList [(p, i) | (p, BigInteger i) <- psAndSs] + in BLS.blsMSM pairs + === foldr (\(p, s) acc -> BLS.blsAddOrDouble acc (BLS.blsMult p s)) (BLS.blsZero @curve) pairs , testProperty "scalar mult distributive right" $ \(a :: BLS.Point curve) (b :: BLS.Point curve) (BigInteger c) -> BLS.blsMult (BLS.blsAddOrDouble a b) c === BLS.blsAddOrDouble (BLS.blsMult a c) (BLS.blsMult b c) , testProperty "mult by zero is inf" $ \(a :: BLS.Point curve) ->