Skip to content

Commit

Permalink
Generalize DirectSerialise
Browse files Browse the repository at this point in the history
  • Loading branch information
tdammers committed Jan 31, 2023
1 parent a10fd54 commit 6bea65f
Show file tree
Hide file tree
Showing 15 changed files with 337 additions and 140 deletions.
38 changes: 27 additions & 11 deletions cardano-crypto-class/src/Cardano/Crypto/DSIGN/Ed25519ML.hs
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,6 @@ import Control.Monad.ST.Unsafe (unsafeIOToST)
import Cardano.Binary (FromCBOR (..), ToCBOR (..))

import Cardano.Foreign
import Cardano.Crypto.PinnedSizedBytes
import Cardano.Crypto.Libsodium.C
import Cardano.Crypto.Libsodium (MLockedSizedBytes)
import Cardano.Crypto.MonadSodium
Expand All @@ -49,6 +48,12 @@ import Cardano.Crypto.MonadSodium
, mlsbFinalize
, mlsbCopy
, MEq (..)
, PinnedSizedBytes
, psbUseAsSizedPtr
, psbToByteString
, psbFromByteStringCheck
, psbCreateSizedResult
, psbCreate
)

import Cardano.Crypto.DSIGNM.Class
Expand Down Expand Up @@ -190,8 +195,9 @@ instance (MonadST m, MonadSodium m, MonadThrow m) => DSIGNMAlgorithm m Ed25519DS
deriveVerKeyDSIGNM (SignKeyEd25519DSIGNM sk) =
VerKeyEd25519DSIGNM <$!> do
mlsbUseAsSizedPtr sk $ \skPtr -> do
(psb, maybeErrno) <- withLiftST $ \fromST -> fromST $ do
psbCreateSizedResult $ \pkPtr ->
(psb, maybeErrno) <-
psbCreateSizedResult $ \pkPtr ->
withLiftST $ \fromST -> fromST $ do
cOrError $ unsafeIOToST $
c_crypto_sign_ed25519_sk_to_pk pkPtr skPtr
throwOnErrno "deriveVerKeyDSIGNM @Ed25519DSIGNM" "c_crypto_sign_ed25519_sk_to_pk" maybeErrno
Expand All @@ -202,8 +208,9 @@ instance (MonadST m, MonadSodium m, MonadThrow m) => DSIGNMAlgorithm m Ed25519DS
let bs = getSignableRepresentation a
in SigEd25519DSIGNM <$!> do
mlsbUseAsSizedPtr sk $ \skPtr -> do
(psb, maybeErrno) <- withLiftST $ \fromST -> fromST $ do
psbCreateSizedResult $ \sigPtr -> do
(psb, maybeErrno) <-
psbCreateSizedResult $ \sigPtr -> do
withLiftST $ \fromST -> fromST $ do
cOrError $ unsafeIOToST $ do
BS.useAsCStringLen bs $ \(ptr, len) ->
c_crypto_sign_ed25519_detached sigPtr nullPtr (castPtr ptr) (fromIntegral len) skPtr
Expand Down Expand Up @@ -273,7 +280,10 @@ instance (MonadST m, MonadSodium m, MonadThrow m) => UnsoundDSIGNMAlgorithm m Ed
mlockedSeedFinalize seed
return sk

instance DirectSerialise (SignKeyDSIGNM Ed25519DSIGNM) where
instance ( MonadThrow m
, MonadST m
, MonadSodium m
) => DirectSerialise m (SignKeyDSIGNM Ed25519DSIGNM) where
-- /Note:/ We only serialize the 32-byte seed, not the full 64-byte key. The
-- latter contains both the seed and the 32-byte verification key, which is
-- convenient, but redundant, since we can always reconstruct it from the
Expand All @@ -288,9 +298,12 @@ instance DirectSerialise (SignKeyDSIGNM Ed25519DSIGNM) where
(castPtr ptr)
(fromIntegral $ seedSizeDSIGNM (Proxy @Ed25519DSIGNM)))

