From 88bc659c4f6212112da5866521e8a88869f32c23 Mon Sep 17 00:00:00 2001 From: Tobias Dammers Date: Thu, 15 Jun 2023 09:31:07 +0200 Subject: [PATCH 1/9] DirectSerialise --- .../cardano-crypto-class.cabal | 1 + .../src/Cardano/Crypto/DSIGN/Ed25519.hs | 55 +++++++- .../src/Cardano/Crypto/DirectSerialise.hs | 41 ++++++ .../src/Cardano/Crypto/KES/CompactSingle.hs | 18 ++- .../src/Cardano/Crypto/KES/CompactSum.hs | 60 +++++++- .../src/Cardano/Crypto/KES/Mock.hs | 36 +++++ .../src/Cardano/Crypto/KES/Simple.hs | 20 +++ .../src/Cardano/Crypto/KES/Single.hs | 19 ++- .../src/Cardano/Crypto/KES/Sum.hs | 59 +++++++- .../src/Cardano/Crypto/Libsodium/Memory.hs | 1 + .../Crypto/Libsodium/Memory/Internal.hs | 36 ++++- cardano-crypto-tests/src/Test/Crypto/DSIGN.hs | 57 ++++++++ cardano-crypto-tests/src/Test/Crypto/KES.hs | 130 +++++++++++++++++- cardano-crypto-tests/src/Test/Crypto/Util.hs | 66 ++++++++- 14 files changed, 580 insertions(+), 19 deletions(-) create mode 100644 cardano-crypto-class/src/Cardano/Crypto/DirectSerialise.hs diff --git a/cardano-crypto-class/cardano-crypto-class.cabal b/cardano-crypto-class/cardano-crypto-class.cabal index 208e000f6..f938fabcc 100644 --- a/cardano-crypto-class/cardano-crypto-class.cabal +++ b/cardano-crypto-class/cardano-crypto-class.cabal @@ -39,6 +39,7 @@ library import: base, project-config hs-source-dirs: src exposed-modules: + Cardano.Crypto.DirectSerialise Cardano.Crypto.DSIGN Cardano.Crypto.DSIGN.Class Cardano.Crypto.DSIGN.Ed25519 diff --git a/cardano-crypto-class/src/Cardano/Crypto/DSIGN/Ed25519.hs b/cardano-crypto-class/src/Cardano/Crypto/DSIGN/Ed25519.hs index 01aa4a4db..26ac94750 100644 --- a/cardano-crypto-class/src/Cardano/Crypto/DSIGN/Ed25519.hs +++ b/cardano-crypto-class/src/Cardano/Crypto/DSIGN/Ed25519.hs @@ -3,6 +3,7 @@ {-# LANGUAGE DerivingVia #-} {-# LANGUAGE FlexibleInstances #-} {-# LANGUAGE GeneralizedNewtypeDeriving #-} +{-# LANGUAGE MultiParamTypeClasses #-} {-# LANGUAGE RankNTypes #-} {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE StandaloneDeriving #-} @@ -66,14 +67,17 @@ import Cardano.Crypto.Libsodium.MLockedSeed import Cardano.Crypto.PinnedSizedBytes ( PinnedSizedBytes , psbUseAsSizedPtr + , psbUseAsCPtrLen , psbToByteString , psbFromByteStringCheck + , psbCreate , psbCreateSized , psbCreateSizedResult ) import Cardano.Crypto.Seed import Cardano.Crypto.Util (SignableRepresentation(..)) import Cardano.Foreign +import Cardano.Crypto.DirectSerialise @@ -261,7 +265,7 @@ instance DSIGNMAlgorithm Ed25519DSIGN where stToIO $ do cOrError $ unsafeIOToST $ c_crypto_sign_ed25519_sk_to_pk pkPtr skPtr - throwOnErrno "deriveVerKeyDSIGNM @Ed25519DSIGN" "c_crypto_sign_ed25519_sk_to_pk" maybeErrno + throwOnErrno "deriveVerKeyDSIGN @Ed25519DSIGN" "c_crypto_sign_ed25519_sk_to_pk" maybeErrno return psb @@ -365,3 +369,52 @@ instance TypeError ('Text "CBOR encoding would violate mlocking guarantees") instance TypeError ('Text "CBOR decoding would violate mlocking guarantees") => FromCBOR (SignKeyDSIGNM Ed25519DSIGN) where fromCBOR = error "unsupported" + +instance ( MonadThrow m + , MonadST m + ) => DirectSerialise m (SignKeyDSIGNM Ed25519DSIGN) where + -- /Note:/ We only serialize the 32-byte seed, not the full 64-byte key. The + -- latter contains both the seed and the 32-byte verification key, which is + -- convenient, but redundant, since we can always reconstruct it from the + -- seed. This is also reflected in the 'SizeSignKeyDSIGNM', which equals + -- 'SeedSizeDSIGNM' == 32, rather than reporting the in-memory size of 64. + directSerialise push sk = do + bracket + (getSeedDSIGNM (Proxy @Ed25519DSIGN) sk) + mlockedSeedFinalize + (\seed -> mlockedSeedUseAsCPtr seed $ \ptr -> + push + (castPtr ptr) + (fromIntegral $ seedSizeDSIGN (Proxy @Ed25519DSIGN))) + +instance ( MonadThrow m + , MonadST m + ) => DirectDeserialise m (SignKeyDSIGNM Ed25519DSIGN) where + -- /Note:/ We only serialize the 32-byte seed, not the full 64-byte key. See + -- the DirectSerialise m instance above. + directDeserialise pull = do + bracket + mlockedSeedNew + mlockedSeedFinalize + (\seed -> do + mlockedSeedUseAsCPtr seed $ \ptr -> do + pull + (castPtr ptr) + (fromIntegral $ seedSizeDSIGN (Proxy @Ed25519DSIGN)) + genKeyDSIGNM seed + ) + +instance MonadST m => DirectSerialise m (VerKeyDSIGN Ed25519DSIGN) where + directSerialise push (VerKeyEd25519DSIGN psb) = do + psbUseAsCPtrLen psb $ \ptr _ -> + push + (castPtr ptr) + (fromIntegral $ sizeVerKeyDSIGN (Proxy @Ed25519DSIGN)) + +instance MonadST m => DirectDeserialise m (VerKeyDSIGN Ed25519DSIGN) where + directDeserialise pull = do + psb <- psbCreate $ \ptr -> + pull + (castPtr ptr) + (fromIntegral $ sizeVerKeyDSIGN (Proxy @Ed25519DSIGN)) + return $! VerKeyEd25519DSIGN $! psb diff --git a/cardano-crypto-class/src/Cardano/Crypto/DirectSerialise.hs b/cardano-crypto-class/src/Cardano/Crypto/DirectSerialise.hs new file mode 100644 index 000000000..d879778ae --- /dev/null +++ b/cardano-crypto-class/src/Cardano/Crypto/DirectSerialise.hs @@ -0,0 +1,41 @@ +{-# LANGUAGE MultiParamTypeClasses #-} + +-- | Direct (de-)serialisation to / from raw memory. +-- +-- The purpose of the typeclasses in this module is to abstract over data +-- structures that can expose the data they store as one or more raw 'Ptr's, +-- without any additional memory copying or conversion to intermediate data +-- structures. +-- +-- This is useful for transmitting data like KES SignKeys over a socket +-- connection: by accessing the memory directly and copying it into or out of +-- a file descriptor, without going through an intermediate @ByteString@ +-- representation (or other data structure that resides in the GHC heap), we +-- can more easily assure that the data is never written to disk, including +-- swap, which is an important requirement for KES. +module Cardano.Crypto.DirectSerialise +where + +import Foreign.Ptr +import Foreign.C.Types + +-- | Direct deserialization from raw memory. +-- +-- @directDeserialise f@ should allocate a new value of type 'a', and +-- call @f@ with a pointer to the raw memory to be filled. @f@ may be called +-- multiple times, for data structures that store their data in multiple +-- non-contiguous blocks of memory. +-- +-- The order in which memory blocks are visited matters. +class DirectDeserialise m a where + directDeserialise :: (Ptr CChar -> CSize -> m ()) -> m a + +-- | Direct serialization to raw memory. +-- +-- @directSerialise f x@ should call @f@ to expose the raw memory underyling +-- @x@. For data types that store their data in multiple non-contiguous blocks +-- of memory, @f@ may be called multiple times, once for each block. +-- +-- The order in which memory blocks are visited matters. +class DirectSerialise m a where + directSerialise :: (Ptr CChar -> CSize -> m ()) -> a -> m () diff --git a/cardano-crypto-class/src/Cardano/Crypto/KES/CompactSingle.hs b/cardano-crypto-class/src/Cardano/Crypto/KES/CompactSingle.hs index cce1102f1..4c7fa2b19 100644 --- a/cardano-crypto-class/src/Cardano/Crypto/KES/CompactSingle.hs +++ b/cardano-crypto-class/src/Cardano/Crypto/KES/CompactSingle.hs @@ -60,7 +60,7 @@ import Cardano.Binary (FromCBOR (..), ToCBOR (..)) import Cardano.Crypto.Hash.Class import Cardano.Crypto.DSIGN.Class as DSIGN import Cardano.Crypto.KES.Class - +import Cardano.Crypto.DirectSerialise -- | A standard signature scheme is a forward-secure signature scheme with a -- single time period. @@ -227,3 +227,19 @@ instance (DSIGNMAlgorithm d, KnownNat (SizeSigKES (CompactSingleKES d))) => From slice :: Word -> Word -> ByteString -> ByteString slice offset size = BS.take (fromIntegral size) . BS.drop (fromIntegral offset) + +-- +-- Direct ser/deser +-- + +instance (DirectSerialise m (SignKeyDSIGNM d)) => DirectSerialise m (SignKeyKES (CompactSingleKES d)) where + directSerialise push (SignKeyCompactSingleKES sk) = directSerialise push sk + +instance (Monad m, DirectDeserialise m (SignKeyDSIGNM d)) => DirectDeserialise m (SignKeyKES (CompactSingleKES d)) where + directDeserialise pull = SignKeyCompactSingleKES <$!> directDeserialise pull + +instance (DirectSerialise m (VerKeyDSIGN d)) => DirectSerialise m (VerKeyKES (CompactSingleKES d)) where + directSerialise push (VerKeyCompactSingleKES sk) = directSerialise push sk + +instance (Monad m, DirectDeserialise m (VerKeyDSIGN d)) => DirectDeserialise m (VerKeyKES (CompactSingleKES d)) where + directDeserialise pull = VerKeyCompactSingleKES <$!> directDeserialise pull diff --git a/cardano-crypto-class/src/Cardano/Crypto/KES/CompactSum.hs b/cardano-crypto-class/src/Cardano/Crypto/KES/CompactSum.hs index ce37acbe8..a6649aed1 100644 --- a/cardano-crypto-class/src/Cardano/Crypto/KES/CompactSum.hs +++ b/cardano-crypto-class/src/Cardano/Crypto/KES/CompactSum.hs @@ -6,6 +6,7 @@ {-# LANGUAGE GeneralizedNewtypeDeriving #-} {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE StandaloneDeriving #-} +{-# LANGUAGE TypeApplications #-} {-# LANGUAGE TypeFamilies #-} {-# LANGUAGE TypeOperators #-} {-# LANGUAGE UndecidableInstances #-} @@ -86,7 +87,7 @@ module Cardano.Crypto.KES.CompactSum ( import Data.Proxy (Proxy(..)) import GHC.Generics (Generic) import qualified Data.ByteString as BS -import Control.Monad (guard) +import Control.Monad (guard, (<$!>)) import NoThunks.Class (NoThunks, OnlyCheckWhnfNamed (..)) import Cardano.Binary (FromCBOR (..), ToCBOR (..)) @@ -97,10 +98,16 @@ import Cardano.Crypto.KES.CompactSingle (CompactSingleKES) import Cardano.Crypto.Util import Cardano.Crypto.Libsodium.MLockedSeed import Cardano.Crypto.Libsodium +import Cardano.Crypto.Libsodium.Memory +import Cardano.Crypto.DirectSerialise + import Control.Monad.Trans.Maybe (MaybeT (..), runMaybeT) import Control.Monad.Trans (lift) +import Control.Monad.Class.MonadST +import Control.Monad.Class.MonadThrow import Control.DeepSeq (NFData (..)) import GHC.TypeLits (KnownNat, type (+), type (*)) +import Foreign.Ptr (castPtr) -- | A 2^0 period KES type CompactSum0KES d = CompactSingleKES d @@ -461,3 +468,54 @@ instance ( OptimizedKESAlgorithm d ) => FromCBOR (SigKES (CompactSumKES h d)) where fromCBOR = decodeSigKES + + +-- +-- Direct ser/deser +-- + +instance ( DirectSerialise m (SignKeyKES d) + , DirectSerialise m (VerKeyKES d) + , MonadST m + , KESAlgorithm d + ) => DirectSerialise m (SignKeyKES (CompactSumKES h d)) where + directSerialise push (SignKeyCompactSumKES sk r vk0 vk1) = do + directSerialise push sk + mlockedSeedUseAsCPtr r $ \ptr -> + push (castPtr ptr) (fromIntegral $ seedSizeKES (Proxy :: Proxy d)) + directSerialise push vk0 + directSerialise push vk1 + +instance ( DirectDeserialise m (SignKeyKES d) + , DirectDeserialise m (VerKeyKES d) + , MonadST m + , KESAlgorithm d + ) => DirectDeserialise m (SignKeyKES (CompactSumKES h d)) where + directDeserialise pull = do + sk <- directDeserialise pull + + r <- mlockedSeedNew + mlockedSeedUseAsCPtr r $ \ptr -> + pull (castPtr ptr) (fromIntegral $ seedSizeKES (Proxy :: Proxy d)) + + vk0 <- directDeserialise pull + vk1 <- directDeserialise pull + + return $! SignKeyCompactSumKES sk r vk0 vk1 + + +instance (MonadST m, MonadThrow m) + => DirectSerialise m (VerKeyKES (CompactSumKES h d)) where + directSerialise push (VerKeyCompactSumKES h) = + unpackByteStringCStringLen (hashToBytes h) $ \(ptr, len) -> + push (castPtr ptr) (fromIntegral len) + +instance (MonadST m, MonadThrow m, MonadFail m, HashAlgorithm h) + => DirectDeserialise m (VerKeyKES (CompactSumKES h d)) where + directDeserialise pull = do + let len :: Num a => a + len = fromIntegral $ sizeHash (Proxy @h) + allocaBytes len $ \ptr -> do + pull ptr len + bs <- packByteStringCStringLen (ptr, len) + maybe (fail "Invalid hash") return $! VerKeyCompactSumKES <$!> hashFromBytes bs diff --git a/cardano-crypto-class/src/Cardano/Crypto/KES/Mock.hs b/cardano-crypto-class/src/Cardano/Crypto/KES/Mock.hs index 4e2a91516..c04ae9e37 100644 --- a/cardano-crypto-class/src/Cardano/Crypto/KES/Mock.hs +++ b/cardano-crypto-class/src/Cardano/Crypto/KES/Mock.hs @@ -26,6 +26,8 @@ import GHC.TypeNats (Nat, KnownNat, natVal) import NoThunks.Class (NoThunks) import Control.Exception (assert) +import Control.Monad.Class.MonadST (MonadST) +import Control.Monad.Class.MonadThrow (MonadThrow) import Cardano.Binary (FromCBOR (..), ToCBOR (..)) @@ -37,6 +39,12 @@ import Cardano.Crypto.Libsodium.MLockedSeed import Cardano.Crypto.Libsodium ( mlsbAsByteString ) +import Cardano.Crypto.Libsodium.Memory + ( unpackByteStringCStringLen + , packByteStringCStringLen + , allocaBytes + ) +import Cardano.Crypto.DirectSerialise data MockKES (t :: Nat) @@ -194,3 +202,31 @@ instance KnownNat t => ToCBOR (SigKES (MockKES t)) where instance KnownNat t => FromCBOR (SigKES (MockKES t)) where fromCBOR = decodeSigKES + +instance (MonadST m, MonadThrow m, KnownNat t) => DirectSerialise m (SignKeyKES (MockKES t)) where + directSerialise put sk = do + let bs = rawSerialiseSignKeyMockKES sk + unpackByteStringCStringLen bs $ \(cstr, len) -> put cstr (fromIntegral len) + +instance (MonadST m, MonadThrow m, KnownNat t) => DirectDeserialise m (SignKeyKES (MockKES t)) where + directDeserialise pull = do + let len = fromIntegral $ sizeSignKeyKES (Proxy @(MockKES t)) + bs <- allocaBytes len $ \cstr -> do + pull cstr (fromIntegral len) + packByteStringCStringLen (cstr, len) + maybe (error "directDeserialise @(SignKeyKES (MockKES t))") return $ + rawDeserialiseSignKeyMockKES bs + +instance (MonadST m, MonadThrow m, KnownNat t) => DirectSerialise m (VerKeyKES (MockKES t)) where + directSerialise put sk = do + let bs = rawSerialiseVerKeyKES sk + unpackByteStringCStringLen bs $ \(cstr, len) -> put cstr (fromIntegral len) + +instance (MonadST m, MonadThrow m, KnownNat t) => DirectDeserialise m (VerKeyKES (MockKES t)) where + directDeserialise pull = do + let len = fromIntegral $ sizeVerKeyKES (Proxy @(MockKES t)) + bs <- allocaBytes len $ \cstr -> do + pull cstr (fromIntegral len) + packByteStringCStringLen (cstr, len) + maybe (error "directDeserialise @(VerKeyKES (MockKES t))") return $ + rawDeserialiseVerKeyKES bs diff --git a/cardano-crypto-class/src/Cardano/Crypto/KES/Simple.hs b/cardano-crypto-class/src/Cardano/Crypto/KES/Simple.hs index b8bfe2186..8ed2d5640 100644 --- a/cardano-crypto-class/src/Cardano/Crypto/KES/Simple.hs +++ b/cardano-crypto-class/src/Cardano/Crypto/KES/Simple.hs @@ -43,6 +43,7 @@ import Cardano.Crypto.KES.Class import Cardano.Crypto.Libsodium.MLockedSeed import Cardano.Crypto.Libsodium.MLockedBytes import Cardano.Crypto.Util +import Cardano.Crypto.DirectSerialise import Data.Unit.Strict (forceElemsToWHNF) data SimpleKES d (t :: Nat) @@ -249,3 +250,22 @@ instance (DSIGNMAlgorithm d => FromCBOR (SigKES (SimpleKES d t)) where fromCBOR = decodeSigKES +instance (Monad m, DirectSerialise m (VerKeyDSIGN d)) => DirectSerialise m (VerKeyKES (SimpleKES d t)) where + directSerialise push (VerKeySimpleKES vks) = + mapM_ (directSerialise push) vks + +instance (Monad m, DirectDeserialise m (VerKeyDSIGN d), KnownNat t) => DirectDeserialise m (VerKeyKES (SimpleKES d t)) where + directDeserialise pull = do + let duration = fromIntegral (natVal (Proxy :: Proxy t)) + vks <- Vec.replicateM duration (directDeserialise pull) + return $! VerKeySimpleKES $! vks + +instance (Monad m, DirectSerialise m (SignKeyDSIGNM d)) => DirectSerialise m (SignKeyKES (SimpleKES d t)) where + directSerialise push (SignKeySimpleKES sks) = + mapM_ (directSerialise push) sks + +instance (Monad m, DirectDeserialise m (SignKeyDSIGNM d), KnownNat t) => DirectDeserialise m (SignKeyKES (SimpleKES d t)) where + directDeserialise pull = do + let duration = fromIntegral (natVal (Proxy :: Proxy t)) + sks <- Vec.replicateM duration (directDeserialise pull) + return $! SignKeySimpleKES $! sks diff --git a/cardano-crypto-class/src/Cardano/Crypto/KES/Single.hs b/cardano-crypto-class/src/Cardano/Crypto/KES/Single.hs index 2d38fb527..e7f0364df 100644 --- a/cardano-crypto-class/src/Cardano/Crypto/KES/Single.hs +++ b/cardano-crypto-class/src/Cardano/Crypto/KES/Single.hs @@ -50,7 +50,7 @@ import Cardano.Binary (FromCBOR (..), ToCBOR (..)) import Cardano.Crypto.Hash.Class import Cardano.Crypto.DSIGN.Class as DSIGN import Cardano.Crypto.KES.Class - +import Cardano.Crypto.DirectSerialise -- | A standard signature scheme is a forward-secure signature scheme with a -- single time period. @@ -187,4 +187,19 @@ instance DSIGNMAlgorithm d => ToCBOR (SigKES (SingleKES d)) where instance DSIGNMAlgorithm d => FromCBOR (SigKES (SingleKES d)) where fromCBOR = decodeSigKES - {-# INLINE fromCBOR #-} + +-- +-- Direct ser/deser +-- + +instance (DirectSerialise m (SignKeyDSIGNM d)) => DirectSerialise m (SignKeyKES (SingleKES d)) where + directSerialise push (SignKeySingleKES sk) = directSerialise push sk + +instance (Monad m, DirectDeserialise m (SignKeyDSIGNM d)) => DirectDeserialise m (SignKeyKES (SingleKES d)) where + directDeserialise pull = SignKeySingleKES <$!> directDeserialise pull + +instance (DirectSerialise m (VerKeyDSIGN d)) => DirectSerialise m (VerKeyKES (SingleKES d)) where + directSerialise push (VerKeySingleKES sk) = directSerialise push sk + +instance (Monad m, DirectDeserialise m (VerKeyDSIGN d)) => DirectDeserialise m (VerKeyKES (SingleKES d)) where + directDeserialise pull = VerKeySingleKES <$!> directDeserialise pull diff --git a/cardano-crypto-class/src/Cardano/Crypto/KES/Sum.hs b/cardano-crypto-class/src/Cardano/Crypto/KES/Sum.hs index 300962d11..d9e53a948 100644 --- a/cardano-crypto-class/src/Cardano/Crypto/KES/Sum.hs +++ b/cardano-crypto-class/src/Cardano/Crypto/KES/Sum.hs @@ -65,10 +65,15 @@ import Cardano.Crypto.KES.Single (SingleKES) import Cardano.Crypto.Util import Cardano.Crypto.Libsodium.MLockedSeed import Cardano.Crypto.Libsodium +import Cardano.Crypto.Libsodium.Memory +import Cardano.Crypto.DirectSerialise + import Control.Monad.Trans.Maybe (MaybeT (..), runMaybeT) +import Control.Monad.Class.MonadST +import Control.Monad.Class.MonadThrow import Control.DeepSeq (NFData (..)) import GHC.TypeLits (KnownNat, type (+), type (*)) - +import Foreign.Ptr (castPtr) -- | A 2^0 period KES type Sum0KES d = SingleKES d @@ -383,4 +388,54 @@ instance (KESAlgorithm (SumKES h d), SodiumHashAlgorithm h, SizeHash h ~ SeedSiz instance (KESAlgorithm (SumKES h d), SodiumHashAlgorithm h, SizeHash h ~ SeedSizeKES d) => FromCBOR (SigKES (SumKES h d)) where fromCBOR = decodeSigKES - {-# INLINE fromCBOR #-} + + +-- +-- Direct ser/deser +-- + +instance ( DirectSerialise m (SignKeyKES d) + , DirectSerialise m (VerKeyKES d) + , MonadST m + , KESAlgorithm d + ) => DirectSerialise m (SignKeyKES (SumKES h d)) where + directSerialise push (SignKeySumKES sk r vk0 vk1) = do + directSerialise push sk + mlockedSeedUseAsCPtr r $ \ptr -> + push (castPtr ptr) (fromIntegral $ seedSizeKES (Proxy :: Proxy d)) + directSerialise push vk0 + directSerialise push vk1 + +instance ( DirectDeserialise m (SignKeyKES d) + , DirectDeserialise m (VerKeyKES d) + , MonadST m + , KESAlgorithm d + ) => DirectDeserialise m (SignKeyKES (SumKES h d)) where + directDeserialise pull = do + sk <- directDeserialise pull + + r <- mlockedSeedNew + mlockedSeedUseAsCPtr r $ \ptr -> + pull (castPtr ptr) (fromIntegral $ seedSizeKES (Proxy :: Proxy d)) + + vk0 <- directDeserialise pull + vk1 <- directDeserialise pull + + return $! SignKeySumKES sk r vk0 vk1 + + +instance (MonadST m, MonadThrow m) + => DirectSerialise m (VerKeyKES (SumKES h d)) where + directSerialise push (VerKeySumKES h) = + unpackByteStringCStringLen (hashToBytes h) $ \(ptr, len) -> + push (castPtr ptr) (fromIntegral len) + +instance (MonadST m, MonadThrow m, MonadFail m, HashAlgorithm h) + => DirectDeserialise m (VerKeyKES (SumKES h d)) where + directDeserialise pull = do + let len :: Num a => a + len = fromIntegral $ sizeHash (Proxy @h) + allocaBytes len $ \ptr -> do + pull ptr len + bs <- packByteStringCStringLen (ptr, len) + maybe (fail "Invalid hash") return $! VerKeySumKES <$!> hashFromBytes bs diff --git a/cardano-crypto-class/src/Cardano/Crypto/Libsodium/Memory.hs b/cardano-crypto-class/src/Cardano/Crypto/Libsodium/Memory.hs index a4405ef5d..4d681b11a 100644 --- a/cardano-crypto-class/src/Cardano/Crypto/Libsodium/Memory.hs +++ b/cardano-crypto-class/src/Cardano/Crypto/Libsodium/Memory.hs @@ -26,6 +26,7 @@ module Cardano.Crypto.Libsodium.Memory ( allocaBytes, -- * ByteString memory access, generalized to 'MonadST' + unpackByteStringCStringLen, packByteStringCStringLen, ) where diff --git a/cardano-crypto-class/src/Cardano/Crypto/Libsodium/Memory/Internal.hs b/cardano-crypto-class/src/Cardano/Crypto/Libsodium/Memory/Internal.hs index 68b57392f..553c782d8 100644 --- a/cardano-crypto-class/src/Cardano/Crypto/Libsodium/Memory/Internal.hs +++ b/cardano-crypto-class/src/Cardano/Crypto/Libsodium/Memory/Internal.hs @@ -34,6 +34,7 @@ module Cardano.Crypto.Libsodium.Memory.Internal ( allocaBytes, -- * ByteString memory access, generalized to 'MonadST' + unpackByteStringCStringLen, packByteStringCStringLen, -- * Helper @@ -46,11 +47,13 @@ import Control.Monad (when, void) import Control.Monad.Class.MonadST import Control.Monad.Class.MonadThrow (MonadThrow (bracket)) import Control.Monad.ST (RealWorld, ST) -import Control.Monad.ST.Unsafe (unsafeIOToST, unsafeSTToIO) +import Control.Monad.ST.Unsafe (unsafeIOToST) import Data.ByteString (ByteString) import qualified Data.ByteString as BS +import qualified Data.ByteString.Unsafe as BS import Data.Coerce (coerce) import Data.Typeable +import Data.Word (Word8) import Debug.Trace (traceShowM) import Foreign.C.Error (errnoToIOError, getErrno) import Foreign.C.String (CStringLen) @@ -61,7 +64,7 @@ import Foreign.ForeignPtr.Unsafe (unsafeForeignPtrToPtr) import qualified Foreign.Marshal.Alloc as Foreign import Foreign.Marshal.Utils (fillBytes) import Foreign.Ptr (Ptr, nullPtr, castPtr) -import Foreign.Storable (Storable (peek), sizeOf, alignment) +import Foreign.Storable (Storable (peek), sizeOf, alignment, pokeByteOff) import GHC.IO.Exception (ioException) import GHC.TypeLits (KnownNat, natVal) import NoThunks.Class (NoThunks, OnlyCheckWhnfNamed (..)) @@ -188,9 +191,29 @@ zeroMem ptr size = unsafeIOToMonadST . void $ c_memset (castPtr ptr) 0 size copyMem :: MonadST m => Ptr a -> Ptr a -> CSize -> m () copyMem dst src size = unsafeIOToMonadST . void $ c_memcpy (castPtr dst) (castPtr src) size -allocaBytes :: Int -> (Ptr a -> ST s b) -> ST s b -allocaBytes size f = - unsafeIOToST $ Foreign.allocaBytes size (unsafeSTToIO . f) +allocaBytes :: (MonadThrow m, MonadST m) => Int -> (Ptr a -> m b) -> m b +allocaBytes size = + bracket + (mallocBytes size) + free + +mallocBytes :: MonadST m => Int -> m (Ptr a) +mallocBytes size = + unsafeIOToMonadST $ Foreign.mallocBytes size + +free :: MonadST m => Ptr a -> m () +free = unsafeIOToMonadST . Foreign.free + +-- | Unpacks a ByteString into a temporary buffer and runs the provided 'ST' +-- function on it. +unpackByteStringCStringLen :: (MonadThrow m, MonadST m) => ByteString -> (CStringLen -> m a) -> m a +unpackByteStringCStringLen bs f = do + let len = BS.length bs + allocaBytes (len + 1) $ \buf -> do + unsafeIOToMonadST $ BS.unsafeUseAsCString bs $ \ptr -> do + copyMem buf ptr (fromIntegral len) + pokeByteOff buf len (0 :: Word8) + f (buf, len) packByteStringCStringLen :: MonadST m => CStringLen -> m ByteString packByteStringCStringLen = @@ -258,7 +281,6 @@ mlockedAllocaWith :: -> (Ptr a -> m b) -> m b mlockedAllocaWith allocator size = - bracket alloc free . flip withMLockedForeignPtr + bracket alloc finalizeMLockedForeignPtr . flip withMLockedForeignPtr where alloc = mlAllocate allocator size - free = finalizeMLockedForeignPtr diff --git a/cardano-crypto-tests/src/Test/Crypto/DSIGN.hs b/cardano-crypto-tests/src/Test/Crypto/DSIGN.hs index 1c81dcd46..b7bb0e759 100644 --- a/cardano-crypto-tests/src/Test/Crypto/DSIGN.hs +++ b/cardano-crypto-tests/src/Test/Crypto/DSIGN.hs @@ -27,6 +27,7 @@ import Test.QuickCheck ( forAllShow, forAllShrinkShow, ioProperty, + counterexample, ) import Test.Tasty (TestTree, testGroup, adjustOption) import Test.Tasty.QuickCheck (testProperty, QuickCheckTests) @@ -94,6 +95,7 @@ import Cardano.Crypto.DSIGN ( ) import Cardano.Binary (FromCBOR, ToCBOR) import Cardano.Crypto.PinnedSizedBytes (PinnedSizedBytes) +import Cardano.Crypto.DirectSerialise import Test.Crypto.Util ( Message, prop_raw_serialise, @@ -111,6 +113,9 @@ import Test.Crypto.Util ( showBadInputFor, Lock, withLock, + directSerialiseToBS, + directDeserialiseFromBS, + hexBS, ) import Cardano.Crypto.Libsodium.MLockedSeed @@ -362,6 +367,10 @@ testDSIGNMAlgorithm , FromCBOR (SigDSIGN v) , ContextDSIGN v ~ () , Signable v Message + , DirectSerialise IO (SignKeyDSIGNM v) + , DirectDeserialise IO (SignKeyDSIGNM v) + , DirectSerialise IO (VerKeyDSIGN v) + , DirectDeserialise IO (VerKeyDSIGN v) ) => Lock -> Proxy v @@ -451,6 +460,36 @@ testDSIGNMAlgorithm lock _ n = sig :: SigDSIGN v <- signDSIGNM () msg sk return $ prop_cbor_direct_vs_class encodeSigDSIGN sig ] + , testGroup "DirectSerialise" + [ testProperty "VerKey" $ + ioPropertyWithSK @v lock $ \sk -> do + vk :: VerKeyDSIGN v <- deriveVerKeyDSIGNM sk + serialized <- directSerialiseToBS (fromIntegral $ sizeVerKeyDSIGN (Proxy @v)) vk + vk' <- directDeserialiseFromBS serialized + return $ vk === vk' + , testProperty "SignKey" $ + ioPropertyWithSK @v lock $ \sk -> do + serialized <- directSerialiseToBS (fromIntegral $ sizeSignKeyDSIGN (Proxy @v)) sk + sk' <- directDeserialiseFromBS serialized + equals <- sk ==! sk' + forgetSignKeyDSIGNM sk' + return $ + counterexample ("Serialized: " ++ hexBS serialized ++ " (length: " ++ show (BS.length serialized) ++ ")") $ + equals + ] + , testGroup "DirectSerialise matches raw" + [ testProperty "VerKey" $ + ioPropertyWithSK @v lock $ \sk -> do + vk :: VerKeyDSIGN v <- deriveVerKeyDSIGNM sk + direct <- directSerialiseToBS (fromIntegral $ sizeVerKeyDSIGN (Proxy @v)) vk + let raw = rawSerialiseVerKeyDSIGN vk + return $ direct === raw + , testProperty "SignKey" $ + ioPropertyWithSK @v lock $ \sk -> do + direct <- directSerialiseToBS (fromIntegral $ sizeSignKeyDSIGN (Proxy @v)) sk + raw <- rawSerialiseSignKeyDSIGNM sk + return $ direct === raw + ] ] , testGroup "verify" @@ -477,6 +516,24 @@ testDSIGNMAlgorithm lock _ n = ioPropertyWithSK @v lock $ prop_no_thunks_IO . return , testProperty "Sig" $ \(msg :: Message) -> ioPropertyWithSK @v lock $ prop_no_thunks_IO . signDSIGNM () msg + , testProperty "SignKey DirectSerialise" $ + ioPropertyWithSK @v lock $ \sk -> do + direct <- directSerialiseToBS (fromIntegral $ sizeSignKeyDSIGN (Proxy @v)) sk + prop_no_thunks_IO (return $! direct) + , testProperty "SignKey DirectDeserialise" $ + ioPropertyWithSK @v lock $ \sk -> do + direct <- directSerialiseToBS (fromIntegral $ sizeSignKeyDSIGN (Proxy @v)) sk + prop_no_thunks_IO (directDeserialiseFromBS @IO @(SignKeyDSIGNM v) $! direct) + , testProperty "VerKey DirectSerialise" $ + ioPropertyWithSK @v lock $ \sk -> do + vk <- deriveVerKeyDSIGNM sk + direct <- directSerialiseToBS (fromIntegral $ sizeVerKeyDSIGN (Proxy @v)) vk + prop_no_thunks_IO (return $! direct) + , testProperty "VerKey DirectDeserialise" $ + ioPropertyWithSK @v lock $ \sk -> do + vk <- deriveVerKeyDSIGNM sk + direct <- directSerialiseToBS (fromIntegral $ sizeVerKeyDSIGN (Proxy @v)) vk + prop_no_thunks_IO (directDeserialiseFromBS @IO @(VerKeyDSIGN v) $! direct) ] ] diff --git a/cardano-crypto-tests/src/Test/Crypto/KES.hs b/cardano-crypto-tests/src/Test/Crypto/KES.hs index fbe1612e7..0e521edfd 100644 --- a/cardano-crypto-tests/src/Test/Crypto/KES.hs +++ b/cardano-crypto-tests/src/Test/Crypto/KES.hs @@ -29,21 +29,29 @@ import Data.List (foldl') import qualified Data.ByteString as BS import Data.Set (Set) import qualified Data.Set as Set -import Foreign.Ptr (WordPtr) +import Foreign.Ptr (WordPtr, plusPtr) import Data.IORef -import GHC.TypeNats (KnownNat) +import GHC.TypeNats (KnownNat, natVal) -import Control.Tracer +import Control.Concurrent.MVar (newMVar, takeMVar, putMVar) +import Control.Monad (void, when) +import Control.Monad.Class.MonadST import Control.Monad.Class.MonadThrow import Control.Monad.IO.Class (liftIO) -import Control.Monad (void) +import Control.Tracer import Cardano.Crypto.DSIGN hiding (Signable) import Cardano.Crypto.Hash import Cardano.Crypto.KES +import Cardano.Crypto.DirectSerialise (DirectSerialise, directSerialise, DirectDeserialise) import Cardano.Crypto.Util (SignableRepresentation(..)) import Cardano.Crypto.Libsodium import Cardano.Crypto.Libsodium.MLockedSeed +import Cardano.Crypto.Libsodium.Memory + ( copyMem + , allocaBytes + , packByteStringCStringLen + ) import Cardano.Crypto.PinnedSizedBytes import Test.QuickCheck @@ -67,6 +75,8 @@ import Test.Crypto.Util ( noExceptionsThrown, Lock, withLock, + directSerialiseToBS, + directDeserialiseFromBS, ) import Test.Crypto.EqST import Test.Crypto.Instances (withMLockedSeedFromPSB) @@ -198,6 +208,10 @@ testKESAlgorithm , Signable v ~ SignableRepresentation , ContextKES v ~ () , UnsoundKESAlgorithm v + , DirectSerialise IO (SignKeyKES v) + , DirectSerialise IO (VerKeyKES v) + , DirectDeserialise IO (SignKeyKES v) + , DirectDeserialise IO (VerKeyKES v) ) => Lock -> String @@ -225,9 +239,32 @@ testKESAlgorithm lock n = , testProperty "Sig" $ \seedPSB (msg :: Message) -> ioProperty $ withLock lock $ fmap conjoin $ withAllUpdatesKES @v seedPSB $ \t sk -> do prop_no_thunks_IO (signKES () t msg sk) + + , testProperty "VerKey DirectSerialise" $ + ioPropertyWithSK @v lock $ \sk -> do + vk :: VerKeyKES v <- deriveVerKeyKES sk + direct <- directSerialiseToBS (fromIntegral $ sizeVerKeyKES (Proxy @v)) vk + prop_no_thunks_IO (return $! direct) + , testProperty "SignKey DirectSerialise" $ + ioPropertyWithSK @v lock $ \sk -> do + direct <- directSerialiseToBS (fromIntegral $ sizeSignKeyKES (Proxy @v)) sk + prop_no_thunks_IO (return $! direct) + , testProperty "VerKey DirectDeserialise" $ + ioPropertyWithSK @v lock $ \sk -> do + vk :: VerKeyKES v <- deriveVerKeyKES sk + direct <- directSerialiseToBS (fromIntegral $ sizeVerKeyKES (Proxy @v)) $! vk + prop_no_thunks_IO (directDeserialiseFromBS @IO @(VerKeyKES v) $! direct) + , testProperty "SignKey DirectDeserialise" $ + ioPropertyWithSK @v lock $ \sk -> do + direct <- directSerialiseToBS (fromIntegral $ sizeSignKeyKES (Proxy @v)) sk + bracket + (directDeserialiseFromBS @IO @(SignKeyKES v) $! direct) + forgetSignKeyKES + (prop_no_thunks_IO . return) ] , testProperty "same VerKey " $ prop_deriveVerKeyKES @v + , testProperty "no forgotten chunks in signkey" $ prop_noErasedBlocksInKey (Proxy @v) , testGroup "serialisation" [ testGroup "raw ser only" @@ -313,6 +350,38 @@ testKESAlgorithm lock n = sig :: SigKES v <- signKES () 0 msg sk return $ prop_cbor_direct_vs_class encodeSigKES sig ] + + , testGroup "DirectSerialise" + [ testProperty "VerKey" $ + ioPropertyWithSK @v lock $ \sk -> do + vk :: VerKeyKES v <- deriveVerKeyKES sk + serialized <- directSerialiseToBS (fromIntegral $ sizeVerKeyKES (Proxy @v)) vk + vk' <- directDeserialiseFromBS serialized + return $ vk === vk' + , testProperty "SignKey" $ + ioPropertyWithSK @v lock $ \sk -> do + serialized <- directSerialiseToBS (fromIntegral $ sizeSignKeyKES (Proxy @v)) sk + equals <- bracket + (directDeserialiseFromBS serialized) + forgetSignKeyKES + (\sk' -> sk ==! sk') + return $ + counterexample ("Serialized: " ++ hexBS serialized ++ " (length: " ++ show (BS.length serialized) ++ ")") $ + equals + ] + , testGroup "DirectSerialise matches raw" + [ testProperty "VerKey" $ + ioPropertyWithSK @v lock $ \sk -> do + vk :: VerKeyKES v <- deriveVerKeyKES sk + direct <- directSerialiseToBS (fromIntegral $ sizeVerKeyKES (Proxy @v)) vk + let raw = rawSerialiseVerKeyKES vk + return $ direct === raw + , testProperty "SignKey" $ + ioPropertyWithSK @v lock $ \sk -> do + direct <- directSerialiseToBS (fromIntegral $ sizeSignKeyKES (Proxy @v)) sk + raw <- rawSerialiseSignKeyKES sk + return $ direct === raw + ] ] , testGroup "verify" @@ -676,3 +745,56 @@ withAllUpdatesKES seedPSB f = withMLockedSeedFromPSB seedPSB $ \seed -> do xs <- go sk' (t + 1) return $ x:xs +withNullSeed :: forall m n a. (MonadThrow m, MonadST m, KnownNat n) => (MLockedSeed n -> m a) -> m a +withNullSeed = bracket + (MLockedSeed <$> mlsbFromByteString (BS.replicate (fromIntegral $ natVal (Proxy @n)) 0)) + mlockedSeedFinalize + +withNullSK :: forall m v a. (KESAlgorithm v, MonadThrow m, MonadST m) + => (SignKeyKES v -> m a) -> m a +withNullSK = bracket + (withNullSeed genKeyKES) + forgetSignKeyKES + + +-- | This test detects whether a sign key contains references to pool-allocated +-- blocks of memory that have been forgotten by the time the key is complete. +-- We do this based on the fact that the pooled allocator erases memory blocks +-- by overwriting them with series of 0xff bytes; thus we cut the serialized +-- key up into chunks of 16 bytes, and if any of those chunks is entirely +-- filled with 0xff bytes, we assume that we're looking at erased memory. +prop_noErasedBlocksInKey + :: forall v. + UnsoundKESAlgorithm v + => DirectSerialise IO (SignKeyKES v) + => Proxy v + -> Property +prop_noErasedBlocksInKey kesAlgorithm = + ioProperty . withNullSK @IO @v $ \sk -> do + let size :: Int = fromIntegral $ sizeSignKeyKES kesAlgorithm + serialized <- allocaBytes size $ \ptr -> do + positionVar <- newMVar (0 :: Int) + directSerialise (\buf nCSize -> do + let n = fromIntegral nCSize :: Int + bracket + (takeMVar positionVar) + (putMVar positionVar . (+ n)) + (\position -> do + when (n + position > size) (error "Buffer size exceeded") + copyMem (plusPtr ptr position) buf (fromIntegral n) + ) + ) + sk + packByteStringCStringLen (ptr, size) + forgetSignKeyKES sk + return $ counterexample (hexBS serialized) $ not (hasLongRunOfFF serialized) + +hasLongRunOfFF :: ByteString -> Bool +hasLongRunOfFF bs + | BS.length bs < 16 + = False + | otherwise + = let first16 = BS.take 16 bs + remainder = BS.drop 16 bs + in (BS.all (== 0xFF) first16) || hasLongRunOfFF remainder + diff --git a/cardano-crypto-tests/src/Test/Crypto/Util.hs b/cardano-crypto-tests/src/Test/Crypto/Util.hs index c0c7d7441..5b61969f9 100644 --- a/cardano-crypto-tests/src/Test/Crypto/Util.hs +++ b/cardano-crypto-tests/src/Test/Crypto/Util.hs @@ -59,6 +59,10 @@ module Test.Crypto.Util , noExceptionsThrown , doesNotThrow + -- * Direct ser/deser helpers + , directSerialiseToBS + , directDeserialiseFromBS + -- * Error handling , eitherShowError @@ -95,12 +99,19 @@ import Codec.CBOR.Write ( ) import Cardano.Crypto.Seed (Seed, mkSeedFromBytes) import Cardano.Crypto.Util (SignableRepresentation(..)) +import Cardano.Crypto.DirectSerialise import Crypto.Random ( ChaChaDRG , MonadPseudoRandom , drgNewTest , withDRG ) +import Cardano.Crypto.Libsodium.Memory + ( unpackByteStringCStringLen + , packByteStringCStringLen + , allocaBytes + , copyMem + ) import Data.ByteString (ByteString) import qualified Data.ByteString as BS import qualified Data.ByteString.Char8 as BS8 @@ -130,7 +141,19 @@ import qualified Test.QuickCheck.Gen as Gen import Control.Monad (guard, when) import GHC.TypeLits (Nat, KnownNat, natVal) import Formatting.Buildable (Buildable (..), build) -import Control.Concurrent.Class.MonadMVar (MVar, withMVar, newMVar) +import Foreign.Ptr (Ptr, plusPtr) +import Foreign.C.Types (CChar, CSize) +import Control.Monad.Class.MonadST (MonadST) +import Control.Monad.Class.MonadThrow (MonadThrow) +import Control.Concurrent.Class.MonadMVar + ( MVar + , withMVar + , newMVar + , putMVar + , takeMVar + , newMVar + , MonadMVar + ) import GHC.Stack (HasCallStack) -------------------------------------------------------------------------------- @@ -375,3 +398,44 @@ mkLock = Lock <$> newMVar () eitherShowError :: (HasCallStack, Show e) => Either e a -> IO a eitherShowError (Left e) = error (show e) eitherShowError (Right a) = return a + +-------------------------------------------------------------------------------- +-- Helpers for direct ser/deser +-------------------------------------------------------------------------------- + +directSerialiseToBS :: forall m a. + DirectSerialise m a + => MonadST m + => MonadThrow m + => MonadMVar m + => Int -> a -> m ByteString +directSerialiseToBS dstsize val = do + allocaBytes dstsize $ \dst -> do + posVar <- newMVar 0 + let pusher :: Ptr CChar -> CSize -> m () + pusher src srcsize = do + pos <- takeMVar posVar + let pos' = pos + fromIntegral srcsize + when (pos' > dstsize) (error "Buffer overrun") + copyMem (plusPtr dst pos) src (fromIntegral srcsize) + putMVar posVar pos' + directSerialise pusher val + packByteStringCStringLen (dst, fromIntegral dstsize) + +directDeserialiseFromBS :: forall m a. + DirectDeserialise m a + => MonadST m + => MonadThrow m + => MonadMVar m + => ByteString -> m a +directDeserialiseFromBS bs = do + unpackByteStringCStringLen bs $ \(src, srcsize) -> do + posVar <- newMVar 0 + let puller :: Ptr CChar -> CSize -> m () + puller dst dstsize = do + pos <- takeMVar posVar + let pos' = pos + fromIntegral dstsize + when (pos' > srcsize) (error "Buffer overrun") + copyMem dst (plusPtr src pos) (fromIntegral dstsize) + putMVar posVar pos' + directDeserialise puller From c3fa788738d256c1eedbfd44846cc61cb8debb79 Mon Sep 17 00:00:00 2001 From: Tobias Dammers Date: Mon, 21 Aug 2023 17:46:00 +0200 Subject: [PATCH 2/9] Cryptographic RNG for MLockedSeed --- .../src/Cardano/Crypto/Libsodium/C.hs | 5 +++++ .../Cardano/Crypto/Libsodium/MLockedSeed.hs | 21 +++++++++++++++++-- 2 files changed, 24 insertions(+), 2 deletions(-) diff --git a/cardano-crypto-class/src/Cardano/Crypto/Libsodium/C.hs b/cardano-crypto-class/src/Cardano/Crypto/Libsodium/C.hs index 1ce55c66b..3fe9623ec 100644 --- a/cardano-crypto-class/src/Cardano/Crypto/Libsodium/C.hs +++ b/cardano-crypto-class/src/Cardano/Crypto/Libsodium/C.hs @@ -29,6 +29,8 @@ module Cardano.Crypto.Libsodium.C ( c_crypto_sign_ed25519_detached, c_crypto_sign_ed25519_verify_detached, c_crypto_sign_ed25519_sk_to_pk, + -- * RNG + c_sodium_randombytes_buf, -- * Helpers c_sodium_compare, -- * Constants @@ -182,3 +184,6 @@ foreign import capi unsafe "sodium.h crypto_sign_ed25519_sk_to_pk" c_crypto_sign -- -- foreign import capi unsafe "sodium.h sodium_compare" c_sodium_compare :: Ptr a -> Ptr a -> CSize -> IO Int + +-- | @void randombytes_buf(void * const buf, const size_t size);@ +foreign import capi unsafe "sodium/randombytes.h randombytes_buf" c_sodium_randombytes_buf :: Ptr a -> CSize -> IO () diff --git a/cardano-crypto-class/src/Cardano/Crypto/Libsodium/MLockedSeed.hs b/cardano-crypto-class/src/Cardano/Crypto/Libsodium/MLockedSeed.hs index 5fb8c600d..cb9520b21 100644 --- a/cardano-crypto-class/src/Cardano/Crypto/Libsodium/MLockedSeed.hs +++ b/cardano-crypto-class/src/Cardano/Crypto/Libsodium/MLockedSeed.hs @@ -2,7 +2,8 @@ {-# LANGUAGE FlexibleInstances #-} {-# LANGUAGE GeneralizedNewtypeDeriving #-} {-# LANGUAGE MultiParamTypeClasses #-} -{-# LANGUAGE StandaloneDeriving #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TypeApplications #-} module Cardano.Crypto.Libsodium.MLockedSeed where @@ -20,12 +21,16 @@ import Cardano.Crypto.Libsodium.Memory ( MLockedAllocator, mlockedMalloc, ) +import Cardano.Crypto.Libsodium.C ( + c_sodium_randombytes_buf, + ) import Cardano.Foreign (SizedPtr) import Control.DeepSeq (NFData) import Control.Monad.Class.MonadST (MonadST) +import Data.Proxy (Proxy (..)) import Data.Word (Word8) import Foreign.Ptr (Ptr) -import GHC.TypeNats (KnownNat) +import GHC.TypeNats (KnownNat, natVal) import NoThunks.Class (NoThunks) -- | A seed of size @n@, stored in mlocked memory. This is required to prevent @@ -66,6 +71,18 @@ mlockedSeedNewZeroWith :: (KnownNat n, MonadST m) => MLockedAllocator m -> m (ML mlockedSeedNewZeroWith allocator = MLockedSeed <$> mlsbNewZeroWith allocator +mlockedSeedNewRandom :: forall n. (KnownNat n) => IO (MLockedSeed n) +mlockedSeedNewRandom = mlockedSeedNewRandomWith mlockedMalloc + +mlockedSeedNewRandomWith :: forall n. (KnownNat n) => MLockedAllocator IO -> IO (MLockedSeed n) +mlockedSeedNewRandomWith allocator = do + mls <- MLockedSeed <$> mlsbNewZeroWith allocator + mlockedSeedUseAsCPtr mls $ \dst -> do + c_sodium_randombytes_buf dst size + return mls + where + size = fromIntegral $ natVal (Proxy @n) + mlockedSeedFinalize :: (MonadST m) => MLockedSeed n -> m () mlockedSeedFinalize = mlsbFinalize . mlockedSeedMLSB From 6207d70d9422d282fc2a21a42800a64aebf004aa Mon Sep 17 00:00:00 2001 From: Tobias Dammers Date: Tue, 29 Aug 2023 12:12:49 +0200 Subject: [PATCH 3/9] Fix incorrect usage of mlocked memory in MockKES --- cardano-crypto-class/src/Cardano/Crypto/KES/Mock.hs | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/cardano-crypto-class/src/Cardano/Crypto/KES/Mock.hs b/cardano-crypto-class/src/Cardano/Crypto/KES/Mock.hs index c04ae9e37..923b098a8 100644 --- a/cardano-crypto-class/src/Cardano/Crypto/KES/Mock.hs +++ b/cardano-crypto-class/src/Cardano/Crypto/KES/Mock.hs @@ -37,7 +37,7 @@ import Cardano.Crypto.KES.Class import Cardano.Crypto.Util import Cardano.Crypto.Libsodium.MLockedSeed import Cardano.Crypto.Libsodium - ( mlsbAsByteString + ( mlsbToByteString ) import Cardano.Crypto.Libsodium.Memory ( unpackByteStringCStringLen @@ -159,7 +159,8 @@ instance KnownNat t => KESAlgorithm (MockKES t) where -- genKeyKESWith _allocator seed = do - let vk = VerKeyMockKES (runMonadRandomWithSeed (mkSeedFromBytes . mlsbAsByteString . mlockedSeedMLSB $ seed) getRandomWord64) + seedBS <- mlsbToByteString . mlockedSeedMLSB $ seed + let vk = VerKeyMockKES (runMonadRandomWithSeed (mkSeedFromBytes seedBS) getRandomWord64) return $! SignKeyMockKES vk 0 forgetSignKeyKESWith _ = const $ return () From c4c59d0db2fcc89f4d6428d3d82c68a9f38d58ab Mon Sep 17 00:00:00 2001 From: Tobias Dammers Date: Tue, 31 Oct 2023 11:06:00 +0100 Subject: [PATCH 4/9] Address review comment --- .../src/Cardano/Crypto/DSIGN/Ed25519.hs | 14 ++++------- .../src/Cardano/Crypto/DirectSerialise.hs | 12 ++++----- .../src/Cardano/Crypto/KES/CompactSingle.hs | 8 +++--- .../src/Cardano/Crypto/KES/CompactSum.hs | 25 ++++++++----------- .../src/Cardano/Crypto/KES/Mock.hs | 10 +++----- .../src/Cardano/Crypto/KES/Simple.hs | 8 +++--- .../src/Cardano/Crypto/KES/Single.hs | 8 +++--- .../src/Cardano/Crypto/KES/Sum.hs | 25 ++++++++----------- cardano-crypto-tests/src/Test/Crypto/DSIGN.hs | 8 +++--- cardano-crypto-tests/src/Test/Crypto/KES.hs | 10 ++++---- cardano-crypto-tests/src/Test/Crypto/Util.hs | 4 +-- 11 files changed, 58 insertions(+), 74 deletions(-) diff --git a/cardano-crypto-class/src/Cardano/Crypto/DSIGN/Ed25519.hs b/cardano-crypto-class/src/Cardano/Crypto/DSIGN/Ed25519.hs index 26ac94750..622d0eae1 100644 --- a/cardano-crypto-class/src/Cardano/Crypto/DSIGN/Ed25519.hs +++ b/cardano-crypto-class/src/Cardano/Crypto/DSIGN/Ed25519.hs @@ -370,9 +370,7 @@ instance TypeError ('Text "CBOR decoding would violate mlocking guarantees") => FromCBOR (SignKeyDSIGNM Ed25519DSIGN) where fromCBOR = error "unsupported" -instance ( MonadThrow m - , MonadST m - ) => DirectSerialise m (SignKeyDSIGNM Ed25519DSIGN) where +instance DirectSerialise (SignKeyDSIGNM Ed25519DSIGN) where -- /Note:/ We only serialize the 32-byte seed, not the full 64-byte key. The -- latter contains both the seed and the 32-byte verification key, which is -- convenient, but redundant, since we can always reconstruct it from the @@ -387,11 +385,9 @@ instance ( MonadThrow m (castPtr ptr) (fromIntegral $ seedSizeDSIGN (Proxy @Ed25519DSIGN))) -instance ( MonadThrow m - , MonadST m - ) => DirectDeserialise m (SignKeyDSIGNM Ed25519DSIGN) where +instance DirectDeserialise (SignKeyDSIGNM Ed25519DSIGN) where -- /Note:/ We only serialize the 32-byte seed, not the full 64-byte key. See - -- the DirectSerialise m instance above. + -- the DirectSerialise instance above. directDeserialise pull = do bracket mlockedSeedNew @@ -404,14 +400,14 @@ instance ( MonadThrow m genKeyDSIGNM seed ) -instance MonadST m => DirectSerialise m (VerKeyDSIGN Ed25519DSIGN) where +instance DirectSerialise (VerKeyDSIGN Ed25519DSIGN) where directSerialise push (VerKeyEd25519DSIGN psb) = do psbUseAsCPtrLen psb $ \ptr _ -> push (castPtr ptr) (fromIntegral $ sizeVerKeyDSIGN (Proxy @Ed25519DSIGN)) -instance MonadST m => DirectDeserialise m (VerKeyDSIGN Ed25519DSIGN) where +instance DirectDeserialise (VerKeyDSIGN Ed25519DSIGN) where directDeserialise pull = do psb <- psbCreate $ \ptr -> pull diff --git a/cardano-crypto-class/src/Cardano/Crypto/DirectSerialise.hs b/cardano-crypto-class/src/Cardano/Crypto/DirectSerialise.hs index d879778ae..7b2c5e787 100644 --- a/cardano-crypto-class/src/Cardano/Crypto/DirectSerialise.hs +++ b/cardano-crypto-class/src/Cardano/Crypto/DirectSerialise.hs @@ -1,5 +1,3 @@ -{-# LANGUAGE MultiParamTypeClasses #-} - -- | Direct (de-)serialisation to / from raw memory. -- -- The purpose of the typeclasses in this module is to abstract over data @@ -18,6 +16,8 @@ where import Foreign.Ptr import Foreign.C.Types +import Control.Monad.Class.MonadThrow (MonadThrow) +import Control.Monad.Class.MonadST (MonadST) -- | Direct deserialization from raw memory. -- @@ -27,8 +27,8 @@ import Foreign.C.Types -- non-contiguous blocks of memory. -- -- The order in which memory blocks are visited matters. -class DirectDeserialise m a where - directDeserialise :: (Ptr CChar -> CSize -> m ()) -> m a +class DirectDeserialise a where + directDeserialise :: (MonadST m, MonadThrow m) => (Ptr CChar -> CSize -> m ()) -> m a -- | Direct serialization to raw memory. -- @@ -37,5 +37,5 @@ class DirectDeserialise m a where -- of memory, @f@ may be called multiple times, once for each block. -- -- The order in which memory blocks are visited matters. -class DirectSerialise m a where - directSerialise :: (Ptr CChar -> CSize -> m ()) -> a -> m () +class DirectSerialise a where + directSerialise :: (MonadST m, MonadThrow m) => (Ptr CChar -> CSize -> m ()) -> a -> m () diff --git a/cardano-crypto-class/src/Cardano/Crypto/KES/CompactSingle.hs b/cardano-crypto-class/src/Cardano/Crypto/KES/CompactSingle.hs index 4c7fa2b19..6bc744e9a 100644 --- a/cardano-crypto-class/src/Cardano/Crypto/KES/CompactSingle.hs +++ b/cardano-crypto-class/src/Cardano/Crypto/KES/CompactSingle.hs @@ -232,14 +232,14 @@ slice offset size = BS.take (fromIntegral size) -- Direct ser/deser -- -instance (DirectSerialise m (SignKeyDSIGNM d)) => DirectSerialise m (SignKeyKES (CompactSingleKES d)) where +instance (DirectSerialise (SignKeyDSIGNM d)) => DirectSerialise (SignKeyKES (CompactSingleKES d)) where directSerialise push (SignKeyCompactSingleKES sk) = directSerialise push sk -instance (Monad m, DirectDeserialise m (SignKeyDSIGNM d)) => DirectDeserialise m (SignKeyKES (CompactSingleKES d)) where +instance (DirectDeserialise (SignKeyDSIGNM d)) => DirectDeserialise (SignKeyKES (CompactSingleKES d)) where directDeserialise pull = SignKeyCompactSingleKES <$!> directDeserialise pull -instance (DirectSerialise m (VerKeyDSIGN d)) => DirectSerialise m (VerKeyKES (CompactSingleKES d)) where +instance (DirectSerialise (VerKeyDSIGN d)) => DirectSerialise (VerKeyKES (CompactSingleKES d)) where directSerialise push (VerKeyCompactSingleKES sk) = directSerialise push sk -instance (Monad m, DirectDeserialise m (VerKeyDSIGN d)) => DirectDeserialise m (VerKeyKES (CompactSingleKES d)) where +instance (DirectDeserialise (VerKeyDSIGN d)) => DirectDeserialise (VerKeyKES (CompactSingleKES d)) where directDeserialise pull = VerKeyCompactSingleKES <$!> directDeserialise pull diff --git a/cardano-crypto-class/src/Cardano/Crypto/KES/CompactSum.hs b/cardano-crypto-class/src/Cardano/Crypto/KES/CompactSum.hs index a6649aed1..47bc62ae6 100644 --- a/cardano-crypto-class/src/Cardano/Crypto/KES/CompactSum.hs +++ b/cardano-crypto-class/src/Cardano/Crypto/KES/CompactSum.hs @@ -103,8 +103,6 @@ import Cardano.Crypto.DirectSerialise import Control.Monad.Trans.Maybe (MaybeT (..), runMaybeT) import Control.Monad.Trans (lift) -import Control.Monad.Class.MonadST -import Control.Monad.Class.MonadThrow import Control.DeepSeq (NFData (..)) import GHC.TypeLits (KnownNat, type (+), type (*)) import Foreign.Ptr (castPtr) @@ -474,11 +472,10 @@ instance ( OptimizedKESAlgorithm d -- Direct ser/deser -- -instance ( DirectSerialise m (SignKeyKES d) - , DirectSerialise m (VerKeyKES d) - , MonadST m +instance ( DirectSerialise (SignKeyKES d) + , DirectSerialise (VerKeyKES d) , KESAlgorithm d - ) => DirectSerialise m (SignKeyKES (CompactSumKES h d)) where + ) => DirectSerialise (SignKeyKES (CompactSumKES h d)) where directSerialise push (SignKeyCompactSumKES sk r vk0 vk1) = do directSerialise push sk mlockedSeedUseAsCPtr r $ \ptr -> @@ -486,11 +483,10 @@ instance ( DirectSerialise m (SignKeyKES d) directSerialise push vk0 directSerialise push vk1 -instance ( DirectDeserialise m (SignKeyKES d) - , DirectDeserialise m (VerKeyKES d) - , MonadST m +instance ( DirectDeserialise (SignKeyKES d) + , DirectDeserialise (VerKeyKES d) , KESAlgorithm d - ) => DirectDeserialise m (SignKeyKES (CompactSumKES h d)) where + ) => DirectDeserialise (SignKeyKES (CompactSumKES h d)) where directDeserialise pull = do sk <- directDeserialise pull @@ -504,18 +500,17 @@ instance ( DirectDeserialise m (SignKeyKES d) return $! SignKeyCompactSumKES sk r vk0 vk1 -instance (MonadST m, MonadThrow m) - => DirectSerialise m (VerKeyKES (CompactSumKES h d)) where +instance DirectSerialise (VerKeyKES (CompactSumKES h d)) where directSerialise push (VerKeyCompactSumKES h) = unpackByteStringCStringLen (hashToBytes h) $ \(ptr, len) -> push (castPtr ptr) (fromIntegral len) -instance (MonadST m, MonadThrow m, MonadFail m, HashAlgorithm h) - => DirectDeserialise m (VerKeyKES (CompactSumKES h d)) where +instance (HashAlgorithm h) + => DirectDeserialise (VerKeyKES (CompactSumKES h d)) where directDeserialise pull = do let len :: Num a => a len = fromIntegral $ sizeHash (Proxy @h) allocaBytes len $ \ptr -> do pull ptr len bs <- packByteStringCStringLen (ptr, len) - maybe (fail "Invalid hash") return $! VerKeyCompactSumKES <$!> hashFromBytes bs + maybe (error "Invalid hash") return $! VerKeyCompactSumKES <$!> hashFromBytes bs diff --git a/cardano-crypto-class/src/Cardano/Crypto/KES/Mock.hs b/cardano-crypto-class/src/Cardano/Crypto/KES/Mock.hs index 923b098a8..5748cb7bd 100644 --- a/cardano-crypto-class/src/Cardano/Crypto/KES/Mock.hs +++ b/cardano-crypto-class/src/Cardano/Crypto/KES/Mock.hs @@ -26,8 +26,6 @@ import GHC.TypeNats (Nat, KnownNat, natVal) import NoThunks.Class (NoThunks) import Control.Exception (assert) -import Control.Monad.Class.MonadST (MonadST) -import Control.Monad.Class.MonadThrow (MonadThrow) import Cardano.Binary (FromCBOR (..), ToCBOR (..)) @@ -204,12 +202,12 @@ instance KnownNat t => ToCBOR (SigKES (MockKES t)) where instance KnownNat t => FromCBOR (SigKES (MockKES t)) where fromCBOR = decodeSigKES -instance (MonadST m, MonadThrow m, KnownNat t) => DirectSerialise m (SignKeyKES (MockKES t)) where +instance (KnownNat t) => DirectSerialise (SignKeyKES (MockKES t)) where directSerialise put sk = do let bs = rawSerialiseSignKeyMockKES sk unpackByteStringCStringLen bs $ \(cstr, len) -> put cstr (fromIntegral len) -instance (MonadST m, MonadThrow m, KnownNat t) => DirectDeserialise m (SignKeyKES (MockKES t)) where +instance (KnownNat t) => DirectDeserialise (SignKeyKES (MockKES t)) where directDeserialise pull = do let len = fromIntegral $ sizeSignKeyKES (Proxy @(MockKES t)) bs <- allocaBytes len $ \cstr -> do @@ -218,12 +216,12 @@ instance (MonadST m, MonadThrow m, KnownNat t) => DirectDeserialise m (SignKeyKE maybe (error "directDeserialise @(SignKeyKES (MockKES t))") return $ rawDeserialiseSignKeyMockKES bs -instance (MonadST m, MonadThrow m, KnownNat t) => DirectSerialise m (VerKeyKES (MockKES t)) where +instance (KnownNat t) => DirectSerialise (VerKeyKES (MockKES t)) where directSerialise put sk = do let bs = rawSerialiseVerKeyKES sk unpackByteStringCStringLen bs $ \(cstr, len) -> put cstr (fromIntegral len) -instance (MonadST m, MonadThrow m, KnownNat t) => DirectDeserialise m (VerKeyKES (MockKES t)) where +instance (KnownNat t) => DirectDeserialise (VerKeyKES (MockKES t)) where directDeserialise pull = do let len = fromIntegral $ sizeVerKeyKES (Proxy @(MockKES t)) bs <- allocaBytes len $ \cstr -> do diff --git a/cardano-crypto-class/src/Cardano/Crypto/KES/Simple.hs b/cardano-crypto-class/src/Cardano/Crypto/KES/Simple.hs index 8ed2d5640..afb15d202 100644 --- a/cardano-crypto-class/src/Cardano/Crypto/KES/Simple.hs +++ b/cardano-crypto-class/src/Cardano/Crypto/KES/Simple.hs @@ -250,21 +250,21 @@ instance (DSIGNMAlgorithm d => FromCBOR (SigKES (SimpleKES d t)) where fromCBOR = decodeSigKES -instance (Monad m, DirectSerialise m (VerKeyDSIGN d)) => DirectSerialise m (VerKeyKES (SimpleKES d t)) where +instance (DirectSerialise (VerKeyDSIGN d)) => DirectSerialise (VerKeyKES (SimpleKES d t)) where directSerialise push (VerKeySimpleKES vks) = mapM_ (directSerialise push) vks -instance (Monad m, DirectDeserialise m (VerKeyDSIGN d), KnownNat t) => DirectDeserialise m (VerKeyKES (SimpleKES d t)) where +instance (DirectDeserialise (VerKeyDSIGN d), KnownNat t) => DirectDeserialise (VerKeyKES (SimpleKES d t)) where directDeserialise pull = do let duration = fromIntegral (natVal (Proxy :: Proxy t)) vks <- Vec.replicateM duration (directDeserialise pull) return $! VerKeySimpleKES $! vks -instance (Monad m, DirectSerialise m (SignKeyDSIGNM d)) => DirectSerialise m (SignKeyKES (SimpleKES d t)) where +instance (DirectSerialise (SignKeyDSIGNM d)) => DirectSerialise (SignKeyKES (SimpleKES d t)) where directSerialise push (SignKeySimpleKES sks) = mapM_ (directSerialise push) sks -instance (Monad m, DirectDeserialise m (SignKeyDSIGNM d), KnownNat t) => DirectDeserialise m (SignKeyKES (SimpleKES d t)) where +instance (DirectDeserialise (SignKeyDSIGNM d), KnownNat t) => DirectDeserialise (SignKeyKES (SimpleKES d t)) where directDeserialise pull = do let duration = fromIntegral (natVal (Proxy :: Proxy t)) sks <- Vec.replicateM duration (directDeserialise pull) diff --git a/cardano-crypto-class/src/Cardano/Crypto/KES/Single.hs b/cardano-crypto-class/src/Cardano/Crypto/KES/Single.hs index e7f0364df..4b9da7d26 100644 --- a/cardano-crypto-class/src/Cardano/Crypto/KES/Single.hs +++ b/cardano-crypto-class/src/Cardano/Crypto/KES/Single.hs @@ -192,14 +192,14 @@ instance DSIGNMAlgorithm d => FromCBOR (SigKES (SingleKES d)) where -- Direct ser/deser -- -instance (DirectSerialise m (SignKeyDSIGNM d)) => DirectSerialise m (SignKeyKES (SingleKES d)) where +instance (DirectSerialise (SignKeyDSIGNM d)) => DirectSerialise (SignKeyKES (SingleKES d)) where directSerialise push (SignKeySingleKES sk) = directSerialise push sk -instance (Monad m, DirectDeserialise m (SignKeyDSIGNM d)) => DirectDeserialise m (SignKeyKES (SingleKES d)) where +instance (DirectDeserialise (SignKeyDSIGNM d)) => DirectDeserialise (SignKeyKES (SingleKES d)) where directDeserialise pull = SignKeySingleKES <$!> directDeserialise pull -instance (DirectSerialise m (VerKeyDSIGN d)) => DirectSerialise m (VerKeyKES (SingleKES d)) where +instance (DirectSerialise (VerKeyDSIGN d)) => DirectSerialise (VerKeyKES (SingleKES d)) where directSerialise push (VerKeySingleKES sk) = directSerialise push sk -instance (Monad m, DirectDeserialise m (VerKeyDSIGN d)) => DirectDeserialise m (VerKeyKES (SingleKES d)) where +instance (DirectDeserialise (VerKeyDSIGN d)) => DirectDeserialise (VerKeyKES (SingleKES d)) where directDeserialise pull = VerKeySingleKES <$!> directDeserialise pull diff --git a/cardano-crypto-class/src/Cardano/Crypto/KES/Sum.hs b/cardano-crypto-class/src/Cardano/Crypto/KES/Sum.hs index d9e53a948..0288bc46c 100644 --- a/cardano-crypto-class/src/Cardano/Crypto/KES/Sum.hs +++ b/cardano-crypto-class/src/Cardano/Crypto/KES/Sum.hs @@ -69,8 +69,6 @@ import Cardano.Crypto.Libsodium.Memory import Cardano.Crypto.DirectSerialise import Control.Monad.Trans.Maybe (MaybeT (..), runMaybeT) -import Control.Monad.Class.MonadST -import Control.Monad.Class.MonadThrow import Control.DeepSeq (NFData (..)) import GHC.TypeLits (KnownNat, type (+), type (*)) import Foreign.Ptr (castPtr) @@ -394,11 +392,10 @@ instance (KESAlgorithm (SumKES h d), SodiumHashAlgorithm h, SizeHash h ~ SeedSiz -- Direct ser/deser -- -instance ( DirectSerialise m (SignKeyKES d) - , DirectSerialise m (VerKeyKES d) - , MonadST m +instance ( DirectSerialise (SignKeyKES d) + , DirectSerialise (VerKeyKES d) , KESAlgorithm d - ) => DirectSerialise m (SignKeyKES (SumKES h d)) where + ) => DirectSerialise (SignKeyKES (SumKES h d)) where directSerialise push (SignKeySumKES sk r vk0 vk1) = do directSerialise push sk mlockedSeedUseAsCPtr r $ \ptr -> @@ -406,11 +403,10 @@ instance ( DirectSerialise m (SignKeyKES d) directSerialise push vk0 directSerialise push vk1 -instance ( DirectDeserialise m (SignKeyKES d) - , DirectDeserialise m (VerKeyKES d) - , MonadST m +instance ( DirectDeserialise (SignKeyKES d) + , DirectDeserialise (VerKeyKES d) , KESAlgorithm d - ) => DirectDeserialise m (SignKeyKES (SumKES h d)) where + ) => DirectDeserialise (SignKeyKES (SumKES h d)) where directDeserialise pull = do sk <- directDeserialise pull @@ -424,18 +420,17 @@ instance ( DirectDeserialise m (SignKeyKES d) return $! SignKeySumKES sk r vk0 vk1 -instance (MonadST m, MonadThrow m) - => DirectSerialise m (VerKeyKES (SumKES h d)) where +instance DirectSerialise (VerKeyKES (SumKES h d)) where directSerialise push (VerKeySumKES h) = unpackByteStringCStringLen (hashToBytes h) $ \(ptr, len) -> push (castPtr ptr) (fromIntegral len) -instance (MonadST m, MonadThrow m, MonadFail m, HashAlgorithm h) - => DirectDeserialise m (VerKeyKES (SumKES h d)) where +instance (HashAlgorithm h) + => DirectDeserialise (VerKeyKES (SumKES h d)) where directDeserialise pull = do let len :: Num a => a len = fromIntegral $ sizeHash (Proxy @h) allocaBytes len $ \ptr -> do pull ptr len bs <- packByteStringCStringLen (ptr, len) - maybe (fail "Invalid hash") return $! VerKeySumKES <$!> hashFromBytes bs + maybe (error "Invalid hash") return $! VerKeySumKES <$!> hashFromBytes bs diff --git a/cardano-crypto-tests/src/Test/Crypto/DSIGN.hs b/cardano-crypto-tests/src/Test/Crypto/DSIGN.hs index b7bb0e759..102bdedf4 100644 --- a/cardano-crypto-tests/src/Test/Crypto/DSIGN.hs +++ b/cardano-crypto-tests/src/Test/Crypto/DSIGN.hs @@ -367,10 +367,10 @@ testDSIGNMAlgorithm , FromCBOR (SigDSIGN v) , ContextDSIGN v ~ () , Signable v Message - , DirectSerialise IO (SignKeyDSIGNM v) - , DirectDeserialise IO (SignKeyDSIGNM v) - , DirectSerialise IO (VerKeyDSIGN v) - , DirectDeserialise IO (VerKeyDSIGN v) + , DirectSerialise (SignKeyDSIGNM v) + , DirectDeserialise (SignKeyDSIGNM v) + , DirectSerialise (VerKeyDSIGN v) + , DirectDeserialise (VerKeyDSIGN v) ) => Lock -> Proxy v diff --git a/cardano-crypto-tests/src/Test/Crypto/KES.hs b/cardano-crypto-tests/src/Test/Crypto/KES.hs index 0e521edfd..6b3b8e558 100644 --- a/cardano-crypto-tests/src/Test/Crypto/KES.hs +++ b/cardano-crypto-tests/src/Test/Crypto/KES.hs @@ -208,10 +208,10 @@ testKESAlgorithm , Signable v ~ SignableRepresentation , ContextKES v ~ () , UnsoundKESAlgorithm v - , DirectSerialise IO (SignKeyKES v) - , DirectSerialise IO (VerKeyKES v) - , DirectDeserialise IO (SignKeyKES v) - , DirectDeserialise IO (VerKeyKES v) + , DirectSerialise (SignKeyKES v) + , DirectSerialise (VerKeyKES v) + , DirectDeserialise (SignKeyKES v) + , DirectDeserialise (VerKeyKES v) ) => Lock -> String @@ -766,7 +766,7 @@ withNullSK = bracket prop_noErasedBlocksInKey :: forall v. UnsoundKESAlgorithm v - => DirectSerialise IO (SignKeyKES v) + => DirectSerialise (SignKeyKES v) => Proxy v -> Property prop_noErasedBlocksInKey kesAlgorithm = diff --git a/cardano-crypto-tests/src/Test/Crypto/Util.hs b/cardano-crypto-tests/src/Test/Crypto/Util.hs index 5b61969f9..a23740e64 100644 --- a/cardano-crypto-tests/src/Test/Crypto/Util.hs +++ b/cardano-crypto-tests/src/Test/Crypto/Util.hs @@ -404,7 +404,7 @@ eitherShowError (Right a) = return a -------------------------------------------------------------------------------- directSerialiseToBS :: forall m a. - DirectSerialise m a + DirectSerialise a => MonadST m => MonadThrow m => MonadMVar m @@ -423,7 +423,7 @@ directSerialiseToBS dstsize val = do packByteStringCStringLen (dst, fromIntegral dstsize) directDeserialiseFromBS :: forall m a. - DirectDeserialise m a + DirectDeserialise a => MonadST m => MonadThrow m => MonadMVar m From f1892f49eed33b955b02a732d2221e1817a2f89f Mon Sep 17 00:00:00 2001 From: Tobias Dammers Date: Tue, 31 Oct 2023 20:53:16 +0100 Subject: [PATCH 5/9] Generalized withForeignPtr and improved allocaBytes --- .../Crypto/Libsodium/Memory/Internal.hs | 45 ++++++++++++------- 1 file changed, 28 insertions(+), 17 deletions(-) diff --git a/cardano-crypto-class/src/Cardano/Crypto/Libsodium/Memory/Internal.hs b/cardano-crypto-class/src/Cardano/Crypto/Libsodium/Memory/Internal.hs index 553c782d8..986698fe6 100644 --- a/cardano-crypto-class/src/Cardano/Crypto/Libsodium/Memory/Internal.hs +++ b/cardano-crypto-class/src/Cardano/Crypto/Libsodium/Memory/Internal.hs @@ -28,6 +28,10 @@ module Cardano.Crypto.Libsodium.Memory.Internal ( mlockedAllocForeignPtrWith, mlockedAllocForeignPtrBytesWith, + -- * 'ForeignPtr' operations, generalized to 'MonadST' + mallocForeignPtrBytes, + withForeignPtr, + -- * Unmanaged memory, generalized to 'MonadST' zeroMem, copyMem, @@ -47,6 +51,7 @@ import Control.Monad (when, void) import Control.Monad.Class.MonadST import Control.Monad.Class.MonadThrow (MonadThrow (bracket)) import Control.Monad.ST (RealWorld, ST) +import Control.Monad.Primitive (touch) import Control.Monad.ST.Unsafe (unsafeIOToST) import Data.ByteString (ByteString) import qualified Data.ByteString as BS @@ -58,10 +63,11 @@ import Debug.Trace (traceShowM) import Foreign.C.Error (errnoToIOError, getErrno) import Foreign.C.String (CStringLen) import Foreign.C.Types (CSize (..)) -import Foreign.Concurrent (newForeignPtr) -import Foreign.ForeignPtr (ForeignPtr, finalizeForeignPtr, touchForeignPtr) +import qualified Foreign.Concurrent as Foreign +import qualified Foreign.ForeignPtr as Foreign hiding (newForeignPtr) +import qualified Foreign.ForeignPtr.Unsafe as Foreign +import Foreign.ForeignPtr (ForeignPtr) import Foreign.ForeignPtr.Unsafe (unsafeForeignPtrToPtr) -import qualified Foreign.Marshal.Alloc as Foreign import Foreign.Marshal.Utils (fillBytes) import Foreign.Ptr (Ptr, nullPtr, castPtr) import Foreign.Storable (Storable (peek), sizeOf, alignment, pokeByteOff) @@ -84,11 +90,11 @@ instance NFData (MLockedForeignPtr a) where withMLockedForeignPtr :: MonadST m => MLockedForeignPtr a -> (Ptr a -> m b) -> m b withMLockedForeignPtr (SFP fptr) f = do r <- f (unsafeForeignPtrToPtr fptr) - r <$ unsafeIOToMonadST (touchForeignPtr fptr) + r <$ unsafeIOToMonadST (Foreign.touchForeignPtr fptr) finalizeMLockedForeignPtr :: MonadST m => MLockedForeignPtr a -> m () finalizeMLockedForeignPtr (SFP fptr) = - unsafeIOToMonadST $ finalizeForeignPtr fptr + unsafeIOToMonadST $ Foreign.finalizeForeignPtr fptr {-# WARNING traceMLockedForeignPtr "Do not use traceMLockedForeignPtr in production" #-} @@ -106,7 +112,7 @@ makeMLockedPool = do (max 1 . fromIntegral $ 4096 `div` natVal (Proxy @n) `div` 64) (\size -> unsafeIOToST $ mask_ $ do ptr <- sodiumMalloc (fromIntegral size) - newForeignPtr ptr (sodiumFree ptr (fromIntegral size)) + Foreign.newForeignPtr ptr (sodiumFree ptr (fromIntegral size)) ) (\ptr -> do eraseMem (Proxy @n) ptr @@ -162,7 +168,7 @@ mlockedMallocIO size = SFP <$> do | otherwise -> do mask_ $ do ptr <- sodiumMalloc size - newForeignPtr ptr $ do + Foreign.newForeignPtr ptr $ do sodiumFree ptr size sodiumMalloc :: CSize -> IO (Ptr a) @@ -191,18 +197,23 @@ zeroMem ptr size = unsafeIOToMonadST . void $ c_memset (castPtr ptr) 0 size copyMem :: MonadST m => Ptr a -> Ptr a -> CSize -> m () copyMem dst src size = unsafeIOToMonadST . void $ c_memcpy (castPtr dst) (castPtr src) size -allocaBytes :: (MonadThrow m, MonadST m) => Int -> (Ptr a -> m b) -> m b -allocaBytes size = - bracket - (mallocBytes size) - free +mallocForeignPtrBytes :: (MonadST m) => Int -> m (ForeignPtr a) +mallocForeignPtrBytes size = + unsafeIOToMonadST (Foreign.mallocForeignPtrBytes size) -mallocBytes :: MonadST m => Int -> m (Ptr a) -mallocBytes size = - unsafeIOToMonadST $ Foreign.mallocBytes size +-- | 'Foreign.withForeignPtr', generalized to 'MonadST'. +-- Caveat: if the monadic action passed to 'withForeignPtr' does not terminate +-- (e.g., 'forever'), the 'ForeignPtr' finalizer may run prematurely. +withForeignPtr :: (MonadST m) => ForeignPtr a -> (Ptr a -> m b) -> m b +withForeignPtr fptr f = do + result <- f $ Foreign.unsafeForeignPtrToPtr fptr + stToIO $ touch fptr + return result -free :: MonadST m => Ptr a -> m () -free = unsafeIOToMonadST . Foreign.free +allocaBytes :: (MonadThrow m, MonadST m) => Int -> (Ptr a -> m b) -> m b +allocaBytes size action = do + fptr <- mallocForeignPtrBytes size + withForeignPtr fptr action -- | Unpacks a ByteString into a temporary buffer and runs the provided 'ST' -- function on it. From d1e8df95216c004820ad44cd5537c5767f141337 Mon Sep 17 00:00:00 2001 From: Tobias Dammers Date: Tue, 31 Oct 2023 21:07:03 +0100 Subject: [PATCH 6/9] Remove zero-termination on 'useByteStringAsCStringLen' --- .../src/Cardano/Crypto/Libsodium/Memory/Internal.hs | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/cardano-crypto-class/src/Cardano/Crypto/Libsodium/Memory/Internal.hs b/cardano-crypto-class/src/Cardano/Crypto/Libsodium/Memory/Internal.hs index 986698fe6..9860caccc 100644 --- a/cardano-crypto-class/src/Cardano/Crypto/Libsodium/Memory/Internal.hs +++ b/cardano-crypto-class/src/Cardano/Crypto/Libsodium/Memory/Internal.hs @@ -58,7 +58,6 @@ import qualified Data.ByteString as BS import qualified Data.ByteString.Unsafe as BS import Data.Coerce (coerce) import Data.Typeable -import Data.Word (Word8) import Debug.Trace (traceShowM) import Foreign.C.Error (errnoToIOError, getErrno) import Foreign.C.String (CStringLen) @@ -70,7 +69,7 @@ import Foreign.ForeignPtr (ForeignPtr) import Foreign.ForeignPtr.Unsafe (unsafeForeignPtrToPtr) import Foreign.Marshal.Utils (fillBytes) import Foreign.Ptr (Ptr, nullPtr, castPtr) -import Foreign.Storable (Storable (peek), sizeOf, alignment, pokeByteOff) +import Foreign.Storable (Storable (peek), sizeOf, alignment) import GHC.IO.Exception (ioException) import GHC.TypeLits (KnownNat, natVal) import NoThunks.Class (NoThunks, OnlyCheckWhnfNamed (..)) @@ -220,10 +219,9 @@ allocaBytes size action = do unpackByteStringCStringLen :: (MonadThrow m, MonadST m) => ByteString -> (CStringLen -> m a) -> m a unpackByteStringCStringLen bs f = do let len = BS.length bs - allocaBytes (len + 1) $ \buf -> do + allocaBytes len $ \buf -> do unsafeIOToMonadST $ BS.unsafeUseAsCString bs $ \ptr -> do copyMem buf ptr (fromIntegral len) - pokeByteOff buf len (0 :: Word8) f (buf, len) packByteStringCStringLen :: MonadST m => CStringLen -> m ByteString From e2d86bb549c48d0092f8299e8a84accf8cc0bd2e Mon Sep 17 00:00:00 2001 From: Tobias Dammers Date: Tue, 31 Oct 2023 21:45:08 +0100 Subject: [PATCH 7/9] Use ForeignPtr instead of allocaBytes for direct serialization --- .../src/Cardano/Crypto/KES/CompactSum.hs | 10 ++++++---- .../src/Cardano/Crypto/KES/Mock.hs | 20 +++++++++++-------- .../src/Cardano/Crypto/KES/Sum.hs | 10 ++++++---- .../src/Cardano/Crypto/Libsodium/Memory.hs | 4 ++++ 4 files changed, 28 insertions(+), 16 deletions(-) diff --git a/cardano-crypto-class/src/Cardano/Crypto/KES/CompactSum.hs b/cardano-crypto-class/src/Cardano/Crypto/KES/CompactSum.hs index 47bc62ae6..a7a769541 100644 --- a/cardano-crypto-class/src/Cardano/Crypto/KES/CompactSum.hs +++ b/cardano-crypto-class/src/Cardano/Crypto/KES/CompactSum.hs @@ -87,6 +87,7 @@ module Cardano.Crypto.KES.CompactSum ( import Data.Proxy (Proxy(..)) import GHC.Generics (Generic) import qualified Data.ByteString as BS +import qualified Data.ByteString.Internal as BS import Control.Monad (guard, (<$!>)) import NoThunks.Class (NoThunks, OnlyCheckWhnfNamed (..)) @@ -510,7 +511,8 @@ instance (HashAlgorithm h) directDeserialise pull = do let len :: Num a => a len = fromIntegral $ sizeHash (Proxy @h) - allocaBytes len $ \ptr -> do - pull ptr len - bs <- packByteStringCStringLen (ptr, len) - maybe (error "Invalid hash") return $! VerKeyCompactSumKES <$!> hashFromBytes bs + fptr <- mallocForeignPtrBytes len + withForeignPtr fptr $ \ptr -> do + pull (castPtr ptr) len + let bs = BS.fromForeignPtr0 fptr len + maybe (error "Invalid hash") return $! VerKeyCompactSumKES <$!> hashFromBytes bs diff --git a/cardano-crypto-class/src/Cardano/Crypto/KES/Mock.hs b/cardano-crypto-class/src/Cardano/Crypto/KES/Mock.hs index 5748cb7bd..a3239d969 100644 --- a/cardano-crypto-class/src/Cardano/Crypto/KES/Mock.hs +++ b/cardano-crypto-class/src/Cardano/Crypto/KES/Mock.hs @@ -24,6 +24,8 @@ import Data.Proxy (Proxy(..)) import GHC.Generics (Generic) import GHC.TypeNats (Nat, KnownNat, natVal) import NoThunks.Class (NoThunks) +import qualified Data.ByteString.Internal as BS +import Foreign.Ptr (castPtr) import Control.Exception (assert) @@ -39,8 +41,8 @@ import Cardano.Crypto.Libsodium ) import Cardano.Crypto.Libsodium.Memory ( unpackByteStringCStringLen - , packByteStringCStringLen - , allocaBytes + , mallocForeignPtrBytes + , withForeignPtr ) import Cardano.Crypto.DirectSerialise @@ -210,9 +212,10 @@ instance (KnownNat t) => DirectSerialise (SignKeyKES (MockKES t)) where instance (KnownNat t) => DirectDeserialise (SignKeyKES (MockKES t)) where directDeserialise pull = do let len = fromIntegral $ sizeSignKeyKES (Proxy @(MockKES t)) - bs <- allocaBytes len $ \cstr -> do - pull cstr (fromIntegral len) - packByteStringCStringLen (cstr, len) + fptr <- mallocForeignPtrBytes len + withForeignPtr fptr $ \ptr -> + pull (castPtr ptr) (fromIntegral len) + let bs = BS.fromForeignPtr0 fptr len maybe (error "directDeserialise @(SignKeyKES (MockKES t))") return $ rawDeserialiseSignKeyMockKES bs @@ -224,8 +227,9 @@ instance (KnownNat t) => DirectSerialise (VerKeyKES (MockKES t)) where instance (KnownNat t) => DirectDeserialise (VerKeyKES (MockKES t)) where directDeserialise pull = do let len = fromIntegral $ sizeVerKeyKES (Proxy @(MockKES t)) - bs <- allocaBytes len $ \cstr -> do - pull cstr (fromIntegral len) - packByteStringCStringLen (cstr, len) + fptr <- mallocForeignPtrBytes len + withForeignPtr fptr $ \ptr -> + pull (castPtr ptr) (fromIntegral len) + let bs = BS.fromForeignPtr0 fptr len maybe (error "directDeserialise @(VerKeyKES (MockKES t))") return $ rawDeserialiseVerKeyKES bs diff --git a/cardano-crypto-class/src/Cardano/Crypto/KES/Sum.hs b/cardano-crypto-class/src/Cardano/Crypto/KES/Sum.hs index 0288bc46c..7f9add51d 100644 --- a/cardano-crypto-class/src/Cardano/Crypto/KES/Sum.hs +++ b/cardano-crypto-class/src/Cardano/Crypto/KES/Sum.hs @@ -54,6 +54,7 @@ module Cardano.Crypto.KES.Sum ( import Data.Proxy (Proxy(..)) import GHC.Generics (Generic) import qualified Data.ByteString as BS +import qualified Data.ByteString.Internal as BS import Control.Monad (guard, (<$!>)) import NoThunks.Class (NoThunks, OnlyCheckWhnfNamed (..)) @@ -430,7 +431,8 @@ instance (HashAlgorithm h) directDeserialise pull = do let len :: Num a => a len = fromIntegral $ sizeHash (Proxy @h) - allocaBytes len $ \ptr -> do - pull ptr len - bs <- packByteStringCStringLen (ptr, len) - maybe (error "Invalid hash") return $! VerKeySumKES <$!> hashFromBytes bs + fptr <- mallocForeignPtrBytes len + withForeignPtr fptr $ \ptr -> do + pull (castPtr ptr) len + let bs = BS.fromForeignPtr0 fptr len + maybe (error "Invalid hash") return $! VerKeySumKES <$!> hashFromBytes bs diff --git a/cardano-crypto-class/src/Cardano/Crypto/Libsodium/Memory.hs b/cardano-crypto-class/src/Cardano/Crypto/Libsodium/Memory.hs index 4d681b11a..696c4a230 100644 --- a/cardano-crypto-class/src/Cardano/Crypto/Libsodium/Memory.hs +++ b/cardano-crypto-class/src/Cardano/Crypto/Libsodium/Memory.hs @@ -25,6 +25,10 @@ module Cardano.Crypto.Libsodium.Memory ( copyMem, allocaBytes, + -- * 'ForeignPtr' operations, generalized to 'MonadST' + mallocForeignPtrBytes, + withForeignPtr, + -- * ByteString memory access, generalized to 'MonadST' unpackByteStringCStringLen, packByteStringCStringLen, From 3f0d62ff870e16ec8f94dc39073b8dcc8fa0307e Mon Sep 17 00:00:00 2001 From: Tobias Dammers Date: Wed, 1 Nov 2023 10:38:07 +0100 Subject: [PATCH 8/9] Fix build failure on GHC 8.10 --- cardano-crypto-class/src/Cardano/Crypto/KES/CompactSum.hs | 2 +- cardano-crypto-class/src/Cardano/Crypto/KES/Mock.hs | 4 ++-- cardano-crypto-class/src/Cardano/Crypto/KES/Sum.hs | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/cardano-crypto-class/src/Cardano/Crypto/KES/CompactSum.hs b/cardano-crypto-class/src/Cardano/Crypto/KES/CompactSum.hs index a7a769541..6bb51c478 100644 --- a/cardano-crypto-class/src/Cardano/Crypto/KES/CompactSum.hs +++ b/cardano-crypto-class/src/Cardano/Crypto/KES/CompactSum.hs @@ -514,5 +514,5 @@ instance (HashAlgorithm h) fptr <- mallocForeignPtrBytes len withForeignPtr fptr $ \ptr -> do pull (castPtr ptr) len - let bs = BS.fromForeignPtr0 fptr len + let bs = BS.fromForeignPtr fptr 0 len maybe (error "Invalid hash") return $! VerKeyCompactSumKES <$!> hashFromBytes bs diff --git a/cardano-crypto-class/src/Cardano/Crypto/KES/Mock.hs b/cardano-crypto-class/src/Cardano/Crypto/KES/Mock.hs index a3239d969..3c6941de5 100644 --- a/cardano-crypto-class/src/Cardano/Crypto/KES/Mock.hs +++ b/cardano-crypto-class/src/Cardano/Crypto/KES/Mock.hs @@ -215,7 +215,7 @@ instance (KnownNat t) => DirectDeserialise (SignKeyKES (MockKES t)) where fptr <- mallocForeignPtrBytes len withForeignPtr fptr $ \ptr -> pull (castPtr ptr) (fromIntegral len) - let bs = BS.fromForeignPtr0 fptr len + let bs = BS.fromForeignPtr fptr 0 len maybe (error "directDeserialise @(SignKeyKES (MockKES t))") return $ rawDeserialiseSignKeyMockKES bs @@ -230,6 +230,6 @@ instance (KnownNat t) => DirectDeserialise (VerKeyKES (MockKES t)) where fptr <- mallocForeignPtrBytes len withForeignPtr fptr $ \ptr -> pull (castPtr ptr) (fromIntegral len) - let bs = BS.fromForeignPtr0 fptr len + let bs = BS.fromForeignPtr fptr 0 len maybe (error "directDeserialise @(VerKeyKES (MockKES t))") return $ rawDeserialiseVerKeyKES bs diff --git a/cardano-crypto-class/src/Cardano/Crypto/KES/Sum.hs b/cardano-crypto-class/src/Cardano/Crypto/KES/Sum.hs index 7f9add51d..b2206419e 100644 --- a/cardano-crypto-class/src/Cardano/Crypto/KES/Sum.hs +++ b/cardano-crypto-class/src/Cardano/Crypto/KES/Sum.hs @@ -434,5 +434,5 @@ instance (HashAlgorithm h) fptr <- mallocForeignPtrBytes len withForeignPtr fptr $ \ptr -> do pull (castPtr ptr) len - let bs = BS.fromForeignPtr0 fptr len + let bs = BS.fromForeignPtr fptr 0 len maybe (error "Invalid hash") return $! VerKeySumKES <$!> hashFromBytes bs From c1680aeacb1803d8cc19cd0a64814a3af4c1f2ba Mon Sep 17 00:00:00 2001 From: Tobias Dammers Date: Thu, 2 Nov 2023 16:28:34 +0100 Subject: [PATCH 9/9] Tag generalized ForeignPtr with the the ST context This prevents a ForeignPtr created in one ST context to leak into another. --- .../src/Cardano/Crypto/KES/CompactSum.hs | 2 +- .../src/Cardano/Crypto/KES/Mock.hs | 5 +++-- .../src/Cardano/Crypto/KES/Sum.hs | 2 +- .../src/Cardano/Crypto/Libsodium/Memory.hs | 1 + .../Crypto/Libsodium/Memory/Internal.hs | 18 ++++++++++++------ 5 files changed, 18 insertions(+), 10 deletions(-) diff --git a/cardano-crypto-class/src/Cardano/Crypto/KES/CompactSum.hs b/cardano-crypto-class/src/Cardano/Crypto/KES/CompactSum.hs index 6bb51c478..96726f7aa 100644 --- a/cardano-crypto-class/src/Cardano/Crypto/KES/CompactSum.hs +++ b/cardano-crypto-class/src/Cardano/Crypto/KES/CompactSum.hs @@ -514,5 +514,5 @@ instance (HashAlgorithm h) fptr <- mallocForeignPtrBytes len withForeignPtr fptr $ \ptr -> do pull (castPtr ptr) len - let bs = BS.fromForeignPtr fptr 0 len + let bs = BS.fromForeignPtr (unsafeRawForeignPtr fptr) 0 len maybe (error "Invalid hash") return $! VerKeyCompactSumKES <$!> hashFromBytes bs diff --git a/cardano-crypto-class/src/Cardano/Crypto/KES/Mock.hs b/cardano-crypto-class/src/Cardano/Crypto/KES/Mock.hs index 3c6941de5..43b503e21 100644 --- a/cardano-crypto-class/src/Cardano/Crypto/KES/Mock.hs +++ b/cardano-crypto-class/src/Cardano/Crypto/KES/Mock.hs @@ -41,6 +41,7 @@ import Cardano.Crypto.Libsodium ) import Cardano.Crypto.Libsodium.Memory ( unpackByteStringCStringLen + , ForeignPtr (..) , mallocForeignPtrBytes , withForeignPtr ) @@ -215,7 +216,7 @@ instance (KnownNat t) => DirectDeserialise (SignKeyKES (MockKES t)) where fptr <- mallocForeignPtrBytes len withForeignPtr fptr $ \ptr -> pull (castPtr ptr) (fromIntegral len) - let bs = BS.fromForeignPtr fptr 0 len + let bs = BS.fromForeignPtr (unsafeRawForeignPtr fptr) 0 len maybe (error "directDeserialise @(SignKeyKES (MockKES t))") return $ rawDeserialiseSignKeyMockKES bs @@ -230,6 +231,6 @@ instance (KnownNat t) => DirectDeserialise (VerKeyKES (MockKES t)) where fptr <- mallocForeignPtrBytes len withForeignPtr fptr $ \ptr -> pull (castPtr ptr) (fromIntegral len) - let bs = BS.fromForeignPtr fptr 0 len + let bs = BS.fromForeignPtr (unsafeRawForeignPtr fptr) 0 len maybe (error "directDeserialise @(VerKeyKES (MockKES t))") return $ rawDeserialiseVerKeyKES bs diff --git a/cardano-crypto-class/src/Cardano/Crypto/KES/Sum.hs b/cardano-crypto-class/src/Cardano/Crypto/KES/Sum.hs index b2206419e..d78be36a8 100644 --- a/cardano-crypto-class/src/Cardano/Crypto/KES/Sum.hs +++ b/cardano-crypto-class/src/Cardano/Crypto/KES/Sum.hs @@ -434,5 +434,5 @@ instance (HashAlgorithm h) fptr <- mallocForeignPtrBytes len withForeignPtr fptr $ \ptr -> do pull (castPtr ptr) len - let bs = BS.fromForeignPtr fptr 0 len + let bs = BS.fromForeignPtr (unsafeRawForeignPtr fptr) 0 len maybe (error "Invalid hash") return $! VerKeySumKES <$!> hashFromBytes bs diff --git a/cardano-crypto-class/src/Cardano/Crypto/Libsodium/Memory.hs b/cardano-crypto-class/src/Cardano/Crypto/Libsodium/Memory.hs index 696c4a230..cd927cb42 100644 --- a/cardano-crypto-class/src/Cardano/Crypto/Libsodium/Memory.hs +++ b/cardano-crypto-class/src/Cardano/Crypto/Libsodium/Memory.hs @@ -26,6 +26,7 @@ module Cardano.Crypto.Libsodium.Memory ( allocaBytes, -- * 'ForeignPtr' operations, generalized to 'MonadST' + ForeignPtr (..), mallocForeignPtrBytes, withForeignPtr, diff --git a/cardano-crypto-class/src/Cardano/Crypto/Libsodium/Memory/Internal.hs b/cardano-crypto-class/src/Cardano/Crypto/Libsodium/Memory/Internal.hs index 9860caccc..ad61bd282 100644 --- a/cardano-crypto-class/src/Cardano/Crypto/Libsodium/Memory/Internal.hs +++ b/cardano-crypto-class/src/Cardano/Crypto/Libsodium/Memory/Internal.hs @@ -29,6 +29,7 @@ module Cardano.Crypto.Libsodium.Memory.Internal ( mlockedAllocForeignPtrBytesWith, -- * 'ForeignPtr' operations, generalized to 'MonadST' + ForeignPtr (..), mallocForeignPtrBytes, withForeignPtr, @@ -65,7 +66,6 @@ import Foreign.C.Types (CSize (..)) import qualified Foreign.Concurrent as Foreign import qualified Foreign.ForeignPtr as Foreign hiding (newForeignPtr) import qualified Foreign.ForeignPtr.Unsafe as Foreign -import Foreign.ForeignPtr (ForeignPtr) import Foreign.ForeignPtr.Unsafe (unsafeForeignPtrToPtr) import Foreign.Marshal.Utils (fillBytes) import Foreign.Ptr (Ptr, nullPtr, castPtr) @@ -74,13 +74,14 @@ import GHC.IO.Exception (ioException) import GHC.TypeLits (KnownNat, natVal) import NoThunks.Class (NoThunks, OnlyCheckWhnfNamed (..)) import System.IO.Unsafe (unsafePerformIO) +import Data.Kind import Cardano.Crypto.Libsodium.C import Cardano.Foreign (c_memset, c_memcpy, SizedPtr (..)) import Cardano.Memory.Pool (initPool, grabNextBlock, Pool) -- | Foreign pointer to securely allocated memory. -newtype MLockedForeignPtr a = SFP { _unwrapMLockedForeignPtr :: ForeignPtr a } +newtype MLockedForeignPtr a = SFP { _unwrapMLockedForeignPtr :: Foreign.ForeignPtr a } deriving NoThunks via OnlyCheckWhnfNamed "MLockedForeignPtr" (MLockedForeignPtr a) instance NFData (MLockedForeignPtr a) where @@ -196,15 +197,20 @@ zeroMem ptr size = unsafeIOToMonadST . void $ c_memset (castPtr ptr) 0 size copyMem :: MonadST m => Ptr a -> Ptr a -> CSize -> m () copyMem dst src size = unsafeIOToMonadST . void $ c_memcpy (castPtr dst) (castPtr src) size -mallocForeignPtrBytes :: (MonadST m) => Int -> m (ForeignPtr a) +-- | A 'ForeignPtr' type, generalized to 'MonadST'. The type is tagged with +-- the correct Monad @m@ in order to ensure that foreign pointers created in +-- one ST context can only be used within the same ST context. +newtype ForeignPtr (m :: Type -> Type) a = ForeignPtr { unsafeRawForeignPtr :: Foreign.ForeignPtr a } + +mallocForeignPtrBytes :: (MonadST m) => Int -> m (ForeignPtr m a) mallocForeignPtrBytes size = - unsafeIOToMonadST (Foreign.mallocForeignPtrBytes size) + ForeignPtr <$> unsafeIOToMonadST (Foreign.mallocForeignPtrBytes size) -- | 'Foreign.withForeignPtr', generalized to 'MonadST'. -- Caveat: if the monadic action passed to 'withForeignPtr' does not terminate -- (e.g., 'forever'), the 'ForeignPtr' finalizer may run prematurely. -withForeignPtr :: (MonadST m) => ForeignPtr a -> (Ptr a -> m b) -> m b -withForeignPtr fptr f = do +withForeignPtr :: (MonadST m) => ForeignPtr m a -> (Ptr a -> m b) -> m b +withForeignPtr (ForeignPtr fptr) f = do result <- f $ Foreign.unsafeForeignPtrToPtr fptr stToIO $ touch fptr return result