Skip to content

Commit

Permalink
implement icmp without using mio::Waker
Browse files Browse the repository at this point in the history
  • Loading branch information
Arian8j2 committed Sep 18, 2024
1 parent e3dddff commit 4d31839
Show file tree
Hide file tree
Showing 2 changed files with 74 additions and 102 deletions.
134 changes: 60 additions & 74 deletions forwarder/src/socket/icmp.rs
Original file line number Diff line number Diff line change
@@ -1,38 +1,20 @@
mod ether_helper;
mod receiver;

use crate::MAX_PACKET_SIZE;

use super::SocketTrait;
use crate::MAX_PACKET_SIZE;
use etherparse::{IcmpEchoHeader, Icmpv4Header, Icmpv4Type, Icmpv6Header, Icmpv6Type};
use mio::{unix::SourceFd, Interest};
use parking_lot::{Mutex, RwLock};
use socket2::{Domain, Protocol, Type};
use std::{
collections::{BTreeMap, VecDeque},
collections::BTreeSet,
io,
mem::MaybeUninit,
net::{SocketAddr, SocketAddrV6},
os::fd::AsRawFd,
sync::Arc,
};

/// Represents single packet
#[derive(Debug)]
struct Packet {
data: Vec<u8>,
from_addr: SocketAddr,
}

/// Thread safe buffer for `Packet`s
type SharedPacketBuffer = Mutex<VecDeque<Packet>>;

/// `Controller` is passed to `IcmpReceiver` so it can communicate to `IcmpSocket`
#[derive(Debug)]
struct Controller {
packets: Arc<SharedPacketBuffer>,
waker: mio::Waker,
}

