diff --git a/forwarder/src/socket/icmp.rs b/forwarder/src/socket/icmp.rs index 0524c09..700e794 100644 --- a/forwarder/src/socket/icmp.rs +++ b/forwarder/src/socket/icmp.rs @@ -15,6 +15,17 @@ use std::{ os::fd::AsRawFd, }; +/// tracks if the icmp receiver thread is started or not, the first index +/// is for icmpv4 and the second is for icmpv6 +static IS_RECEIVER_STARTED: [Mutex; 2] = [Mutex::new(false), Mutex::new(false)]; + +/// icmp receiver only handles ports that are in OPEN_PORTS so each +/// `IcmpSocket` must register it's port via adding it to `OPEN_PORTS` +/// and removing it when the `IcmpSocket` is dropped, the first index +/// is for icmpv4 open ports and the second index is for icmpv6 open ports +static OPEN_PORTS: [RwLock>; 2] = + [RwLock::new(BTreeSet::new()), RwLock::new(BTreeSet::new())]; + /// `IcmpSocket` that is very similiar to `UdpSocket` #[derive(Debug)] pub struct IcmpSocket { @@ -29,17 +40,6 @@ pub struct IcmpSocket { connected_addr: Option, } -static IS_RECEIVER_STARTED: Mutex = Mutex::new(false); - -/// each nonblocking `IcmpSocket` does not actually listen for new packets because -/// icmp protocol is on layer 2 and doesn't have any concept of ports -/// so each packet will wake up all `IcmpSocket`s, to fix that and remove -/// overheads of parsing each packet multiple times we listen to packets -/// only on one socket on another thread and after parsing port and packet -/// we put it in the corresponding controller `packets`, each nonblocking -/// `IcmpSocket` can register it's port via adding it to `OPEN_PORTS` -static OPEN_PORTS: RwLock> = RwLock::new(BTreeSet::new()); - impl IcmpSocket { pub fn bind(addr: &SocketAddr) -> io::Result { let udp_socket = std::net::UdpSocket::bind(addr)?; @@ -47,7 +47,8 @@ impl IcmpSocket { let socket = IcmpSocket::inner_bind(*addr)?; // run the icmp receiver if it isn't running - let mut is_receiver_alive = IS_RECEIVER_STARTED.lock(); + let receiver_index = addr.is_ipv6() as usize; + let mut is_receiver_alive = IS_RECEIVER_STARTED[receiver_index].lock(); if !*is_receiver_alive { let addr_clone = addr.to_owned(); std::thread::spawn(move || { @@ -83,7 +84,8 @@ impl IcmpSocket { impl Drop for IcmpSocket { fn drop(&mut self) { // clear port - let mut open_ports = OPEN_PORTS.write(); + let open_ports_index = self.udp_socket_addr.is_ipv6() as usize; + let mut open_ports = OPEN_PORTS[open_ports_index].write(); open_ports.remove(&self.udp_socket_addr.port()); } } @@ -178,7 +180,8 @@ impl SocketTrait for IcmpSocket { } fn register(&mut self, registry: &mio::Registry, token: mio::Token) -> io::Result<()> { - let mut open_ports = OPEN_PORTS.write(); + let open_ports_index = self.udp_socket_addr.is_ipv6() as usize; + let mut open_ports = OPEN_PORTS[open_ports_index].write(); open_ports.insert(self.udp_socket_addr.port()); registry.register( diff --git a/forwarder/src/socket/icmp/receiver.rs b/forwarder/src/socket/icmp/receiver.rs index cde2153..bf8ab58 100644 --- a/forwarder/src/socket/icmp/receiver.rs +++ b/forwarder/src/socket/icmp/receiver.rs @@ -3,6 +3,12 @@ use crate::MAX_PACKET_SIZE; use etherparse::Ipv4HeaderSlice; use std::{mem::MaybeUninit, net::SocketAddr}; +// each nonblocking `IcmpSocket` does not actually listen for new packets because +// icmp protocol is on layer 2 and doesn't have any concept of ports +// so each packet will wake up all `IcmpSocket`s, to fix that and remove +// overheads of parsing each packet multiple times we listen to packets +// only on one socket on another thread and after parsing port and packet +// we send it back to `IcmpSocket` via udp protocol pub fn run_icmp_receiver(addr: SocketAddr) -> anyhow::Result<()> { let is_ipv6 = addr.is_ipv6(); let socket: socket2::Socket = IcmpSocket::inner_bind(addr)?; @@ -11,6 +17,7 @@ pub fn run_icmp_receiver(addr: SocketAddr) -> anyhow::Result<()> { let mut buffer = [0u8; MAX_PACKET_SIZE]; let mut addr_buffer = addr; + let open_ports = &OPEN_PORTS[is_ipv6 as usize]; loop { let Ok(size) = @@ -21,7 +28,7 @@ pub fn run_icmp_receiver(addr: SocketAddr) -> anyhow::Result<()> { let Some(icmp_packet) = parse_icmp_packet(&buffer[..size], is_ipv6) else { continue; }; - let open_ports = OPEN_PORTS.read(); + let open_ports = open_ports.read(); let port = icmp_packet.dst_port; if open_ports.contains(&port) { addr_buffer.set_port(port); diff --git a/forwarder/tests/server.rs b/forwarder/tests/server.rs index e1073f4..67acc3b 100644 --- a/forwarder/tests/server.rs +++ b/forwarder/tests/server.rs @@ -100,7 +100,7 @@ fn test_udp_double_forwarder_back_and_forth() { #[test] #[ignore = "icmp sockets requires special access, please run this test with ./test_icmp.sh"] -fn test_icmp_double_forwarder_back_and_forth() { +fn test_icmpv4_double_forwarder_back_and_forth() { let forwarder_uri = SocketUri::from_str("127.0.0.1:38809/udp").unwrap(); let second_forwarder_uri = SocketUri::from_str("127.0.0.1:38810/icmp").unwrap(); let remote_uri = SocketUri::from_str("127.0.0.1:38811/udp").unwrap(); @@ -115,14 +115,14 @@ fn test_udp_ipv6_double_forwarder_back_and_forth() { spawn_double_forwarder_and_test_connection(forwarder_uri, second_forwarder_uri, remote_uri); } -// #[test] -// #[ignore = "icmp sockets requires special access, please run this test with ./test_icmp.sh"] -// fn test_icmp_ipv6_double_forwarder_back_and_forth() { -// let forwarder_uri = SocketUri::from_str("127.0.0.1:38815/udp").unwrap(); -// let second_forwarder_uri = SocketUri::from_str("[::1]:38816/icmp").unwrap(); -// let remote_uri = SocketUri::from_str("127.0.0.1:38817/udp").unwrap(); -// spawn_double_forwarder_and_test_connection(forwarder_uri, second_forwarder_uri, remote_uri); -// } +#[test] +#[ignore = "icmp sockets requires special access, please run this test with ./test_icmp.sh"] +fn test_icmpv6_double_forwarder_back_and_forth() { + let forwarder_uri = SocketUri::from_str("127.0.0.1:38815/udp").unwrap(); + let second_forwarder_uri = SocketUri::from_str("[::1]:38816/icmp").unwrap(); + let remote_uri = SocketUri::from_str("127.0.0.1:38817/udp").unwrap(); + spawn_double_forwarder_and_test_connection(forwarder_uri, second_forwarder_uri, remote_uri); +} fn spawn_double_forwarder_and_test_connection( forwarder_uri: SocketUri,