Skip to content

Commit 044387c

Browse files
committed
initial support for udp proxy via dumbpipe added.
1 parent b90406a commit 044387c

File tree

2 files changed

+454
-5
lines changed

2 files changed

+454
-5
lines changed

src/main.rs

+262-5
Original file line numberDiff line numberDiff line change
@@ -6,16 +6,15 @@ use iroh::{
66
endpoint::{get_remote_node_id, Connecting},
77
Endpoint, NodeAddr, SecretKey,
88
};
9+
use quinn::Connection;
910
use std::{
10-
io,
11-
net::{SocketAddr, SocketAddrV4, SocketAddrV6, ToSocketAddrs},
12-
str::FromStr,
11+
collections::HashMap, io, net::{SocketAddr, SocketAddrV4, SocketAddrV6, ToSocketAddrs}, str::FromStr, sync::Arc
1312
};
1413
use tokio::{
15-
io::{AsyncRead, AsyncWrite, AsyncWriteExt},
16-
select,
14+
io::{AsyncRead, AsyncWrite, AsyncWriteExt}, net::UdpSocket, select
1715
};
1816
use tokio_util::sync::CancellationToken;
17+
mod udpconn;
1918

2019
/// Create a dumb pipe between two machines, using an iroh magicsocket.
2120
///
@@ -54,6 +53,15 @@ pub enum Commands {
5453
/// connecting to a TCP socket for which you have to specify the host and port.
5554
ListenTcp(ListenTcpArgs),
5655

56+
/// Listen on a magicsocket and forward incoming connections to the specified
57+
/// host and port. Every incoming bidi stream is forwarded to a new connection.
58+
///
59+
/// Will print a node ticket on stderr that can be used to connect.
60+
///
61+
/// As far as the magic socket is concerned, this is listening. But it is
62+
/// connecting to a UDP socket for which you have to specify the host and port.
63+
ListenUdp(ListenUdpArgs),
64+
5765
/// Connect to a magicsocket, open a bidi stream, and forward stdin/stdout.
5866
///
5967
/// A node ticket is required to connect.
@@ -67,6 +75,15 @@ pub enum Commands {
6775
/// As far as the magic socket is concerned, this is connecting. But it is
6876
/// listening on a TCP socket for which you have to specify the interface and port.
6977
ConnectTcp(ConnectTcpArgs),
78+
79+
/// Connect to a magicsocket, open a bidi stream, and forward stdin/stdout
80+
/// to it.
81+
///
82+
/// A node ticket is required to connect.
83+
///
84+
/// As far as the magic socket is concerned, this is connecting. But it is
85+
/// listening on a UDP socket for which you have to specify the interface and port.
86+
ConnectUdp(ConnectUdpArgs),
7087
}
7188

7289
#[derive(Parser, Debug)]
@@ -140,6 +157,15 @@ pub struct ListenTcpArgs {
140157
pub common: CommonArgs,
141158
}
142159

160+
#[derive(Parser, Debug)]
161+
pub struct ListenUdpArgs {
162+
#[clap(long)]
163+
pub host: String,
164+
165+
#[clap(flatten)]
166+
pub common: CommonArgs,
167+
}
168+
143169
#[derive(Parser, Debug)]
144170
pub struct ConnectTcpArgs {
145171
/// The addresses to listen on for incoming tcp connections.
@@ -155,6 +181,21 @@ pub struct ConnectTcpArgs {
155181
pub common: CommonArgs,
156182
}
157183

184+
#[derive(Parser, Debug)]
185+
pub struct ConnectUdpArgs {
186+
/// The addresses to listen on for incoming udp connections.
187+
///
188+
/// To listen on all network interfaces, use 0.0.0.0:12345
189+
#[clap(long)]
190+
pub addr: String,
191+
192+
/// The node to connect to
193+
pub ticket: NodeTicket,
194+
195+
#[clap(flatten)]
196+
pub common: CommonArgs,
197+
}
198+
158199
#[derive(Parser, Debug)]
159200
pub struct ConnectArgs {
160201
/// The node to connect to
@@ -440,6 +481,126 @@ async fn connect_tcp(args: ConnectTcpArgs) -> anyhow::Result<()> {
440481
Ok(())
441482
}
442483

484+
pub struct SplitUdpConn {
485+
// TODO: Do we need to store this connection?
486+
// Holding on to this for the future where we need to cleanup the resources.
487+
connection: quinn::Connection,
488+
send: quinn::SendStream,
489+
}
490+
491+
impl SplitUdpConn {
492+
pub fn new(connection: quinn::Connection, send: quinn::SendStream) -> Self {
493+
Self {
494+
connection,
495+
send
496+
}
497+
}
498+
}
499+
500+
// 1- Receives request message from socket
501+
// 2- Forwards it to the quinn stream
502+
// 3- Receives response message back from quinn stream
503+
// 4- Forwards it back to the socket
504+
async fn connect_udp(args: ConnectUdpArgs) -> anyhow::Result<()> {
505+
let addrs = args
506+
.addr
507+
.to_socket_addrs()
508+
.context(format!("invalid host string {}", args.addr))?;
509+
let secret_key = get_or_create_secret()?;
510+
let mut builder = Endpoint::builder().secret_key(secret_key).alpns(vec![]);
511+
if let Some(addr) = args.common.magic_ipv4_addr {
512+
builder = builder.bind_addr_v4(addr);
513+
}
514+
if let Some(addr) = args.common.magic_ipv6_addr {
515+
builder = builder.bind_addr_v6(addr);
516+
}
517+
let endpoint = builder.bind().await.context("unable to bind magicsock")?;
518+
tracing::info!("udp listening on {:?}", addrs);
519+
let socket = Arc::new(UdpSocket::bind(addrs.as_slice()).await?);
520+
521+
let node_addr = args.ticket.node_addr();
522+
let mut buf: Vec<u8> = vec![0u8; 65535];
523+
let mut conns = HashMap::<SocketAddr, SplitUdpConn>::new();
524+
loop {
525+
match socket.recv_from(&mut buf).await {
526+
Ok((size, sock_addr)) => {
527+
// Check if we already have a connection for this socket address
528+
let connection = match conns.get_mut(&sock_addr) {
529+
Some(conn) => conn,
530+
None => {
531+
// We need to finish the connection to be done or we should use something like promise because
532+
// when the connection was getting established, it might receive another message.
533+
let endpoint = endpoint.clone();
534+
let addr = node_addr.clone();
535+
let handshake = !args.common.is_custom_alpn();
536+
let alpn = args.common.alpn()?;
537+
538+
let remote_node_id = addr.node_id;
539+
tracing::info!("forwarding UDP to {}", remote_node_id);
540+
541+
// connect to the node, try only once
542+
let connection = endpoint
543+
.connect(addr.clone(), &alpn)
544+
.await
545+
.context(format!("error connecting to {}", remote_node_id))?;
546+
tracing::info!("connected to {}", remote_node_id);
547+
548+
// open a bidi stream, try only once
549+
let (mut send, recv) = connection
550+
.open_bi()
551+
.await
552+
.context(format!("error opening bidi stream to {}", remote_node_id))?;
553+
tracing::info!("opened bidi stream to {}", remote_node_id);
554+
555+
// send the handshake unless we are using a custom alpn
556+
if handshake {
557+
send.write_all(&dumbpipe::HANDSHAKE).await?;
558+
}
559+
560+
let sock_send = socket.clone();
561+
// Spawn a task for listening the quinn connection, and forwarding the data to the UDP socket
562+
tokio::spawn(async move {
563+
// 3- Receives response message back from quinn stream
564+
// 4- Forwards it back to the socket
565+
if let Err(cause) = udpconn::handle_udp_accept(sock_addr, sock_send, recv )
566+
.await {
567+
// log error at warn level
568+
//
569+
// we should know about it, but it's not fatal
570+
tracing::warn!("error handling connection: {}", cause);
571+
572+
// TODO: cleanup resources
573+
}
574+
});
575+
576+
// Create and store the split connection
577+
let split_conn = SplitUdpConn::new(connection.clone(), send);
578+
conns.insert(sock_addr, split_conn);
579+
conns.get_mut(&sock_addr).expect("connection was just inserted")
580+
}
581+
};
582+
583+
tracing::info!("forward_udp_to_quinn: Received {} bytes from {}", size, sock_addr);
584+
585+
// 1- Receives request message from socket
586+
// 2- Forwards it to the quinn stream
587+
if let Err(e) = connection.send.write_all(&buf[..size]).await {
588+
tracing::error!("Error writing to Quinn stream: {}", e);
589+
// TODO: Cleanup the resources on error.
590+
// Remove the failed connection
591+
// conns.remove(&sock_addr);
592+
return Err(e.into());
593+
}
594+
}
595+
Err(e) => {
596+
tracing::warn!("error receiving from UDP socket: {}", e);
597+
break;
598+
}
599+
}
600+
}
601+
Ok(())
602+
}
603+
443604
/// Listen on a magicsocket and forward incoming connections to a tcp socket.
444605
async fn listen_tcp(args: ListenTcpArgs) -> anyhow::Result<()> {
445606
let addrs = match args.host.to_socket_addrs() {
@@ -533,15 +694,111 @@ async fn listen_tcp(args: ListenTcpArgs) -> anyhow::Result<()> {
533694
Ok(())
534695
}
535696

697+
/// Listen on a magicsocket and forward incoming connections to a udp socket.
698+
async fn listen_udp(args: ListenUdpArgs) -> anyhow::Result<()> {
699+
let addrs = match args.host.to_socket_addrs() {
700+
Ok(addrs) => addrs.collect::<Vec<_>>(),
701+
Err(e) => anyhow::bail!("invalid host string {}: {}", args.host, e),
702+
};
703+
let secret_key = get_or_create_secret()?;
704+
let mut builder = Endpoint::builder()
705+
.alpns(vec![args.common.alpn()?])
706+
.secret_key(secret_key);
707+
if let Some(addr) = args.common.magic_ipv4_addr {
708+
builder = builder.bind_addr_v4(addr);
709+
}
710+
if let Some(addr) = args.common.magic_ipv6_addr {
711+
builder = builder.bind_addr_v6(addr);
712+
}
713+
let endpoint = builder.bind().await?;
714+
// wait for the endpoint to figure out its address before making a ticket
715+
endpoint.home_relay().initialized().await?;
716+
let node_addr = endpoint.node_addr().await?;
717+
let mut short = node_addr.clone();
718+
let ticket = NodeTicket::new(node_addr);
719+
short.direct_addresses.clear();
720+
let short = NodeTicket::new(short);
721+
722+
// print the ticket on stderr so it doesn't interfere with the data itself
723+
//
724+
// note that the tests rely on the ticket being the last thing printed
725+
eprintln!("Forwarding incoming requests to '{}'.", args.host);
726+
eprintln!("To connect, use e.g.:");
727+
eprintln!("dumbpipe connect-udp {ticket}");
728+
if args.common.verbose > 0 {
729+
eprintln!("or:\ndumbpipe connect-udp {}", short);
730+
}
731+
tracing::info!("node id is {}", ticket.node_addr().node_id);
732+
tracing::info!("derp url is {:?}", ticket.node_addr().relay_url);
733+
734+
// handle a new incoming connection on the magic endpoint
735+
async fn handle_magic_accept(
736+
connecting: Connecting,
737+
addrs: Vec<std::net::SocketAddr>,
738+
handshake: bool,
739+
) -> anyhow::Result<()> {
740+
let connection = connecting.await.context("error accepting connection")?;
741+
let remote_node_id = get_remote_node_id(&connection)?;
742+
tracing::info!("got connection from {}", remote_node_id);
743+
let (s, mut r) = connection
744+
.accept_bi()
745+
.await
746+
.context("error accepting stream")?;
747+
tracing::info!("accepted bidi stream from {}", remote_node_id);
748+
if handshake {
749+
// read the handshake and verify it
750+
let mut buf = [0u8; dumbpipe::HANDSHAKE.len()];
751+
r.read_exact(&mut buf).await?;
752+
anyhow::ensure!(buf == dumbpipe::HANDSHAKE, "invalid handshake");
753+
}
754+
755+
// 1- Receives request message from quinn stream
756+
// 2- Forwards it to the (addrs) via UDP socket
757+
// 3- Receives response message back from UDP socket
758+
// 4- Forwards it back to the quinn stream
759+
udpconn::handle_udp_listen(addrs.as_slice(), r, s).await?;
760+
Ok(())
761+
}
762+
763+
loop {
764+
let incoming = select! {
765+
incoming = endpoint.accept() => incoming,
766+
_ = tokio::signal::ctrl_c() => {
767+
eprintln!("got ctrl-c, exiting");
768+
break;
769+
}
770+
};
771+
let Some(incoming) = incoming else {
772+
break;
773+
};
774+
let Ok(connecting) = incoming.accept() else {
775+
break;
776+
};
777+
let addrs = addrs.clone();
778+
let handshake = !args.common.is_custom_alpn();
779+
tokio::spawn(async move {
780+
if let Err(cause) = handle_magic_accept(connecting, addrs, handshake).await {
781+
// log error at warn level
782+
//
783+
// we should know about it, but it's not fatal
784+
tracing::warn!("error handling connection: {}", cause);
785+
}
786+
});
787+
}
788+
Ok(())
789+
}
790+
536791
#[tokio::main]
537792
async fn main() -> anyhow::Result<()> {
538793
tracing_subscriber::fmt::init();
539794
let args = Args::parse();
540795
let res = match args.command {
541796
Commands::Listen(args) => listen_stdio(args).await,
542797
Commands::ListenTcp(args) => listen_tcp(args).await,
798+
Commands::ListenUdp(args) => listen_udp(args).await,
543799
Commands::Connect(args) => connect_stdio(args).await,
544800
Commands::ConnectTcp(args) => connect_tcp(args).await,
801+
Commands::ConnectUdp(args) => connect_udp(args).await,
545802
};
546803
match res {
547804
Ok(()) => std::process::exit(0),

0 commit comments

Comments
 (0)