/// `IcmpSocket` that is very similiar to `UdpSocket`
#[derive(Debug)]
pub struct IcmpSocket {
Expand All @@ -41,23 +23,27 @@ pub struct IcmpSocket {
is_blocking: bool,
/// udp socket that is kept alive for avoiding duplicate port
udp_socket: std::net::UdpSocket,
/// address of udp socket
udp_socket_addr: SocketAddr,
/// saves the socket that is connected to
connected_addr: Option<SocketAddr>,
/// each `IcmpSocket` does not actually listen for new packets because
/// icmp protocol is on layer 2 and doesn't have any concept of ports
/// so each packet will wake up all `IcmpSocket`s, to fix that and remove
/// overheads of parsing each packet multiple times we listen to packets
/// only on one socket on another thread and after parsing port and packet
/// we put it in the corresponding controller `packets`
packets: Arc<SharedPacketBuffer>,
}

static IS_RECEIVER_STARTED: Mutex<bool> = Mutex::new(false);
static OPEN_PORTS: RwLock<BTreeMap<u16, Controller>> = RwLock::new(BTreeMap::new());

/// each nonblocking `IcmpSocket` does not actually listen for new packets because
/// icmp protocol is on layer 2 and doesn't have any concept of ports
/// so each packet will wake up all `IcmpSocket`s, to fix that and remove
/// overheads of parsing each packet multiple times we listen to packets
/// only on one socket on another thread and after parsing port and packet
/// we put it in the corresponding controller `packets`, each nonblocking
/// `IcmpSocket` can register it's port via adding it to `OPEN_PORTS`
static OPEN_PORTS: RwLock<BTreeSet<u16>> = RwLock::new(BTreeSet::new());

impl IcmpSocket {
pub fn bind(addr: &SocketAddr) -> io::Result<Self> {
let udp_socket = std::net::UdpSocket::bind(addr)?;
let udp_socket_addr = udp_socket.local_addr()?;
let socket = IcmpSocket::inner_bind(*addr)?;

// run the icmp receiver if it isn't running
Expand All @@ -73,11 +59,10 @@ impl IcmpSocket {
*is_receiver_alive = true;
}

let packets = Mutex::new(VecDeque::with_capacity(10));
Ok(IcmpSocket {
udp_socket,
udp_socket_addr,
socket,
packets: Arc::new(packets),
connected_addr: None,
is_blocking: true,
})
Expand All @@ -99,15 +84,24 @@ impl Drop for IcmpSocket {
fn drop(&mut self) {
// clear port
let mut open_ports = OPEN_PORTS.write();
let port = self.udp_socket.local_addr().unwrap().port();
open_ports.remove(&port);
open_ports.remove(&self.udp_socket_addr.port());
}
}

impl SocketTrait for IcmpSocket {
fn recv(&self, buffer: &mut [u8]) -> io::Result<usize> {
let (size, _) = self.recv_from(buffer)?;
Ok(size)
if self.is_blocking {
unimplemented!("currently IcmpSocket::recv in blocking mode is not being used")
}
// icmp receiver sends packets that it receives to udp socket of `IcmpSocket`
let (size, from_addr) = self.udp_socket.recv_from(buffer)?;
// make sure that the receiver sent the packet
// receiver is local so the packet ip is from loopback
if from_addr.ip().is_loopback() {
Ok(size)
} else {
Err(io::ErrorKind::ConnectionRefused.into())
}
}

fn send(&self, buffer: &[u8]) -> io::Result<usize> {
Expand All @@ -134,60 +128,52 @@ impl SocketTrait for IcmpSocket {
}

fn recv_from(&self, buffer: &mut [u8]) -> io::Result<(usize, SocketAddr)> {
if self.is_blocking {
let mut second_buffer = [0u8; MAX_PACKET_SIZE];
let local_addr = self.local_addr()?;
loop {
let (size, from_addr) = self.socket.recv_from(unsafe {
&mut *(&mut second_buffer as *mut [u8] as *mut [MaybeUninit<u8>])
})?;
let Some(packet) =
receiver::parse_icmp_packet(&second_buffer[..size], local_addr.is_ipv6())
else {
continue;
};
if packet.dst_port != local_addr.port() {
continue;
}
let payload_len = packet.payload.len();
buffer[..payload_len].copy_from_slice(packet.payload);

let mut from_addr = from_addr.as_socket().unwrap();
from_addr.set_port(packet.src_port);
return Ok((payload_len, from_addr));
}
} else {
let mut packets = self.packets.lock();
match packets.pop_front() {
Some(packet) => {
let len = packet.data.len();
buffer[..len].copy_from_slice(&packet.data);
Ok((len, packet.from_addr))
}
None => Err(io::ErrorKind::WouldBlock.into()),
if !self.is_blocking {
unimplemented!("currently IcmpSocket::recv_from in nonblocking mode is not being used")
}
let mut second_buffer = [0u8; MAX_PACKET_SIZE];
let local_addr = self.local_addr()?;
loop {
let (size, from_addr) = self.socket.recv_from(unsafe {
&mut *(&mut second_buffer as *mut [u8] as *mut [MaybeUninit<u8>])
})?;
let Some(packet) =
receiver::parse_icmp_packet(&second_buffer[..size], local_addr.is_ipv6())
else {
continue;
};
if packet.dst_port != local_addr.port() {
continue;
}
let payload_len = packet.payload.len();
buffer[..payload_len].copy_from_slice(packet.payload);

let mut from_addr = from_addr.as_socket().unwrap();
from_addr.set_port(packet.src_port);
return Ok((payload_len, from_addr));
}
}

fn local_addr(&self) -> io::Result<SocketAddr> {
self.udp_socket.local_addr()
Ok(self.udp_socket_addr)
}

fn set_nonblocking(&mut self, nonblocking: bool) -> io::Result<()> {
self.socket.set_nonblocking(nonblocking)?;
self.udp_socket.set_nonblocking(nonblocking)?;
self.is_blocking = !nonblocking;
Ok(())
}

fn register(&mut self, registry: &mio::Registry, token: mio::Token) -> io::Result<()> {
let waker = mio::Waker::new(registry, token)?;
let mut open_ports = OPEN_PORTS.write();
let port = self.local_addr()?.port();
let controller = Controller {
packets: self.packets.clone(),
waker,
};
open_ports.insert(port, controller);
open_ports.insert(self.udp_socket_addr.port());

registry.register(
&mut SourceFd(&self.udp_socket.as_raw_fd()),
token,
Interest::READABLE,
)?;
Ok(())
}
}
Expand Down
42 changes: 14 additions & 28 deletions forwarder/src/socket/icmp/receiver.rs
Original file line number Diff line number Diff line change
@@ -1,47 +1,33 @@
use super::{ether_helper::IcmpSlice, IcmpSocket, Packet, OPEN_PORTS};
use super::{ether_helper::IcmpSlice, IcmpSocket, OPEN_PORTS};
use crate::MAX_PACKET_SIZE;
use etherparse::Ipv4HeaderSlice;
use socket2::SockAddr;
use std::{mem::MaybeUninit, net::SocketAddr};

pub fn run_icmp_receiver(addr: SocketAddr) -> anyhow::Result<()> {
let is_ipv6 = addr.is_ipv6();
let socket: socket2::Socket = IcmpSocket::inner_bind(addr)?;
let udp_socket = std::net::UdpSocket::bind(SocketAddr::new(addr.ip(), 0))?;
udp_socket.set_nonblocking(true)?;

let mut buffer = [0u8; MAX_PACKET_SIZE];
let mut addr_buffer = addr;

loop {
let Ok((size, from_addr)) =
socket.recv_from(unsafe { &mut *(&mut buffer as *mut [u8] as *mut [MaybeUninit<u8>]) })
let Ok(size) =
socket.recv(unsafe { &mut *(&mut buffer as *mut [u8] as *mut [MaybeUninit<u8>]) })
else {
continue;
};

let Some(packet) = parse_icmp_packet(&buffer[..size], is_ipv6) else {
let Some(icmp_packet) = parse_icmp_packet(&buffer[..size], is_ipv6) else {
continue;
};
handle_packet(packet, from_addr);
}
}

fn handle_packet(icmp: IcmpPacket<'_>, from_addr: SockAddr) -> Option<()> {
let open_ports = OPEN_PORTS.write();
let controller = open_ports.get(&icmp.dst_port)?;

let mut source_addr = from_addr.as_socket().unwrap();
source_addr.set_port(icmp.src_port);

let packet = Packet {
data: icmp.payload.to_vec(),
from_addr: source_addr,
};
{
let mut packets = controller.packets.lock();
packets.push_back(packet);
}
if let Err(error) = controller.waker.wake() {
log::warn!("couldn't wake up icmp socket: {error:?}")
let open_ports = OPEN_PORTS.read();
let port = icmp_packet.dst_port;
if open_ports.contains(&port) {
addr_buffer.set_port(port);
udp_socket.send_to(icmp_packet.payload, addr_buffer).ok();
}
}
Some(())
}

pub struct IcmpPacket<'a> {
Expand Down

0 comments on commit 4d31839

Please sign in to comment.