From 5f2102a7481d88507d32094ebea36404d5e62b11 Mon Sep 17 00:00:00 2001 From: Mac L Date: Fri, 20 Dec 2024 04:22:32 +0400 Subject: [PATCH] Implement TreeHash for bitfield (#18) * Implement TreeHash for bitfield * Use new release of ethereum_ssz * Delete stray impl for Option * Add some basic tests --------- Co-authored-by: Michael Sproul --- tree_hash/Cargo.toml | 5 ++- tree_hash/src/impls.rs | 93 ++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 96 insertions(+), 2 deletions(-) diff --git a/tree_hash/Cargo.toml b/tree_hash/Cargo.toml index d7d9165..75a4434 100644 --- a/tree_hash/Cargo.toml +++ b/tree_hash/Cargo.toml @@ -13,13 +13,14 @@ categories = ["cryptography::cryptocurrencies"] [dependencies] alloy-primitives = "0.8.0" ethereum_hashing = "0.7.0" +ethereum_ssz = "0.8.0" smallvec = "1.6.1" +typenum = "1.12.0" [dev-dependencies] rand = "0.8.5" tree_hash_derive = { path = "../tree_hash_derive", version = "0.8.0" } -ethereum_ssz = "0.7" -ethereum_ssz_derive = "0.7" +ethereum_ssz_derive = "0.8.0" [features] arbitrary = ["alloy-primitives/arbitrary"] diff --git a/tree_hash/src/impls.rs b/tree_hash/src/impls.rs index 53d28f4..54ea302 100644 --- a/tree_hash/src/impls.rs +++ b/tree_hash/src/impls.rs @@ -1,6 +1,8 @@ use super::*; use alloy_primitives::{Address, B256, U128, U256}; +use ssz::{Bitfield, Fixed, Variable}; use std::sync::Arc; +use typenum::Unsigned; fn int_to_hash256(int: u64) -> Hash256 { let mut bytes = [0; HASHSIZE]; @@ -197,9 +199,68 @@ impl TreeHash for Arc { } } +/// A helper function providing common functionality for finding the Merkle root of some bytes that +/// represent a bitfield. +pub fn bitfield_bytes_tree_hash_root(bytes: &[u8]) -> Hash256 { + let byte_size = (N::to_usize() + 7) / 8; + let leaf_count = (byte_size + BYTES_PER_CHUNK - 1) / BYTES_PER_CHUNK; + + let mut hasher = MerkleHasher::with_leaves(leaf_count); + + hasher + .write(bytes) + .expect("bitfield should not exceed tree hash leaf limit"); + + hasher + .finish() + .expect("bitfield tree hash buffer should not exceed leaf limit") +} + +impl TreeHash for Bitfield> { + fn tree_hash_type() -> TreeHashType { + TreeHashType::List + } + + fn tree_hash_packed_encoding(&self) -> PackedEncoding { + unreachable!("List should never be packed.") + } + + fn tree_hash_packing_factor() -> usize { + unreachable!("List should never be packed.") + } + + fn tree_hash_root(&self) -> Hash256 { + // Note: we use `as_slice` because it does _not_ have the length-delimiting bit set (or + // present). + let root = bitfield_bytes_tree_hash_root::(self.as_slice()); + mix_in_length(&root, self.len()) + } +} + +impl TreeHash for Bitfield> { + fn tree_hash_type() -> TreeHashType { + TreeHashType::Vector + } + + fn tree_hash_packed_encoding(&self) -> PackedEncoding { + unreachable!("Vector should never be packed.") + } + + fn tree_hash_packing_factor() -> usize { + unreachable!("Vector should never be packed.") + } + + fn tree_hash_root(&self) -> Hash256 { + bitfield_bytes_tree_hash_root::(self.as_slice()) + } +} + #[cfg(test)] mod test { use super::*; + use ssz::{BitList, BitVector}; + use std::str::FromStr; + use typenum::{U32, U8}; #[test] fn bool() { @@ -237,4 +298,36 @@ mod test { ] ); } + + #[test] + fn bitvector() { + let empty_bitvector = BitVector::::new(); + assert_eq!(empty_bitvector.tree_hash_root(), Hash256::ZERO); + + let small_bitvector_bytes = vec![0xff_u8, 0xee, 0xdd, 0xcc]; + let small_bitvector = + BitVector::::from_bytes(small_bitvector_bytes.clone().into()).unwrap(); + assert_eq!( + small_bitvector.tree_hash_root().as_slice()[..4], + small_bitvector_bytes + ); + } + + #[test] + fn bitlist() { + let empty_bitlist = BitList::::with_capacity(8).unwrap(); + assert_eq!( + empty_bitlist.tree_hash_root(), + Hash256::from_str("0x5ac78d953211aa822c3ae6e9b0058e42394dd32e5992f29f9c12da3681985130") + .unwrap() + ); + + let mut small_bitlist = BitList::::with_capacity(4).unwrap(); + small_bitlist.set(1, true).unwrap(); + assert_eq!( + small_bitlist.tree_hash_root(), + Hash256::from_str("0x7eb03d394d83a389980b79897207be3a6512d964cb08978bb7f3cfc0db8cfb8a") + .unwrap() + ); + } }