diff --git a/cardano-crypto-class/cardano-crypto-class.cabal b/cardano-crypto-class/cardano-crypto-class.cabal index 208e000f6..f938fabcc 100644 --- a/cardano-crypto-class/cardano-crypto-class.cabal +++ b/cardano-crypto-class/cardano-crypto-class.cabal @@ -39,6 +39,7 @@ library import: base, project-config hs-source-dirs: src exposed-modules: + Cardano.Crypto.DirectSerialise Cardano.Crypto.DSIGN Cardano.Crypto.DSIGN.Class Cardano.Crypto.DSIGN.Ed25519 diff --git a/cardano-crypto-class/src/Cardano/Crypto/DSIGN/Ed25519.hs b/cardano-crypto-class/src/Cardano/Crypto/DSIGN/Ed25519.hs index 01aa4a4db..622d0eae1 100644 --- a/cardano-crypto-class/src/Cardano/Crypto/DSIGN/Ed25519.hs +++ b/cardano-crypto-class/src/Cardano/Crypto/DSIGN/Ed25519.hs @@ -3,6 +3,7 @@ {-# LANGUAGE DerivingVia #-} {-# LANGUAGE FlexibleInstances #-} {-# LANGUAGE GeneralizedNewtypeDeriving #-} +{-# LANGUAGE MultiParamTypeClasses #-} {-# LANGUAGE RankNTypes #-} {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE StandaloneDeriving #-} @@ -66,14 +67,17 @@ import Cardano.Crypto.Libsodium.MLockedSeed import Cardano.Crypto.PinnedSizedBytes ( PinnedSizedBytes , psbUseAsSizedPtr + , psbUseAsCPtrLen , psbToByteString , psbFromByteStringCheck + , psbCreate , psbCreateSized , psbCreateSizedResult ) import Cardano.Crypto.Seed import Cardano.Crypto.Util (SignableRepresentation(..)) import Cardano.Foreign +import Cardano.Crypto.DirectSerialise @@ -261,7 +265,7 @@ instance DSIGNMAlgorithm Ed25519DSIGN where stToIO $ do cOrError $ unsafeIOToST $ c_crypto_sign_ed25519_sk_to_pk pkPtr skPtr - throwOnErrno "deriveVerKeyDSIGNM @Ed25519DSIGN" "c_crypto_sign_ed25519_sk_to_pk" maybeErrno + throwOnErrno "deriveVerKeyDSIGN @Ed25519DSIGN" "c_crypto_sign_ed25519_sk_to_pk" maybeErrno return psb @@ -365,3 +369,48 @@ instance TypeError ('Text "CBOR encoding would violate mlocking guarantees") instance TypeError ('Text "CBOR decoding would violate mlocking guarantees") => FromCBOR (SignKeyDSIGNM Ed25519DSIGN) where fromCBOR = error "unsupported" + +instance DirectSerialise (SignKeyDSIGNM Ed25519DSIGN) where + -- /Note:/ We only serialize the 32-byte seed, not the full 64-byte key. The + -- latter contains both the seed and the 32-byte verification key, which is + -- convenient, but redundant, since we can always reconstruct it from the + -- seed. This is also reflected in the 'SizeSignKeyDSIGNM', which equals + -- 'SeedSizeDSIGNM' == 32, rather than reporting the in-memory size of 64. + directSerialise push sk = do + bracket + (getSeedDSIGNM (Proxy @Ed25519DSIGN) sk) + mlockedSeedFinalize + (\seed -> mlockedSeedUseAsCPtr seed $ \ptr -> + push + (castPtr ptr) + (fromIntegral $ seedSizeDSIGN (Proxy @Ed25519DSIGN))) + +instance DirectDeserialise (SignKeyDSIGNM Ed25519DSIGN) where + -- /Note:/ We only serialize the 32-byte seed, not the full 64-byte key. See + -- the DirectSerialise instance above. + directDeserialise pull = do + bracket + mlockedSeedNew + mlockedSeedFinalize + (\seed -> do + mlockedSeedUseAsCPtr seed $ \ptr -> do + pull + (castPtr ptr) + (fromIntegral $ seedSizeDSIGN (Proxy @Ed25519DSIGN)) + genKeyDSIGNM seed + ) + +instance DirectSerialise (VerKeyDSIGN Ed25519DSIGN) where + directSerialise push (VerKeyEd25519DSIGN psb) = do + psbUseAsCPtrLen psb $ \ptr _ -> + push + (castPtr ptr) + (fromIntegral $ sizeVerKeyDSIGN (Proxy @Ed25519DSIGN)) + +instance DirectDeserialise (VerKeyDSIGN Ed25519DSIGN) where + directDeserialise pull = do + psb <- psbCreate $ \ptr -> + pull + (castPtr ptr) + (fromIntegral $ sizeVerKeyDSIGN (Proxy @Ed25519DSIGN)) + return $! VerKeyEd25519DSIGN $! psb diff --git a/cardano-crypto-class/src/Cardano/Crypto/DirectSerialise.hs b/cardano-crypto-class/src/Cardano/Crypto/DirectSerialise.hs new file mode 100644 index 000000000..7b2c5e787 --- /dev/null +++ b/cardano-crypto-class/src/Cardano/Crypto/DirectSerialise.hs @@ -0,0 +1,41 @@ +-- | Direct (de-)serialisation to / from raw memory. +-- +-- The purpose of the typeclasses in this module is to abstract over data +-- structures that can expose the data they store as one or more raw 'Ptr's, +-- without any additional memory copying or conversion to intermediate data +-- structures. +-- +-- This is useful for transmitting data like KES SignKeys over a socket +-- connection: by accessing the memory directly and copying it into or out of +-- a file descriptor, without going through an intermediate @ByteString@ +-- representation (or other data structure that resides in the GHC heap), we +-- can more easily assure that the data is never written to disk, including +-- swap, which is an important requirement for KES. +module Cardano.Crypto.DirectSerialise +where + +import Foreign.Ptr +import Foreign.C.Types +import Control.Monad.Class.MonadThrow (MonadThrow) +import Control.Monad.Class.MonadST (MonadST) + +-- | Direct deserialization from raw memory. +-- +-- @directDeserialise f@ should allocate a new value of type 'a', and +-- call @f@ with a pointer to the raw memory to be filled. @f@ may be called +-- multiple times, for data structures that store their data in multiple +-- non-contiguous blocks of memory. +-- +-- The order in which memory blocks are visited matters. +class DirectDeserialise a where + directDeserialise :: (MonadST m, MonadThrow m) => (Ptr CChar -> CSize -> m ()) -> m a + +-- | Direct serialization to raw memory. +-- +-- @directSerialise f x@ should call @f@ to expose the raw memory underyling +-- @x@. For data types that store their data in multiple non-contiguous blocks +-- of memory, @f@ may be called multiple times, once for each block. +-- +-- The order in which memory blocks are visited matters. +class DirectSerialise a where + directSerialise :: (MonadST m, MonadThrow m) => (Ptr CChar -> CSize -> m ()) -> a -> m () diff --git a/cardano-crypto-class/src/Cardano/Crypto/KES/CompactSingle.hs b/cardano-crypto-class/src/Cardano/Crypto/KES/CompactSingle.hs index cce1102f1..6bc744e9a 100644 --- a/cardano-crypto-class/src/Cardano/Crypto/KES/CompactSingle.hs +++ b/cardano-crypto-class/src/Cardano/Crypto/KES/CompactSingle.hs @@ -60,7 +60,7 @@ import Cardano.Binary (FromCBOR (..), ToCBOR (..)) import Cardano.Crypto.Hash.Class import Cardano.Crypto.DSIGN.Class as DSIGN import Cardano.Crypto.KES.Class - +import Cardano.Crypto.DirectSerialise -- | A standard signature scheme is a forward-secure signature scheme with a -- single time period. @@ -227,3 +227,19 @@ instance (DSIGNMAlgorithm d, KnownNat (SizeSigKES (CompactSingleKES d))) => From slice :: Word -> Word -> ByteString -> ByteString slice offset size = BS.take (fromIntegral size) . BS.drop (fromIntegral offset) + +-- +-- Direct ser/deser +-- + +instance (DirectSerialise (SignKeyDSIGNM d)) => DirectSerialise (SignKeyKES (CompactSingleKES d)) where + directSerialise push (SignKeyCompactSingleKES sk) = directSerialise push sk + +instance (DirectDeserialise (SignKeyDSIGNM d)) => DirectDeserialise (SignKeyKES (CompactSingleKES d)) where + directDeserialise pull = SignKeyCompactSingleKES <$!> directDeserialise pull + +instance (DirectSerialise (VerKeyDSIGN d)) => DirectSerialise (VerKeyKES (CompactSingleKES d)) where + directSerialise push (VerKeyCompactSingleKES sk) = directSerialise push sk + +instance (DirectDeserialise (VerKeyDSIGN d)) => DirectDeserialise (VerKeyKES (CompactSingleKES d)) where + directDeserialise pull = VerKeyCompactSingleKES <$!> directDeserialise pull diff --git a/cardano-crypto-class/src/Cardano/Crypto/KES/CompactSum.hs b/cardano-crypto-class/src/Cardano/Crypto/KES/CompactSum.hs index ce37acbe8..96726f7aa 100644 --- a/cardano-crypto-class/src/Cardano/Crypto/KES/CompactSum.hs +++ b/cardano-crypto-class/src/Cardano/Crypto/KES/CompactSum.hs @@ -6,6 +6,7 @@ {-# LANGUAGE GeneralizedNewtypeDeriving #-} {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE StandaloneDeriving #-} +{-# LANGUAGE TypeApplications #-} {-# LANGUAGE TypeFamilies #-} {-# LANGUAGE TypeOperators #-} {-# LANGUAGE UndecidableInstances #-} @@ -86,7 +87,8 @@ module Cardano.Crypto.KES.CompactSum ( import Data.Proxy (Proxy(..)) import GHC.Generics (Generic) import qualified Data.ByteString as BS -import Control.Monad (guard) +import qualified Data.ByteString.Internal as BS +import Control.Monad (guard, (<$!>)) import NoThunks.Class (NoThunks, OnlyCheckWhnfNamed (..)) import Cardano.Binary (FromCBOR (..), ToCBOR (..)) @@ -97,10 +99,14 @@ import Cardano.Crypto.KES.CompactSingle (CompactSingleKES) import Cardano.Crypto.Util import Cardano.Crypto.Libsodium.MLockedSeed import Cardano.Crypto.Libsodium +import Cardano.Crypto.Libsodium.Memory +import Cardano.Crypto.DirectSerialise + import Control.Monad.Trans.Maybe (MaybeT (..), runMaybeT) import Control.Monad.Trans (lift) import Control.DeepSeq (NFData (..)) import GHC.TypeLits (KnownNat, type (+), type (*)) +import Foreign.Ptr (castPtr) -- | A 2^0 period KES type CompactSum0KES d = CompactSingleKES d @@ -461,3 +467,52 @@ instance ( OptimizedKESAlgorithm d ) => FromCBOR (SigKES (CompactSumKES h d)) where fromCBOR = decodeSigKES + + +-- +-- Direct ser/deser +-- + +instance ( DirectSerialise (SignKeyKES d) + , DirectSerialise (VerKeyKES d) + , KESAlgorithm d + ) => DirectSerialise (SignKeyKES (CompactSumKES h d)) where + directSerialise push (SignKeyCompactSumKES sk r vk0 vk1) = do + directSerialise push sk + mlockedSeedUseAsCPtr r $ \ptr -> + push (castPtr ptr) (fromIntegral $ seedSizeKES (Proxy :: Proxy d)) + directSerialise push vk0 + directSerialise push vk1 + +instance ( DirectDeserialise (SignKeyKES d) + , DirectDeserialise (VerKeyKES d) + , KESAlgorithm d + ) => DirectDeserialise (SignKeyKES (CompactSumKES h d)) where + directDeserialise pull = do + sk <- directDeserialise pull + + r <- mlockedSeedNew + mlockedSeedUseAsCPtr r $ \ptr -> + pull (castPtr ptr) (fromIntegral $ seedSizeKES (Proxy :: Proxy d)) + + vk0 <- directDeserialise pull + vk1 <- directDeserialise pull + + return $! SignKeyCompactSumKES sk r vk0 vk1 + + +instance DirectSerialise (VerKeyKES (CompactSumKES h d)) where + directSerialise push (VerKeyCompactSumKES h) = + unpackByteStringCStringLen (hashToBytes h) $ \(ptr, len) -> + push (castPtr ptr) (fromIntegral len) + +instance (HashAlgorithm h) + => DirectDeserialise (VerKeyKES (CompactSumKES h d)) where + directDeserialise pull = do + let len :: Num a => a + len = fromIntegral $ sizeHash (Proxy @h) + fptr <- mallocForeignPtrBytes len + withForeignPtr fptr $ \ptr -> do + pull (castPtr ptr) len + let bs = BS.fromForeignPtr (unsafeRawForeignPtr fptr) 0 len + maybe (error "Invalid hash") return $! VerKeyCompactSumKES <$!> hashFromBytes bs diff --git a/cardano-crypto-class/src/Cardano/Crypto/KES/Mock.hs b/cardano-crypto-class/src/Cardano/Crypto/KES/Mock.hs index 4e2a91516..43b503e21 100644 --- a/cardano-crypto-class/src/Cardano/Crypto/KES/Mock.hs +++ b/cardano-crypto-class/src/Cardano/Crypto/KES/Mock.hs @@ -24,6 +24,8 @@ import Data.Proxy (Proxy(..)) import GHC.Generics (Generic) import GHC.TypeNats (Nat, KnownNat, natVal) import NoThunks.Class (NoThunks) +import qualified Data.ByteString.Internal as BS +import Foreign.Ptr (castPtr) import Control.Exception (assert) @@ -35,8 +37,15 @@ import Cardano.Crypto.KES.Class import Cardano.Crypto.Util import Cardano.Crypto.Libsodium.MLockedSeed import Cardano.Crypto.Libsodium - ( mlsbAsByteString + ( mlsbToByteString ) +import Cardano.Crypto.Libsodium.Memory + ( unpackByteStringCStringLen + , ForeignPtr (..) + , mallocForeignPtrBytes + , withForeignPtr + ) +import Cardano.Crypto.DirectSerialise data MockKES (t :: Nat) @@ -151,7 +160,8 @@ instance KnownNat t => KESAlgorithm (MockKES t) where -- genKeyKESWith _allocator seed = do - let vk = VerKeyMockKES (runMonadRandomWithSeed (mkSeedFromBytes . mlsbAsByteString . mlockedSeedMLSB $ seed) getRandomWord64) + seedBS <- mlsbToByteString . mlockedSeedMLSB $ seed + let vk = VerKeyMockKES (runMonadRandomWithSeed (mkSeedFromBytes seedBS) getRandomWord64) return $! SignKeyMockKES vk 0 forgetSignKeyKESWith _ = const $ return () @@ -194,3 +204,33 @@ instance KnownNat t => ToCBOR (SigKES (MockKES t)) where instance KnownNat t => FromCBOR (SigKES (MockKES t)) where fromCBOR = decodeSigKES + +instance (KnownNat t) => DirectSerialise (SignKeyKES (MockKES t)) where + directSerialise put sk = do + let bs = rawSerialiseSignKeyMockKES sk + unpackByteStringCStringLen bs $ \(cstr, len) -> put cstr (fromIntegral len) + +instance (KnownNat t) => DirectDeserialise (SignKeyKES (MockKES t)) where + directDeserialise pull = do + let len = fromIntegral $ sizeSignKeyKES (Proxy @(MockKES t)) + fptr <- mallocForeignPtrBytes len + withForeignPtr fptr $ \ptr -> + pull (castPtr ptr) (fromIntegral len) + let bs = BS.fromForeignPtr (unsafeRawForeignPtr fptr) 0 len + maybe (error "directDeserialise @(SignKeyKES (MockKES t))") return $ + rawDeserialiseSignKeyMockKES bs + +instance (KnownNat t) => DirectSerialise (VerKeyKES (MockKES t)) where + directSerialise put sk = do + let bs = rawSerialiseVerKeyKES sk + unpackByteStringCStringLen bs $ \(cstr, len) -> put cstr (fromIntegral len) + +instance (KnownNat t) => DirectDeserialise (VerKeyKES (MockKES t)) where + directDeserialise pull = do + let len = fromIntegral $ sizeVerKeyKES (Proxy @(MockKES t)) + fptr <- mallocForeignPtrBytes len + withForeignPtr fptr $ \ptr -> + pull (castPtr ptr) (fromIntegral len) + let bs = BS.fromForeignPtr (unsafeRawForeignPtr fptr) 0 len + maybe (error "directDeserialise @(VerKeyKES (MockKES t))") return $ + rawDeserialiseVerKeyKES bs diff --git a/cardano-crypto-class/src/Cardano/Crypto/KES/Simple.hs b/cardano-crypto-class/src/Cardano/Crypto/KES/Simple.hs index b8bfe2186..afb15d202 100644 --- a/cardano-crypto-class/src/Cardano/Crypto/KES/Simple.hs +++ b/cardano-crypto-class/src/Cardano/Crypto/KES/Simple.hs @@ -43,6 +43,7 @@ import Cardano.Crypto.KES.Class import Cardano.Crypto.Libsodium.MLockedSeed import Cardano.Crypto.Libsodium.MLockedBytes import Cardano.Crypto.Util +import Cardano.Crypto.DirectSerialise import Data.Unit.Strict (forceElemsToWHNF) data SimpleKES d (t :: Nat) @@ -249,3 +250,22 @@ instance (DSIGNMAlgorithm d => FromCBOR (SigKES (SimpleKES d t)) where fromCBOR = decodeSigKES +instance (DirectSerialise (VerKeyDSIGN d)) => DirectSerialise (VerKeyKES (SimpleKES d t)) where + directSerialise push (VerKeySimpleKES vks) = + mapM_ (directSerialise push) vks + +instance (DirectDeserialise (VerKeyDSIGN d), KnownNat t) => DirectDeserialise (VerKeyKES (SimpleKES d t)) where + directDeserialise pull = do + let duration = fromIntegral (natVal (Proxy :: Proxy t)) + vks <- Vec.replicateM duration (directDeserialise pull) + return $! VerKeySimpleKES $! vks + +instance (DirectSerialise (SignKeyDSIGNM d)) => DirectSerialise (SignKeyKES (SimpleKES d t)) where + directSerialise push (SignKeySimpleKES sks) = + mapM_ (directSerialise push) sks + +instance (DirectDeserialise (SignKeyDSIGNM d), KnownNat t) => DirectDeserialise (SignKeyKES (SimpleKES d t)) where + directDeserialise pull = do + let duration = fromIntegral (natVal (Proxy :: Proxy t)) + sks <- Vec.replicateM duration (directDeserialise pull) + return $! SignKeySimpleKES $! sks diff --git a/cardano-crypto-class/src/Cardano/Crypto/KES/Single.hs b/cardano-crypto-class/src/Cardano/Crypto/KES/Single.hs index 2d38fb527..4b9da7d26 100644 --- a/cardano-crypto-class/src/Cardano/Crypto/KES/Single.hs +++ b/cardano-crypto-class/src/Cardano/Crypto/KES/Single.hs @@ -50,7 +50,7 @@ import Cardano.Binary (FromCBOR (..), ToCBOR (..)) import Cardano.Crypto.Hash.Class import Cardano.Crypto.DSIGN.Class as DSIGN import Cardano.Crypto.KES.Class - +import Cardano.Crypto.DirectSerialise -- | A standard signature scheme is a forward-secure signature scheme with a -- single time period. @@ -187,4 +187,19 @@ instance DSIGNMAlgorithm d => ToCBOR (SigKES (SingleKES d)) where instance DSIGNMAlgorithm d => FromCBOR (SigKES (SingleKES d)) where fromCBOR = decodeSigKES - {-# INLINE fromCBOR #-} + +-- +-- Direct ser/deser +-- + +instance (DirectSerialise (SignKeyDSIGNM d)) => DirectSerialise (SignKeyKES (SingleKES d)) where + directSerialise push (SignKeySingleKES sk) = directSerialise push sk + +instance (DirectDeserialise (SignKeyDSIGNM d)) => DirectDeserialise (SignKeyKES (SingleKES d)) where + directDeserialise pull = SignKeySingleKES <$!> directDeserialise pull + +instance (DirectSerialise (VerKeyDSIGN d)) => DirectSerialise (VerKeyKES (SingleKES d)) where + directSerialise push (VerKeySingleKES sk) = directSerialise push sk + +instance (DirectDeserialise (VerKeyDSIGN d)) => DirectDeserialise (VerKeyKES (SingleKES d)) where + directDeserialise pull = VerKeySingleKES <$!> directDeserialise pull diff --git a/cardano-crypto-class/src/Cardano/Crypto/KES/Sum.hs b/cardano-crypto-class/src/Cardano/Crypto/KES/Sum.hs index 300962d11..d78be36a8 100644 --- a/cardano-crypto-class/src/Cardano/Crypto/KES/Sum.hs +++ b/cardano-crypto-class/src/Cardano/Crypto/KES/Sum.hs @@ -54,6 +54,7 @@ module Cardano.Crypto.KES.Sum ( import Data.Proxy (Proxy(..)) import GHC.Generics (Generic) import qualified Data.ByteString as BS +import qualified Data.ByteString.Internal as BS import Control.Monad (guard, (<$!>)) import NoThunks.Class (NoThunks, OnlyCheckWhnfNamed (..)) @@ -65,10 +66,13 @@ import Cardano.Crypto.KES.Single (SingleKES) import Cardano.Crypto.Util import Cardano.Crypto.Libsodium.MLockedSeed import Cardano.Crypto.Libsodium +import Cardano.Crypto.Libsodium.Memory +import Cardano.Crypto.DirectSerialise + import Control.Monad.Trans.Maybe (MaybeT (..), runMaybeT) import Control.DeepSeq (NFData (..)) import GHC.TypeLits (KnownNat, type (+), type (*)) - +import Foreign.Ptr (castPtr) -- | A 2^0 period KES type Sum0KES d = SingleKES d @@ -383,4 +387,52 @@ instance (KESAlgorithm (SumKES h d), SodiumHashAlgorithm h, SizeHash h ~ SeedSiz instance (KESAlgorithm (SumKES h d), SodiumHashAlgorithm h, SizeHash h ~ SeedSizeKES d) => FromCBOR (SigKES (SumKES h d)) where fromCBOR = decodeSigKES - {-# INLINE fromCBOR #-} + + +-- +-- Direct ser/deser +-- + +instance ( DirectSerialise (SignKeyKES d) + , DirectSerialise (VerKeyKES d) + , KESAlgorithm d + ) => DirectSerialise (SignKeyKES (SumKES h d)) where + directSerialise push (SignKeySumKES sk r vk0 vk1) = do + directSerialise push sk + mlockedSeedUseAsCPtr r $ \ptr -> + push (castPtr ptr) (fromIntegral $ seedSizeKES (Proxy :: Proxy d)) + directSerialise push vk0 + directSerialise push vk1 + +instance ( DirectDeserialise (SignKeyKES d) + , DirectDeserialise (VerKeyKES d) + , KESAlgorithm d + ) => DirectDeserialise (SignKeyKES (SumKES h d)) where + directDeserialise pull = do + sk <- directDeserialise pull + + r <- mlockedSeedNew + mlockedSeedUseAsCPtr r $ \ptr -> + pull (castPtr ptr) (fromIntegral $ seedSizeKES (Proxy :: Proxy d)) + + vk0 <- directDeserialise pull + vk1 <- directDeserialise pull + + return $! SignKeySumKES sk r vk0 vk1 + + +instance DirectSerialise (VerKeyKES (SumKES h d)) where + directSerialise push (VerKeySumKES h) = + unpackByteStringCStringLen (hashToBytes h) $ \(ptr, len) -> + push (castPtr ptr) (fromIntegral len) + +instance (HashAlgorithm h) + => DirectDeserialise (VerKeyKES (SumKES h d)) where + directDeserialise pull = do + let len :: Num a => a + len = fromIntegral $ sizeHash (Proxy @h) + fptr <- mallocForeignPtrBytes len + withForeignPtr fptr $ \ptr -> do + pull (castPtr ptr) len + let bs = BS.fromForeignPtr (unsafeRawForeignPtr fptr) 0 len + maybe (error "Invalid hash") return $! VerKeySumKES <$!> hashFromBytes bs diff --git a/cardano-crypto-class/src/Cardano/Crypto/Libsodium/C.hs b/cardano-crypto-class/src/Cardano/Crypto/Libsodium/C.hs index 1ce55c66b..3fe9623ec 100644 --- a/cardano-crypto-class/src/Cardano/Crypto/Libsodium/C.hs +++ b/cardano-crypto-class/src/Cardano/Crypto/Libsodium/C.hs @@ -29,6 +29,8 @@ module Cardano.Crypto.Libsodium.C ( c_crypto_sign_ed25519_detached, c_crypto_sign_ed25519_verify_detached, c_crypto_sign_ed25519_sk_to_pk, + -- * RNG + c_sodium_randombytes_buf, -- * Helpers c_sodium_compare, -- * Constants @@ -182,3 +184,6 @@ foreign import capi unsafe "sodium.h crypto_sign_ed25519_sk_to_pk" c_crypto_sign -- -- foreign import capi unsafe "sodium.h sodium_compare" c_sodium_compare :: Ptr a -> Ptr a -> CSize -> IO Int + +-- | @void randombytes_buf(void * const buf, const size_t size);@ +foreign import capi unsafe "sodium/randombytes.h randombytes_buf" c_sodium_randombytes_buf :: Ptr a -> CSize -> IO () diff --git a/cardano-crypto-class/src/Cardano/Crypto/Libsodium/MLockedSeed.hs b/cardano-crypto-class/src/Cardano/Crypto/Libsodium/MLockedSeed.hs index 5fb8c600d..cb9520b21 100644 --- a/cardano-crypto-class/src/Cardano/Crypto/Libsodium/MLockedSeed.hs +++ b/cardano-crypto-class/src/Cardano/Crypto/Libsodium/MLockedSeed.hs @@ -2,7 +2,8 @@ {-# LANGUAGE FlexibleInstances #-} {-# LANGUAGE GeneralizedNewtypeDeriving #-} {-# LANGUAGE MultiParamTypeClasses #-} -{-# LANGUAGE StandaloneDeriving #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TypeApplications #-} module Cardano.Crypto.Libsodium.MLockedSeed where @@ -20,12 +21,16 @@ import Cardano.Crypto.Libsodium.Memory ( MLockedAllocator, mlockedMalloc, ) +import Cardano.Crypto.Libsodium.C ( + c_sodium_randombytes_buf, + ) import Cardano.Foreign (SizedPtr) import Control.DeepSeq (NFData) import Control.Monad.Class.MonadST (MonadST) +import Data.Proxy (Proxy (..)) import Data.Word (Word8) import Foreign.Ptr (Ptr) -import GHC.TypeNats (KnownNat) +import GHC.TypeNats (KnownNat, natVal) import NoThunks.Class (NoThunks) -- | A seed of size @n@, stored in mlocked memory. This is required to prevent @@ -66,6 +71,18 @@ mlockedSeedNewZeroWith :: (KnownNat n, MonadST m) => MLockedAllocator m -> m (ML mlockedSeedNewZeroWith allocator = MLockedSeed <$> mlsbNewZeroWith allocator +mlockedSeedNewRandom :: forall n. (KnownNat n) => IO (MLockedSeed n) +mlockedSeedNewRandom = mlockedSeedNewRandomWith mlockedMalloc + +mlockedSeedNewRandomWith :: forall n. (KnownNat n) => MLockedAllocator IO -> IO (MLockedSeed n) +mlockedSeedNewRandomWith allocator = do + mls <- MLockedSeed <$> mlsbNewZeroWith allocator + mlockedSeedUseAsCPtr mls $ \dst -> do + c_sodium_randombytes_buf dst size + return mls + where + size = fromIntegral $ natVal (Proxy @n) + mlockedSeedFinalize :: (MonadST m) => MLockedSeed n -> m () mlockedSeedFinalize = mlsbFinalize . mlockedSeedMLSB diff --git a/cardano-crypto-class/src/Cardano/Crypto/Libsodium/Memory.hs b/cardano-crypto-class/src/Cardano/Crypto/Libsodium/Memory.hs index a4405ef5d..cd927cb42 100644 --- a/cardano-crypto-class/src/Cardano/Crypto/Libsodium/Memory.hs +++ b/cardano-crypto-class/src/Cardano/Crypto/Libsodium/Memory.hs @@ -25,7 +25,13 @@ module Cardano.Crypto.Libsodium.Memory ( copyMem, allocaBytes, + -- * 'ForeignPtr' operations, generalized to 'MonadST' + ForeignPtr (..), + mallocForeignPtrBytes, + withForeignPtr, + -- * ByteString memory access, generalized to 'MonadST' + unpackByteStringCStringLen, packByteStringCStringLen, ) where diff --git a/cardano-crypto-class/src/Cardano/Crypto/Libsodium/Memory/Internal.hs b/cardano-crypto-class/src/Cardano/Crypto/Libsodium/Memory/Internal.hs index 68b57392f..ad61bd282 100644 --- a/cardano-crypto-class/src/Cardano/Crypto/Libsodium/Memory/Internal.hs +++ b/cardano-crypto-class/src/Cardano/Crypto/Libsodium/Memory/Internal.hs @@ -28,12 +28,18 @@ module Cardano.Crypto.Libsodium.Memory.Internal ( mlockedAllocForeignPtrWith, mlockedAllocForeignPtrBytesWith, + -- * 'ForeignPtr' operations, generalized to 'MonadST' + ForeignPtr (..), + mallocForeignPtrBytes, + withForeignPtr, + -- * Unmanaged memory, generalized to 'MonadST' zeroMem, copyMem, allocaBytes, -- * ByteString memory access, generalized to 'MonadST' + unpackByteStringCStringLen, packByteStringCStringLen, -- * Helper @@ -46,19 +52,21 @@ import Control.Monad (when, void) import Control.Monad.Class.MonadST import Control.Monad.Class.MonadThrow (MonadThrow (bracket)) import Control.Monad.ST (RealWorld, ST) -import Control.Monad.ST.Unsafe (unsafeIOToST, unsafeSTToIO) +import Control.Monad.Primitive (touch) +import Control.Monad.ST.Unsafe (unsafeIOToST) import Data.ByteString (ByteString) import qualified Data.ByteString as BS +import qualified Data.ByteString.Unsafe as BS import Data.Coerce (coerce) import Data.Typeable import Debug.Trace (traceShowM) import Foreign.C.Error (errnoToIOError, getErrno) import Foreign.C.String (CStringLen) import Foreign.C.Types (CSize (..)) -import Foreign.Concurrent (newForeignPtr) -import Foreign.ForeignPtr (ForeignPtr, finalizeForeignPtr, touchForeignPtr) +import qualified Foreign.Concurrent as Foreign +import qualified Foreign.ForeignPtr as Foreign hiding (newForeignPtr) +import qualified Foreign.ForeignPtr.Unsafe as Foreign import Foreign.ForeignPtr.Unsafe (unsafeForeignPtrToPtr) -import qualified Foreign.Marshal.Alloc as Foreign import Foreign.Marshal.Utils (fillBytes) import Foreign.Ptr (Ptr, nullPtr, castPtr) import Foreign.Storable (Storable (peek), sizeOf, alignment) @@ -66,13 +74,14 @@ import GHC.IO.Exception (ioException) import GHC.TypeLits (KnownNat, natVal) import NoThunks.Class (NoThunks, OnlyCheckWhnfNamed (..)) import System.IO.Unsafe (unsafePerformIO) +import Data.Kind import Cardano.Crypto.Libsodium.C import Cardano.Foreign (c_memset, c_memcpy, SizedPtr (..)) import Cardano.Memory.Pool (initPool, grabNextBlock, Pool) -- | Foreign pointer to securely allocated memory. -newtype MLockedForeignPtr a = SFP { _unwrapMLockedForeignPtr :: ForeignPtr a } +newtype MLockedForeignPtr a = SFP { _unwrapMLockedForeignPtr :: Foreign.ForeignPtr a } deriving NoThunks via OnlyCheckWhnfNamed "MLockedForeignPtr" (MLockedForeignPtr a) instance NFData (MLockedForeignPtr a) where @@ -81,11 +90,11 @@ instance NFData (MLockedForeignPtr a) where withMLockedForeignPtr :: MonadST m => MLockedForeignPtr a -> (Ptr a -> m b) -> m b withMLockedForeignPtr (SFP fptr) f = do r <- f (unsafeForeignPtrToPtr fptr) - r <$ unsafeIOToMonadST (touchForeignPtr fptr) + r <$ unsafeIOToMonadST (Foreign.touchForeignPtr fptr) finalizeMLockedForeignPtr :: MonadST m => MLockedForeignPtr a -> m () finalizeMLockedForeignPtr (SFP fptr) = - unsafeIOToMonadST $ finalizeForeignPtr fptr + unsafeIOToMonadST $ Foreign.finalizeForeignPtr fptr {-# WARNING traceMLockedForeignPtr "Do not use traceMLockedForeignPtr in production" #-} @@ -103,7 +112,7 @@ makeMLockedPool = do (max 1 . fromIntegral $ 4096 `div` natVal (Proxy @n) `div` 64) (\size -> unsafeIOToST $ mask_ $ do ptr <- sodiumMalloc (fromIntegral size) - newForeignPtr ptr (sodiumFree ptr (fromIntegral size)) + Foreign.newForeignPtr ptr (sodiumFree ptr (fromIntegral size)) ) (\ptr -> do eraseMem (Proxy @n) ptr @@ -159,7 +168,7 @@ mlockedMallocIO size = SFP <$> do | otherwise -> do mask_ $ do ptr <- sodiumMalloc size - newForeignPtr ptr $ do + Foreign.newForeignPtr ptr $ do sodiumFree ptr size sodiumMalloc :: CSize -> IO (Ptr a) @@ -188,9 +197,38 @@ zeroMem ptr size = unsafeIOToMonadST . void $ c_memset (castPtr ptr) 0 size copyMem :: MonadST m => Ptr a -> Ptr a -> CSize -> m () copyMem dst src size = unsafeIOToMonadST . void $ c_memcpy (castPtr dst) (castPtr src) size -allocaBytes :: Int -> (Ptr a -> ST s b) -> ST s b -allocaBytes size f = - unsafeIOToST $ Foreign.allocaBytes size (unsafeSTToIO . f) +-- | A 'ForeignPtr' type, generalized to 'MonadST'. The type is tagged with +-- the correct Monad @m@ in order to ensure that foreign pointers created in +-- one ST context can only be used within the same ST context. +newtype ForeignPtr (m :: Type -> Type) a = ForeignPtr { unsafeRawForeignPtr :: Foreign.ForeignPtr a } + +mallocForeignPtrBytes :: (MonadST m) => Int -> m (ForeignPtr m a) +mallocForeignPtrBytes size = + ForeignPtr <$> unsafeIOToMonadST (Foreign.mallocForeignPtrBytes size) + +-- | 'Foreign.withForeignPtr', generalized to 'MonadST'. +-- Caveat: if the monadic action passed to 'withForeignPtr' does not terminate +-- (e.g., 'forever'), the 'ForeignPtr' finalizer may run prematurely. +withForeignPtr :: (MonadST m) => ForeignPtr m a -> (Ptr a -> m b) -> m b +withForeignPtr (ForeignPtr fptr) f = do + result <- f $ Foreign.unsafeForeignPtrToPtr fptr + stToIO $ touch fptr + return result + +allocaBytes :: (MonadThrow m, MonadST m) => Int -> (Ptr a -> m b) -> m b +allocaBytes size action = do + fptr <- mallocForeignPtrBytes size + withForeignPtr fptr action + +-- | Unpacks a ByteString into a temporary buffer and runs the provided 'ST' +-- function on it. +unpackByteStringCStringLen :: (MonadThrow m, MonadST m) => ByteString -> (CStringLen -> m a) -> m a +unpackByteStringCStringLen bs f = do + let len = BS.length bs + allocaBytes len $ \buf -> do + unsafeIOToMonadST $ BS.unsafeUseAsCString bs $ \ptr -> do + copyMem buf ptr (fromIntegral len) + f (buf, len) packByteStringCStringLen :: MonadST m => CStringLen -> m ByteString packByteStringCStringLen = @@ -258,7 +296,6 @@ mlockedAllocaWith :: -> (Ptr a -> m b) -> m b mlockedAllocaWith allocator size = - bracket alloc free . flip withMLockedForeignPtr + bracket alloc finalizeMLockedForeignPtr . flip withMLockedForeignPtr where alloc = mlAllocate allocator size - free = finalizeMLockedForeignPtr diff --git a/cardano-crypto-tests/src/Test/Crypto/DSIGN.hs b/cardano-crypto-tests/src/Test/Crypto/DSIGN.hs index 1c81dcd46..102bdedf4 100644 --- a/cardano-crypto-tests/src/Test/Crypto/DSIGN.hs +++ b/cardano-crypto-tests/src/Test/Crypto/DSIGN.hs @@ -27,6 +27,7 @@ import Test.QuickCheck ( forAllShow, forAllShrinkShow, ioProperty, + counterexample, ) import Test.Tasty (TestTree, testGroup, adjustOption) import Test.Tasty.QuickCheck (testProperty, QuickCheckTests) @@ -94,6 +95,7 @@ import Cardano.Crypto.DSIGN ( ) import Cardano.Binary (FromCBOR, ToCBOR) import Cardano.Crypto.PinnedSizedBytes (PinnedSizedBytes) +import Cardano.Crypto.DirectSerialise import Test.Crypto.Util ( Message, prop_raw_serialise, @@ -111,6 +113,9 @@ import Test.Crypto.Util ( showBadInputFor, Lock, withLock, + directSerialiseToBS, + directDeserialiseFromBS, + hexBS, ) import Cardano.Crypto.Libsodium.MLockedSeed @@ -362,6 +367,10 @@ testDSIGNMAlgorithm , FromCBOR (SigDSIGN v) , ContextDSIGN v ~ () , Signable v Message + , DirectSerialise (SignKeyDSIGNM v) + , DirectDeserialise (SignKeyDSIGNM v) + , DirectSerialise (VerKeyDSIGN v) + , DirectDeserialise (VerKeyDSIGN v) ) => Lock -> Proxy v @@ -451,6 +460,36 @@ testDSIGNMAlgorithm lock _ n = sig :: SigDSIGN v <- signDSIGNM () msg sk return $ prop_cbor_direct_vs_class encodeSigDSIGN sig ] + , testGroup "DirectSerialise" + [ testProperty "VerKey" $ + ioPropertyWithSK @v lock $ \sk -> do + vk :: VerKeyDSIGN v <- deriveVerKeyDSIGNM sk + serialized <- directSerialiseToBS (fromIntegral $ sizeVerKeyDSIGN (Proxy @v)) vk + vk' <- directDeserialiseFromBS serialized + return $ vk === vk' + , testProperty "SignKey" $ + ioPropertyWithSK @v lock $ \sk -> do + serialized <- directSerialiseToBS (fromIntegral $ sizeSignKeyDSIGN (Proxy @v)) sk + sk' <- directDeserialiseFromBS serialized + equals <- sk ==! sk' + forgetSignKeyDSIGNM sk' + return $ + counterexample ("Serialized: " ++ hexBS serialized ++ " (length: " ++ show (BS.length serialized) ++ ")") $ + equals + ] + , testGroup "DirectSerialise matches raw" + [ testProperty "VerKey" $ + ioPropertyWithSK @v lock $ \sk -> do + vk :: VerKeyDSIGN v <- deriveVerKeyDSIGNM sk + direct <- directSerialiseToBS (fromIntegral $ sizeVerKeyDSIGN (Proxy @v)) vk + let raw = rawSerialiseVerKeyDSIGN vk + return $ direct === raw + , testProperty "SignKey" $ + ioPropertyWithSK @v lock $ \sk -> do + direct <- directSerialiseToBS (fromIntegral $ sizeSignKeyDSIGN (Proxy @v)) sk + raw <- rawSerialiseSignKeyDSIGNM sk + return $ direct === raw + ] ] , testGroup "verify" @@ -477,6 +516,24 @@ testDSIGNMAlgorithm lock _ n = ioPropertyWithSK @v lock $ prop_no_thunks_IO . return , testProperty "Sig" $ \(msg :: Message) -> ioPropertyWithSK @v lock $ prop_no_thunks_IO . signDSIGNM () msg + , testProperty "SignKey DirectSerialise" $ + ioPropertyWithSK @v lock $ \sk -> do + direct <- directSerialiseToBS (fromIntegral $ sizeSignKeyDSIGN (Proxy @v)) sk + prop_no_thunks_IO (return $! direct) + , testProperty "SignKey DirectDeserialise" $ + ioPropertyWithSK @v lock $ \sk -> do + direct <- directSerialiseToBS (fromIntegral $ sizeSignKeyDSIGN (Proxy @v)) sk + prop_no_thunks_IO (directDeserialiseFromBS @IO @(SignKeyDSIGNM v) $! direct) + , testProperty "VerKey DirectSerialise" $ + ioPropertyWithSK @v lock $ \sk -> do + vk <- deriveVerKeyDSIGNM sk + direct <- directSerialiseToBS (fromIntegral $ sizeVerKeyDSIGN (Proxy @v)) vk + prop_no_thunks_IO (return $! direct) + , testProperty "VerKey DirectDeserialise" $ + ioPropertyWithSK @v lock $ \sk -> do + vk <- deriveVerKeyDSIGNM sk + direct <- directSerialiseToBS (fromIntegral $ sizeVerKeyDSIGN (Proxy @v)) vk + prop_no_thunks_IO (directDeserialiseFromBS @IO @(VerKeyDSIGN v) $! direct) ] ] diff --git a/cardano-crypto-tests/src/Test/Crypto/KES.hs b/cardano-crypto-tests/src/Test/Crypto/KES.hs index fbe1612e7..6b3b8e558 100644 --- a/cardano-crypto-tests/src/Test/Crypto/KES.hs +++ b/cardano-crypto-tests/src/Test/Crypto/KES.hs @@ -29,21 +29,29 @@ import Data.List (foldl') import qualified Data.ByteString as BS import Data.Set (Set) import qualified Data.Set as Set -import Foreign.Ptr (WordPtr) +import Foreign.Ptr (WordPtr, plusPtr) import Data.IORef -import GHC.TypeNats (KnownNat) +import GHC.TypeNats (KnownNat, natVal) -import Control.Tracer +import Control.Concurrent.MVar (newMVar, takeMVar, putMVar) +import Control.Monad (void, when) +import Control.Monad.Class.MonadST import Control.Monad.Class.MonadThrow import Control.Monad.IO.Class (liftIO) -import Control.Monad (void) +import Control.Tracer import Cardano.Crypto.DSIGN hiding (Signable) import Cardano.Crypto.Hash import Cardano.Crypto.KES +import Cardano.Crypto.DirectSerialise (DirectSerialise, directSerialise, DirectDeserialise) import Cardano.Crypto.Util (SignableRepresentation(..)) import Cardano.Crypto.Libsodium import Cardano.Crypto.Libsodium.MLockedSeed +import Cardano.Crypto.Libsodium.Memory + ( copyMem + , allocaBytes + , packByteStringCStringLen + ) import Cardano.Crypto.PinnedSizedBytes import Test.QuickCheck @@ -67,6 +75,8 @@ import Test.Crypto.Util ( noExceptionsThrown, Lock, withLock, + directSerialiseToBS, + directDeserialiseFromBS, ) import Test.Crypto.EqST import Test.Crypto.Instances (withMLockedSeedFromPSB) @@ -198,6 +208,10 @@ testKESAlgorithm , Signable v ~ SignableRepresentation , ContextKES v ~ () , UnsoundKESAlgorithm v + , DirectSerialise (SignKeyKES v) + , DirectSerialise (VerKeyKES v) + , DirectDeserialise (SignKeyKES v) + , DirectDeserialise (VerKeyKES v) ) => Lock -> String @@ -225,9 +239,32 @@ testKESAlgorithm lock n = , testProperty "Sig" $ \seedPSB (msg :: Message) -> ioProperty $ withLock lock $ fmap conjoin $ withAllUpdatesKES @v seedPSB $ \t sk -> do prop_no_thunks_IO (signKES () t msg sk) + + , testProperty "VerKey DirectSerialise" $ + ioPropertyWithSK @v lock $ \sk -> do + vk :: VerKeyKES v <- deriveVerKeyKES sk + direct <- directSerialiseToBS (fromIntegral $ sizeVerKeyKES (Proxy @v)) vk + prop_no_thunks_IO (return $! direct) + , testProperty "SignKey DirectSerialise" $ + ioPropertyWithSK @v lock $ \sk -> do + direct <- directSerialiseToBS (fromIntegral $ sizeSignKeyKES (Proxy @v)) sk + prop_no_thunks_IO (return $! direct) + , testProperty "VerKey DirectDeserialise" $ + ioPropertyWithSK @v lock $ \sk -> do + vk :: VerKeyKES v <- deriveVerKeyKES sk + direct <- directSerialiseToBS (fromIntegral $ sizeVerKeyKES (Proxy @v)) $! vk + prop_no_thunks_IO (directDeserialiseFromBS @IO @(VerKeyKES v) $! direct) + , testProperty "SignKey DirectDeserialise" $ + ioPropertyWithSK @v lock $ \sk -> do + direct <- directSerialiseToBS (fromIntegral $ sizeSignKeyKES (Proxy @v)) sk + bracket + (directDeserialiseFromBS @IO @(SignKeyKES v) $! direct) + forgetSignKeyKES + (prop_no_thunks_IO . return) ] , testProperty "same VerKey " $ prop_deriveVerKeyKES @v + , testProperty "no forgotten chunks in signkey" $ prop_noErasedBlocksInKey (Proxy @v) , testGroup "serialisation" [ testGroup "raw ser only" @@ -313,6 +350,38 @@ testKESAlgorithm lock n = sig :: SigKES v <- signKES () 0 msg sk return $ prop_cbor_direct_vs_class encodeSigKES sig ] + + , testGroup "DirectSerialise" + [ testProperty "VerKey" $ + ioPropertyWithSK @v lock $ \sk -> do + vk :: VerKeyKES v <- deriveVerKeyKES sk + serialized <- directSerialiseToBS (fromIntegral $ sizeVerKeyKES (Proxy @v)) vk + vk' <- directDeserialiseFromBS serialized + return $ vk === vk' + , testProperty "SignKey" $ + ioPropertyWithSK @v lock $ \sk -> do + serialized <- directSerialiseToBS (fromIntegral $ sizeSignKeyKES (Proxy @v)) sk + equals <- bracket + (directDeserialiseFromBS serialized) + forgetSignKeyKES + (\sk' -> sk ==! sk') + return $ + counterexample ("Serialized: " ++ hexBS serialized ++ " (length: " ++ show (BS.length serialized) ++ ")") $ + equals + ] + , testGroup "DirectSerialise matches raw" + [ testProperty "VerKey" $ + ioPropertyWithSK @v lock $ \sk -> do + vk :: VerKeyKES v <- deriveVerKeyKES sk + direct <- directSerialiseToBS (fromIntegral $ sizeVerKeyKES (Proxy @v)) vk + let raw = rawSerialiseVerKeyKES vk + return $ direct === raw + , testProperty "SignKey" $ + ioPropertyWithSK @v lock $ \sk -> do + direct <- directSerialiseToBS (fromIntegral $ sizeSignKeyKES (Proxy @v)) sk + raw <- rawSerialiseSignKeyKES sk + return $ direct === raw + ] ] , testGroup "verify" @@ -676,3 +745,56 @@ withAllUpdatesKES seedPSB f = withMLockedSeedFromPSB seedPSB $ \seed -> do xs <- go sk' (t + 1) return $ x:xs +withNullSeed :: forall m n a. (MonadThrow m, MonadST m, KnownNat n) => (MLockedSeed n -> m a) -> m a +withNullSeed = bracket + (MLockedSeed <$> mlsbFromByteString (BS.replicate (fromIntegral $ natVal (Proxy @n)) 0)) + mlockedSeedFinalize + +withNullSK :: forall m v a. (KESAlgorithm v, MonadThrow m, MonadST m) + => (SignKeyKES v -> m a) -> m a +withNullSK = bracket + (withNullSeed genKeyKES) + forgetSignKeyKES + + +-- | This test detects whether a sign key contains references to pool-allocated +-- blocks of memory that have been forgotten by the time the key is complete. +-- We do this based on the fact that the pooled allocator erases memory blocks +-- by overwriting them with series of 0xff bytes; thus we cut the serialized +-- key up into chunks of 16 bytes, and if any of those chunks is entirely +-- filled with 0xff bytes, we assume that we're looking at erased memory. +prop_noErasedBlocksInKey + :: forall v. + UnsoundKESAlgorithm v + => DirectSerialise (SignKeyKES v) + => Proxy v + -> Property +prop_noErasedBlocksInKey kesAlgorithm = + ioProperty . withNullSK @IO @v $ \sk -> do + let size :: Int = fromIntegral $ sizeSignKeyKES kesAlgorithm + serialized <- allocaBytes size $ \ptr -> do + positionVar <- newMVar (0 :: Int) + directSerialise (\buf nCSize -> do + let n = fromIntegral nCSize :: Int + bracket + (takeMVar positionVar) + (putMVar positionVar . (+ n)) + (\position -> do + when (n + position > size) (error "Buffer size exceeded") + copyMem (plusPtr ptr position) buf (fromIntegral n) + ) + ) + sk + packByteStringCStringLen (ptr, size) + forgetSignKeyKES sk + return $ counterexample (hexBS serialized) $ not (hasLongRunOfFF serialized) + +hasLongRunOfFF :: ByteString -> Bool +hasLongRunOfFF bs + | BS.length bs < 16 + = False + | otherwise + = let first16 = BS.take 16 bs + remainder = BS.drop 16 bs + in (BS.all (== 0xFF) first16) || hasLongRunOfFF remainder + diff --git a/cardano-crypto-tests/src/Test/Crypto/Util.hs b/cardano-crypto-tests/src/Test/Crypto/Util.hs index c0c7d7441..a23740e64 100644 --- a/cardano-crypto-tests/src/Test/Crypto/Util.hs +++ b/cardano-crypto-tests/src/Test/Crypto/Util.hs @@ -59,6 +59,10 @@ module Test.Crypto.Util , noExceptionsThrown , doesNotThrow + -- * Direct ser/deser helpers + , directSerialiseToBS + , directDeserialiseFromBS + -- * Error handling , eitherShowError @@ -95,12 +99,19 @@ import Codec.CBOR.Write ( ) import Cardano.Crypto.Seed (Seed, mkSeedFromBytes) import Cardano.Crypto.Util (SignableRepresentation(..)) +import Cardano.Crypto.DirectSerialise import Crypto.Random ( ChaChaDRG , MonadPseudoRandom , drgNewTest , withDRG ) +import Cardano.Crypto.Libsodium.Memory + ( unpackByteStringCStringLen + , packByteStringCStringLen + , allocaBytes + , copyMem + ) import Data.ByteString (ByteString) import qualified Data.ByteString as BS import qualified Data.ByteString.Char8 as BS8 @@ -130,7 +141,19 @@ import qualified Test.QuickCheck.Gen as Gen import Control.Monad (guard, when) import GHC.TypeLits (Nat, KnownNat, natVal) import Formatting.Buildable (Buildable (..), build) -import Control.Concurrent.Class.MonadMVar (MVar, withMVar, newMVar) +import Foreign.Ptr (Ptr, plusPtr) +import Foreign.C.Types (CChar, CSize) +import Control.Monad.Class.MonadST (MonadST) +import Control.Monad.Class.MonadThrow (MonadThrow) +import Control.Concurrent.Class.MonadMVar + ( MVar + , withMVar + , newMVar + , putMVar + , takeMVar + , newMVar + , MonadMVar + ) import GHC.Stack (HasCallStack) -------------------------------------------------------------------------------- @@ -375,3 +398,44 @@ mkLock = Lock <$> newMVar () eitherShowError :: (HasCallStack, Show e) => Either e a -> IO a eitherShowError (Left e) = error (show e) eitherShowError (Right a) = return a + +-------------------------------------------------------------------------------- +-- Helpers for direct ser/deser +-------------------------------------------------------------------------------- + +directSerialiseToBS :: forall m a. + DirectSerialise a + => MonadST m + => MonadThrow m + => MonadMVar m + => Int -> a -> m ByteString +directSerialiseToBS dstsize val = do + allocaBytes dstsize $ \dst -> do + posVar <- newMVar 0 + let pusher :: Ptr CChar -> CSize -> m () + pusher src srcsize = do + pos <- takeMVar posVar + let pos' = pos + fromIntegral srcsize + when (pos' > dstsize) (error "Buffer overrun") + copyMem (plusPtr dst pos) src (fromIntegral srcsize) + putMVar posVar pos' + directSerialise pusher val + packByteStringCStringLen (dst, fromIntegral dstsize) + +directDeserialiseFromBS :: forall m a. + DirectDeserialise a + => MonadST m + => MonadThrow m + => MonadMVar m + => ByteString -> m a +directDeserialiseFromBS bs = do + unpackByteStringCStringLen bs $ \(src, srcsize) -> do + posVar <- newMVar 0 + let puller :: Ptr CChar -> CSize -> m () + puller dst dstsize = do + pos <- takeMVar posVar + let pos' = pos + fromIntegral dstsize + when (pos' > srcsize) (error "Buffer overrun") + copyMem dst (plusPtr src pos) (fromIntegral dstsize) + putMVar posVar pos' + directDeserialise puller