diff --git a/Cargo.toml b/Cargo.toml index 351c7d62d..61b4a826d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -2,7 +2,7 @@ name = "discv5" authors = ["Age Manning "] edition = "2018" -version = "0.3.0" +version = "0.4.1" description = "Implementation of the p2p discv5 discovery protocol" license = "Apache-2.0" repository = "https://github.com/sigp/discv5" @@ -12,43 +12,42 @@ categories = ["network-programming", "asynchronous"] exclude = [".gitignore", ".github/*"] [dependencies] -enr = { version = "0.8.1", features = ["k256", "ed25519"] } -tokio = { version = "1.15.0", features = ["net", "sync", "macros", "rt"] } -libp2p-core = { version = "0.40.0", optional = true } -libp2p-identity = { version = "0.2.1", features = ["ed25519", "secp256k1"], optional = true } -zeroize = { version = "1.4.3", features = ["zeroize_derive"] } -futures = "0.3.19" -uint = { version = "0.9.1", default-features = false } -rlp = "0.5.1" +enr = { version = "0.10", features = ["k256", "ed25519"] } +tokio = { version = "1", features = ["net", "sync", "macros", "rt"] } +libp2p = { version = "0.53", features = ["ed25519", "secp256k1"], optional = true } +zeroize = { version = "1", features = ["zeroize_derive"] } +futures = "0.3" +uint = { version = "0.9", default-features = false } +rlp = "0.5" # This version must be kept up to date do it uses the same dependencies as ENR -hkdf = "0.12.3" -hex = "0.4.3" -fnv = "1.0.7" -arrayvec = "0.7.2" -rand = { version = "0.8.4", package = "rand" } -socket2 = "0.4.4" -smallvec = "1.7.0" -parking_lot = "0.11.2" -lazy_static = "1.4.0" -aes = { version = "0.7.5", features = ["ctr"] } -aes-gcm = "0.9.4" -tracing = { version = "0.1.29", features = ["log"] } -tracing-subscriber = { version = "0.3.3", features = ["env-filter"] } -lru = {version = "0.7.1", default-features = false } -hashlink = "0.7.0" -delay_map = "0.3.0" -more-asserts = "0.2.2" -derive_more = { version = "0.99.17", default-features = false, features = ["from", "display", "deref", "deref_mut"] } +hkdf = "0.12" +hex = "0.4" +fnv = "1" +arrayvec = "0.7" +rand = { version = "0.8", package = "rand" } +socket2 = "0.4" +smallvec = "1" +parking_lot = "0.11" +lazy_static = "1" +aes = { version = "0.7", features = ["ctr"] } +aes-gcm = "0.9" +tracing = { version = "0.1", features = ["log"] } +lru = { version = "0.12", default-features = false } +hashlink = "0.8" +delay_map = "0.3" +more-asserts = "0.3" +derive_more = { version = "0.99", default-features = false, features = ["from", "display", "deref", "deref_mut"] } [dev-dependencies] +clap = { version = "4", features = ["derive"] } +if-addrs = "0.10" +quickcheck = "0.9" rand_07 = { package = "rand", version = "0.7" } -quickcheck = "0.9.2" -tokio = { version = "1.15.0", features = ["full"] } -rand_xorshift = "0.3.0" -rand_core = "0.6.3" -clap = { version = "3.1", features = ["derive"] } -if-addrs = "0.10.1" +rand_core = "0.6" +rand_xorshift = "0.3" +tokio = { version = "1", features = ["full"] } +tracing-subscriber = { version = "0.3", features = ["env-filter"] } [features] -libp2p = ["libp2p-core", "libp2p-identity"] +libp2p = ["dep:libp2p"] serde = ["enr/serde"] diff --git a/README.md b/README.md index c87f42577..3ce88a19d 100644 --- a/README.md +++ b/README.md @@ -19,10 +19,10 @@ Status]][Crates Link] This is a rust implementation of the [Discovery v5](https://github.com/ethereum/devp2p/blob/master/discv5/discv5.md) peer discovery protocol. -Discovery v5 is a protocol designed for encrypted peer discovery (and topic advertisement tba). Each peer/node -on the network is identified via it's `ENR` ([Ethereum Node -Record](https://eips.ethereum.org/EIPS/eip-778)), which is essentially a signed key-value store -containing the node's public key and optionally IP address and port. +Discovery v5 is a protocol designed for encrypted peer discovery. Each peer/node on the network is +identified via it's `ENR` ([Ethereum Node Record](https://eips.ethereum.org/EIPS/eip-778)), which +is essentially a signed key-value store containing the node's public key and optionally IP address +and port. Discv5 employs a kademlia-like routing table to store and manage discovered peers and topics. The protocol allows for external IP discovery in NAT environments through regular PING/PONG's with @@ -37,13 +37,13 @@ For a simple CLI discovery service see [discv5-cli](https://github.com/AgeMannin A simple example of creating this service is as follows: ```rust - use discv5::{enr, enr::{CombinedKey, NodeId}, TokioExecutor, Discv5, Discv5ConfigBuilder}; + use discv5::{enr, enr::{CombinedKey, NodeId}, TokioExecutor, Discv5, ConfigBuilder}; use discv5::socket::ListenConfig; use std::net::SocketAddr; // construct a local ENR let enr_key = CombinedKey::generate_secp256k1(); - let enr = enr::EnrBuilder::new("v4").build(&enr_key).unwrap(); + let enr = enr::Enr::empty(&enr_key).unwrap(); // build the tokio executor let mut runtime = tokio::runtime::Builder::new_multi_thread() @@ -59,7 +59,7 @@ A simple example of creating this service is as follows: }; // default configuration - let config = Discv5ConfigBuilder::new(listen_config).build(); + let config = ConfigBuilder::new(listen_config).build(); // construct the discv5 server let mut discv5: Discv5 = Discv5::new(enr, enr_key, config).unwrap(); diff --git a/examples/custom_executor.rs b/examples/custom_executor.rs index 9ac0df7c4..a184fbf6e 100644 --- a/examples/custom_executor.rs +++ b/examples/custom_executor.rs @@ -9,7 +9,7 @@ //! $ cargo run --example custom_executor //! ``` -use discv5::{enr, enr::CombinedKey, Discv5, Discv5ConfigBuilder, Discv5Event, ListenConfig}; +use discv5::{enr, enr::CombinedKey, ConfigBuilder, Discv5, Event, ListenConfig}; use std::net::Ipv4Addr; fn main() { @@ -29,7 +29,7 @@ fn main() { let enr_key = CombinedKey::generate_secp256k1(); // construct a local ENR - let enr = enr::EnrBuilder::new("v4").build(&enr_key).unwrap(); + let enr = enr::Enr::empty(&enr_key).unwrap(); // build the tokio executor let runtime = tokio::runtime::Builder::new_multi_thread() @@ -39,7 +39,7 @@ fn main() { .unwrap(); // default configuration - uses the current executor - let config = Discv5ConfigBuilder::new(listen_config).build(); + let config = ConfigBuilder::new(listen_config).build(); // construct the discv5 server let mut discv5: Discv5 = Discv5::new(enr, enr_key, config).unwrap(); @@ -72,10 +72,10 @@ fn main() { loop { match event_stream.recv().await { - Some(Discv5Event::SocketUpdated(addr)) => { + Some(Event::SocketUpdated(addr)) => { println!("Nodes ENR socket address has been updated to: {addr:?}"); } - Some(Discv5Event::Discovered(enr)) => { + Some(Event::Discovered(enr)) => { println!("A peer has been discovered: {}", enr.node_id()); } _ => {} diff --git a/examples/find_nodes.rs b/examples/find_nodes.rs index f52d8364f..012256816 100644 --- a/examples/find_nodes.rs +++ b/examples/find_nodes.rs @@ -19,7 +19,7 @@ use clap::Parser; use discv5::{ enr, enr::{k256, CombinedKey}, - Discv5, Discv5ConfigBuilder, Discv5Event, ListenConfig, + ConfigBuilder, Discv5, Event, ListenConfig, }; use std::{ net::{IpAddr, Ipv4Addr, Ipv6Addr}, @@ -90,7 +90,7 @@ async fn main() { }; let enr = { - let mut builder = enr::EnrBuilder::new("v4"); + let mut builder = enr::Enr::builder(); if let Some(ip4) = args.enr_ip4 { // if the given address is the UNSPECIFIED address we want to advertise localhost if ip4.is_unspecified() { @@ -120,10 +120,10 @@ async fn main() { }; // default configuration with packet filtering - // let config = Discv5ConfigBuilder::new(listen_config).enable_packet_filter().build(); + // let config = ConfigBuilder::new(listen_config).enable_packet_filter().build(); // default configuration without packet filtering - let config = Discv5ConfigBuilder::new(listen_config).build(); + let config = ConfigBuilder::new(listen_config).build(); info!("Node Id: {}", enr.node_id()); if args.enr_ip6.is_some() || args.enr_ip4.is_some() { @@ -192,18 +192,19 @@ async fn main() { continue; } match discv5_ev { - Discv5Event::Discovered(enr) => info!("Enr discovered {}", enr), - Discv5Event::EnrAdded { enr, replaced: _ } => info!("Enr added {}", enr), - Discv5Event::NodeInserted { node_id, replaced: _ } => info!("Node inserted {}", node_id), - Discv5Event::SessionEstablished(enr, _) => info!("Session established {}", enr), - Discv5Event::SocketUpdated(addr) => info!("Socket updated {}", addr), - Discv5Event::TalkRequest(_) => info!("Talk request received"), + Event::Discovered(enr) => info!("Enr discovered {}", enr), + Event::EnrAdded { enr, replaced: _ } => info!("Enr added {}", enr), + Event::NodeInserted { node_id, replaced: _ } => info!("Node inserted {}", node_id), + Event::SessionEstablished(enr, _) => info!("Session established {}", enr), + Event::SocketUpdated(addr) => info!("Socket updated {}", addr), + Event::TalkRequest(_) => info!("Talk request received"), }; } } } } +#[derive(Clone)] pub enum SocketKind { Ip4, Ip6, diff --git a/examples/request_enr.rs b/examples/request_enr.rs index 1fa780ba5..d3daac640 100644 --- a/examples/request_enr.rs +++ b/examples/request_enr.rs @@ -13,7 +13,7 @@ //! //! This requires the "libp2p" feature. #[cfg(feature = "libp2p")] -use discv5::Discv5ConfigBuilder; +use discv5::ConfigBuilder; #[cfg(feature = "libp2p")] use discv5::ListenConfig; #[cfg(feature = "libp2p")] @@ -43,10 +43,10 @@ async fn main() { // generate a new enr key let enr_key = CombinedKey::generate_secp256k1(); // construct a local ENR - let enr = enr::EnrBuilder::new("v4").build(&enr_key).unwrap(); + let enr = enr::Enr::empty(&enr_key).unwrap(); // default discv5 configuration - let config = Discv5ConfigBuilder::new(listen_config).build(); + let config = ConfigBuilder::new(listen_config).build(); let multiaddr = std::env::args() .nth(1) diff --git a/examples/simple_server.rs b/examples/simple_server.rs index c4b2509b8..2cc453281 100644 --- a/examples/simple_server.rs +++ b/examples/simple_server.rs @@ -10,7 +10,7 @@ //! $ cargo run --example simple_server -- //! ``` -use discv5::{enr, enr::CombinedKey, Discv5, Discv5ConfigBuilder, Discv5Event, ListenConfig}; +use discv5::{enr, enr::CombinedKey, ConfigBuilder, Discv5, Event, ListenConfig}; use std::net::Ipv4Addr; #[tokio::main] @@ -46,7 +46,7 @@ async fn main() { // construct a local ENR let enr = { - let mut builder = enr::EnrBuilder::new("v4"); + let mut builder = enr::Enr::builder(); // if an IP was specified, use it if let Some(external_address) = address { builder.ip4(external_address); @@ -72,7 +72,7 @@ async fn main() { } // default configuration - let config = Discv5ConfigBuilder::new(listen_config).build(); + let config = ConfigBuilder::new(listen_config).build(); // construct the discv5 server let mut discv5: Discv5 = Discv5::new(enr, enr_key, config).unwrap(); @@ -104,10 +104,10 @@ async fn main() { loop { match event_stream.recv().await { - Some(Discv5Event::SocketUpdated(addr)) => { + Some(Event::SocketUpdated(addr)) => { println!("Nodes ENR socket address has been updated to: {addr:?}"); } - Some(Discv5Event::Discovered(enr)) => { + Some(Event::Discovered(enr)) => { println!("A peer has been discovered: {}", enr.node_id()); } _ => {} diff --git a/src/config.rs b/src/config.rs index 0d3fafa3b..3e3c8e600 100644 --- a/src/config.rs +++ b/src/config.rs @@ -8,11 +8,11 @@ use crate::{ /// boostrap. const MIN_SESSIONS_UNREACHABLE_ENR: usize = 10; -use std::{ops::RangeInclusive, time::Duration}; +use std::{num::NonZeroUsize, ops::RangeInclusive, time::Duration}; /// Configuration parameters that define the performance of the discovery network. #[derive(Clone)] -pub struct Discv5Config { +pub struct Config { /// Whether to enable the incoming packet filter. Default: false. pub enable_packet_filter: bool, @@ -38,7 +38,7 @@ pub struct Discv5Config { pub session_timeout: Duration, /// The maximum number of established sessions to maintain. Default: 1000. - pub session_cache_capacity: usize, + pub session_cache_capacity: NonZeroUsize, /// Updates the local ENR IP and port based on PONG responses from peers. Default: true. pub enr_update: bool, @@ -115,11 +115,11 @@ pub struct Discv5Config { } #[derive(Debug)] -pub struct Discv5ConfigBuilder { - config: Discv5Config, +pub struct ConfigBuilder { + config: Config, } -impl Discv5ConfigBuilder { +impl ConfigBuilder { pub fn new(listen_config: ListenConfig) -> Self { // This is only applicable if enable_packet_filter is set. let filter_rate_limiter = Some( @@ -132,7 +132,7 @@ impl Discv5ConfigBuilder { ); // set default values - let config = Discv5Config { + let config = Config { enable_packet_filter: false, request_timeout: Duration::from_secs(1), vote_duration: Duration::from_secs(30), @@ -140,7 +140,7 @@ impl Discv5ConfigBuilder { query_timeout: Duration::from_secs(60), request_retries: 1, session_timeout: Duration::from_secs(86400), - session_cache_capacity: 1000, + session_cache_capacity: NonZeroUsize::new(1000).expect("infallible"), enr_update: true, max_nodes_response: 16, enr_peer_update_min: 10, @@ -161,7 +161,7 @@ impl Discv5ConfigBuilder { listen_config, }; - Discv5ConfigBuilder { config } + ConfigBuilder { config } } /// Whether to enable the incoming packet filter. @@ -211,7 +211,8 @@ impl Discv5ConfigBuilder { /// The maximum number of established sessions to maintain. pub fn session_cache_capacity(&mut self, capacity: usize) -> &mut Self { - self.config.session_cache_capacity = capacity; + self.config.session_cache_capacity = + NonZeroUsize::new(capacity).expect("session_cache_capacity must be greater than 0"); self } @@ -335,7 +336,7 @@ impl Discv5ConfigBuilder { self } - pub fn build(&mut self) -> Discv5Config { + pub fn build(&mut self) -> Config { // If an executor is not provided, assume a current tokio runtime is running. if self.config.executor.is_none() { self.config.executor = Some(Box::::default()); @@ -350,9 +351,9 @@ impl Discv5ConfigBuilder { } } -impl std::fmt::Debug for Discv5Config { +impl std::fmt::Debug for Config { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.debug_struct("Discv5Config") + f.debug_struct("Config") .field("filter_enabled", &self.enable_packet_filter) .field("request_timeout", &self.request_timeout) .field("vote_duration", &self.vote_duration) @@ -360,7 +361,7 @@ impl std::fmt::Debug for Discv5Config { .field("query_peer_timeout", &self.query_peer_timeout) .field("request_retries", &self.request_retries) .field("session_timeout", &self.session_timeout) - .field("session_cache_capacity", &self.session_cache_capacity) + .field("session_cache_capacity", &self.session_cache_capacity.get()) .field("enr_update", &self.enr_update) .field("query_parallelism", &self.query_parallelism) .field("report_discovered_peers", &self.report_discovered_peers) diff --git a/src/discv5.rs b/src/discv5.rs index f813e8a42..0acf1587a 100644 --- a/src/discv5.rs +++ b/src/discv5.rs @@ -2,9 +2,9 @@ //! //! This provides the main struct for running and interfacing with a discovery v5 server. //! -//! A [`Discv5`] struct needs to be created either with an [`crate::executor::Executor`] specified in the -//! [`Discv5Config`] via the [`crate::Discv5ConfigBuilder`] or in the presence of a tokio runtime that has -//! timing and io enabled. +//! A [`Discv5`] struct needs to be created either with an [`crate::executor::Executor`] specified +//! in the [`Config`] via the [`crate::ConfigBuilder`] or in the presence of a tokio runtime that +//! has timing and io enabled. //! //! Once a [`Discv5`] struct has been created the service is started by running the [`Discv5::start`] //! functions with a UDP socket. This will start a discv5 server in the background listening on the @@ -13,7 +13,7 @@ //! The server can be shutdown using the [`Discv5::shutdown`] function. use crate::{ - error::{Discv5Error, QueryError, RequestError}, + error::{Error, QueryError, RequestError}, kbucket::{ self, ConnectionDirection, ConnectionState, FailureReason, InsertResult, KBucketsTable, NodeStatus, UpdateResult, @@ -21,7 +21,7 @@ use crate::{ node_info::NodeContact, packet::ProtocolIdentity, service::{QueryKind, Service, ServiceRequest, TalkRequest}, - DefaultProtocolId, Discv5Config, Enr, IpMode, + Config, DefaultProtocolId, Enr, IpMode, }; use enr::{CombinedKey, EnrError, EnrKey, NodeId}; use parking_lot::RwLock; @@ -36,7 +36,7 @@ use tokio::sync::{mpsc, oneshot}; use tracing::{debug, warn}; #[cfg(feature = "libp2p")] -use libp2p_core::Multiaddr; +use libp2p::Multiaddr; // Create lazy static variable for the global permit/ban list use crate::{ @@ -49,11 +49,11 @@ lazy_static! { RwLock::new(crate::PermitBanList::default()); } -mod test; +pub(crate) mod test; /// Events that can be produced by the `Discv5` event stream. #[derive(Debug)] -pub enum Discv5Event { +pub enum Event { /// A node has been discovered from a FINDNODES request. /// /// The ENR of the node is returned. Various properties can be derived from the ENR. @@ -81,7 +81,7 @@ pub struct Discv5