instance DirectDeserialise (SignKeyDSIGNM Ed25519DSIGNM) where
instance ( MonadThrow m
, MonadST m
, MonadSodium m
) => DirectDeserialise m (SignKeyDSIGNM Ed25519DSIGNM) where
-- /Note:/ We only serialize the 32-byte seed, not the full 64-byte key. See
-- the DirectSerialise instance above.
-- the DirectSerialise m instance above.
directDeserialise pull = do
bracket
mlockedSeedNew
Expand All @@ -303,14 +316,17 @@ instance DirectDeserialise (SignKeyDSIGNM Ed25519DSIGNM) where
genKeyDSIGNM seed
)

instance DirectSerialise (VerKeyDSIGNM Ed25519DSIGNM) where
instance ( MonadSodium m
) => DirectSerialise m (VerKeyDSIGNM Ed25519DSIGNM) where
directSerialise push (VerKeyEd25519DSIGNM psb) = do
psbUseAsCPtr psb $ \ptr ->
psbUseAsCPtrLen psb $ \ptr _ ->
push
(castPtr ptr)
(fromIntegral $ sizeVerKeyDSIGNM (Proxy @Ed25519DSIGNM))

instance DirectDeserialise (VerKeyDSIGNM Ed25519DSIGNM) where
instance ( MonadThrow m
, MonadSodium m
) => DirectDeserialise m (VerKeyDSIGNM Ed25519DSIGNM) where
directDeserialise pull = do
psb <- psbCreate $ \ptr ->
pull
Expand Down
10 changes: 6 additions & 4 deletions cardano-crypto-class/src/Cardano/Crypto/DirectSerialise.hs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
{-# LANGUAGE MultiParamTypeClasses #-}

-- | Direct (de-)serialisation to / from raw memory.
--
-- The purpose of the typeclasses in this module is to abstract over data
Expand Down Expand Up @@ -25,8 +27,8 @@ import Foreign.C.Types
-- non-contiguous blocks of memory.
--
-- The order in which memory blocks are visited matters.
class DirectDeserialise a where
directDeserialise :: (Ptr CChar -> CSize -> IO ()) -> IO a
class DirectDeserialise m a where
directDeserialise :: (Ptr CChar -> CSize -> m ()) -> m a

-- | Direct serialization to raw memory.
--
Expand All @@ -35,5 +37,5 @@ class DirectDeserialise a where
-- of memory, @f@ may be called multiple times, once for each block.
--
-- The order in which memory blocks are visited matters.
class DirectSerialise a where
directSerialise :: (Ptr CChar -> CSize -> IO ()) -> a -> IO ()
class DirectSerialise m a where
directSerialise :: (Ptr CChar -> CSize -> m ()) -> a -> m ()
12 changes: 6 additions & 6 deletions cardano-crypto-class/src/Cardano/Crypto/KES/CompactSingle.hs
Original file line number Diff line number Diff line change
Expand Up @@ -239,14 +239,14 @@ slice offset size = BS.take (fromIntegral size)
-- Direct ser/deser
--

instance (DirectSerialise (SignKeyDSIGNM d)) => DirectSerialise (SignKeyKES (CompactSingleKES d)) where
instance (DirectSerialise m (SignKeyDSIGNM d)) => DirectSerialise m (SignKeyKES (CompactSingleKES d)) where
directSerialise push (SignKeyCompactSingleKES sk) = directSerialise push sk

instance (DirectDeserialise (SignKeyDSIGNM d)) => DirectDeserialise (SignKeyKES (CompactSingleKES d)) where
directDeserialise pull = SignKeyCompactSingleKES <$> directDeserialise pull
instance (Monad m, DirectDeserialise m (SignKeyDSIGNM d)) => DirectDeserialise m (SignKeyKES (CompactSingleKES d)) where
directDeserialise pull = SignKeyCompactSingleKES <$!> directDeserialise pull

instance (DirectSerialise (VerKeyDSIGNM d)) => DirectSerialise (VerKeyKES (CompactSingleKES d)) where
instance (DirectSerialise m (VerKeyDSIGNM d)) => DirectSerialise m (VerKeyKES (CompactSingleKES d)) where
directSerialise push (VerKeyCompactSingleKES sk) = directSerialise push sk

instance (DirectDeserialise (VerKeyDSIGNM d)) => DirectDeserialise (VerKeyKES (CompactSingleKES d)) where
directDeserialise pull = VerKeyCompactSingleKES <$> directDeserialise pull
instance (Monad m, DirectDeserialise m (VerKeyDSIGNM d)) => DirectDeserialise m (VerKeyKES (CompactSingleKES d)) where
directDeserialise pull = VerKeyCompactSingleKES <$!> directDeserialise pull
47 changes: 24 additions & 23 deletions cardano-crypto-class/src/Cardano/Crypto/KES/CompactSum.hs
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,6 @@ import qualified Data.ByteString as BS
import Control.Monad (guard, (<$!>))
import NoThunks.Class (NoThunks, OnlyCheckWhnfNamed (..))
import Foreign.Ptr (castPtr)
import Foreign.Marshal.Alloc (allocaBytes)

import Cardano.Binary (FromCBOR (..), ToCBOR (..))

Expand All @@ -100,7 +99,7 @@ import Cardano.Crypto.KES.Class
import Cardano.Crypto.KES.CompactSingle (CompactSingleKES)
import Cardano.Crypto.Util
import Cardano.Crypto.MLockedSeed
import qualified Cardano.Crypto.MonadSodium as NaCl
import Cardano.Crypto.MonadSodium
import Control.Monad.Class.MonadST (MonadST)
import Control.Monad.Class.MonadThrow (MonadThrow)
import Control.Monad.Trans.Maybe (MaybeT (..), runMaybeT)
Expand Down Expand Up @@ -153,7 +152,7 @@ instance (NFData (SignKeyKES d), NFData (VerKeyKES d)) =>
rnf (sk, r, vk1, vk2)

instance ( OptimizedKESAlgorithm d
, NaCl.SodiumHashAlgorithm h -- needed for secure forgetting
, SodiumHashAlgorithm h -- needed for secure forgetting
, SizeHash h ~ SeedSizeKES d -- can be relaxed
, NoThunks (VerKeyKES (CompactSumKES h d))
, KnownNat (SizeVerKeyKES (CompactSumKES h d))
Expand Down Expand Up @@ -253,9 +252,9 @@ instance ( OptimizedKESAlgorithm d

instance ( OptimizedKESAlgorithm d
, KESSignAlgorithm m d
, NaCl.SodiumHashAlgorithm h -- needed for secure forgetting
, SodiumHashAlgorithm h -- needed for secure forgetting
, SizeHash h ~ SeedSizeKES d -- can be relaxed
, NaCl.MonadSodium m
, MonadSodium m
, MonadST m -- only needed for unsafe raw ser/deser
, MonadThrow m
, NoThunks (VerKeyKES (CompactSumKES h d))
Expand Down Expand Up @@ -313,7 +312,7 @@ instance ( OptimizedKESAlgorithm d

{-# NOINLINE genKeyKES #-}
genKeyKES r = do
(r0raw, r1raw) <- NaCl.expandHash (Proxy :: Proxy h) (mlockedSeedMLSB r)
(r0raw, r1raw) <- expandHash (Proxy :: Proxy h) (mlockedSeedMLSB r)
let r0 = MLockedSeed r0raw
r1 = MLockedSeed r1raw
sk_0 <- genKeyKES r0
Expand All @@ -333,7 +332,7 @@ instance ( OptimizedKESAlgorithm d

instance ( KESSignAlgorithm m (CompactSumKES h d)
, UnsoundKESSignAlgorithm m d
, NaCl.MonadSodium m
, MonadSodium m
, MonadST m
) => UnsoundKESSignAlgorithm m (CompactSumKES h d) where
--
Expand All @@ -343,7 +342,7 @@ instance ( KESSignAlgorithm m (CompactSumKES h d)
{-# NOINLINE rawSerialiseSignKeyKES #-}
rawSerialiseSignKeyKES (SignKeyCompactSumKES sk r_1 vk_0 vk_1) = do
ssk <- rawSerialiseSignKeyKES sk
sr1 <- NaCl.mlsbToByteString . mlockedSeedMLSB $ r_1
sr1 <- mlsbToByteString . mlockedSeedMLSB $ r_1
return $ mconcat
[ ssk
, sr1
Expand All @@ -355,7 +354,7 @@ instance ( KESSignAlgorithm m (CompactSumKES h d)
rawDeserialiseSignKeyKES b = runMaybeT $ do
guard (BS.length b == fromIntegral size_total)
sk <- MaybeT $ rawDeserialiseSignKeyKES b_sk
r <- MaybeT $ NaCl.mlsbFromByteStringCheck b_r
r <- MaybeT $ mlsbFromByteStringCheck b_r
vk_0 <- MaybeT . return $ rawDeserialiseVerKeyKES b_vk0
vk_1 <- MaybeT . return $ rawDeserialiseVerKeyKES b_vk1
return (SignKeyCompactSumKES sk (MLockedSeed r) vk_0 vk_1)
Expand Down Expand Up @@ -409,7 +408,7 @@ deriving via OnlyCheckWhnfNamed "SignKeyKES (CompactSumKES h d)" (SignKeyKES (Co
instance (KESAlgorithm d) => NoThunks (VerKeyKES (CompactSumKES h d))

instance ( OptimizedKESAlgorithm d
, NaCl.SodiumHashAlgorithm h
, SodiumHashAlgorithm h
, SizeHash h ~ SeedSizeKES d
, NoThunks (VerKeyKES (CompactSumKES h d))
, KnownNat (SizeVerKeyKES (CompactSumKES h d))
Expand All @@ -421,7 +420,7 @@ instance ( OptimizedKESAlgorithm d
encodedSizeExpr _size = encodedVerKeyKESSizeExpr

instance ( OptimizedKESAlgorithm d
, NaCl.SodiumHashAlgorithm h
, SodiumHashAlgorithm h
, SizeHash h ~ SeedSizeKES d
, NoThunks (VerKeyKES (CompactSumKES h d))
, KnownNat (SizeVerKeyKES (CompactSumKES h d))
Expand Down Expand Up @@ -461,7 +460,7 @@ deriving instance KESAlgorithm d => Eq (SigKES (CompactSumKES h d))
instance KESAlgorithm d => NoThunks (SigKES (CompactSumKES h d))

instance ( OptimizedKESAlgorithm d
, NaCl.SodiumHashAlgorithm h
, SodiumHashAlgorithm h
, SizeHash h ~ SeedSizeKES d
, NoThunks (VerKeyKES (CompactSumKES h d))
, KnownNat (SizeVerKeyKES (CompactSumKES h d))
Expand All @@ -473,7 +472,7 @@ instance ( OptimizedKESAlgorithm d
encodedSizeExpr _size = encodedSigKESSizeExpr

instance ( OptimizedKESAlgorithm d
, NaCl.SodiumHashAlgorithm h
, SodiumHashAlgorithm h
, SizeHash h ~ SeedSizeKES d
, NoThunks (VerKeyKES (CompactSumKES h d))
, KnownNat (SizeVerKeyKES (CompactSumKES h d))
Expand All @@ -487,21 +486,23 @@ instance ( OptimizedKESAlgorithm d
-- Direct ser/deser
--

instance ( DirectSerialise (SignKeyKES d)
, DirectSerialise (VerKeyKES d)
instance ( DirectSerialise m (SignKeyKES d)
, DirectSerialise m (VerKeyKES d)
, MonadSodium m
, KESAlgorithm d
) => DirectSerialise (SignKeyKES (CompactSumKES h d)) where
) => DirectSerialise m (SignKeyKES (CompactSumKES h d)) where
directSerialise push (SignKeyCompactSumKES sk r vk0 vk1) = do
directSerialise push sk
mlockedSeedUseAsCPtr r $ \ptr ->
push (castPtr ptr) (fromIntegral $ seedSizeKES (Proxy :: Proxy d))
directSerialise push vk0
directSerialise push vk1

instance ( DirectDeserialise (SignKeyKES d)
, DirectDeserialise (VerKeyKES d)
instance ( DirectDeserialise m (SignKeyKES d)
, DirectDeserialise m (VerKeyKES d)
, MonadSodium m
, KESAlgorithm d
) => DirectDeserialise (SignKeyKES (CompactSumKES h d)) where
) => DirectDeserialise m (SignKeyKES (CompactSumKES h d)) where
directDeserialise pull = do
sk <- directDeserialise pull

Expand All @@ -515,16 +516,16 @@ instance ( DirectDeserialise (SignKeyKES d)
return $! SignKeyCompactSumKES sk r vk0 vk1


instance DirectSerialise (VerKeyKES (CompactSumKES h d)) where
instance MonadSodium m => DirectSerialise m (VerKeyKES (CompactSumKES h d)) where
directSerialise push (VerKeyCompactSumKES h) = do
BS.useAsCStringLen (hashToBytes h) $ \(ptr, len) ->
useByteStringAsCStringLen (hashToBytes h) $ \(ptr, len) ->
push (castPtr ptr) (fromIntegral len)

instance HashAlgorithm h => DirectDeserialise (VerKeyKES (CompactSumKES h d)) where
instance (MonadSodium m, MonadST m, MonadFail m, HashAlgorithm h) => DirectDeserialise m (VerKeyKES (CompactSumKES h d)) where
directDeserialise pull = do
let len :: Num a => a
len = fromIntegral $ sizeHash (Proxy @h)
allocaBytes len $ \ptr -> do
pull ptr len
bs <- BS.packCStringLen (ptr, len)
bs <- packByteStringCStringLen (ptr, len)
maybe (fail "Invalid hash") return $! VerKeyCompactSumKES <$!> hashFromBytes bs
26 changes: 15 additions & 11 deletions cardano-crypto-class/src/Cardano/Crypto/KES/Mock.hs
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,9 @@ import Data.Proxy (Proxy(..))
import GHC.Generics (Generic)
import GHC.TypeNats (Nat, KnownNat, natVal)
import NoThunks.Class (NoThunks)
import qualified Data.ByteString as BS
import Foreign.Marshal.Alloc (allocaBytes)

import Control.Exception (assert)
import Control.Monad.Class.MonadST (MonadST (..))

import Cardano.Binary (FromCBOR (..), ToCBOR (..))

Expand All @@ -36,7 +35,12 @@ import Cardano.Crypto.Seed
import Cardano.Crypto.KES.Class
import Cardano.Crypto.Util
import Cardano.Crypto.MLockedSeed
import Cardano.Crypto.MonadSodium (mlsbAsByteString)
import Cardano.Crypto.MonadSodium
( MonadSodium (..)
, mlsbAsByteString
, useByteStringAsCStringLen
, packByteStringCStringLen
)
import Cardano.Crypto.DirectSerialise

data MockKES (t :: Nat)
Expand Down Expand Up @@ -182,31 +186,31 @@ rawSerialiseSignKeyMockKES (SignKeyMockKES vk t) =
rawSerialiseVerKeyKES vk
<> writeBinaryWord64 (fromIntegral t)

instance KnownNat t => DirectSerialise (SignKeyKES (MockKES t)) where
instance (MonadSodium m, KnownNat t) => DirectSerialise m (SignKeyKES (MockKES t)) where
directSerialise put sk = do
let bs = rawSerialiseSignKeyMockKES sk
BS.useAsCStringLen bs $ \(cstr, len) -> put cstr (fromIntegral len)
useByteStringAsCStringLen bs $ \(cstr, len) -> put cstr (fromIntegral len)

instance KnownNat t => DirectDeserialise (SignKeyKES (MockKES t)) where
instance (MonadSodium m, MonadST m, KnownNat t) => DirectDeserialise m (SignKeyKES (MockKES t)) where
directDeserialise pull = do
let len = fromIntegral $ sizeSignKeyKES (Proxy @(MockKES t))
bs <- allocaBytes len $ \cstr -> do
pull cstr (fromIntegral len)
BS.packCStringLen (cstr, len)
packByteStringCStringLen (cstr, len)
maybe (error "directDeserialise @(SignKeyKES (MockKES t))") return $
rawDeserialiseSignKeyMockKES bs

instance KnownNat t => DirectSerialise (VerKeyKES (MockKES t)) where
instance (MonadSodium m, KnownNat t) => DirectSerialise m (VerKeyKES (MockKES t)) where
directSerialise put sk = do
let bs = rawSerialiseVerKeyKES sk
BS.useAsCStringLen bs $ \(cstr, len) -> put cstr (fromIntegral len)
useByteStringAsCStringLen bs $ \(cstr, len) -> put cstr (fromIntegral len)

instance KnownNat t => DirectDeserialise (VerKeyKES (MockKES t)) where
instance (MonadSodium m, MonadST m, KnownNat t) => DirectDeserialise m (VerKeyKES (MockKES t)) where
directDeserialise pull = do
let len = fromIntegral $ sizeVerKeyKES (Proxy @(MockKES t))
bs <- allocaBytes len $ \cstr -> do
pull cstr (fromIntegral len)
BS.packCStringLen (cstr, len)
packByteStringCStringLen (cstr, len)
maybe (error "directDeserialise @(VerKeyKES (MockKES t))") return $
rawDeserialiseVerKeyKES bs

Expand Down
8 changes: 4 additions & 4 deletions cardano-crypto-class/src/Cardano/Crypto/KES/Simple.hs
Original file line number Diff line number Diff line change
Expand Up @@ -211,21 +211,21 @@ instance ( UnsoundDSIGNMAlgorithm m d, KnownNat t, KESSignAlgorithm m (SimpleKES
| otherwise
= return Nothing

instance DirectSerialise (VerKeyDSIGNM d) => DirectSerialise (VerKeyKES (SimpleKES d t)) where
instance (Monad m, DirectSerialise m (VerKeyDSIGNM d)) => DirectSerialise m (VerKeyKES (SimpleKES d t)) where
directSerialise push (VerKeySimpleKES vks) =
mapM_ (directSerialise push) vks

instance (DirectDeserialise (VerKeyDSIGNM d), KnownNat t) => DirectDeserialise (VerKeyKES (SimpleKES d t)) where
instance (Monad m, DirectDeserialise m (VerKeyDSIGNM d), KnownNat t) => DirectDeserialise m (VerKeyKES (SimpleKES d t)) where
directDeserialise pull = do
let duration = fromIntegral (natVal (Proxy :: Proxy t))
vks <- Vec.replicateM duration (directDeserialise pull)
return $! VerKeySimpleKES $! vks

instance (DirectSerialise (SignKeyDSIGNM d)) => DirectSerialise (SignKeyKES (SimpleKES d t)) where
instance (Monad m, DirectSerialise m (SignKeyDSIGNM d)) => DirectSerialise m (SignKeyKES (SimpleKES d t)) where
directSerialise push (SignKeySimpleKES sks) =
mapM_ (directSerialise push) sks

instance (DirectDeserialise (SignKeyDSIGNM d), KnownNat t) => DirectDeserialise (SignKeyKES (SimpleKES d t)) where
instance (Monad m, DirectDeserialise m (SignKeyDSIGNM d), KnownNat t) => DirectDeserialise m (SignKeyKES (SimpleKES d t)) where
directDeserialise pull = do
let duration = fromIntegral (natVal (Proxy :: Proxy t))
sks <- Vec.replicateM duration (directDeserialise pull)
Expand Down
Loading

0 comments on commit 6bea65f

Please sign in to comment.