diff --git a/crates/core/src/stateless/data.rs b/crates/core/src/stateless/data.rs index 598d839e..415084a3 100644 --- a/crates/core/src/stateless/data.rs +++ b/crates/core/src/stateless/data.rs @@ -35,7 +35,7 @@ pub struct StatelessClientData { /// Maps each address with its storage trie and the used storage slots. pub storage_tries: HashMap, /// The code for each account - pub contracts: HashMap, + pub contracts: Vec, /// Immediate parent header pub parent_header: Header, /// List of at most 256 previous block headers diff --git a/crates/core/src/stateless/initialize.rs b/crates/core/src/stateless/initialize.rs index 7bf4901f..f8e9bd63 100644 --- a/crates/core/src/stateless/initialize.rs +++ b/crates/core/src/stateless/initialize.rs @@ -23,7 +23,6 @@ use alloy_primitives::{Address, Bytes, B256, U256}; use anyhow::bail; use core::mem::take; use reth_primitives::revm_primitives::Bytecode; -use reth_primitives::KECCAK_EMPTY; use reth_revm::db::{AccountState, DbAccount}; use reth_revm::primitives::AccountInfo; use std::default::Default; @@ -32,7 +31,7 @@ pub trait InitializationStrategy { fn initialize_database( state_trie: &mut MptNode, storage_tries: &mut HashMap, - contracts: &mut HashMap, + contracts: &mut Vec, parent_header: &mut Driver::Header, ancestor_headers: &mut Vec, ) -> anyhow::Result; @@ -44,7 +43,7 @@ impl InitializationStrategy for MemoryDbSt fn initialize_database( state_trie: &mut MptNode, storage_tries: &mut HashMap, - contracts: &mut HashMap, + contracts: &mut Vec, parent_header: &mut Driver::Header, ancestor_headers: &mut Vec, ) -> anyhow::Result { @@ -58,9 +57,9 @@ impl InitializationStrategy for MemoryDbSt } // hash all the contract code - let mut contracts: HashMap = take(contracts) + let contracts = take(contracts) .into_iter() - .map(|(address, bytes)| (address, (keccak(&bytes).into(), bytes))) + .map(|bytes| (keccak(&bytes).into(), Bytecode::new_raw(bytes))) .collect(); // Load account data into db @@ -84,16 +83,6 @@ impl InitializationStrategy for MemoryDbSt ); } - // load the corresponding code - let code_hash = state_account.code_hash; - let bytecode = if code_hash.0 == KECCAK_EMPTY.0 { - Bytecode::new() - } else { - let (bytecode_hash, bytes) = contracts.remove(address).unwrap(); - assert_eq!(bytecode_hash, code_hash); - Bytecode::new_raw(bytes) - }; - // load storage reads let mut storage = HashMap::with_capacity_and_hasher(slots.len(), Default::default()); for slot in slots { @@ -108,7 +97,7 @@ impl InitializationStrategy for MemoryDbSt balance: state_account.balance, nonce: state_account.nonce, code_hash: state_account.code_hash, - code: Some(bytecode), + code: None, }, account_state: AccountState::None, storage, @@ -150,6 +139,7 @@ impl InitializationStrategy for MemoryDbSt // Initialize database Ok(MemoryDB { accounts, + contracts, block_hashes, ..Default::default() }) diff --git a/crates/preflight/src/client.rs b/crates/preflight/src/client.rs index de0b7d1d..56d300b2 100644 --- a/crates/preflight/src/client.rs +++ b/crates/preflight/src/client.rs @@ -20,6 +20,7 @@ use crate::provider::{new_provider, Provider}; use crate::trie::extend_proof_tries; use alloy::network::Network; use alloy::primitives::map::HashMap; +use alloy::primitives::Bytes; use anyhow::Context; use log::{debug, info, warn}; use std::cell::RefCell; @@ -167,7 +168,7 @@ where let core_parent_header = P::derive_header(data.parent_header.clone()); let mut state_trie = MptNode::from(R::state_root(&core_parent_header)); let mut storage_tries = Default::default(); - let mut contracts = data.contracts.clone(); + let mut contracts: Vec = Default::default(); let mut ancestor_headers: Vec = Default::default(); for num_blocks in 1..=block_count { @@ -218,16 +219,13 @@ where info!("Saving provider cache ..."); preflight_db.save_provider()?; - // collect the code from each account - info!("Collecting contracts ..."); + // collect the code of the used contracts let initial_db = preflight_db.inner.db.db.borrow(); - for (address, account) in initial_db.accounts.iter() { - let code = account.info.code.clone().context("missing code")?; - if !code.is_empty() && !contracts.contains_key(address) { - contracts.insert(*address, code.bytes()); - } + for code in initial_db.contracts.values() { + contracts.push(code.bytes().clone()); } drop(initial_db); + info!("Collected contracts: {}", contracts.len()); // construct the sparse MPTs from the inclusion proofs info!( diff --git a/crates/preflight/src/db.rs b/crates/preflight/src/db.rs index f83f40dc..6712fe34 100644 --- a/crates/preflight/src/db.rs +++ b/crates/preflight/src/db.rs @@ -26,48 +26,33 @@ use reth_primitives::revm_primitives::{Account, AccountInfo, Bytecode}; use reth_revm::db::states::StateChangeset; use reth_revm::db::CacheDB; use reth_revm::{Database, DatabaseCommit, DatabaseRef}; -use std::cell::RefCell; +use std::cell::{Ref, RefCell}; use std::marker::PhantomData; use std::ops::DerefMut; use zeth_core::db::apply_changeset; use zeth_core::driver::CoreDriver; use zeth_core::rescue::{Recoverable, Rescued}; -#[derive(Clone)] -pub struct MutCacheDB { - pub db: RefCell>, +/// Wraps a [`Database`] to provide a [`DatabaseRef`] implementation. +#[derive(Clone, Debug, Default)] +pub struct MutDB { + pub db: RefCell, } -impl MutCacheDB { - pub fn new(db: CacheDB) -> Self { +impl MutDB { + pub fn new(db: T) -> Self { Self { db: RefCell::new(db), } } -} - -impl Database for MutCacheDB { - type Error = as Database>::Error; - - fn basic(&mut self, address: Address) -> Result, Self::Error> { - self.db.borrow_mut().basic(address) - } - - fn code_by_hash(&mut self, code_hash: B256) -> Result { - self.db.borrow_mut().code_by_hash(code_hash) - } - - fn storage(&mut self, address: Address, index: U256) -> Result { - self.db.borrow_mut().storage(address, index) - } - fn block_hash(&mut self, number: u64) -> Result { - self.db.borrow_mut().block_hash(number) + pub fn borrow_db(&self) -> Ref { + self.db.borrow() } } -impl DatabaseRef for MutCacheDB { - type Error = as DatabaseRef>::Error; +impl DatabaseRef for MutDB { + type Error = ::Error; fn basic_ref(&self, address: Address) -> Result, Self::Error> { self.db.borrow_mut().basic(address) @@ -86,7 +71,7 @@ impl DatabaseRef for MutCacheDB { } } -pub type PrePostDB = CacheDB>>; +pub type PrePostDB = CacheDB>>>>; #[derive(Clone)] pub struct PreflightDB> { @@ -109,7 +94,7 @@ impl> From) -> Self { Self { - inner: CacheDB::new(MutCacheDB::new(CacheDB::new(value))), + inner: CacheDB::new(MutDB::new(CacheDB::new(MutDB::new(value)))), driver: PhantomData, } } @@ -136,17 +121,24 @@ impl> From> PreflightDB { pub fn clear(&mut self) -> anyhow::Result<()> { - let cleared = Self::from(self.inner.db.db.borrow().db.clone()); + let cleared = Self::from(self.inner.db.borrow_db().db.borrow_db().clone()); drop(core::mem::replace(self, cleared)); Ok(()) } pub fn save_provider(&mut self) -> anyhow::Result<()> { - self.inner.db.db.borrow_mut().db.save_provider() + self.inner.db.db.borrow_mut().db.db.borrow().save_provider() } pub fn advance_provider_block(&mut self) -> anyhow::Result<()> { - self.inner.db.db.borrow_mut().db.advance_provider_block() + self.inner + .db + .db + .borrow_mut() + .db + .db + .borrow_mut() + .advance_provider_block() } pub fn apply_changeset(&mut self, state_changeset: StateChangeset) -> anyhow::Result<()> { @@ -156,11 +148,11 @@ impl> PreflightDB { pub fn sanity_check(&mut self, state_changeset: StateChangeset) -> anyhow::Result<()> { // storage sanity check let initial_db = &self.inner.db; - let mut provider_db = initial_db.db.borrow().db.clone(); + let mut provider_db = initial_db.db.borrow().db.db.borrow().clone(); provider_db.block_no += 1; for (address, db_account) in &self.inner.accounts { use reth_revm::Database; - let provider_info = provider_db.basic(*address)?.unwrap(); + let provider_info = provider_db.basic(*address)?.unwrap_or_default(); if db_account.info != provider_info { error!("State difference for account {address}:"); if db_account.info.balance != provider_info.balance { @@ -198,13 +190,11 @@ impl> PreflightDB { pub fn get_initial_proofs( &mut self, ) -> anyhow::Result> { - let initial_db = &self.inner.db; - let storage_keys = enumerate_storage_keys(&initial_db.db.borrow()); - - let initial_db = self.inner.db.db.borrow_mut(); - let block_no = initial_db.db.block_no; + let initial_db = self.inner.db.borrow_db(); + let storage_keys = enumerate_storage_keys(&initial_db); + let block_no = initial_db.db.borrow_db().block_no; let res = get_proofs( - initial_db.db.provider.borrow_mut().deref_mut(), + initial_db.db.borrow_db().provider.borrow_mut().deref_mut(), block_no, storage_keys, )?; @@ -215,8 +205,8 @@ impl> PreflightDB { &mut self, ) -> anyhow::Result> { // get initial keys - let initial_db = &self.inner.db; - let mut storage_keys = enumerate_storage_keys(&initial_db.db.borrow()); + let initial_db = self.inner.db.borrow_db(); + let mut storage_keys = enumerate_storage_keys(&initial_db); // merge initial keys with latest db storage keys for (address, mut indices) in enumerate_storage_keys(&self.inner) { match storage_keys.get_mut(&address) { @@ -227,10 +217,9 @@ impl> PreflightDB { } } // return proofs as of next block - let initial_db = self.inner.db.db.borrow_mut(); - let block_no = initial_db.db.block_no + 1; + let block_no = initial_db.db.borrow_db().block_no + 1; let res = get_proofs( - initial_db.db.provider.borrow_mut().deref_mut(), + initial_db.db.borrow_db().provider.borrow_mut().deref_mut(), block_no, storage_keys, )?; @@ -238,8 +227,8 @@ impl> PreflightDB { } pub fn get_ancestor_headers(&mut self) -> anyhow::Result> { - let initial_db = &self.inner.db.db.borrow_mut(); - let db_block_number = initial_db.db.block_no; + let initial_db = self.inner.db.db.borrow_mut(); + let db_block_number = initial_db.db.borrow_db().block_no; let earliest_block = initial_db .block_hashes .keys() @@ -247,7 +236,8 @@ impl> PreflightDB { .copied() .map(|v| v.to()) .unwrap_or(db_block_number); - let mut provider = initial_db.db.provider.borrow_mut(); + let provider_db = initial_db.db.borrow_db(); + let mut provider = provider_db.provider.borrow_mut(); let headers = (earliest_block..db_block_number) .rev() .map(|block_no| { @@ -277,14 +267,12 @@ impl> PreflightDB { state_orphans: &[TrieOrphan], block_count: u64, ) -> Vec { - let initial_db = &self.inner.db.db.borrow_mut(); - let mut provider = initial_db.db.provider.borrow_mut(); + let initial_db = self.inner.db.db.borrow_mut(); + let provider_db = initial_db.db.borrow_db(); + let mut provider = provider_db.provider.borrow_mut(); let mut result = Vec::new(); - let block_no = initial_db.db.block_no + block_count - 1; + let block_no = initial_db.db.borrow_db().block_no + block_count - 1; for (start, digest) in state_orphans { - // if let Ok(val) = provider.get_preimage(&PreimageQuery { digest: *digest }) { - // continue; - // } if let Ok(next_account) = provider.get_next_account(&AccountRangeQuery { block_no, start: *start, @@ -313,10 +301,11 @@ impl> PreflightDB { storage_orphans: &[(Address, TrieOrphan)], block_count: u64, ) -> Vec { - let initial_db = &self.inner.db.db.borrow_mut(); - let mut provider = initial_db.db.provider.borrow_mut(); + let initial_db = self.inner.db.db.borrow_mut(); + let provider_db = initial_db.db.borrow_db(); + let mut provider = provider_db.provider.borrow_mut(); let mut result = Vec::new(); - let block_no = initial_db.db.block_no + block_count - 1; + let block_no = initial_db.db.borrow_db().block_no + block_count - 1; for (address, (start, digest)) in storage_orphans { // if let Ok(val) = provider.get_preimage(&PreimageQuery { digest: *digest }) { // continue; diff --git a/crates/preflight/src/provider/db.rs b/crates/preflight/src/provider/db.rs index f04edbb7..a44ff0e1 100644 --- a/crates/preflight/src/provider/db.rs +++ b/crates/preflight/src/provider/db.rs @@ -19,7 +19,7 @@ use alloy::network::Network; use alloy::primitives::map::HashMap; use alloy::primitives::{Address, B256, U256}; use reth_revm::primitives::{Account, AccountInfo, Bytecode}; -use reth_revm::{Database, DatabaseCommit, DatabaseRef}; +use reth_revm::{Database, DatabaseCommit}; use reth_storage_errors::db::DatabaseError; use std::cell::RefCell; use std::marker::PhantomData; @@ -29,6 +29,9 @@ use zeth_core::driver::CoreDriver; pub struct ProviderDB> { pub provider: Rc>>, pub block_no: u64, + /// Bytecode cache to allow querying bytecode by hash instead of address. + pub contracts: HashMap, + pub driver: PhantomData<(R, P)>, } @@ -37,6 +40,7 @@ impl> Clone for ProviderDB> ProviderDB { ProviderDB { provider, block_no, + contracts: HashMap::default(), driver: PhantomData, } } @@ -62,55 +67,9 @@ impl> ProviderDB { } impl> Database for ProviderDB { - type Error = anyhow::Error; - - fn basic(&mut self, address: Address) -> Result, Self::Error> { - let query = AccountQuery { - block_no: self.block_no, - address: address.into_array().into(), - }; - let nonce = self.provider.borrow_mut().get_transaction_count(&query)?; - let balance = self.provider.borrow_mut().get_balance(&query)?; - let code = self.provider.borrow_mut().get_code(&query)?; - let bytecode = Bytecode::new_raw(code); - Ok(Some(AccountInfo::new( - balance, - nonce.to(), - bytecode.hash_slow(), - bytecode, - ))) - } - - fn code_by_hash(&mut self, code_hash: B256) -> Result { - // not needed because we already load code with basic info - unreachable!("ProviderDB::code_by_hash {code_hash}") - } - - fn storage(&mut self, address: Address, index: U256) -> Result { - let bytes = index.to_be_bytes::<32>(); - let index = U256::from_be_bytes(bytes); - - self.provider.borrow_mut().get_storage(&StorageQuery { - block_no: self.block_no, - address: address.into_array().into(), - index, - }) - } - - fn block_hash(&mut self, block_no: u64) -> Result { - let header = P::derive_header(P::derive_header_response( - self.provider - .borrow_mut() - .get_full_block(&BlockQuery { block_no })?, - )); - Ok(R::header_hash(&header)) - } -} - -impl> DatabaseRef for ProviderDB { type Error = DatabaseError; - fn basic_ref(&self, address: Address) -> Result, Self::Error> { + fn basic(&mut self, address: Address) -> Result, Self::Error> { let query = AccountQuery { block_no: self.block_no, address: address.into_array().into(), @@ -119,44 +78,61 @@ impl> DatabaseRef for Provid .provider .borrow_mut() .get_transaction_count(&query) - .unwrap(); - let balance = self.provider.borrow_mut().get_balance(&query).unwrap(); - let code = self.provider.borrow_mut().get_code(&query).unwrap(); + .map_err(db_error)?; + let balance = self + .provider + .borrow_mut() + .get_balance(&query) + .map_err(db_error)?; + let code = self + .provider + .borrow_mut() + .get_code(&query) + .map_err(db_error)?; let bytecode = Bytecode::new_raw(code); - Ok(Some(AccountInfo::new( + + // index the code by its hash, so that we can later use code_by_hash + let code_hash = bytecode.hash_slow(); + self.contracts.insert(code_hash, bytecode); + + Ok(Some(AccountInfo { + nonce: nonce.to(), balance, - nonce.to(), - bytecode.hash_slow(), - bytecode, - ))) + code_hash, + code: None, // will be queried later using code_by_hash + })) } - fn code_by_hash_ref(&self, code_hash: B256) -> Result { - // not needed because we already load code with basic info - unreachable!("ProviderDB::code_by_hash_ref {code_hash}") + fn code_by_hash(&mut self, code_hash: B256) -> Result { + // this works because `basic` is always called first + let code = self + .contracts + .get(&code_hash) + .expect("`basic` must be called first for the corresponding account"); + + Ok(code.clone()) } - fn storage_ref(&self, address: Address, index: U256) -> Result { + fn storage(&mut self, address: Address, index: U256) -> Result { let bytes = index.to_be_bytes::<32>(); let index = U256::from_be_bytes(bytes); - Ok(self - .provider + self.provider .borrow_mut() .get_storage(&StorageQuery { block_no: self.block_no, address: address.into_array().into(), index, }) - .unwrap()) + .map_err(db_error) } - fn block_hash_ref(&self, block_no: u64) -> Result { + fn block_hash(&mut self, block_no: u64) -> Result { let header = P::derive_header(P::derive_header_response( self.provider .borrow_mut() .get_full_block(&BlockQuery { block_no }) - .unwrap(), + .map_err(db_error)?, )); Ok(R::header_hash(&header)) } @@ -165,3 +141,7 @@ impl> DatabaseRef for Provid impl> DatabaseCommit for ProviderDB { fn commit(&mut self, _changes: HashMap) {} } + +fn db_error(err: anyhow::Error) -> DatabaseError { + DatabaseError::Other(format!("provider error: {err:#}")) +}