where P: ProtocolIdentity, { - config: Discv5Config, + config: Config, /// The channel to make requests from the main service. service_channel: Option>, /// The exit channel to shutdown the underlying service. @@ -102,7 +102,7 @@ impl Discv5

{ pub fn new( local_enr: Enr, enr_key: CombinedKey, - mut config: Discv5Config, + mut config: Config, ) -> Result { // ensure the keypair matches the one that signed the enr. if local_enr.public_key() != enr_key.public() { @@ -154,10 +154,10 @@ impl Discv5

{ } /// Starts the required tasks and begins listening on a given UDP SocketAddr. - pub async fn start(&mut self) -> Result<(), Discv5Error> { + pub async fn start(&mut self) -> Result<(), Error> { if self.service_channel.is_some() { warn!("Service is already started"); - return Err(Discv5Error::ServiceAlreadyStarted); + return Err(Error::ServiceAlreadyStarted); } // create the main service @@ -301,6 +301,13 @@ impl Discv5

{ self.local_enr.read().clone() } + /// Identical to `Discv5::local_enr` except that this exposes the `Arc` itself. + /// + /// This is useful for synchronising views of the local ENR outside of `Discv5`. + pub fn external_enr(&self) -> Arc> { + self.local_enr.clone() + } + /// Returns the routing table of the discv5 service pub fn kbuckets(&self) -> KBucketsTable { self.kbuckets.read().clone() @@ -670,7 +677,7 @@ impl Discv5

{ /// Creates an event stream channel which can be polled to receive Discv5 events. pub fn event_stream( &self, - ) -> impl Future, Discv5Error>> + 'static { + ) -> impl Future, Error>> + 'static { let channel = self.clone_channel(); async move { @@ -682,20 +689,18 @@ impl Discv5

{ channel .send(event) .await - .map_err(|_| Discv5Error::ServiceChannelClosed)?; + .map_err(|_| Error::ServiceChannelClosed)?; - callback_recv - .await - .map_err(|_| Discv5Error::ServiceChannelClosed) + callback_recv.await.map_err(|_| Error::ServiceChannelClosed) } } /// Internal helper function to send events to the Service. - fn clone_channel(&self) -> Result, Discv5Error> { + fn clone_channel(&self) -> Result, Error> { if let Some(channel) = self.service_channel.as_ref() { Ok(channel.clone()) } else { - Err(Discv5Error::ServiceNotStarted) + Err(Error::ServiceNotStarted) } } } diff --git a/src/discv5/test.rs b/src/discv5/test.rs index bb627ff7f..f6cd70225 100644 --- a/src/discv5/test.rs +++ b/src/discv5/test.rs @@ -1,7 +1,7 @@ #![cfg(test)] use crate::{socket::ListenConfig, Discv5, *}; -use enr::{k256, CombinedKey, Enr, EnrBuilder, EnrKey, NodeId}; +use enr::{k256, CombinedKey, Enr, EnrKey, NodeId}; use rand_core::{RngCore, SeedableRng}; use std::{ collections::HashMap, @@ -26,13 +26,9 @@ async fn build_nodes(n: usize, base_port: u16) -> Vec { for port in base_port..base_port + n as u16 { let enr_key = CombinedKey::generate_secp256k1(); let listen_config = ListenConfig::Ipv4 { ip, port }; - let config = Discv5ConfigBuilder::new(listen_config).build(); + let config = ConfigBuilder::new(listen_config).build(); - let enr = EnrBuilder::new("v4") - .ip4(ip) - .udp4(port) - .build(&enr_key) - .unwrap(); + let enr = Enr::builder().ip4(ip).udp4(port).build(&enr_key).unwrap(); // transport for building a swarm let mut discv5 = Discv5::new(enr, enr_key, config).unwrap(); discv5.start().await.unwrap(); @@ -50,13 +46,9 @@ async fn build_nodes_from_keypairs(keys: Vec, base_port: u16) -> Ve let port = base_port + i as u16; let listen_config = ListenConfig::Ipv4 { ip, port }; - let config = Discv5ConfigBuilder::new(listen_config).build(); + let config = ConfigBuilder::new(listen_config).build(); - let enr = EnrBuilder::new("v4") - .ip4(ip) - .udp4(port) - .build(&enr_key) - .unwrap(); + let enr = Enr::builder().ip4(ip).udp4(port).build(&enr_key).unwrap(); let mut discv5 = Discv5::new(enr, enr_key, config).unwrap(); discv5.start().await.unwrap(); @@ -75,9 +67,9 @@ async fn build_nodes_from_keypairs_ipv6(keys: Vec, base_port: u16) ip: Ipv6Addr::LOCALHOST, port, }; - let config = Discv5ConfigBuilder::new(listen_config).build(); + let config = ConfigBuilder::new(listen_config).build(); - let enr = EnrBuilder::new("v4") + let enr = Enr::builder() .ip6(Ipv6Addr::LOCALHOST) .udp6(port) .build(&enr_key) @@ -106,9 +98,9 @@ async fn build_nodes_from_keypairs_dual_stack( ipv6: Ipv6Addr::LOCALHOST, ipv6_port, }; - let config = Discv5ConfigBuilder::new(listen_config).build(); + let config = ConfigBuilder::new(listen_config).build(); - let enr = EnrBuilder::new("v4") + let enr = Enr::builder() .ip4(Ipv4Addr::LOCALHOST) .udp4(ipv4_port) .ip6(Ipv6Addr::LOCALHOST) @@ -124,7 +116,7 @@ async fn build_nodes_from_keypairs_dual_stack( } /// Generate `n` deterministic keypairs from a given seed. -fn generate_deterministic_keypair(n: usize, seed: u64) -> Vec { +pub(crate) fn generate_deterministic_keypair(n: usize, seed: u64) -> Vec { let mut keypairs = Vec::new(); for i in 0..n { let sk = { @@ -744,16 +736,12 @@ async fn test_table_limits() { let mut keypairs = generate_deterministic_keypair(12, 9487); let ip: Ipv4Addr = "127.0.0.1".parse().unwrap(); let enr_key: CombinedKey = keypairs.remove(0); - let enr = EnrBuilder::new("v4") - .ip4(ip) - .udp4(9050) - .build(&enr_key) - .unwrap(); + let enr = Enr::builder().ip4(ip).udp4(9050).build(&enr_key).unwrap(); let listen_config = ListenConfig::Ipv4 { ip: enr.ip4().unwrap(), port: enr.udp4().unwrap(), }; - let config = Discv5ConfigBuilder::new(listen_config).ip_limit().build(); + let config = ConfigBuilder::new(listen_config).ip_limit().build(); // let socket_addr = enr.udp_socket().unwrap(); let discv5: Discv5 = Discv5::new(enr, enr_key, config).unwrap(); @@ -763,7 +751,7 @@ async fn test_table_limits() { .map(|i| { let ip: Ipv4Addr = Ipv4Addr::new(192, 168, 1, i as u8); let enr_key: CombinedKey = keypairs.remove(0); - EnrBuilder::new("v4") + Enr::builder() .ip4(ip) .udp4(9050 + i as u16) .build(&enr_key) @@ -782,11 +770,7 @@ async fn test_table_limits() { async fn test_bucket_limits() { let enr_key = CombinedKey::generate_secp256k1(); let ip: Ipv4Addr = "127.0.0.1".parse().unwrap(); - let enr = EnrBuilder::new("v4") - .ip4(ip) - .udp4(9500) - .build(&enr_key) - .unwrap(); + let enr = Enr::builder().ip4(ip).udp4(9500).build(&enr_key).unwrap(); let bucket_limit: usize = 2; // Generate `bucket_limit + 1` keypairs that go in `enr` node's 256th bucket. let keys = { @@ -794,7 +778,7 @@ async fn test_bucket_limits() { for _ in 0..bucket_limit + 1 { loop { let key = CombinedKey::generate_secp256k1(); - let enr_new = EnrBuilder::new("v4").build(&key).unwrap(); + let enr_new = Enr::empty(&key).unwrap(); let node_key: Key = enr.node_id().into(); let distance = node_key.log2_distance(&enr_new.node_id().into()).unwrap(); if distance == 256 { @@ -810,7 +794,7 @@ async fn test_bucket_limits() { .map(|i| { let kp = &keys[i - 1]; let ip: Ipv4Addr = Ipv4Addr::new(192, 168, 1, i as u8); - EnrBuilder::new("v4") + Enr::builder() .ip4(ip) .udp4(9500 + i as u16) .build(kp) @@ -822,7 +806,7 @@ async fn test_bucket_limits() { ip: enr.ip4().unwrap(), port: enr.udp4().unwrap(), }; - let config = Discv5ConfigBuilder::new(listen_config).ip_limit().build(); + let config = ConfigBuilder::new(listen_config).ip_limit().build(); let discv5: Discv5 = Discv5::new(enr, enr_key, config).unwrap(); for enr in enrs { diff --git a/src/error.rs b/src/error.rs index 7b9198d8c..5f1dd411a 100644 --- a/src/error.rs +++ b/src/error.rs @@ -8,7 +8,7 @@ use std::fmt; #[derive(Debug, From)] /// A general error that is used throughout the Discv5 library. -pub enum Discv5Error { +pub enum Error { /// An invalid message type was received. InvalidMessage, /// An invalid ENR was received. @@ -60,11 +60,11 @@ pub enum Discv5Error { #[derive(Debug)] pub enum NatError { /// Initiator error. - Initiator(Discv5Error), + Initiator(Error), /// Relayer error. - Relay(Discv5Error), + Relay(Error), /// Target error. - Target(Discv5Error), + Target(Error), } macro_rules! impl_from_variant { @@ -76,8 +76,8 @@ macro_rules! impl_from_variant { } }; } -impl_from_variant!(, tokio::sync::mpsc::error::SendError, Discv5Error, Self::ServiceChannelClosed); -impl_from_variant!(, NonContactable, Discv5Error, Self::InvalidEnr); +impl_from_variant!(, tokio::sync::mpsc::error::SendError, Error, Self::ServiceChannelClosed); +impl_from_variant!(, NonContactable, Error, Self::InvalidEnr); #[derive(Debug, Clone, PartialEq, Eq)] /// Types of packet errors. @@ -162,7 +162,7 @@ pub enum QueryError { InvalidMultiaddr(String), } -impl fmt::Display for Discv5Error { +impl fmt::Display for Error { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { write!(f, "{self:?}") } diff --git a/src/handler/active_requests.rs b/src/handler/active_requests.rs index e46ccee83..8cfa2d602 100644 --- a/src/handler/active_requests.rs +++ b/src/handler/active_requests.rs @@ -1,68 +1,142 @@ use super::*; use delay_map::HashMapDelay; use more_asserts::debug_unreachable; +use std::collections::hash_map::Entry; pub(super) struct ActiveRequests { /// A list of raw messages we are awaiting a response from the remote. - active_requests_mapping: HashMapDelay, + active_requests_mapping: HashMap>, // WHOAREYOU messages do not include the source node id. We therefore maintain another // mapping of active_requests via message_nonce. This allows us to match WHOAREYOU // requests with active requests sent. - /// A mapping of all pending active raw requests message nonces to their NodeAddress. - active_requests_nonce_mapping: HashMap, + /// A mapping of all active raw requests message nonces to their NodeAddress. + active_requests_nonce_mapping: HashMapDelay, } impl ActiveRequests { pub fn new(request_timeout: Duration) -> Self { ActiveRequests { - active_requests_mapping: HashMapDelay::new(request_timeout), - active_requests_nonce_mapping: HashMap::new(), + active_requests_mapping: HashMap::new(), + active_requests_nonce_mapping: HashMapDelay::new(request_timeout), } } + /// Insert a new request into the active requests mapping. pub fn insert(&mut self, node_address: NodeAddress, request_call: RequestCall) { let nonce = *request_call.packet().message_nonce(); self.active_requests_mapping - .insert(node_address.clone(), request_call); + .entry(node_address.clone()) + .or_default() + .push(request_call); self.active_requests_nonce_mapping .insert(nonce, node_address); } - pub fn get(&self, node_address: &NodeAddress) -> Option<&RequestCall> { + /// Update the underlying packet for the request via message nonce. + pub fn update_packet(&mut self, old_nonce: MessageNonce, new_packet: Packet) { + let node_address = + if let Some(node_address) = self.active_requests_nonce_mapping.remove(&old_nonce) { + node_address + } else { + debug_unreachable!("expected to find nonce in active_requests_nonce_mapping"); + error!("expected to find nonce in active_requests_nonce_mapping"); + return; + }; + + self.active_requests_nonce_mapping + .insert(new_packet.header.message_nonce, node_address.clone()); + + match self.active_requests_mapping.entry(node_address) { + Entry::Occupied(mut requests) => { + let maybe_request_call = requests + .get_mut() + .iter_mut() + .find(|req| req.packet().message_nonce() == &old_nonce); + + if let Some(request_call) = maybe_request_call { + request_call.update_packet(new_packet); + } else { + debug_unreachable!("expected to find request call in active_requests_mapping"); + error!("expected to find request call in active_requests_mapping"); + } + } + Entry::Vacant(_) => { + debug_unreachable!("expected to find node address in active_requests_mapping"); + error!("expected to find node address in active_requests_mapping"); + } + } + } + + pub fn get(&self, node_address: &NodeAddress) -> Option<&Vec> { self.active_requests_mapping.get(node_address) } + /// Remove a single request identified by its nonce. pub fn remove_by_nonce(&mut self, nonce: &MessageNonce) -> Option<(NodeAddress, RequestCall)> { - match self.active_requests_nonce_mapping.remove(nonce) { - Some(node_address) => match self.active_requests_mapping.remove(&node_address) { - Some(request_call) => Some((node_address, request_call)), - None => { - debug_unreachable!("A matching request call doesn't exist"); - error!("A matching request call doesn't exist"); - None + let node_address = self.active_requests_nonce_mapping.remove(nonce)?; + match self.active_requests_mapping.entry(node_address.clone()) { + Entry::Vacant(_) => { + debug_unreachable!("expected to find node address in active_requests_mapping"); + error!("expected to find node address in active_requests_mapping"); + None + } + Entry::Occupied(mut requests) => { + let result = requests + .get() + .iter() + .position(|req| req.packet().message_nonce() == nonce) + .map(|index| (node_address, requests.get_mut().remove(index))); + if requests.get().is_empty() { + requests.remove(); } - }, - None => None, + result + } } } - pub fn remove(&mut self, node_address: &NodeAddress) -> Option { - match self.active_requests_mapping.remove(node_address) { - Some(request_call) => { - // Remove the associated nonce mapping. - match self - .active_requests_nonce_mapping - .remove(request_call.packet().message_nonce()) - { - Some(_) => Some(request_call), - None => { - debug_unreachable!("A matching nonce mapping doesn't exist"); - error!("A matching nonce mapping doesn't exist"); - None - } + /// Remove all requests associated with a node. + pub fn remove_requests(&mut self, node_address: &NodeAddress) -> Option> { + let requests = self.active_requests_mapping.remove(node_address)?; + // Account for node addresses in `active_requests_nonce_mapping` with an empty list + if requests.is_empty() { + debug_unreachable!("expected to find requests in active_requests_mapping"); + return None; + } + for req in &requests { + if self + .active_requests_nonce_mapping + .remove(req.packet().message_nonce()) + .is_none() + { + debug_unreachable!("expected to find req with nonce"); + error!("expected to find req with nonce"); + } + } + Some(requests) + } + + /// Remove a single request identified by its id. + pub fn remove_request( + &mut self, + node_address: &NodeAddress, + id: &RequestId, + ) -> Option { + match self.active_requests_mapping.entry(node_address.clone()) { + Entry::Vacant(_) => None, + Entry::Occupied(mut requests) => { + let index = requests.get().iter().position(|req| { + let req_id: RequestId = req.id().into(); + &req_id == id + })?; + let request_call = requests.get_mut().remove(index); + if requests.get().is_empty() { + requests.remove(); } + // Remove the associated nonce mapping. + self.active_requests_nonce_mapping + .remove(request_call.packet().message_nonce()); + Some(request_call) } - None => None, } } @@ -80,10 +154,12 @@ impl ActiveRequests { } } - for (address, request) in self.active_requests_mapping.iter() { - let nonce = request.packet().message_nonce(); - if !self.active_requests_nonce_mapping.contains_key(nonce) { - panic!("Address {} maps to request with nonce {:?}, which does not exist in `active_requests_nonce_mapping`", address, nonce); + for (address, requests) in self.active_requests_mapping.iter() { + for req in requests { + let nonce = req.packet().message_nonce(); + if !self.active_requests_nonce_mapping.contains_key(nonce) { + panic!("Address {} maps to request with nonce {:?}, which does not exist in `active_requests_nonce_mapping`", address, nonce); + } } } } @@ -92,12 +168,27 @@ impl ActiveRequests { impl Stream for ActiveRequests { type Item = Result<(NodeAddress, RequestCall), String>; fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - match self.active_requests_mapping.poll_next_unpin(cx) { - Poll::Ready(Some(Ok((node_address, request_call)))) => { - // Remove the associated nonce mapping. - self.active_requests_nonce_mapping - .remove(request_call.packet().message_nonce()); - Poll::Ready(Some(Ok((node_address, request_call)))) + match self.active_requests_nonce_mapping.poll_next_unpin(cx) { + Poll::Ready(Some(Ok((nonce, node_address)))) => { + match self.active_requests_mapping.entry(node_address.clone()) { + Entry::Vacant(_) => Poll::Ready(None), + Entry::Occupied(mut requests) => { + match requests + .get() + .iter() + .position(|req| req.packet().message_nonce() == &nonce) + { + Some(index) => { + let result = (node_address, requests.get_mut().remove(index)); + if requests.get().is_empty() { + requests.remove(); + } + Poll::Ready(Some(Ok(result))) + } + None => Poll::Ready(None), + } + } + } } Poll::Ready(Some(Err(err))) => Poll::Ready(Some(Err(err))), Poll::Ready(None) => Poll::Ready(None), diff --git a/src/handler/crypto/mod.rs b/src/handler/crypto/mod.rs index 2bda3b0f7..6b609d3a8 100644 --- a/src/handler/crypto/mod.rs +++ b/src/handler/crypto/mod.rs @@ -6,7 +6,7 @@ //! encryption and key-derivation algorithms. Future versions may abstract some of these to allow //! for different algorithms. use crate::{ - error::Discv5Error, + error::Error, node_info::NodeContact, packet::{ChallengeData, MessageNonce}, }; @@ -48,7 +48,7 @@ pub(crate) fn generate_session_keys( local_id: &NodeId, contact: &NodeContact, challenge_data: &ChallengeData, -) -> Result<(Key, Key, Vec), Discv5Error> { +) -> Result<(Key, Key, Vec), Error> { let (secret, ephem_pk) = { match contact.public_key() { CombinedPublicKey::Secp256k1(remote_pk) => { @@ -57,9 +57,7 @@ pub(crate) fn generate_session_keys( let ephem_pk = ephem_sk.verifying_key(); (secret, ephem_pk.to_sec1_bytes().to_vec()) } - CombinedPublicKey::Ed25519(_) => { - return Err(Discv5Error::KeyTypeNotSupported("Ed25519")) - } + CombinedPublicKey::Ed25519(_) => return Err(Error::KeyTypeNotSupported("Ed25519")), } }; @@ -74,7 +72,7 @@ fn derive_key( first_id: &NodeId, second_id: &NodeId, challenge_data: &ChallengeData, -) -> Result<(Key, Key), Discv5Error> { +) -> Result<(Key, Key), Error> { let mut info = [0u8; INFO_LENGTH]; info[0..26].copy_from_slice(KEY_AGREEMENT_STRING.as_bytes()); info[26..26 + NODE_ID_LENGTH].copy_from_slice(&first_id.raw()); @@ -84,7 +82,7 @@ fn derive_key( let mut okm = [0u8; 2 * KEY_LENGTH]; hk.expand(&info, &mut okm) - .map_err(|_| Discv5Error::KeyDerivationFailed)?; + .map_err(|_| Error::KeyDerivationFailed)?; let mut initiator_key: Key = Default::default(); let mut recipient_key: Key = Default::default(); @@ -101,17 +99,17 @@ pub(crate) fn derive_keys_from_pubkey( remote_id: &NodeId, challenge_data: &ChallengeData, ephem_pubkey: &[u8], -) -> Result<(Key, Key), Discv5Error> { +) -> Result<(Key, Key), Error> { let secret = { match local_key { CombinedKey::Secp256k1(key) => { // convert remote pubkey into secp256k1 public key // the key type should match our own node record let remote_pubkey = k256::ecdsa::VerifyingKey::from_sec1_bytes(ephem_pubkey) - .map_err(|_| Discv5Error::InvalidRemotePublicKey)?; + .map_err(|_| Error::InvalidRemotePublicKey)?; ecdh(&remote_pubkey, key) } - CombinedKey::Ed25519(_) => return Err(Discv5Error::KeyTypeNotSupported("Ed25519")), + CombinedKey::Ed25519(_) => return Err(Error::KeyTypeNotSupported("Ed25519")), } }; @@ -127,7 +125,7 @@ pub(crate) fn sign_nonce( challenge_data: &ChallengeData, ephem_pubkey: &[u8], dst_id: &NodeId, -) -> Result, Discv5Error> { +) -> Result, Error> { let signing_message = generate_signing_nonce(challenge_data, ephem_pubkey, dst_id); match signing_key { @@ -135,10 +133,10 @@ pub(crate) fn sign_nonce( let message = Sha256::new().chain_update(signing_message); let signature: Signature = key .try_sign_digest(message) - .map_err(|e| Discv5Error::Error(format!("Failed to sign message: {e}")))?; + .map_err(|e| Error::Error(format!("Failed to sign message: {e}")))?; Ok(signature.to_vec()) } - CombinedKey::Ed25519(_) => Err(Discv5Error::KeyTypeNotSupported("Ed25519")), + CombinedKey::Ed25519(_) => Err(Error::KeyTypeNotSupported("Ed25519")), } } @@ -191,9 +189,9 @@ pub(crate) fn decrypt_message( message_nonce: MessageNonce, msg: &[u8], aad: &[u8], -) -> Result, Discv5Error> { +) -> Result, Error> { if msg.len() < 16 { - return Err(Discv5Error::DecryptionFailed( + return Err(Error::DecryptionFailed( "Message not long enough to contain a MAC".into(), )); } @@ -201,7 +199,7 @@ pub(crate) fn decrypt_message( let aead = Aes128Gcm::new(GenericArray::from_slice(key)); let payload = Payload { msg, aad }; aead.decrypt(GenericArray::from_slice(&message_nonce), payload) - .map_err(|e| Discv5Error::DecryptionFailed(e.to_string())) + .map_err(|e| Error::DecryptionFailed(e.to_string())) } /* Encryption related functions */ @@ -213,11 +211,11 @@ pub(crate) fn encrypt_message( message_nonce: MessageNonce, msg: &[u8], aad: &[u8], -) -> Result, Discv5Error> { +) -> Result, Error> { let aead = Aes128Gcm::new(GenericArray::from_slice(key)); let payload = Payload { msg, aad }; aead.encrypt(GenericArray::from_slice(&message_nonce), payload) - .map_err(|e| Discv5Error::DecryptionFailed(e.to_string())) + .map_err(|e| Error::DecryptionFailed(e.to_string())) } #[cfg(test)] @@ -225,7 +223,7 @@ mod tests { use crate::packet::DefaultProtocolId; use super::*; - use enr::{CombinedKey, EnrBuilder, EnrKey}; + use enr::{CombinedKey, Enr, EnrKey}; use std::convert::TryInto; fn hex_decode(x: &'static str) -> Vec { @@ -343,12 +341,12 @@ mod tests { let node1_key = CombinedKey::generate_secp256k1(); let node2_key = CombinedKey::generate_secp256k1(); - let node1_enr = EnrBuilder::new("v4") + let node1_enr = Enr::builder() .ip("127.0.0.1".parse().unwrap()) .udp4(9000) .build(&node1_key) .unwrap(); - let node2_enr = EnrBuilder::new("v4") + let node2_enr = Enr::builder() .ip("127.0.0.1".parse().unwrap()) .udp4(9000) .build(&node2_key) diff --git a/src/handler/mod.rs b/src/handler/mod.rs index 74a54f348..4cfa262bd 100644 --- a/src/handler/mod.rs +++ b/src/handler/mod.rs @@ -27,9 +27,9 @@ //! Messages from a node on the network come by [`Socket`] and get the form of a [`HandlerOut`] //! and can be forwarded to the application layer via the send channel. use crate::{ - config::Discv5Config, + config::Config, discv5::PERMIT_BAN_LIST, - error::{Discv5Error, NatError, RequestError}, + error::{Error, NatError, RequestError}, packet::{ChallengeData, IdNonce, MessageNonce, Packet, PacketKind, ProtocolIdentity}, rpc::{ Message, Payload, RelayInitNotification, RelayMsgNotification, Request, RequestBody, @@ -42,6 +42,7 @@ use crate::{ use delay_map::HashMapDelay; use enr::{CombinedKey, NodeId}; use futures::prelude::*; +use more_asserts::debug_unreachable; use parking_lot::RwLock; use smallvec::SmallVec; use std::{ @@ -202,6 +203,15 @@ struct PendingRequest { request: RequestBody, } +impl From<&HandlerReqId> for RequestId { + fn from(id: &HandlerReqId) -> Self { + match id { + HandlerReqId::Internal(id) => id.clone(), + HandlerReqId::External(id) => id.clone(), + } + } +} + /// Process to handle handshakes and sessions established from raw RPC communications between nodes. pub struct Handler { /// Configuration for the discv5 service. @@ -213,7 +223,7 @@ pub struct Handler { enr: Arc>, /// The key to sign the ENR and set up encrypted communication with peers. key: Arc>, - /// Pending raw requests. + /// Active requests that are awaiting a response. active_requests: ActiveRequests, /// The expected responses by SocketAddr which allows packets to pass the underlying filter. filter_expected_responses: Arc>>, @@ -250,7 +260,7 @@ impl Handler { pub async fn spawn( enr: Arc>, key: Arc>, - config: Discv5Config, + config: Config, ) -> Result { let (exit_sender, exit) = oneshot::channel(); // create the channels to send/receive messages from the application @@ -265,7 +275,7 @@ impl Handler { // The local node id let node_id = enr.read().node_id(); - let Discv5Config { + let Config { enable_packet_filter, filter_rate_limiter, filter_max_nodes_per_ip, @@ -319,7 +329,7 @@ impl Handler { // Attempt to bind to the socket before spinning up the send/recv tasks. let socket = Socket::new::

(socket_config).await?; - let sessions = LruTimeCache::new(session_timeout, Some(session_cache_capacity)); + let sessions = LruTimeCache::new(session_timeout, Some(session_cache_capacity.get())); let nat = Nat::new( &listen_sockets, @@ -421,13 +431,13 @@ impl Handler { Some(inbound_packet) = self.socket.recv.recv() => { self.process_inbound_packet::

(inbound_packet).await; } - Some(Ok((node_address, pending_request))) = self.active_requests.next() => { - self.handle_request_timeout::

(node_address, pending_request).await; + Some(Ok((node_address, active_request))) = self.active_requests.next() => { + self.handle_request_timeout::

(node_address, active_request).await; } Some(Ok((node_address, _challenge))) = self.active_challenges.next() => { // A challenge has expired. There could be pending requests awaiting this // challenge. We process them here - self.send_next_request::

(node_address).await; + self.send_pending_requests::

(&node_address).await; } Some(Ok(peer_socket)) = self.nat.hole_punch_tracker.next() => { if self.nat.is_behind_nat == Some(false) { @@ -492,7 +502,7 @@ impl Handler { socket_addr: inbound_packet.src_address, node_id: src_id, }; - self.handle_message::

( + self.handle_message( node_address, message_nonce, &inbound_packet.message, @@ -559,9 +569,7 @@ impl Handler { .on_request_time_out::

(relay, local_enr, nonce, target) .await { - Err(NatError::Initiator(Discv5Error::SessionAlreadyEstablished( - node_address, - ))) => { + Err(NatError::Initiator(Error::SessionAlreadyEstablished(node_address))) => { debug!("Session to peer already established, aborting hole punch attempt. Peer: {node_address}"); } Err(e) => { @@ -606,10 +614,10 @@ impl Handler { return Err(RequestError::SelfRequest); } - // If there is already an active request or an active challenge (WHOAREYOU sent) for this - // node, add to pending requests - if self.active_requests.get(&node_address).is_some() - || self.active_challenges.get(&node_address).is_some() + // If there is already an active challenge (WHOAREYOU sent) for this node, or if we are + // awaiting a session with this node to be established, add the request to pending requests. + if self.active_challenges.get(&node_address).is_some() + || self.is_awaiting_session_to_be_established(&node_address) { trace!("Request queued for node: {}", node_address); self.pending_requests @@ -637,14 +645,13 @@ impl Handler { .map_err(|e| RequestError::EncryptionFailed(format!("{e:?}")))?; (packet, false) } else { - // No session exists, start a new handshake + // No session exists, start a new handshake initiating a new session trace!( "Starting session. Sending random packet to: {}", node_address ); let packet = Packet::new_random(&self.node_id).map_err(RequestError::EntropyFailure)?; - // We are initiating a new session (packet, true) } }; @@ -831,6 +838,7 @@ impl Handler { // All sent requests must have an associated node_id. Therefore the following // must not panic. let node_address = request_call.contact().node_address(); + let auth_message_nonce = auth_packet.header.message_nonce; // Keep track if the ENR is reachable. In the case we don't know the ENR, we assume its // fine. @@ -850,7 +858,11 @@ impl Handler { enr_not_reachable = Nat::is_enr_reachable(&enr); // We already know the ENR. Send the handshake response packet - trace!("Sending Authentication response to node: {}", node_address); + trace!( + "Sending Authentication response to node: {} ({:?})", + node_address, + request_call.id() + ); request_call.update_packet(auth_packet.clone()); request_call.set_handshake_sent(); request_call.set_initiating_session(false); @@ -868,7 +880,11 @@ impl Handler { // Send the Auth response let contact = request_call.contact().clone(); - trace!("Sending Authentication response to node: {}", node_address); + trace!( + "Sending Authentication response to node: {} ({:?})", + node_address, + request_call.id() + ); request_call.update_packet(auth_packet.clone()); request_call.set_handshake_sent(); // Reinsert the request_call @@ -886,7 +902,13 @@ impl Handler { } } } - self.new_session(node_address, session, enr_not_reachable); + self.new_session::

( + node_address.clone(), + session, + Some(auth_message_nonce), + enr_not_reachable, + ) + .await; } /// Verifies a Node ENR to it's observed address. If it fails, any associated session is also @@ -957,6 +979,8 @@ impl Handler { most_recent_enr, ) { Ok((mut session, enr)) => { + // Remove the expected response for the challenge. + self.remove_expected_response(node_address.socket_addr); // Receiving an AuthResponse must give us an up-to-date view of the node ENR. // Verify the ENR is valid if self.verify_enr(&enr, &node_address) { @@ -970,20 +994,26 @@ impl Handler { ConnectionDirection::Incoming, ) .await; - self.new_session(node_address.clone(), session, enr_not_reachable); + // When (re-)establishing a session from an outgoing challenge, we do not need + // to filter out this request from active requests, so we do not pass + // the message nonce on to `new_session`. + self.new_session::

( + node_address.clone(), + session, + None, + enr_not_reachable, + ) + .await; self.nat .new_peer_latest_relay_cache .pop(&node_address.node_id); - self.handle_message::

( + self.handle_message( node_address.clone(), message_nonce, message, authenticated_data, ) .await; - // We could have pending messages that were awaiting this session to be - // established. If so process them. - self.send_next_request::

(node_address).await; } else { // IP's or NodeAddress don't match. Drop the session. warn!( @@ -1028,7 +1058,7 @@ impl Handler { } } } - Err(Discv5Error::InvalidChallengeSignature(challenge)) => { + Err(Error::InvalidChallengeSignature(challenge)) => { warn!( "Authentication header contained invalid signature. Ignoring packet from: {}", node_address @@ -1047,47 +1077,43 @@ impl Handler { } } else { warn!( - "Received an authenticated header without a matching WHOAREYOU request. {}", - node_address + node_id = %node_address.node_id, addr = %node_address.socket_addr, + "Received an authenticated header without a matching WHOAREYOU request", ); } } - async fn send_next_request(&mut self, node_address: NodeAddress) { - // ensure we are not over writing any existing requests - if self.active_requests.get(&node_address).is_none() { - if let std::collections::hash_map::Entry::Occupied(mut entry) = - self.pending_requests.entry(node_address) + /// Send all pending requests corresponding to the given node address, that were waiting for a + /// new session to be established or when an active outgoing challenge has expired. + async fn send_pending_requests(&mut self, node_address: &NodeAddress) { + let pending_requests = self + .pending_requests + .remove(node_address) + .unwrap_or_default(); + for req in pending_requests { + trace!( + "Sending pending request {} to {node_address}. {}", + RequestId::from(&req.request_id), + req.request, + ); + if let Err(request_error) = self + .send_request::

(req.contact, req.request_id.clone(), req.request) + .await { - // If it exists, there must be a request here - let PendingRequest { - contact, - request_id, - request, - } = entry.get_mut().remove(0); - if entry.get().is_empty() { - entry.remove(); - } - trace!("Sending next awaiting message. Node: {}", contact); - if let Err(request_error) = self - .send_request::

(contact, request_id.clone(), request) - .await - { - warn!("Failed to send next awaiting request {}", request_error); - // Inform the service that the request failed - match request_id { - HandlerReqId::Internal(_) => { - // An internal request could not be sent. For now we do nothing about - // this. - } - HandlerReqId::External(id) => { - if let Err(e) = self - .service_send - .send(HandlerOut::RequestFailed(id, request_error)) - .await - { - warn!("Failed to inform that request failed {}", e); - } + warn!("Failed to send next pending request {request_error}"); + // Inform the service that the request failed + match req.request_id { + HandlerReqId::Internal(_) => { + // An internal request could not be sent. For now we do nothing about + // this. + } + HandlerReqId::External(id) => { + if let Err(e) = self + .service_send + .send(HandlerOut::RequestFailed(id, request_error)) + .await + { + warn!("Failed to inform that request failed {e}"); } } } @@ -1095,6 +1121,64 @@ impl Handler { } } + /// Replays all active requests for the given node address, in the case that a new session has + /// been established. If an optional message nonce is provided, the corresponding request will + /// be skipped, eg. the request that established the new session. + async fn replay_active_requests( + &mut self, + node_address: &NodeAddress, + // Optional message nonce to filter out the request used to establish the session. + message_nonce: Option, + ) { + trace!( + "Replaying active requests. {}, {:?}", + node_address, + message_nonce + ); + + let packets = if let Some(session) = self.sessions.get_mut(node_address) { + let mut packets = vec![]; + for request_call in self + .active_requests + .get(node_address) + .unwrap_or(&vec![]) + .iter() + .filter(|req| { + // Except the active request that was used to establish the new session, as it has + // already been handled and shouldn't be replayed. + if let Some(nonce) = message_nonce.as_ref() { + req.packet().message_nonce() != nonce + } else { + true + } + }) + { + if let Ok(new_packet) = + session.encrypt_message::

(self.node_id, &request_call.encode()) + { + packets.push((*request_call.packet().message_nonce(), new_packet)); + } else { + error!( + "Failed to re-encrypt packet while replaying active request with id: {:?}", + request_call.id() + ); + } + } + + packets + } else { + debug_unreachable!("Attempted to replay active requests but session doesn't exist."); + error!("Attempted to replay active requests but session doesn't exist."); + return; + }; + + for (old_nonce, new_packet) in packets { + self.active_requests + .update_packet(old_nonce, new_packet.clone()); + self.send(node_address.clone(), new_packet).await; + } + } + /// Handle a session message packet, that is dropped if it can't be decrypted. async fn handle_session_message( &mut self, @@ -1107,7 +1191,7 @@ impl Handler { let Some(session) = self.sessions.get_mut(&node_address) else { warn!( "Dropping message. Error: {}, {}", - Discv5Error::SessionNotEstablished, + Error::SessionNotEstablished, node_address ); return; @@ -1139,7 +1223,7 @@ impl Handler { }; match message { - Message::Response(response) => self.handle_response::

(node_address, response).await, + Message::Response(response) => self.handle_response(node_address, response).await, Message::RelayInitNotification(notification) => { let initiator_node_id = notification.initiator_enr().node_id(); if initiator_node_id != node_address.node_id { @@ -1183,7 +1267,8 @@ impl Handler { } /// Handle a standard message that does not contain an authentication header. - async fn handle_message( + #[allow(clippy::single_match)] + async fn handle_message( &mut self, node_address: NodeAddress, message_nonce: MessageNonce, @@ -1250,7 +1335,7 @@ impl Handler { Message::Response(response) => { // Accept response in Message packet for backwards compatibility warn!("Received a response in a `Message` packet, should be sent in a `SessionMessage`"); - self.handle_response::

(node_address, response).await + self.handle_response(node_address, response).await } Message::RelayInitNotification(_) | Message::RelayMsgNotification(_) => { warn!( @@ -1283,18 +1368,14 @@ impl Handler { /// Handles a response to a request. Re-inserts the request call if the response is a multiple /// Nodes response. - async fn handle_response( - &mut self, - node_address: NodeAddress, - response: Response, - ) { + async fn handle_response(&mut self, node_address: NodeAddress, response: Response) { // Sessions could be awaiting an ENR response. Check if this response matches // this // check if we have an available session let Some(session) = self.sessions.get_mut(&node_address) else { warn!( "Dropping response. Error: {}, {}", - Discv5Error::SessionNotEstablished, + Error::SessionNotEstablished, node_address ); return; @@ -1332,22 +1413,11 @@ impl Handler { // Handle standard responses // Find a matching request, if any - if let Some(mut request_call) = self.active_requests.remove(&node_address) { - let id = match request_call.id() { - HandlerReqId::Internal(id) | HandlerReqId::External(id) => id, - }; - if id != &response.id { - trace!( - "Received an RPC Response to an unknown request. Likely late response. {}", - node_address - ); - // add the request back and reset the timer - self.active_requests.insert(node_address, request_call); - return; - } - + if let Some(mut request_call) = self + .active_requests + .remove_request(&node_address, &response.id) + { // The response matches a request - // Check to see if this is a Nodes response, in which case we may require to wait for // extra responses if let ResponseBody::Nodes { total, ref nodes } = response.body { @@ -1415,7 +1485,6 @@ impl Handler { { warn!("Failed to inform of response {}", e) } - self.send_next_request::

(node_address).await; } else { // This is likely a late response and we have already failed the request. These get // dropped here. @@ -1431,21 +1500,33 @@ impl Handler { self.active_requests.insert(node_address, request_call); } - /// Updates the session cache for a new session. - fn new_session( + /// Establishes a new session with a peer, or re-establishes an existing session if a + /// new challenge was issued during an ongoing session. + async fn new_session( &mut self, node_address: NodeAddress, session: Session, + // Optional message nonce is required to filter out the request that was used in the + // handshake to re-establish a session, if applicable. + message_nonce: Option, enr_not_reachable: bool, ) { if let Some(current_session) = self.sessions.get_mut(&node_address) { current_session.update(session); + // If a session is re-established, due to a new handshake during an ongoing + // session, we need to replay any active requests from the prior session, excluding + // the request that was used to re-establish the session handshake. + self.replay_active_requests::

(&node_address, message_nonce) + .await; } else { self.sessions - .insert_raw(node_address, session, enr_not_reachable); + .insert_raw(node_address.clone(), session, enr_not_reachable); METRICS .active_sessions .store(self.sessions.len(), Ordering::Relaxed); + // We could have pending messages that were awaiting this session to be + // established. If so process them. + self.send_pending_requests::

(&node_address).await; } } @@ -1498,7 +1579,7 @@ impl Handler { .await; } - /// Removes a session and updates associated metrics and fields. + /// Removes a session, fails all of that session's active & pending requests, and updates associated metrics and fields. async fn fail_session( &mut self, node_address: &NodeAddress, @@ -1513,6 +1594,7 @@ impl Handler { // stop keeping hole punched for peer self.nat.untrack(&node_address.socket_addr); } + // fail all pending requests if let Some(to_remove) = self.pending_requests.remove(node_address) { for PendingRequest { request_id, .. } in to_remove { match request_id { @@ -1531,6 +1613,28 @@ impl Handler { } } } + // fail all active requests + for req in self + .active_requests + .remove_requests(node_address) + .unwrap_or_default() + { + match req.id() { + HandlerReqId::Internal(_) => { + // Do not report failures on requests belonging to the handler. + } + HandlerReqId::External(id) => { + if let Err(e) = self + .service_send + .send(HandlerOut::RequestFailed(id.clone(), error.clone())) + .await + { + warn!("Failed to inform request failure {e}") + } + } + } + self.remove_expected_response(node_address.socket_addr); + } } /// Assembles and sends a [`Packet`]. @@ -1563,6 +1667,21 @@ impl Handler { .retain(|_, time| time.is_none() || Some(Instant::now()) < *time); } + /// Returns whether a session with this node does not exist and a request that initiates + /// a session has been sent. + fn is_awaiting_session_to_be_established(&mut self, node_address: &NodeAddress) -> bool { + if self.sessions.get(node_address).is_some() { + // session exists + return false; + } + + if let Some(requests) = self.active_requests.get(node_address) { + requests.iter().any(|req| req.initiating_session()) + } else { + false + } + } + async fn new_connection( &mut self, enr: Enr, @@ -1613,7 +1732,7 @@ impl Handler { ) -> Result<(), NatError> { // Another hole punch process with this target may have just completed. if self.sessions.get(&target_node_address).is_some() { - return Err(NatError::Initiator(Discv5Error::SessionAlreadyEstablished( + return Err(NatError::Initiator(Error::SessionAlreadyEstablished( target_node_address, ))); } @@ -1735,7 +1854,7 @@ impl Handler { // time out of the udp entrypoint for the target peer in the initiator's NAT, set by // the original timed out FINDNODE request from the initiator, as the initiator may // also be behind a NAT. - Err(NatError::Relay(Discv5Error::SessionNotEstablished)) + Err(NatError::Relay(Error::SessionNotEstablished)) } } diff --git a/src/handler/nat.rs b/src/handler/nat.rs index b9e859f65..0b552b713 100644 --- a/src/handler/nat.rs +++ b/src/handler/nat.rs @@ -1,5 +1,6 @@ use std::{ net::{IpAddr, SocketAddr, UdpSocket}, + num::NonZeroUsize, ops::RangeInclusive, time::Duration, }; @@ -49,7 +50,7 @@ impl Nat { ip_mode: IpMode, unused_port_range: Option>, ban_duration: Option, - session_cache_capacity: usize, + session_cache_capacity: NonZeroUsize, unreachable_enr_limit: Option, ) -> Self { let mut nat = Nat { @@ -95,7 +96,7 @@ impl Nat { } /// Called when a new observed address is reported at start up or after a - /// [`crate::Discv5Event::SocketUpdated`]. + /// [`crate::Event::SocketUpdated`]. pub fn set_is_behind_nat( &mut self, listen_sockets: &[SocketAddr], diff --git a/src/handler/session.rs b/src/handler/session.rs index d0f53f305..0a14f371c 100644 --- a/src/handler/session.rs +++ b/src/handler/session.rs @@ -7,7 +7,7 @@ use crate::{ MESSAGE_NONCE_LENGTH, }, rpc::RequestId, - Discv5Error, Enr, + Enr, Error, }; use enr::{CombinedKey, NodeId}; @@ -76,7 +76,7 @@ impl Session { &mut self, src_id: NodeId, message: &[u8], - ) -> Result { + ) -> Result { self.encrypt::

(message, PacketKind::SessionMessage { src_id }) } @@ -85,7 +85,7 @@ impl Session { &mut self, src_id: NodeId, message: &[u8], - ) -> Result { + ) -> Result { self.encrypt::

(message, PacketKind::Message { src_id }) } @@ -95,7 +95,7 @@ impl Session { &mut self, message: &[u8], packet_kind: PacketKind, - ) -> Result { + ) -> Result { self.counter += 1; let random_nonce: [u8; MESSAGE_NONCE_LENGTH - 4] = rand::random(); @@ -136,7 +136,7 @@ impl Session { message_nonce: MessageNonce, message: &[u8], aad: &[u8], - ) -> Result, Discv5Error> { + ) -> Result, Error> { // First try with the canonical keys. let result_canon = crypto::decrypt_message(&self.keys.decryption_key, message_nonce, message, aad); @@ -173,7 +173,7 @@ impl Session { id_nonce_sig: &[u8], ephem_pubkey: &[u8], session_enr: Enr, - ) -> Result<(Session, Enr), Discv5Error> { + ) -> Result<(Session, Enr), Error> { // verify the auth header nonce if !crypto::verify_authentication_nonce( &session_enr.public_key(), @@ -186,7 +186,7 @@ impl Session { data: challenge_data, remote_enr: Some(session_enr), }; - return Err(Discv5Error::InvalidChallengeSignature(challenge)); + return Err(Error::InvalidChallengeSignature(challenge)); } // The keys are derived after the message has been verified to prevent potential extra work @@ -217,7 +217,7 @@ impl Session { local_node_id: &NodeId, challenge_data: &ChallengeData, message: &[u8], - ) -> Result<(Packet, Session), Discv5Error> { + ) -> Result<(Packet, Session), Error> { // generate the session keys let (encryption_key, decryption_key, ephem_pubkey) = crypto::generate_session_keys(local_node_id, remote_contact, challenge_data)?; @@ -234,7 +234,7 @@ impl Session { &ephem_pubkey, &remote_contact.node_id(), ) - .map_err(|_| Discv5Error::Custom("Could not sign WHOAREYOU nonce"))?; + .map_err(|_| Error::Custom("Could not sign WHOAREYOU nonce"))?; // build an authentication packet let message_nonce: MessageNonce = rand::random(); diff --git a/src/handler/tests.rs b/src/handler/tests.rs index c2358c394..b45614244 100644 --- a/src/handler/tests.rs +++ b/src/handler/tests.rs @@ -6,13 +6,18 @@ use crate::{ packet::{DefaultProtocolId, PacketHeader, MAX_PACKET_SIZE, MESSAGE_NONCE_LENGTH}, return_if_ipv6_is_not_supported, rpc::{Request, Response}, - Discv5ConfigBuilder, IpMode, + ConfigBuilder, IpMode, +}; +use std::{ + collections::HashSet, + convert::TryInto, + net::{Ipv4Addr, Ipv6Addr}, + num::NonZeroU16, + ops::Add, }; -use std::net::{Ipv4Addr, Ipv6Addr}; use crate::{handler::HandlerOut::RequestFailed, RequestError::SelfRequest}; use active_requests::ActiveRequests; -use enr::EnrBuilder; use std::time::Duration; use tokio::{net::UdpSocket, time::sleep}; @@ -22,31 +27,18 @@ fn init() { .try_init(); } -struct MockService { - tx: mpsc::UnboundedSender, - rx: mpsc::Receiver, - exit_tx: oneshot::Sender<()>, -} - -async fn build_handler() -> (Handler, MockService) { - build_handler_with_listen_config::

(ListenConfig::default()).await -} - -async fn build_handler_with_listen_config( - listen_config: ListenConfig, -) -> (Handler, MockService) { - let listen_port = listen_config - .ipv4_port() - .expect("listen config should default to ipv4"); - let config = Discv5ConfigBuilder::new(listen_config).build(); - let key = CombinedKey::generate_secp256k1(); - let enr = EnrBuilder::new("v4") - .ip4(Ipv4Addr::LOCALHOST) - .udp4(listen_port) - .build(&key) - .unwrap(); +async fn build_handler( + enr: Enr, + key: CombinedKey, + config: Config, +) -> ( + oneshot::Sender<()>, + mpsc::UnboundedSender, + mpsc::Receiver, + Handler, +) { let mut listen_sockets = SmallVec::default(); - listen_sockets.push((Ipv4Addr::LOCALHOST, listen_port).into()); + listen_sockets.push((Ipv4Addr::LOCALHOST, 9000).into()); let node_id = enr.node_id(); let filter_expected_responses = Arc::new(RwLock::new(HashMap::new())); @@ -71,9 +63,9 @@ async fn build_handler_with_listen_config( Socket::new::

(socket_config).await.unwrap() }; - let (handler_sender, service_recv) = mpsc::unbounded_channel(); + let (handler_send, service_recv) = mpsc::unbounded_channel(); let (service_send, handler_recv) = mpsc::channel(50); - let (exit_tx, exit) = oneshot::channel(); + let (exit_sender, exit) = oneshot::channel(); let nat = Nat::new( &listen_sockets, @@ -85,37 +77,31 @@ async fn build_handler_with_listen_config( None, ); - ( - Handler { - request_retries: config.request_retries, - node_id, - enr: Arc::new(RwLock::new(enr)), - key: Arc::new(RwLock::new(key)), - active_requests: ActiveRequests::new(config.request_timeout), - pending_requests: HashMap::new(), - filter_expected_responses, - sessions: LruTimeCache::new( - config.session_timeout, - Some(config.session_cache_capacity), - ), - one_time_sessions: LruTimeCache::new( - Duration::from_secs(ONE_TIME_SESSION_TIMEOUT), - Some(ONE_TIME_SESSION_CACHE_CAPACITY), - ), - active_challenges: HashMapDelay::new(config.request_timeout), - service_recv, - service_send, - listen_sockets, - socket, - nat, - exit, - }, - MockService { - tx: handler_sender, - rx: handler_recv, - exit_tx, - }, - ) + let handler = Handler { + request_retries: config.request_retries, + node_id, + enr: Arc::new(RwLock::new(enr)), + key: Arc::new(RwLock::new(key)), + active_requests: ActiveRequests::new(config.request_timeout), + pending_requests: HashMap::new(), + filter_expected_responses, + sessions: LruTimeCache::new( + config.session_timeout, + Some(config.session_cache_capacity.get()), + ), + one_time_sessions: LruTimeCache::new( + Duration::from_secs(ONE_TIME_SESSION_TIMEOUT), + Some(ONE_TIME_SESSION_CACHE_CAPACITY), + ), + active_challenges: HashMapDelay::new(config.request_timeout), + service_recv, + service_send, + listen_sockets, + socket, + nat, + exit, + }; + (exit_sender, handler_send, handler_recv, handler) } macro_rules! arc_rw { @@ -136,12 +122,12 @@ async fn simple_session_message() { let key1 = CombinedKey::generate_secp256k1(); let key2 = CombinedKey::generate_secp256k1(); - let sender_enr = EnrBuilder::new("v4") + let sender_enr = Enr::builder() .ip4(ip) .udp4(sender_port) .build(&key1) .unwrap(); - let receiver_enr = EnrBuilder::new("v4") + let receiver_enr = Enr::builder() .ip4(ip) .udp4(receiver_port) .build(&key2) @@ -151,7 +137,7 @@ async fn simple_session_message() { ip: sender_enr.ip4().unwrap(), port: sender_enr.udp4().unwrap(), }; - let sender_config = Discv5ConfigBuilder::new(sender_listen_config) + let sender_config = ConfigBuilder::new(sender_listen_config) .enable_packet_filter() .build(); let (_exit_send, sender_send, _sender_recv) = Handler::spawn::( @@ -166,7 +152,7 @@ async fn simple_session_message() { ip: receiver_enr.ip4().unwrap(), port: receiver_enr.udp4().unwrap(), }; - let receiver_config = Discv5ConfigBuilder::new(receiver_listen_config) + let receiver_config = ConfigBuilder::new(receiver_listen_config) .enable_packet_filter() .build(); let (_exit_recv, recv_send, mut receiver_recv) = Handler::spawn::( @@ -225,44 +211,55 @@ async fn multiple_messages() { let key1 = CombinedKey::generate_secp256k1(); let key2 = CombinedKey::generate_secp256k1(); - let sender_enr = EnrBuilder::new("v4") + let sender_enr = Enr::builder() .ip4(ip) .udp4(sender_port) .build(&key1) .unwrap(); - let sender_listen_config = ListenConfig::Ipv4 { - ip: sender_enr.ip4().unwrap(), - port: sender_enr.udp4().unwrap(), - }; - let sender_config = Discv5ConfigBuilder::new(sender_listen_config).build(); - let receiver_enr = EnrBuilder::new("v4") + let receiver_enr = Enr::builder() .ip4(ip) .udp4(receiver_port) .build(&key2) .unwrap(); - let receiver_listen_config = ListenConfig::Ipv4 { - ip: receiver_enr.ip4().unwrap(), - port: receiver_enr.udp4().unwrap(), - }; - let receiver_config = Discv5ConfigBuilder::new(receiver_listen_config).build(); - let (_exit_send, sender_handler, mut sender_handler_recv) = - Handler::spawn::( - arc_rw!(sender_enr.clone()), - arc_rw!(key1), - sender_config, - ) - .await - .unwrap(); + // Build sender handler + let (sender_exit, sender_send, mut sender_recv, mut handler) = { + let sender_listen_config = ListenConfig::Ipv4 { + ip: sender_enr.ip4().unwrap(), + port: sender_enr.udp4().unwrap(), + }; + let sender_config = ConfigBuilder::new(sender_listen_config).build(); + build_handler::(sender_enr.clone(), key1, sender_config).await + }; + let sender = async move { + // Start sender handler. + handler.start::().await; + // After the handler has been terminated test the handler's states. + assert!(handler.pending_requests.is_empty()); + assert_eq!(0, handler.active_requests.count().await); + assert!(handler.active_challenges.is_empty()); + assert!(handler.filter_expected_responses.read().is_empty()); + }; - let (_exit_recv, recv_send, mut receiver_handler) = Handler::spawn::( - arc_rw!(receiver_enr.clone()), - arc_rw!(key2), - receiver_config, - ) - .await - .unwrap(); + // Build receiver handler + let (receiver_exit, receiver_send, mut receiver_recv, mut handler) = { + let receiver_listen_config = ListenConfig::Ipv4 { + ip: receiver_enr.ip4().unwrap(), + port: receiver_enr.udp4().unwrap(), + }; + let receiver_config = ConfigBuilder::new(receiver_listen_config).build(); + build_handler::(receiver_enr.clone(), key2, receiver_config).await + }; + let receiver = async move { + // Start receiver handler. + handler.start::().await; + // After the handler has been terminated test the handler's states. + assert!(handler.pending_requests.is_empty()); + assert_eq!(0, handler.active_requests.count().await); + assert!(handler.active_challenges.is_empty()); + assert!(handler.filter_expected_responses.read().is_empty()); + }; let send_message = Box::new(Request { id: RequestId(vec![1]), @@ -270,7 +267,7 @@ async fn multiple_messages() { }); // sender to send the first message then await for the session to be established - let _ = sender_handler.send(HandlerIn::Request( + let _ = sender_send.send(HandlerIn::Request( receiver_enr.clone().into(), send_message.clone(), )); @@ -280,7 +277,7 @@ async fn multiple_messages() { body: ResponseBody::Pong { enr_seq: 1, ip: ip.into(), - port: sender_port, + port: sender_port.try_into().unwrap(), }, }; @@ -289,28 +286,38 @@ async fn multiple_messages() { let mut message_count = 0usize; let recv_send_message = send_message.clone(); - let sender = async move { + let sender_ops = async move { + let mut response_count = 0usize; loop { - match sender_handler_recv.recv().await { + match sender_recv.recv().await { Some(HandlerOut::Established(_, _, _)) => { // now the session is established, send the rest of the messages for _ in 0..messages_to_send - 1 { - let _ = sender_handler.send(HandlerIn::Request( + let _ = sender_send.send(HandlerIn::Request( receiver_enr.clone().into(), send_message.clone(), )); } } + Some(HandlerOut::Response(_, _)) => { + response_count += 1; + if response_count == messages_to_send { + // Notify the handlers that the message exchange has been completed. + sender_exit.send(()).unwrap(); + receiver_exit.send(()).unwrap(); + return; + } + } _ => continue, }; } }; - let receiver = async move { + let receiver_ops = async move { loop { - match receiver_handler.recv().await { + match receiver_recv.recv().await { Some(HandlerOut::RequestEnr(EnrRequestData::WhoAreYou(wru_ref))) => { - let _ = recv_send.send(HandlerIn::EnrResponse( + let _ = receiver_send.send(HandlerIn::EnrResponse( Some(sender_enr.clone()), EnrRequestData::WhoAreYou(wru_ref), )); @@ -319,8 +326,8 @@ async fn multiple_messages() { assert_eq!(request, recv_send_message); message_count += 1; // required to send a pong response to establish the session - let _ = - recv_send.send(HandlerIn::Response(addr, Box::new(pong_response.clone()))); + let _ = receiver_send + .send(HandlerIn::Response(addr, Box::new(pong_response.clone()))); if message_count == messages_to_send { return; } @@ -333,49 +340,182 @@ async fn multiple_messages() { }; let sleep_future = sleep(Duration::from_millis(100)); + let message_exchange = async move { + let _ = tokio::join!(sender, sender_ops, receiver, receiver_ops); + }; tokio::select! { - _ = sender => {} - _ = receiver => {} + _ = message_exchange => {} _ = sleep_future => { panic!("Test timed out"); } } } +fn create_node() -> Enr { + let key = CombinedKey::generate_secp256k1(); + let ip = "127.0.0.1".parse().unwrap(); + let port = 8080 + rand::random::() % 1000; + Enr::builder().ip4(ip).udp4(port).build(&key).unwrap() +} + +fn create_req_call(node: &Enr) -> (RequestCall, NodeAddress) { + let node_contact: NodeContact = node.clone().into(); + let packet = Packet::new_random(&node.node_id()).unwrap(); + let id = HandlerReqId::Internal(RequestId::random()); + let request = RequestBody::Ping { enr_seq: 1 }; + let initiating_session = true; + let node_addr = node_contact.node_address(); + let req = RequestCall::new(node_contact, packet, id, request, initiating_session); + (req, node_addr) +} + #[tokio::test] async fn test_active_requests_insert() { const EXPIRY: Duration = Duration::from_secs(5); let mut active_requests = ActiveRequests::new(EXPIRY); - // Create the test values needed - let port = 5000; - let ip = "127.0.0.1".parse().unwrap(); + let node_1 = create_node(); + let node_2 = create_node(); + let (req_1, req_1_addr) = create_req_call(&node_1); + let (req_2, req_2_addr) = create_req_call(&node_2); + let (req_3, req_3_addr) = create_req_call(&node_2); - let key = CombinedKey::generate_secp256k1(); + // insert the pair and verify the mapping remains in sync + active_requests.insert(req_1_addr, req_1); + active_requests.check_invariant(); + active_requests.insert(req_2_addr, req_2); + active_requests.check_invariant(); + active_requests.insert(req_3_addr, req_3); + active_requests.check_invariant(); +} - let enr = EnrBuilder::new("v4") - .ip4(ip) - .udp4(port) - .build(&key) - .unwrap(); - let node_id = enr.node_id(); +#[tokio::test] +async fn test_active_requests_remove_requests() { + const EXPIRY: Duration = Duration::from_secs(5); + let mut active_requests = ActiveRequests::new(EXPIRY); - let contact: NodeContact = enr.into(); - let node_address = contact.node_address(); + let node_1 = create_node(); + let node_2 = create_node(); + let (req_1, req_1_addr) = create_req_call(&node_1); + let (req_2, req_2_addr) = create_req_call(&node_2); + let (req_3, req_3_addr) = create_req_call(&node_2); + active_requests.insert(req_1_addr.clone(), req_1); + active_requests.insert(req_2_addr.clone(), req_2); + active_requests.insert(req_3_addr.clone(), req_3); + active_requests.check_invariant(); + let reqs = active_requests.remove_requests(&req_1_addr).unwrap(); + assert_eq!(reqs.len(), 1); + active_requests.check_invariant(); + let reqs = active_requests.remove_requests(&req_2_addr).unwrap(); + assert_eq!(reqs.len(), 2); + active_requests.check_invariant(); + assert!(active_requests.remove_requests(&req_3_addr).is_none()); +} - let packet = Packet::new_random(&node_id).unwrap(); - let id = HandlerReqId::Internal(RequestId::random()); - let request = RequestBody::Ping { enr_seq: 1 }; - let initiating_session = true; - let request_call = RequestCall::new(contact, packet, id, request, initiating_session); +#[tokio::test] +async fn test_active_requests_remove_request() { + const EXPIRY: Duration = Duration::from_secs(5); + let mut active_requests = ActiveRequests::new(EXPIRY); - // insert the pair and verify the mapping remains in sync - let nonce = *request_call.packet().message_nonce(); - active_requests.insert(node_address, request_call); + let node_1 = create_node(); + let node_2 = create_node(); + let (req_1, req_1_addr) = create_req_call(&node_1); + let (req_2, req_2_addr) = create_req_call(&node_2); + let (req_3, req_3_addr) = create_req_call(&node_2); + let req_1_id = req_1.id().into(); + let req_2_id = req_2.id().into(); + let req_3_id = req_3.id().into(); + + active_requests.insert(req_1_addr.clone(), req_1); + active_requests.insert(req_2_addr.clone(), req_2); + active_requests.insert(req_3_addr.clone(), req_3); + active_requests.check_invariant(); + let req_id: RequestId = active_requests + .remove_request(&req_1_addr, &req_1_id) + .unwrap() + .id() + .into(); + assert_eq!(req_id, req_1_id); + active_requests.check_invariant(); + let req_id: RequestId = active_requests + .remove_request(&req_2_addr, &req_2_id) + .unwrap() + .id() + .into(); + assert_eq!(req_id, req_2_id); + active_requests.check_invariant(); + let req_id: RequestId = active_requests + .remove_request(&req_3_addr, &req_3_id) + .unwrap() + .id() + .into(); + assert_eq!(req_id, req_3_id); + active_requests.check_invariant(); + assert!(active_requests + .remove_request(&req_3_addr, &req_3_id) + .is_none()); +} + +#[tokio::test] +async fn test_active_requests_remove_by_nonce() { + const EXPIRY: Duration = Duration::from_secs(5); + let mut active_requests = ActiveRequests::new(EXPIRY); + + let node_1 = create_node(); + let node_2 = create_node(); + let (req_1, req_1_addr) = create_req_call(&node_1); + let (req_2, req_2_addr) = create_req_call(&node_2); + let (req_3, req_3_addr) = create_req_call(&node_2); + let req_1_nonce = *req_1.packet().message_nonce(); + let req_2_nonce = *req_2.packet().message_nonce(); + let req_3_nonce = *req_3.packet().message_nonce(); + + active_requests.insert(req_1_addr.clone(), req_1); + active_requests.insert(req_2_addr.clone(), req_2); + active_requests.insert(req_3_addr.clone(), req_3); + active_requests.check_invariant(); + + let req = active_requests.remove_by_nonce(&req_1_nonce).unwrap(); + assert_eq!(req.0, req_1_addr); + active_requests.check_invariant(); + let req = active_requests.remove_by_nonce(&req_2_nonce).unwrap(); + assert_eq!(req.0, req_2_addr); active_requests.check_invariant(); - active_requests.remove_by_nonce(&nonce); + let req = active_requests.remove_by_nonce(&req_3_nonce).unwrap(); + assert_eq!(req.0, req_3_addr); active_requests.check_invariant(); + let random_nonce = rand::random(); + assert!(active_requests.remove_by_nonce(&random_nonce).is_none()); +} + +#[tokio::test] +async fn test_active_requests_update_packet() { + const EXPIRY: Duration = Duration::from_secs(5); + let mut active_requests = ActiveRequests::new(EXPIRY); + + let node_1 = create_node(); + let node_2 = create_node(); + let (req_1, req_1_addr) = create_req_call(&node_1); + let (req_2, req_2_addr) = create_req_call(&node_2); + let (req_3, req_3_addr) = create_req_call(&node_2); + + let old_nonce = *req_2.packet().message_nonce(); + active_requests.insert(req_1_addr, req_1); + active_requests.insert(req_2_addr.clone(), req_2); + active_requests.insert(req_3_addr, req_3); + active_requests.check_invariant(); + + let new_packet = Packet::new_random(&node_2.node_id()).unwrap(); + let new_nonce = new_packet.message_nonce(); + active_requests.update_packet(old_nonce, new_packet.clone()); + active_requests.check_invariant(); + + assert_eq!(2, active_requests.get(&req_2_addr).unwrap().len()); + assert!(active_requests.remove_by_nonce(&old_nonce).is_none()); + let (addr, req) = active_requests.remove_by_nonce(new_nonce).unwrap(); + assert_eq!(addr, req_2_addr); + assert_eq!(req.packet(), &new_packet); } #[tokio::test] @@ -383,7 +523,7 @@ async fn test_self_request_ipv4() { init(); let key = CombinedKey::generate_secp256k1(); - let enr = EnrBuilder::new("v4") + let enr = Enr::builder() .ip4(Ipv4Addr::LOCALHOST) .udp4(5004) .build(&key) @@ -392,7 +532,7 @@ async fn test_self_request_ipv4() { ip: enr.ip4().unwrap(), port: enr.udp4().unwrap(), }; - let config = Discv5ConfigBuilder::new(listen_config) + let config = ConfigBuilder::new(listen_config) .enable_packet_filter() .build(); @@ -423,7 +563,7 @@ async fn test_self_request_ipv6() { init(); let key = CombinedKey::generate_secp256k1(); - let enr = EnrBuilder::new("v4") + let enr = Enr::builder() .ip6(Ipv6Addr::LOCALHOST) .udp6(5005) .build(&key) @@ -432,7 +572,7 @@ async fn test_self_request_ipv6() { ip: enr.ip6().unwrap(), port: enr.udp6().unwrap(), }; - let config = Discv5ConfigBuilder::new(listen_config) + let config = ConfigBuilder::new(listen_config) .enable_packet_filter() .build(); @@ -458,11 +598,18 @@ async fn test_self_request_ipv6() { #[tokio::test] async fn remove_one_time_session() { - let (mut handler, _) = build_handler::().await; + let config = ConfigBuilder::new(ListenConfig::default()).build(); + let key = CombinedKey::generate_secp256k1(); + let enr = Enr::builder() + .ip4(Ipv4Addr::LOCALHOST) + .udp4(9000) + .build(&key) + .unwrap(); + let (_, _, _, mut handler) = build_handler::(enr, key, config).await; let enr = { let key = CombinedKey::generate_secp256k1(); - EnrBuilder::new("v4") + Enr::builder() .ip4(Ipv4Addr::LOCALHOST) .udp4(9000) .build(&key) @@ -493,21 +640,450 @@ async fn remove_one_time_session() { assert_eq!(0, handler.one_time_sessions.len()); } +// Tests replaying active requests. +// +// In this test, Receiver's session expires and Receiver returns WHOAREYOU. +// Sender then creates a new session and resend active requests. +// +// ```mermaid +// sequenceDiagram +// participant Sender +// participant Receiver +// Note over Sender: Start discv5 server +// Note over Receiver: Start discv5 server +// +// Note over Sender,Receiver: Session established +// +// rect rgb(100, 100, 0) +// Note over Receiver: ** Session expired ** +// end +// +// rect rgb(10, 10, 10) +// Note left of Sender: Sender sends requests
**in parallel**. +// par +// Sender ->> Receiver: PING(id:2) +// and +// Sender -->> Receiver: PING(id:3) +// and +// Sender -->> Receiver: PING(id:4) +// and +// Sender -->> Receiver: PING(id:5) +// end +// end +// +// Note over Receiver: Send WHOAREYOU
since the session has been expired +// Receiver ->> Sender: WHOAREYOU +// +// rect rgb(100, 100, 0) +// Note over Receiver: Drop PING(id:2,3,4,5) request
since WHOAREYOU already sent. +// end +// +// Note over Sender: New session established with Receiver +// +// Sender ->> Receiver: Handshake message (id:2) +// +// Note over Receiver: New session established with Sender +// +// rect rgb(10, 10, 10) +// Note left of Sender: Handler::replay_active_requests() +// Sender ->> Receiver: PING (id:3) +// Sender ->> Receiver: PING (id:4) +// Sender ->> Receiver: PING (id:5) +// end +// +// Receiver ->> Sender: PONG (id:2) +// Receiver ->> Sender: PONG (id:3) +// Receiver ->> Sender: PONG (id:4) +// Receiver ->> Sender: PONG (id:5) +// ``` +#[tokio::test] +async fn test_replay_active_requests() { + init(); + let sender_port = 5006; + let receiver_port = 5007; + let ip = "127.0.0.1".parse().unwrap(); + let key1 = CombinedKey::generate_secp256k1(); + let key2 = CombinedKey::generate_secp256k1(); + + let sender_enr = Enr::builder() + .ip4(ip) + .udp4(sender_port) + .build(&key1) + .unwrap(); + + let receiver_enr = Enr::builder() + .ip4(ip) + .udp4(receiver_port) + .build(&key2) + .unwrap(); + + // Build sender handler + let (sender_exit, sender_send, mut sender_recv, mut handler) = { + let sender_listen_config = ListenConfig::Ipv4 { + ip: sender_enr.ip4().unwrap(), + port: sender_enr.udp4().unwrap(), + }; + let sender_config = ConfigBuilder::new(sender_listen_config).build(); + build_handler::(sender_enr.clone(), key1, sender_config).await + }; + let sender = async move { + // Start sender handler. + handler.start::().await; + // After the handler has been terminated test the handler's states. + assert!(handler.pending_requests.is_empty()); + assert_eq!(0, handler.active_requests.count().await); + assert!(handler.active_challenges.is_empty()); + assert!(handler.filter_expected_responses.read().is_empty()); + }; + + // Build receiver handler + // Shorten receiver's timeout to reproduce session expired. + let receiver_session_timeout = Duration::from_secs(1); + let (receiver_exit, receiver_send, mut receiver_recv, mut handler) = { + let receiver_listen_config = ListenConfig::Ipv4 { + ip: receiver_enr.ip4().unwrap(), + port: receiver_enr.udp4().unwrap(), + }; + let receiver_config = ConfigBuilder::new(receiver_listen_config) + .session_timeout(receiver_session_timeout) + .build(); + build_handler::(receiver_enr.clone(), key2, receiver_config).await + }; + let receiver = async move { + // Start receiver handler. + handler.start::().await; + // After the handler has been terminated test the handler's states. + assert!(handler.pending_requests.is_empty()); + assert_eq!(0, handler.active_requests.count().await); + assert!(handler.active_challenges.is_empty()); + assert!(handler.filter_expected_responses.read().is_empty()); + }; + + let messages_to_send = 5usize; + + let sender_ops = async move { + let mut response_count = 0usize; + let mut expected_request_ids = HashSet::new(); + expected_request_ids.insert(RequestId(vec![1])); + + // sender to send the first message then await for the session to be established + let _ = sender_send.send(HandlerIn::Request( + receiver_enr.clone().into(), + Box::new(Request { + id: RequestId(vec![1]), + body: RequestBody::Ping { enr_seq: 1 }, + }), + )); + + match sender_recv.recv().await { + Some(HandlerOut::Established(_, _, _)) => { + // Sleep until receiver's session expired. + tokio::time::sleep(receiver_session_timeout.add(Duration::from_millis(500))).await; + // send the rest of the messages + for req_id in 2..=messages_to_send { + let request_id = RequestId(vec![req_id as u8]); + expected_request_ids.insert(request_id.clone()); + let _ = sender_send.send(HandlerIn::Request( + receiver_enr.clone().into(), + Box::new(Request { + id: request_id, + body: RequestBody::Ping { enr_seq: 1 }, + }), + )); + } + } + handler_out => panic!("Unexpected message: {:?}", handler_out), + } + + loop { + match sender_recv.recv().await { + Some(HandlerOut::Response(_, response)) => { + assert!(expected_request_ids.remove(&response.id)); + response_count += 1; + if response_count == messages_to_send { + // Notify the handlers that the message exchange has been completed. + assert!(expected_request_ids.is_empty()); + sender_exit.send(()).unwrap(); + receiver_exit.send(()).unwrap(); + return; + } + } + _ => continue, + }; + } + }; + + let receiver_ops = async move { + let mut message_count = 0usize; + loop { + match receiver_recv.recv().await { + Some(HandlerOut::RequestEnr(enr_request_data)) => { + receiver_send + .send(HandlerIn::EnrResponse( + Some(sender_enr.clone()), + enr_request_data, + )) + .unwrap(); + } + Some(HandlerOut::Request(addr, request)) => { + assert!(matches!(request.body, RequestBody::Ping { .. })); + let pong_response = Response { + id: request.id, + body: ResponseBody::Pong { + enr_seq: 1, + ip: ip.into(), + port: NonZeroU16::new(sender_port).unwrap(), + }, + }; + receiver_send + .send(HandlerIn::Response(addr, Box::new(pong_response))) + .unwrap(); + message_count += 1; + if message_count == messages_to_send { + return; + } + } + _ => { + continue; + } + } + } + }; + + let sleep_future = sleep(Duration::from_secs(5)); + let message_exchange = async move { + let _ = tokio::join!(sender, sender_ops, receiver, receiver_ops); + }; + + tokio::select! { + _ = message_exchange => {} + _ = sleep_future => { + panic!("Test timed out"); + } + } +} + +// Tests sending pending requests. +// +// Sender attempts to send multiple requests in parallel, but due to the absence of a session, only +// one of the requests from Sender is sent and others are inserted into `pending_requests`. +// The pending requests are sent once a session is established. +// +// ```mermaid +// sequenceDiagram +// participant Sender +// participant Receiver +// +// Note over Sender: No session with Receiver +// +// rect rgb(10, 10, 10) +// Note left of Sender: Sender attempts to send multiple requests in parallel
but no session with Receiver.
So Sender sends a random packet for the first request,
and the rest of the requests are inserted into pending_requests. +// par +// Sender ->> Receiver: Random packet (id:1) +// Note over Sender: Insert the request into `active_requests` +// and +// Note over Sender: Insert Request(id:2) into *pending_requests* +// and +// Note over Sender: Insert Request(id:3) into *pending_requests* +// end +// end +// +// Receiver ->> Sender: WHOAREYOU (id:1) +// +// Note over Sender: New session established with Receiver +// +// rect rgb(0, 100, 0) +// Note over Sender: Send pending requests since a session has been established. +// Sender ->> Receiver: Request (id:2) +// Sender ->> Receiver: Request (id:3) +// end +// +// Sender ->> Receiver: Handshake message (id:1) +// +// Note over Receiver: New session established with Sender +// +// Receiver ->> Sender: Response (id:2) +// Receiver ->> Sender: Response (id:3) +// Receiver ->> Sender: Response (id:1) +// +// Note over Sender: The request (id:2) completed. +// Note over Sender: The request (id:3) completed. +// Note over Sender: The request (id:1) completed. +// ``` +#[tokio::test] +async fn test_send_pending_request() { + init(); + let sender_port = 5008; + let receiver_port = 5009; + let ip = "127.0.0.1".parse().unwrap(); + let key1 = CombinedKey::generate_secp256k1(); + let key2 = CombinedKey::generate_secp256k1(); + + let sender_enr = Enr::builder() + .ip4(ip) + .udp4(sender_port) + .build(&key1) + .unwrap(); + + let receiver_enr = Enr::builder() + .ip4(ip) + .udp4(receiver_port) + .build(&key2) + .unwrap(); + + // Build sender handler + let (sender_exit, sender_send, mut sender_recv, mut handler) = { + let sender_listen_config = ListenConfig::Ipv4 { + ip: sender_enr.ip4().unwrap(), + port: sender_enr.udp4().unwrap(), + }; + let sender_config = ConfigBuilder::new(sender_listen_config).build(); + build_handler::(sender_enr.clone(), key1, sender_config).await + }; + let sender = async move { + // Start sender handler. + handler.start::().await; + // After the handler has been terminated test the handler's states. + assert!(handler.pending_requests.is_empty()); + assert_eq!(0, handler.active_requests.count().await); + assert!(handler.active_challenges.is_empty()); + assert!(handler.filter_expected_responses.read().is_empty()); + }; + + // Build receiver handler + // Shorten receiver's timeout to reproduce session expired. + let receiver_session_timeout = Duration::from_secs(1); + let (receiver_exit, receiver_send, mut receiver_recv, mut handler) = { + let receiver_listen_config = ListenConfig::Ipv4 { + ip: receiver_enr.ip4().unwrap(), + port: receiver_enr.udp4().unwrap(), + }; + let receiver_config = ConfigBuilder::new(receiver_listen_config) + .session_timeout(receiver_session_timeout) + .build(); + build_handler::(receiver_enr.clone(), key2, receiver_config).await + }; + let receiver = async move { + // Start receiver handler. + handler.start::().await; + // After the handler has been terminated test the handler's states. + assert!(handler.pending_requests.is_empty()); + assert_eq!(0, handler.active_requests.count().await); + assert!(handler.active_challenges.is_empty()); + assert!(handler.filter_expected_responses.read().is_empty()); + }; + + let messages_to_send = 3usize; + + let sender_ops = async move { + let mut response_count = 0usize; + let mut expected_request_ids = HashSet::new(); + + // send requests + for req_id in 1..=messages_to_send { + let request_id = RequestId(vec![req_id as u8]); + expected_request_ids.insert(request_id.clone()); + let _ = sender_send.send(HandlerIn::Request( + receiver_enr.clone().into(), + Box::new(Request { + id: request_id, + body: RequestBody::Ping { enr_seq: 1 }, + }), + )); + } + + loop { + match sender_recv.recv().await { + Some(HandlerOut::Response(_, response)) => { + assert!(expected_request_ids.remove(&response.id)); + response_count += 1; + if response_count == messages_to_send { + // Notify the handlers that the message exchange has been completed. + assert!(expected_request_ids.is_empty()); + sender_exit.send(()).unwrap(); + receiver_exit.send(()).unwrap(); + return; + } + } + _ => continue, + }; + } + }; + + let receiver_ops = async move { + let mut message_count = 0usize; + loop { + match receiver_recv.recv().await { + Some(HandlerOut::RequestEnr(enr_request_data)) => { + receiver_send + .send(HandlerIn::EnrResponse( + Some(sender_enr.clone()), + enr_request_data, + )) + .unwrap(); + } + Some(HandlerOut::Request(addr, request)) => { + assert!(matches!(request.body, RequestBody::Ping { .. })); + let pong_response = Response { + id: request.id, + body: ResponseBody::Pong { + enr_seq: 1, + ip: ip.into(), + port: NonZeroU16::new(sender_port).unwrap(), + }, + }; + receiver_send + .send(HandlerIn::Response(addr, Box::new(pong_response))) + .unwrap(); + message_count += 1; + if message_count == messages_to_send { + return; + } + } + _ => { + continue; + } + } + } + }; + + let sleep_future = sleep(Duration::from_secs(5)); + let message_exchange = async move { + let _ = tokio::join!(sender, sender_ops, receiver, receiver_ops); + }; + + tokio::select! { + _ = message_exchange => {} + _ = sleep_future => { + panic!("Test timed out"); + } + } +} + #[tokio::test(flavor = "multi_thread")] async fn nat_hole_punch_relay() { init(); // Relay - let listen_config = ListenConfig::default().with_ipv4(Ipv4Addr::LOCALHOST, 9901); - let (mut handler, mock_service) = - build_handler_with_listen_config::(listen_config).await; + let (relay_exit, relay_send, mut relay_recv, mut handler) = { + let key = CombinedKey::generate_secp256k1(); + let enr = Enr::builder() + .ip4(Ipv4Addr::LOCALHOST) + .udp4(9901) + .build(&key) + .unwrap(); + let listen_config = + ListenConfig::default().with_ipv4(enr.ip4().unwrap(), enr.udp4().unwrap()); + let config = ConfigBuilder::new(listen_config).build(); + build_handler::(enr, key, config).await + }; let relay_addr = handler.enr.read().udp4_socket().unwrap().into(); let relay_node_id = handler.enr.read().node_id(); // Initiator let inr_enr = { let key = CombinedKey::generate_secp256k1(); - EnrBuilder::new("v4") + Enr::builder() .ip4(Ipv4Addr::LOCALHOST) .udp4(9011) .build(&key) @@ -528,7 +1104,7 @@ async fn nat_hole_punch_relay() { // Target let tgt_enr = { let key = CombinedKey::generate_secp256k1(); - EnrBuilder::new("v4") + Enr::builder() .ip4(Ipv4Addr::LOCALHOST) .udp4(9012) .build(&key) @@ -551,12 +1127,13 @@ async fn nat_hole_punch_relay() { // Relay mock service let tgt_enr_clone = tgt_enr.clone(); - let tx = mock_service.tx; - let mut rx = mock_service.rx; let mock_service_handle = tokio::spawn(async move { - let service_msg = rx.recv().await.expect("should receive service message"); + let service_msg = relay_recv + .recv() + .await + .expect("should receive service message"); match service_msg { - HandlerOut::RequestEnr(EnrRequestData::Nat(relay_init)) => tx + HandlerOut::RequestEnr(EnrRequestData::Nat(relay_init)) => relay_send .send(HandlerIn::EnrResponse( Some(tgt_enr_clone), EnrRequestData::Nat(relay_init), @@ -584,7 +1161,6 @@ async fn nat_hole_punch_relay() { }); // Target handle - let relay_exit = mock_service.exit_tx; let tgt_handle = tokio::spawn(async move { let mut buffer = [0; MAX_PACKET_SIZE]; let res = tgt_socket @@ -644,9 +1220,18 @@ async fn nat_hole_punch_target() { init(); // Target - let listen_config = ListenConfig::default().with_ipv4(Ipv4Addr::LOCALHOST, 9902); - let (mut handler, mock_service) = - build_handler_with_listen_config::(listen_config).await; + let (target_exit, _, _, mut handler) = { + let key = CombinedKey::generate_secp256k1(); + let enr = Enr::builder() + .ip4(Ipv4Addr::LOCALHOST) + .udp4(9902) + .build(&key) + .unwrap(); + let listen_config = + ListenConfig::default().with_ipv4(enr.ip4().unwrap(), enr.udp4().unwrap()); + let config = ConfigBuilder::new(listen_config).build(); + build_handler::(enr, key, config).await + }; let tgt_addr = handler.enr.read().udp4_socket().unwrap().into(); let tgt_node_id = handler.enr.read().node_id(); handler.nat.is_behind_nat = Some(true); @@ -654,7 +1239,7 @@ async fn nat_hole_punch_target() { // Relay let relay_enr = { let key = CombinedKey::generate_secp256k1(); - EnrBuilder::new("v4") + Enr::builder() .ip4(Ipv4Addr::LOCALHOST) .udp4(9022) .build(&key) @@ -675,7 +1260,7 @@ async fn nat_hole_punch_target() { // Initiator let inr_enr = { let key = CombinedKey::generate_secp256k1(); - EnrBuilder::new("v4") + Enr::builder() .ip4(Ipv4Addr::LOCALHOST) .udp4(9021) .build(&key) @@ -709,7 +1294,6 @@ async fn nat_hole_punch_target() { }); // Initiator handle - let target_exit = mock_service.exit_tx; let inr_handle = tokio::spawn(async move { let mut buffer = [0; MAX_PACKET_SIZE]; let res = inr_socket diff --git a/src/ipmode.rs b/src/ipmode.rs index f2dbe48da..bc292d7f9 100644 --- a/src/ipmode.rs +++ b/src/ipmode.rs @@ -136,7 +136,7 @@ mod tests { fn test(&self) { let test_enr = { - let builder = &mut enr::EnrBuilder::new("v4"); + let builder = &mut enr::Enr::builder(); if let Some(ip4) = self.enr_ip4 { builder.ip4(ip4).udp4(IP4_TEST_PORT); } diff --git a/src/lib.rs b/src/lib.rs index 3dde575e8..99b2eb256 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,70 +1,65 @@ -#![warn(rust_2018_idioms)] #![deny(rustdoc::broken_intra_doc_links)] -#![cfg_attr(docsrs, feature(doc_cfg))] -#![allow(clippy::needless_doctest_main)] //! An implementation of [Discovery V5](https://github.com/ethereum/devp2p/blob/master/discv5/discv5.md). //! //! # Overview //! -//! Discovery v5 is a protocol designed for encrypted peer discovery and topic advertisement. Each peer/node -//! on the network is identified via it's ENR ([Ethereum Name +//! Discovery v5 is a protocol designed for encrypted peer discovery and topic advertisement. Each +//! peer/node on the network is identified via it's ENR ([Ethereum Name //! Record](https://eips.ethereum.org/EIPS/eip-778)), which is essentially a signed key-value store //! containing the node's public key and optionally IP address and port. //! -//! Discv5 employs a kademlia-like routing table to store and manage discovered peers (and topics tba). The -//! protocol allows for external IP discovery in NAT environments through regular PING/PONG's with +//! Discv5 employs a kademlia-like routing table to store and manage discovered peers. The protocol +//! allows for external IP discovery in NAT environments through regular PING/PONG's with //! discovered nodes. Nodes return the external IP address that they have received and a simple //! majority is chosen as our external IP address. If an external IP address is updated, this is -//! produced as an event to notify the swarm (if one is used for this behaviour). +//! produced as an event. //! -//! For a simple CLI discovery service see [discv5-cli](https://github.com/AgeManning/discv5-cli) +//! For a simple CLI discovery service see [discv5-cli](https://github.com/AgeManning/discv5-cli) //! -//! This protocol is split into four main sections/layers: +//! This protocol is split into four main layers: //! -//! * Socket - The [`socket`] module is responsible for opening the underlying UDP socket. It -//! creates individual tasks for sending/encoding and receiving/decoding packets from the UDP -//! socket. -//! * Handler - The protocol's communication is encrypted with `AES_GCM`. All node communication -//! undergoes a handshake, which results in a [`Session`]. [`Session`]'s are established when -//! needed and get dropped after a timeout. This section manages the creation and maintenance of -//! sessions between nodes and the encryption/decryption of packets from the socket. It is -//! realised by the [`handler::Handler`] struct and it runs in its own task. -//! * Service - This section contains the protocol-level logic. In particular it manages the -//! routing table of known ENR's, (and topic registration/advertisement tba) and performs -//! parallel queries for peer discovery. This section is realised by the [`Service`] struct. This -//! also runs in it's own thread. -//! * Application - This section is the user-facing API which can start/stop the underlying -//! tasks, initiate queries and obtain metrics about the underlying server. +//! - [`socket`]: Responsible for opening the underlying UDP socket. It creates individual tasks +//! for sending/encoding and receiving/decoding packets from the UDP socket. +//! - [`handler`]: The protocol's communication is encrypted with `AES_GCM`. All node communication +//! undergoes a handshake, which results in a `Session`. These are established when needed and get +//! dropped after a timeout. The creation and maintenance of sessions between nodes and the +//! encryption/decryption of packets from the socket is realised by the [`handler::Handler`] struct +//! runnning in its own task. +//! - [`service`]: Contains the protocol-level logic. The [`service::Service`] manages the routing +//! table of known ENR's, and performs parallel queries for peer discovery. It also runs in it's +//! own task. +//! - [`Discv5`]: The application level. Manages the user-facing API. It starts/stops the underlying +//! tasks, allows initiating queries and obtain metrics about the underlying server. //! -//! ## Event Stream +//! ## Event Stream //! -//! The [`Discv5`] struct provides access to an event-stream which allows the user to listen to -//! [`Discv5Event`] that get generated from the underlying server. The stream can be obtained -//! from the [`Discv5::event_stream()`] function. +//! The [`Discv5`] struct provides access to an event-stream which allows the user to listen to +//! [`Event`] that get generated from the underlying server. The stream can be obtained from the +//! [`Discv5::event_stream`] function. //! -//! ## Runtimes +//! ## Runtimes //! -//! Discv5 requires a tokio runtime with timing and io enabled. An explicit runtime can be given -//! via the configuration. See the [`Discv5ConfigBuilder`] for further details. Such a runtime -//! must implement the [`Executor`] trait. +//! Discv5 requires a tokio runtime with timing and io enabled. An explicit runtime can be given +//! via the configuration. See the [`ConfigBuilder`] for further details. Such a runtime must +//! implement the [`Executor`] trait. //! -//! If an explicit runtime is not provided via the configuration parameters, it is assumed that -//! a tokio runtime is present when creating the [`Discv5`] struct. The struct will use the -//! existing runtime for spawning the underlying server tasks. If a runtime is not present, the -//! creation of the [`Discv5`] struct will panic. +//! If an explicit runtime is not provided via the configuration parameters, it is assumed that a +//! tokio runtime is present when creating the [`Discv5`] struct. The struct will use the existing +//! runtime for spawning the underlying server tasks. If a runtime is not present, the creation of +//! the [`Discv5`] struct will panic. //! //! # Usage //! //! A simple example of creating this service is as follows: //! //! ```rust -//! use discv5::{enr, enr::{CombinedKey, NodeId}, TokioExecutor, Discv5, Discv5ConfigBuilder}; +//! use discv5::{enr, enr::{CombinedKey, NodeId}, TokioExecutor, Discv5, ConfigBuilder}; //! use discv5::socket::ListenConfig; //! use std::net::{Ipv4Addr, SocketAddr}; //! //! // construct a local ENR //! let enr_key = CombinedKey::generate_secp256k1(); -//! let enr = enr::EnrBuilder::new("v4").build(&enr_key).unwrap(); +//! let enr = enr::Enr::empty(&enr_key).unwrap(); //! //! // build the tokio executor //! let mut runtime = tokio::runtime::Builder::new_multi_thread() @@ -80,7 +75,7 @@ //! }; //! //! // default configuration -//! let config = Discv5ConfigBuilder::new(listen_config).build(); +//! let config = ConfigBuilder::new(listen_config).build(); //! //! // construct the discv5 server //! let mut discv5: Discv5 = Discv5::new(enr, enr_key, config).unwrap(); @@ -98,14 +93,6 @@ //! println!("Found nodes: {:?}", found_nodes); //! }); //! ``` -//! -//! [`Discv5`]: struct.Discv5.html -//! [`Discv5Event`]: enum.Discv5Event.html -//! [`Discv5Config`]: config/struct.Discv5Config.html -//! [`Discv5ConfigBuilder`]: config/struct.Discv5ConfigBuilder.html -//! [Packet]: packet/enum.Packet.html -//! [`Service`]: service/struct.Service.html -//! [`Session`]: session/struct.Session.html mod config; mod discv5; @@ -129,9 +116,9 @@ extern crate lazy_static; pub type Enr = enr::Enr; -pub use crate::discv5::{Discv5, Discv5Event}; -pub use config::{Discv5Config, Discv5ConfigBuilder}; -pub use error::{Discv5Error, QueryError, RequestError, ResponseError}; +pub use crate::discv5::{Discv5, Event}; +pub use config::{Config, ConfigBuilder}; +pub use error::{Error, QueryError, RequestError, ResponseError}; pub use executor::{Executor, TokioExecutor}; pub use ipmode::IpMode; pub use kbucket::{ConnectionDirection, ConnectionState, Key}; diff --git a/src/node_info.rs b/src/node_info.rs index 980e6b0ec..ba04b7790 100644 --- a/src/node_info.rs +++ b/src/node_info.rs @@ -5,9 +5,11 @@ use enr::{CombinedPublicKey, NodeId}; use std::net::SocketAddr; #[cfg(feature = "libp2p")] -use libp2p_core::{multiaddr::Protocol, Multiaddr}; -#[cfg(feature = "libp2p")] -use libp2p_identity::{KeyType, PublicKey}; +use libp2p::{ + identity::{KeyType, PublicKey}, + multiaddr::Protocol, + Multiaddr, +}; /// This type relaxes the requirement of having an ENR to connect to a node, to allow for unsigned /// connection types, such as multiaddrs. diff --git a/src/query_pool/peers/closest.rs b/src/query_pool/peers/closest.rs index 936cf2c9d..334d8a39f 100644 --- a/src/query_pool/peers/closest.rs +++ b/src/query_pool/peers/closest.rs @@ -23,7 +23,7 @@ // use super::*; use crate::{ - config::Discv5Config, + config::Config, kbucket::{Distance, Key, MAX_NODES_PER_BUCKET}, }; use std::{ @@ -76,7 +76,7 @@ pub struct FindNodeQueryConfig { } impl FindNodeQueryConfig { - pub fn new_from_config(config: &Discv5Config) -> Self { + pub fn new_from_config(config: &Config) -> Self { Self { parallelism: config.query_parallelism, num_results: MAX_NODES_PER_BUCKET, diff --git a/src/query_pool/peers/predicate.rs b/src/query_pool/peers/predicate.rs index 4768a1c35..3b4442019 100644 --- a/src/query_pool/peers/predicate.rs +++ b/src/query_pool/peers/predicate.rs @@ -1,6 +1,6 @@ use super::*; use crate::{ - config::Discv5Config, + config::Config, kbucket::{Distance, Key, PredicateKey, MAX_NODES_PER_BUCKET}, }; use std::{ @@ -55,7 +55,7 @@ pub(crate) struct PredicateQueryConfig { } impl PredicateQueryConfig { - pub(crate) fn new_from_config(config: &Discv5Config) -> Self { + pub(crate) fn new_from_config(config: &Config) -> Self { Self { parallelism: config.query_parallelism, num_results: MAX_NODES_PER_BUCKET, diff --git a/src/rpc.rs b/src/rpc.rs index 974cb8d4b..82af87ea0 100644 --- a/src/rpc.rs +++ b/src/rpc.rs @@ -10,8 +10,6 @@ pub use notification::{RelayInitNotification, RelayMsgNotification}; pub use request::{Request, RequestBody, RequestId}; pub use response::{Response, ResponseBody}; -/// Message type IDs. -#[derive(Debug)] #[repr(u8)] pub enum MessageType { Ping = 1, @@ -91,8 +89,18 @@ impl Message { return Err(DecoderError::RlpIsTooShort); } let msg_type = data[0]; + let data = &data[1..]; - let rlp = rlp::Rlp::new(&data[1..]); + let rlp = rlp::Rlp::new(data); + + if rlp.item_count()? < 2 { + return Err(DecoderError::RlpIncorrectListLen); + } + + let payload_info = rlp.payload_info()?; + if data.len() != payload_info.header_len + payload_info.value_len { + return Err(DecoderError::RlpInconsistentLengthAndData); + } match msg_type.try_into()? { MessageType::Ping | MessageType::FindNode | MessageType::TalkReq => { @@ -120,7 +128,7 @@ impl Message { mod tests { use super::*; use crate::packet::MESSAGE_NONCE_LENGTH; - use enr::{CombinedKey, Enr, EnrBuilder}; + use enr::{CombinedKey, Enr}; use std::net::{IpAddr, Ipv4Addr, Ipv6Addr}; #[test] @@ -166,7 +174,11 @@ mod tests { let port = 5000; let message = Message::Response(Response { id, - body: ResponseBody::Pong { enr_seq, ip, port }, + body: ResponseBody::Pong { + enr_seq, + ip, + port: port.try_into().unwrap(), + }, }); // expected hex output @@ -283,7 +295,7 @@ mod tests { body: ResponseBody::Pong { enr_seq: 15, ip: "127.0.0.1".parse().unwrap(), - port: 80, + port: 80.try_into().unwrap(), }, }); @@ -301,7 +313,7 @@ mod tests { body: ResponseBody::Pong { enr_seq: 15, ip: IpAddr::V6(Ipv4Addr::new(192, 0, 2, 1).to_ipv6_mapped()), - port: 80, + port: 80.try_into().unwrap(), }, }); @@ -312,7 +324,7 @@ mod tests { body: ResponseBody::Pong { enr_seq: 15, ip: IpAddr::V4(Ipv4Addr::new(192, 0, 2, 1)), - port: 80, + port: 80.try_into().unwrap(), }, }); @@ -327,7 +339,7 @@ mod tests { body: ResponseBody::Pong { enr_seq: 15, ip: IpAddr::V6(Ipv6Addr::LOCALHOST), - port: 80, + port: 80.try_into().unwrap(), }, }); @@ -356,17 +368,17 @@ mod tests { #[test] fn encode_decode_nodes_response() { let key = CombinedKey::generate_secp256k1(); - let enr1 = EnrBuilder::new("v4") + let enr1 = Enr::builder() .ip4("127.0.0.1".parse().unwrap()) .udp4(500) .build(&key) .unwrap(); - let enr2 = EnrBuilder::new("v4") + let enr2 = Enr::builder() .ip4("10.0.0.1".parse().unwrap()) .tcp4(8080) .build(&key) .unwrap(); - let enr3 = EnrBuilder::new("v4") + let enr3 = Enr::builder() .ip("10.4.5.6".parse().unwrap()) .build(&key) .unwrap(); @@ -387,6 +399,25 @@ mod tests { assert_eq!(request, decoded); } + #[test] + fn reject_extra_data() { + let data = [6, 194, 0, 75]; + let msg = Message::decode(&data).unwrap(); + assert_eq!( + msg, + Message::Response(Response { + id: RequestId(vec![0]), + body: ResponseBody::TalkResp { response: vec![75] } + }) + ); + + let data2 = [6, 193, 0, 75, 252]; + Message::decode(&data2).expect_err("should reject extra data"); + + let data3 = [6, 194, 0, 75, 252]; + Message::decode(&data3).expect_err("should reject extra data"); + } + #[test] fn encode_decode_talk_request() { let id = RequestId(vec![1]); @@ -409,12 +440,12 @@ mod tests { // generate a new enr key for the initiator let enr_key = CombinedKey::generate_secp256k1(); // construct the initiator's ENR - let inr_enr = EnrBuilder::new("v4").build(&enr_key).unwrap(); + let inr_enr = Enr::builder().build(&enr_key).unwrap(); // generate a new enr key for the target let enr_key_tgt = CombinedKey::generate_secp256k1(); // construct the target's ENR - let tgt_enr = EnrBuilder::new("v4").build(&enr_key_tgt).unwrap(); + let tgt_enr = Enr::builder().build(&enr_key_tgt).unwrap(); let tgt_node_id = tgt_enr.node_id(); let nonce_bytes = hex::decode("47644922f5d6e951051051ac").unwrap(); @@ -435,7 +466,7 @@ mod tests { // generate a new enr key for the initiator let enr_key = CombinedKey::generate_secp256k1(); // construct the initiator's ENR - let inr_enr = EnrBuilder::new("v4").build(&enr_key).unwrap(); + let inr_enr = Enr::builder().build(&enr_key).unwrap(); let nonce_bytes = hex::decode("9951051051aceb").unwrap(); let mut nonce = [0u8; MESSAGE_NONCE_LENGTH]; diff --git a/src/rpc/response.rs b/src/rpc/response.rs index e0978e6fa..79882976e 100644 --- a/src/rpc/response.rs +++ b/src/rpc/response.rs @@ -5,6 +5,7 @@ use rlp::{DecoderError, Rlp, RlpStream}; use std::{ convert::TryInto, net::{IpAddr, Ipv6Addr}, + num::NonZeroU16, }; use tracing::debug; @@ -44,7 +45,7 @@ impl Payload for Response { IpAddr::V4(addr) => s.append(&(&addr.octets() as &[u8])), IpAddr::V6(addr) => s.append(&(&addr.octets() as &[u8])), }; - s.append(&port); + s.append(&port.get()); buf.extend_from_slice(&s.out()); buf } @@ -117,14 +118,19 @@ impl Payload for Response { return Err(DecoderError::RlpIncorrectListLen); } }; - let port = rlp.val_at::(3)?; - Self { - id, - body: ResponseBody::Pong { - enr_seq: rlp.val_at::(1)?, - ip, - port, - }, + let raw_port = rlp.val_at::(3)?; + if let Ok(port) = raw_port.try_into() { + Self { + id, + body: ResponseBody::Pong { + enr_seq: rlp.val_at::(1)?, + ip, + port, + }, + } + } else { + debug!("The port number should be non zero: {raw_port}"); + return Err(DecoderError::Custom("PONG response port number invalid")); } } MessageType::Nodes => { @@ -197,7 +203,7 @@ pub enum ResponseBody { /// Our external IP address as observed by the responder. ip: IpAddr, /// Our external UDP port as observed by the responder. - port: u16, + port: NonZeroU16, }, /// A NODES response. Nodes { diff --git a/src/service.rs b/src/service.rs index f55089010..9983a1cd2 100644 --- a/src/service.rs +++ b/src/service.rs @@ -29,7 +29,7 @@ use crate::{ query_pool::{ FindNodeQueryConfig, PredicateQueryConfig, QueryId, QueryPool, QueryPoolState, TargetKey, }, - rpc, Discv5Config, Discv5Event, Enr, IpMode, + rpc, Config, Enr, Event, IpMode, }; use delay_map::HashSetDelay; use enr::{CombinedKey, NodeId}; @@ -40,6 +40,7 @@ use parking_lot::RwLock; use rpc::*; use std::{ collections::HashMap, + convert::TryInto, net::{IpAddr, SocketAddr}, sync::Arc, task::Poll, @@ -161,14 +162,14 @@ pub enum ServiceRequest { Ping(Enr, Option>>), /// Sets up an event stream where the discv5 server will return various events such as /// discovered nodes as it traverses the DHT. - RequestEventStream(oneshot::Sender>), + RequestEventStream(oneshot::Sender>), } use crate::discv5::PERMIT_BAN_LIST; pub struct Service { /// Configuration parameters. - config: Discv5Config, + config: Config, /// The local ENR of the server. local_enr: Arc>, @@ -187,7 +188,7 @@ pub struct Service { active_requests: FnvHashMap, /// Keeps track of the number of responses received from a NODES response. - active_nodes_responses: HashMap, + active_nodes_responses: HashMap, /// A map of votes nodes have made about our external IP address. We accept the majority. ip_votes: Option, @@ -211,7 +212,7 @@ pub struct Service { peers_to_ping: HashSetDelay, /// A channel that the service emits events on. - event_stream: Option>, + event_stream: Option>, // Type of socket we are using ip_mode: IpMode, @@ -277,7 +278,7 @@ impl Service { local_enr: Arc>, enr_key: Arc>, kbuckets: Arc>>, - config: Discv5Config, + config: Config, ) -> Result<(oneshot::Sender<()>, mpsc::Sender), std::io::Error> { // process behaviour-level configuration parameters let ip_votes = if config.enr_update { @@ -323,7 +324,7 @@ impl Service { ip_mode, }; - info!("Discv5 Service started"); + info!(mode = ?service.ip_mode, "Discv5 Service started"); service.start().await; })); @@ -332,7 +333,6 @@ impl Service { /// The main execution loop of the discv5 serviced. async fn start(&mut self) { - info!("{:?}", self.ip_mode); loop { tokio::select! { _ = &mut self.exit => { @@ -378,7 +378,7 @@ impl Service { Some(event) = self.handler_recv.recv() => { match event { HandlerOut::Established(enr, socket_addr, direction) => { - self.send_event(Discv5Event::SessionEstablished(enr.clone(), socket_addr)); + self.send_event(Event::SessionEstablished(enr.clone(), socket_addr)); self.inject_session_established(enr, direction); } HandlerOut::Request(node_address, request) => { @@ -646,20 +646,24 @@ impl Service { // build the PONG response let src = node_address.socket_addr; - let response = Response { - id, - body: ResponseBody::Pong { - enr_seq: self.local_enr.read().seq(), - ip: src.ip(), - port: src.port(), - }, - }; - debug!("Sending PONG response to {}", node_address); - if let Err(e) = self - .handler_send - .send(HandlerIn::Response(node_address, Box::new(response))) - { - warn!("Failed to send response {}", e) + if let Ok(port) = src.port().try_into() { + let response = Response { + id, + body: ResponseBody::Pong { + enr_seq: self.local_enr.read().seq(), + ip: src.ip(), + port, + }, + }; + debug!("Sending PONG response to {}", node_address); + if let Err(e) = self + .handler_send + .send(HandlerIn::Response(node_address, Box::new(response))) + { + warn!("Failed to send response {}", e); + } + } else { + warn!("The src port number should be non zero. {src}"); } } RequestBody::TalkReq { protocol, request } => { @@ -671,7 +675,7 @@ impl Service { sender: Some(self.handler_send.clone()), }; - self.send_event(Discv5Event::TalkRequest(req)); + self.send_event(Event::TalkRequest(req)); } } } @@ -757,10 +761,9 @@ impl Service { if nodes.len() < before_len { // Peer sent invalid ENRs. Blacklist the Node - warn!( - "Peer sent invalid ENR. Blacklisting {}", - active_request.contact - ); + let node_id = active_request.contact.node_id(); + let addr = active_request.contact.socket_addr(); + warn!(%node_id, %addr, "ENRs received of unsolicited distances. Blacklisting"); let ban_timeout = self.config.ban_duration.map(|v| Instant::now() + v); PERMIT_BAN_LIST.write().ban(node_address, ban_timeout); } @@ -768,10 +771,8 @@ impl Service { // handle the case that there is more than one response if total > 1 { - let mut current_response = self - .active_nodes_responses - .remove(&node_id) - .unwrap_or_default(); + let mut current_response = + self.active_nodes_responses.remove(&id).unwrap_or_default(); debug!( "Nodes Response: {} of {} received", @@ -789,7 +790,7 @@ impl Service { current_response.received_nodes.append(&mut nodes); self.active_nodes_responses - .insert(node_id, current_response); + .insert(id.clone(), current_response); self.active_requests.insert(id, active_request); return; } @@ -811,19 +812,23 @@ impl Service { // in a later response sends a response with a total of 1, all previous nodes // will be ignored. // ensure any mapping is removed in this rare case - self.active_nodes_responses.remove(&node_id); + self.active_nodes_responses.remove(&id); self.discovered(&node_id, nodes, active_request.query_id); } ResponseBody::Pong { enr_seq, ip, port } => { // Send the response to the user, if they are who asked if let Some(CallbackResponse::Pong(callback)) = active_request.callback { - let response = Pong { enr_seq, ip, port }; + let response = Pong { + enr_seq, + ip, + port: port.get(), + }; if let Err(e) = callback.send(Ok(response)) { warn!("Failed to send callback response {:?}", e) }; } else { - let socket = SocketAddr::new(ip, port); + let socket = SocketAddr::new(ip, port.get()); // perform ENR majority-based update if required. // Only count votes that from peers we have contacted. @@ -874,11 +879,9 @@ impl Service { updated = true; info!( "Local UDP ip6 socket updated to: {}", - new_ip6 - ); - self.send_event(Discv5Event::SocketUpdated( new_ip6, - )); + ); + self.send_event(Event::SocketUpdated(new_ip6)); // Notify Handler of socket update if let Err(e) = self.handler_send.send(HandlerIn::SocketUpdate( @@ -904,9 +907,7 @@ impl Service { Ok(_) => { updated = true; info!("Local UDP socket updated to: {}", new_ip4); - self.send_event(Discv5Event::SocketUpdated( - new_ip4, - )); + self.send_event(Event::SocketUpdated(new_ip4)); // Notify Handler of socket update if let Err(e) = self.handler_send.send(HandlerIn::SocketUpdate( @@ -1231,7 +1232,7 @@ impl Service { } } - fn send_event(&mut self, event: Discv5Event) { + fn send_event(&mut self, event: Event) { if let Some(stream) = self.event_stream.as_mut() { if let Err(mpsc::error::TrySendError::Closed(_)) = stream.try_send(event) { // If the stream has been dropped prevent future attempts to send events @@ -1251,7 +1252,7 @@ impl Service { // If any of the discovered nodes are in the routing table, and there contains an older ENR, update it. // If there is an event stream send the Discovered event if self.config.report_discovered_peers { - self.send_event(Discv5Event::Discovered(enr.clone())); + self.send_event(Event::Discovered(enr.clone())); } // ignore peers that don't pass the table filter @@ -1347,7 +1348,7 @@ impl Service { self.send_ping(enr, None); } - let event = Discv5Event::NodeInserted { + let event = Event::NodeInserted { node_id, replaced: None, }; @@ -1507,13 +1508,14 @@ impl Service { match active_request.request_body { // if a failed FindNodes request, ensure we haven't partially received packets. If // so, process the partially found nodes - RequestBody::FindNode { .. } => { - if let Some(nodes_response) = self.active_nodes_responses.remove(&node_id) { + RequestBody::FindNode { ref distances } => { + if let Some(nodes_response) = self.active_nodes_responses.remove(&id) { if !nodes_response.received_nodes.is_empty() { - warn!( - "NODES Response failed, but was partially processed from: {}", - active_request.contact - ); + let node_id = active_request.contact.node_id(); + let addr = active_request.contact.socket_addr(); + let received = nodes_response.received_nodes.len(); + let expected = distances.len(); + warn!(%node_id, %addr, %error, %received, %expected, "FINDNODE request failed with partial results"); // if it's a query mark it as success, to process the partial // collection of peers self.discovered( @@ -1561,14 +1563,12 @@ impl Service { } /// A future that maintains the routing table and inserts nodes when required. This returns the - /// `Discv5Event::NodeInserted` variant if a new node has been inserted into the routing table. - async fn bucket_maintenance_poll( - kbuckets: &Arc>>, - ) -> Discv5Event { + /// [`Event::NodeInserted`] variant if a new node has been inserted into the routing table. + async fn bucket_maintenance_poll(kbuckets: &Arc>>) -> Event { future::poll_fn(move |_cx| { // Drain applied pending entries from the routing table. if let Some(entry) = kbuckets.write().take_applied_pending() { - let event = Discv5Event::NodeInserted { + let event = Event::NodeInserted { node_id: entry.inserted.into_preimage(), replaced: entry.evicted.map(|n| n.key.into_preimage()), }; @@ -1589,10 +1589,8 @@ impl Service { let request_body = query.target().rpc_request(return_peer); Poll::Ready(QueryEvent::Waiting(query.id(), node_id, request_body)) } - QueryPoolState::Timeout(query) => { - warn!("Query id: {:?} timed out", query.id()); - Poll::Ready(QueryEvent::TimedOut(Box::new(query))) - } + + QueryPoolState::Timeout(query) => Poll::Ready(QueryEvent::TimedOut(Box::new(query))), QueryPoolState::Waiting(None) | QueryPoolState::Idle => Poll::Pending, }) .await diff --git a/src/service/test.rs b/src/service/test.rs index d02f8394c..c85f6a97a 100644 --- a/src/service/test.rs +++ b/src/service/test.rs @@ -3,6 +3,7 @@ use super::*; use crate::{ + discv5::test::generate_deterministic_keypair, handler::Handler, kbucket, kbucket::{BucketInsertResult, KBucketsTable, NodeStatus}, @@ -12,13 +13,16 @@ use crate::{ rpc::RequestId, service::{ActiveRequest, Service}, socket::ListenConfig, - Discv5ConfigBuilder, Enr, + ConfigBuilder, Enr, }; -use enr::{CombinedKey, EnrBuilder}; +use enr::CombinedKey; use parking_lot::RwLock; -use std::{collections::HashMap, sync::Arc, time::Duration}; +use std::{collections::HashMap, net::Ipv4Addr, sync::Arc, time::Duration}; use tokio::sync::{mpsc, oneshot}; +/// Default UDP port number to use for tests requiring UDP exposure +pub const DEFAULT_UDP_PORT: u16 = 0; + fn _connected_state() -> NodeStatus { NodeStatus { state: ConnectionState::Connected, @@ -48,7 +52,7 @@ async fn build_service( ip: local_enr.read().ip4().unwrap(), port: local_enr.read().udp4().unwrap(), }; - let config = Discv5ConfigBuilder::new(listen_config) + let config = ConfigBuilder::new(listen_config) .executor(Box::::default()) .build(); // build the session service @@ -103,16 +107,16 @@ async fn test_updating_connection_on_ping() { init(); let enr_key1 = CombinedKey::generate_secp256k1(); let ip = "127.0.0.1".parse().unwrap(); - let enr = EnrBuilder::new("v4") + let enr = Enr::builder() .ip4(ip) - .udp4(10001) + .udp4(DEFAULT_UDP_PORT) .build(&enr_key1) .unwrap(); let ip2 = "127.0.0.1".parse().unwrap(); let enr_key2 = CombinedKey::generate_secp256k1(); - let enr2 = EnrBuilder::new("v4") + let enr2 = Enr::builder() .ip4(ip2) - .udp4(10002) + .udp4(DEFAULT_UDP_PORT) .build(&enr_key2) .unwrap(); @@ -141,7 +145,7 @@ async fn test_updating_connection_on_ping() { body: ResponseBody::Pong { enr_seq: 2, ip: ip2.into(), - port: 10002, + port: 9000.try_into().unwrap(), }, }; @@ -171,17 +175,17 @@ async fn test_connection_direction_on_inject_session_established() { let enr_key1 = CombinedKey::generate_secp256k1(); let ip = std::net::Ipv4Addr::LOCALHOST; - let enr = EnrBuilder::new("v4") + let enr = Enr::builder() .ip4(ip) - .udp4(10003) + .udp4(DEFAULT_UDP_PORT) .build(&enr_key1) .unwrap(); let enr_key2 = CombinedKey::generate_secp256k1(); let ip2 = std::net::Ipv4Addr::LOCALHOST; - let enr2 = EnrBuilder::new("v4") + let enr2 = Enr::builder() .ip4(ip2) - .udp4(10004) + .udp4(DEFAULT_UDP_PORT) .build(&enr_key2) .unwrap(); @@ -218,3 +222,122 @@ async fn test_connection_direction_on_inject_session_established() { assert!(status.is_connected()); assert_eq!(ConnectionDirection::Outgoing, status.direction); } + +#[tokio::test] +async fn test_handling_concurrent_responses() { + init(); + + // Seed is chosen such that all nodes are in the 256th distance of the first node. + let seed = 1652; + let mut keypairs = generate_deterministic_keypair(5, seed); + + let mut service = { + let enr_key = keypairs.pop().unwrap(); + let enr = Enr::builder() + .ip4(Ipv4Addr::LOCALHOST) + .udp4(10005) + .build(&enr_key) + .unwrap(); + build_service::( + Arc::new(RwLock::new(enr)), + Arc::new(RwLock::new(enr_key)), + false, + ) + .await + }; + + let node_contact: NodeContact = Enr::builder() + .ip4(Ipv4Addr::LOCALHOST) + .udp4(10006) + .build(&keypairs.remove(0)) + .unwrap() + .into(); + let node_address = node_contact.node_address(); + + // Add fake requests + // Request1 + service.active_requests.insert( + RequestId(vec![1]), + ActiveRequest { + contact: node_contact.clone(), + request_body: RequestBody::FindNode { + distances: vec![254, 255, 256], + }, + query_id: Some(QueryId(1)), + callback: None, + }, + ); + // Request2 + service.active_requests.insert( + RequestId(vec![2]), + ActiveRequest { + contact: node_contact, + request_body: RequestBody::FindNode { + distances: vec![254, 255, 256], + }, + query_id: Some(QueryId(2)), + callback: None, + }, + ); + + assert_eq!(3, keypairs.len()); + let mut enrs_for_response = keypairs + .iter() + .enumerate() + .map(|(i, key)| { + Enr::builder() + .ip4(Ipv4Addr::LOCALHOST) + .udp4(10007 + i as u16) + .build(key) + .unwrap() + }) + .collect::>(); + + // Response to `Request1` is sent as two separate messages in total. Handle the first one of the + // messages here. + service.handle_rpc_response( + node_address.clone(), + Response { + id: RequestId(vec![1]), + body: ResponseBody::Nodes { + total: 2, + nodes: vec![enrs_for_response.pop().unwrap()], + }, + }, + ); + // Service has still two active requests since we are waiting for the second NODE response to + // `Request1`. + assert_eq!(2, service.active_requests.len()); + // Service stores the first response to `Request1` into `active_nodes_responses`. + assert!(!service.active_nodes_responses.is_empty()); + + // Second, handle a response to *`Request2`* before the second response to `Request1`. + service.handle_rpc_response( + node_address.clone(), + Response { + id: RequestId(vec![2]), + body: ResponseBody::Nodes { + total: 1, + nodes: vec![enrs_for_response.pop().unwrap()], + }, + }, + ); + // `Request2` is completed so now the number of active requests should be one. + assert_eq!(1, service.active_requests.len()); + // Service still keeps the first response in `active_nodes_responses`. + assert!(!service.active_nodes_responses.is_empty()); + + // Finally, handle the second response to `Request1`. + service.handle_rpc_response( + node_address, + Response { + id: RequestId(vec![1]), + body: ResponseBody::Nodes { + total: 2, + nodes: vec![enrs_for_response.pop().unwrap()], + }, + }, + ); + assert!(service.active_requests.is_empty()); + assert!(service.active_nodes_responses.is_empty()); +} diff --git a/src/socket/filter/mod.rs b/src/socket/filter/mod.rs index d7a7da2e4..7bb54ad77 100644 --- a/src/socket/filter/mod.rs +++ b/src/socket/filter/mod.rs @@ -7,6 +7,7 @@ use lru::LruCache; use std::{ collections::HashSet, net::{IpAddr, SocketAddr}, + num::NonZeroUsize, sync::atomic::Ordering, time::{Duration, Instant}, }; @@ -19,9 +20,16 @@ pub use config::FilterConfig; use rate_limiter::{LimitKind, RateLimiter}; /// The maximum number of IPs to retain when calculating the number of nodes per IP. -const KNOWN_ADDRS_SIZE: usize = 500; +const KNOWN_ADDRS_SIZE: NonZeroUsize = match NonZeroUsize::new(500) { + Some(non_zero) => non_zero, + None => unreachable!(), +}; /// The number of IPs to retain at any given time that have banned nodes. -const BANNED_NODES_SIZE: usize = 50; +const BANNED_NODES_SIZE: NonZeroUsize = match NonZeroUsize::new(50) { + Some(non_zero) => non_zero, + None => unreachable!(), +}; + /// The maximum number of packets to keep record of for metrics if the rate limiter is not /// specified. const DEFAULT_PACKETS_PER_SECOND: usize = 20;