diff --git a/Cargo.lock b/Cargo.lock index f3006eb..5b2a0cc 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -711,6 +711,7 @@ dependencies = [ name = "dumbpipe" version = "0.28.0" dependencies = [ + "bytes", "clap", "data-encoding", "duct", diff --git a/Cargo.toml b/Cargo.toml index 640c11b..e454eb9 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -27,6 +27,7 @@ tracing = "0.1.40" tracing-subscriber = { version = "0.3.18", features = ["env-filter"] } data-encoding = "2.9.0" n0-snafu = "0.2.1" +bytes = "1.5.0" [dev-dependencies] duct = "0.13.6" diff --git a/src/main.rs b/src/main.rs index 1de688e..a1ad5d8 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,4 +1,5 @@ //! Command line arguments. +use bytes::Bytes; use clap::{Parser, Subcommand}; use dumbpipe::NodeTicket; use iroh::{endpoint::Connecting, Endpoint, NodeAddr, SecretKey, Watcher}; @@ -51,6 +52,12 @@ pub enum Commands { /// connecting to a TCP socket for which you have to specify the host and port. ListenTcp(ListenTcpArgs), + /// Listen on a magicsocket and forward incoming connections to the specified + /// UDP socket. Every incoming connection is forwarded to a new UDP socket. + /// + /// Will print a node ticket on stderr that can be used to connect. + ListenUdp(ListenUdpArgs), + /// Connect to a magicsocket, open a bidi stream, and forward stdin/stdout. /// /// A node ticket is required to connect. @@ -64,6 +71,14 @@ pub enum Commands { /// As far as the magic socket is concerned, this is connecting. But it is /// listening on a TCP socket for which you have to specify the interface and port. ConnectTcp(ConnectTcpArgs), + + /// Connect to a magicsocket and forward UDP packets bidirectionally. + /// + /// A node ticket is required to connect. + /// + /// As far as the magic socket is concerned, this is connecting. But it is + /// listening on a UDP socket for which you have to specify the interface and port. + ConnectUdp(ConnectUdpArgs), } #[derive(Parser, Debug)] @@ -137,6 +152,14 @@ pub struct ListenTcpArgs { pub common: CommonArgs, } +#[derive(Parser, Debug)] +pub struct ListenUdpArgs { + #[clap(long)] + pub host: String, + #[clap(flatten)] + pub common: CommonArgs, +} + #[derive(Parser, Debug)] pub struct ConnectTcpArgs { /// The addresses to listen on for incoming tcp connections. @@ -152,6 +175,18 @@ pub struct ConnectTcpArgs { pub common: CommonArgs, } +#[derive(Parser, Debug)] +pub struct ConnectUdpArgs { + /// The addresses to listen on for incoming udp datagrams. + /// + /// To listen on all network interfaces, use 0.0.0.0:12345 + #[clap(long)] + pub addr: String, + pub ticket: NodeTicket, + #[clap(flatten)] + pub common: CommonArgs, +} + #[derive(Parser, Debug)] pub struct ConnectArgs { /// The node to connect to @@ -161,6 +196,114 @@ pub struct ConnectArgs { pub common: CommonArgs, } +/// Forward UDP packets between a QUIC connection (unreliable datagrams) and a +/// local UDP socket. +/// +/// Spawns two tasks: +/// - QUIC → UDP +/// - UDP → QUIC +/// +/// If the `udp` socket is not `connect`ed to a peer, this will learn the peer +/// address from the first incoming UDP packet and send all QUIC datagrams to that +/// peer. This is the mode used by `connect-udp`. +/// +/// If the `udp` socket is `connect`ed, it will use `send` to send to the +/// connected peer. This is the mode used by `listen-udp`. +/// +/// Both directions are cancelled when either task finishes or on ctrl-c. +async fn forward_udp_bidi( + conn: iroh::endpoint::Connection, + udp: tokio::net::UdpSocket, +) -> Result<()> { + let token = CancellationToken::new(); + let udp = std::sync::Arc::new(udp); + // The remote peer for an unconnected UDP socket. + // This is the address of the local application that sends us packets. + // It is None until the first packet is received. + let remote_udp_peer = std::sync::Arc::new(tokio::sync::Mutex::new(None::)); + let is_connected_udp = udp.peer_addr().is_ok(); + + // QUIC -> UDP + let t1 = tokio::spawn({ + let conn = conn.clone(); + let udp = udp.clone(); + let token = token.clone(); + let remote_udp_peer = remote_udp_peer.clone(); + async move { + loop { + tokio::select! { + res = conn.read_datagram() => { + let pkt = res.context("read_datagram")?; + if is_connected_udp { + // The UDP socket is 'connected', we can use send. + // This is the `listen-udp` case. + udp.send(&pkt).await.context("send udp")?; + } else { + // The UDP socket is not 'connected', we must use send_to. + // This is the `connect-udp` case. + // We need to have a destination address, which we learn + // from the first incoming packet in the other task. + if let Some(peer) = *remote_udp_peer.lock().await { + udp.send_to(&pkt, peer).await.context("send_to udp")?; + } else { + // We have received a packet from QUIC, but we don't know + // where to send it on the local UDP network yet. + // So we just drop it. + tracing::trace!("dropping datagram from quic, no udp peer yet"); + } + } + } + _ = token.cancelled() => break, + } + } + Result::<_, n0_snafu::Error>::Ok(()) + } + }); + + // UDP -> QUIC + let t2 = tokio::spawn({ + let udp = udp.clone(); + let token = token.clone(); + // remote_udp_peer is moved into this closure + async move { + let mut buf = vec![0u8; 65536]; + loop { + tokio::select! { + res = udp.recv_from(&mut buf) => { + let (len, src) = res.context("recv udp")?; + if !is_connected_udp { + // This is the `connect-udp` case. We are acting as a server + // for a local application. The first packet we receive + // tells us the address of the client. We'll send all + // subsequent packets from QUIC back to this address. + let mut peer = remote_udp_peer.lock().await; + if peer.is_none() { + tracing::info!("established udp session with {}", src); + *peer = Some(src); + } + } + // Forward the packet to the QUIC connection. + conn.send_datagram(Bytes::copy_from_slice(&buf[..len])) + .context("send_datagram")?; + } + _ = token.cancelled() => break, + } + } + Result::<_, n0_snafu::Error>::Ok(()) + } + }); + + // Wait for first task to finish or ctrl-c + tokio::select! { + _ = tokio::signal::ctrl_c() => { + token.cancel(); + } + res = t1 => res.context("quic->udp task")?.e()?, + res = t2 => res.context("udp->quic task")?.e()?, + } + Ok(()) +} + /// Copy from a reader to a quinn stream. /// /// Will send a reset to the other side if the operation is cancelled, and fail @@ -288,9 +431,11 @@ async fn listen_stdio(args: ListenArgs) -> Result<()> { // print the ticket on stderr so it doesn't interfere with the data itself // // note that the tests rely on the ticket being the last thing printed - eprintln!("Listening. To connect, use:\ndumbpipe connect {ticket}"); + eprintln!("Listening. To connect, use: +dumbpipe connect {ticket}"); if args.common.verbose > 0 { - eprintln!("or:\ndumbpipe connect {short}"); + eprintln!("or: +dumbpipe connect {short}"); } loop { @@ -472,7 +617,8 @@ async fn listen_tcp(args: ListenTcpArgs) -> Result<()> { eprintln!("To connect, use e.g.:"); eprintln!("dumbpipe connect-tcp {ticket}"); if args.common.verbose > 0 { - eprintln!("or:\ndumbpipe connect-tcp {short}"); + eprintln!("or: +dumbpipe connect-tcp {short}"); } tracing::info!("node id is {}", ticket.node_addr().node_id); tracing::info!("derp url is {:?}", ticket.node_addr().relay_url); @@ -540,8 +686,10 @@ async fn main() -> Result<()> { let res = match args.command { Commands::Listen(args) => listen_stdio(args).await, Commands::ListenTcp(args) => listen_tcp(args).await, + Commands::ListenUdp(args) => listen_udp(args).await, Commands::Connect(args) => connect_stdio(args).await, Commands::ConnectTcp(args) => connect_tcp(args).await, + Commands::ConnectUdp(args) => connect_udp(args).await, }; match res { Ok(()) => std::process::exit(0), @@ -551,3 +699,138 @@ async fn main() -> Result<()> { } } } + +/// Listen on a magicsocket and forward incoming connections to a UDP socket. +async fn listen_udp(args: ListenUdpArgs) -> Result<()> { + let addrs = match args.host.to_socket_addrs() { + Ok(addrs) => addrs.collect::>(), + Err(e) => snafu::whatever!("invalid host string {}: {}", args.host, e), + }; + let secret_key = get_or_create_secret()?; + let mut builder = Endpoint::builder() + .alpns(vec![args.common.alpn()?]) + .secret_key(secret_key); + if let Some(addr) = args.common.magic_ipv4_addr { + builder = builder.bind_addr_v4(addr); + } + if let Some(addr) = args.common.magic_ipv6_addr { + builder = builder.bind_addr_v6(addr); + } + let endpoint = builder.bind().await?; + endpoint.home_relay().initialized().await?; + let node_addr = endpoint.node_addr().initialized().await?; + let mut short = node_addr.clone(); + let ticket = NodeTicket::new(node_addr); + short.direct_addresses.clear(); + let short = NodeTicket::new(short); + + eprintln!("Forwarding incoming magic connections to UDP '{}'.", args.host); + eprintln!("To connect, use e.g.:"); + eprintln!("dumbpipe connect-udp --addr 0.0.0.0:0 {ticket}"); + if args.common.verbose > 0 { + eprintln!("or: +dumbpipe connect-udp --addr 0.0.0.0:0 {short}"); + } + + async fn handle_magic_udp( + connecting: Connecting, + addrs: Vec, + handshake: bool, + ) -> Result<()> { + let conn = connecting.await.context("accept connection")?; + let remote_node_id = &conn.remote_node_id()?; + tracing::info!("got connection from {}", remote_node_id); + + if handshake { + // read the handshake and verify it + let mut buf = [0u8; dumbpipe::HANDSHAKE.len()]; + let (_s, mut r) = conn.accept_bi().await.context("accept_bi")?; + r.read_exact(&mut buf).await.e()?; + snafu::ensure_whatever!(buf == dumbpipe::HANDSHAKE, "invalid handshake"); + // we don't need the stream anymore; drop it and let the unreliable datagram API do the work + } + + let udp = tokio::net::UdpSocket::bind("0.0.0.0:0") + .await + .context("bind udp socket")?; + udp.connect(&*addrs).await.context("udp connect")?; + tracing::info!("opened UDP {} <-> {}", remote_node_id, addrs[0]); + + forward_udp_bidi(conn, udp).await + } + + loop { + let incoming = select! { + incoming = endpoint.accept() => incoming, + _ = tokio::signal::ctrl_c() => { + eprintln!("got ctrl-c, exiting"); + break; + } + }; + let Some(incoming) = incoming else { break }; + let Ok(connecting) = incoming.accept() else { continue }; + let addrs = addrs.clone(); + let handshake = !args.common.is_custom_alpn(); + tokio::spawn(async move { + if let Err(cause) = handle_magic_udp(connecting, addrs, handshake).await { + tracing::warn!("error handling connection: {}", cause); + } + }); + } + Ok(()) +} + +/// Connect to a magicsocket and forward UDP packets bidirectionally. +async fn connect_udp(args: ConnectUdpArgs) -> Result<()> { + let addrs = args + .addr + .to_socket_addrs() + .context(format!("invalid host string {}", args.addr))?; + let secret_key = get_or_create_secret()?; + let mut builder = Endpoint::builder().secret_key(secret_key).alpns(vec![]); + if let Some(addr) = args.common.magic_ipv4_addr { + builder = builder.bind_addr_v4(addr); + } + if let Some(addr) = args.common.magic_ipv6_addr { + builder = builder.bind_addr_v6(addr); + } + let endpoint = builder.bind().await.context("unable to bind magicsock")?; + + let udp = tokio::net::UdpSocket::bind(addrs.as_slice()) + .await + .context("bind udp socket")?; + // This is the fix: get the actual local address and print it for the user. + // This is important if the user specifies port 0 to get a random free port. + let local_addr = udp.local_addr().context("failed to get local udp address")?; + eprintln!("UDP listening on {}", local_addr); + tracing::info!("UDP listening on {}", local_addr); + + let addr = args.ticket.node_addr(); + let remote_node_id = addr.node_id; + let connection = endpoint + .connect(addr.clone(), &args.common.alpn()?) + .await + .context(format!("connect to {remote_node_id}"))?; + + if !args.common.is_custom_alpn() { + // send the handshake using a short-lived bidi stream + let (mut s, r) = connection.open_bi().await.context("open_bi")?; + s.write_all(&dumbpipe::HANDSHAKE).await.e()?; + // we don't need the stream anymore + drop((s, r)); + } + + tracing::info!("starting UDP <-> QUIC forwarding to {}", remote_node_id); + forward_udp_bidi(connection, udp).await +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_parse_alpn() { + assert_eq!(parse_alpn("utf8:foo").unwrap(), b"foo"); + assert_eq!(parse_alpn("666f6f").unwrap(), b"foo"); + } +}