Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 2 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ clap = { version = "4.4.7", features = ["derive"] }
ctrlc = { version = "3.4.2", features = ["termination"] }
delegate = "0.12.0"
educe = { version = "0.6.0", default-features = false, features = ["Debug"] }
io-uring = "0.7.0"
ipnet = { version = "2.8.0", features = ["serde"]}
libc = "0.2.152"
lightway-app-utils = { path = "./lightway-app-utils" }
Expand All @@ -52,3 +53,4 @@ tokio-util = "0.7.10"
tracing = "0.1.37"
tracing-subscriber = "0.3.17"
twelf = { version = "0.15.0", default-features = false, features = ["env", "clap", "yaml"]}
tun = { version = "0.7.1" }
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ Protocol and design documentation can be found in the
Lightway rust implementation currently supports Linux OS. Both x86_64 and arm64 platforms are
supported and built as part of CI.

Support for other platforms will be added soon.
Support for other client platforms will be added soon.

## Development steps

Expand Down
5 changes: 3 additions & 2 deletions lightway-app-utils/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ bytes.workspace = true
clap.workspace = true
fs-mistrust = { version = "0.8.0", default-features = false }
humantime = "2.1.0"
io-uring = { version = "0.7.0", optional = true }
io-uring = { workspace = true, optional = true }
ipnet.workspace = true
libc.workspace = true
lightway-core.workspace = true
Expand All @@ -38,11 +38,12 @@ tokio-stream = { workspace = true, optional = true }
tokio-util.workspace = true
tracing.workspace = true
tracing-subscriber = { workspace = true, features = ["json"] }
tun = { version = "0.7", features = ["async"] }
tun = { workspace = true, features = ["async"] }

[[example]]
name = "udprelay"
path = "examples/udprelay.rs"
required-features = ["io-uring"]

[dev-dependencies]
async-trait.workspace = true
Expand Down
3 changes: 3 additions & 0 deletions lightway-app-utils/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@ mod event_stream;
mod iouring;
mod tun;

mod net;
pub use net::{sockaddr_from_socket_addr, socket_addr_from_sockaddr};

#[cfg(feature = "tokio")]
pub use connection_ticker::{
connection_ticker_cb, ConnectionTicker, ConnectionTickerState, ConnectionTickerTask, Tickable,
Expand Down
179 changes: 179 additions & 0 deletions lightway-app-utils/src/net.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,179 @@
use std::{io, net::SocketAddr};

/// Convert from `libc::sockaddr_storage` to `std::net::SocketAddr`
#[allow(unsafe_code)]
pub fn socket_addr_from_sockaddr(
storage: &libc::sockaddr_storage,
len: libc::socklen_t,
) -> io::Result<SocketAddr> {
match storage.ss_family as libc::c_int {
libc::AF_INET => {
if (len as usize) < std::mem::size_of::<libc::sockaddr_in>() {
return Err(io::Error::new(
io::ErrorKind::InvalidInput,
"invalid argument (inet len)",
));
}

// SAFETY: Casting from sockaddr_storage to sockaddr_in is safe since we have validated the len.
let addr =
unsafe { &*(storage as *const libc::sockaddr_storage as *const libc::sockaddr_in) };

let ip = u32::from_be(addr.sin_addr.s_addr);
let ip = std::net::Ipv4Addr::from_bits(ip);
let port = u16::from_be(addr.sin_port);

Ok((ip, port).into())
}
libc::AF_INET6 => {
if (len as usize) < std::mem::size_of::<libc::sockaddr_in6>() {
return Err(io::Error::new(
io::ErrorKind::InvalidInput,
"invalid argument (inet6 len)",
));
}
// SAFETY: Casting from sockaddr_storage to sockaddr_in6 is safe since we have validated the len.
let addr = unsafe {
&*(storage as *const libc::sockaddr_storage as *const libc::sockaddr_in6)
};

let ip = u128::from_be_bytes(addr.sin6_addr.s6_addr);
let ip = std::net::Ipv6Addr::from_bits(ip);
let port = u16::from_be(addr.sin6_port);

Ok((ip, port).into())
}
_ => Err(io::Error::new(
std::io::ErrorKind::InvalidInput,
"invalid argument (ss_family)",
)),
}
}

/// Convert from `std::net::SocketAddr` to `libc::sockaddr_storage`+`libc::socklen_t`
#[allow(unsafe_code)]
pub fn sockaddr_from_socket_addr(addr: SocketAddr) -> (libc::sockaddr_storage, libc::socklen_t) {
// SAFETY: All zeroes is a valid sockaddr_storage
let mut storage: libc::sockaddr_storage = unsafe { std::mem::zeroed() };

let len = match addr {
SocketAddr::V4(v4) => {
let p = &mut storage as *mut libc::sockaddr_storage as *mut libc::sockaddr_in;
// SAFETY: sockaddr_storage is defined to be big enough for any sockaddr_*.
unsafe {
p.write(libc::sockaddr_in {
sin_family: libc::AF_INET as _,
sin_port: v4.port().to_be(),
sin_addr: libc::in_addr {
s_addr: v4.ip().to_bits().to_be(),
},
sin_zero: Default::default(),
})
};
std::mem::size_of::<libc::sockaddr_in>() as libc::socklen_t
}
SocketAddr::V6(v6) => {
let p = &mut storage as *mut libc::sockaddr_storage as *mut libc::sockaddr_in6;
// SAFETY: sockaddr_storage is defined to be big enough for any sockaddr_*.
unsafe {
p.write(libc::sockaddr_in6 {
sin6_family: libc::AF_INET6 as _,
sin6_port: v6.port().to_be(),
sin6_flowinfo: v6.flowinfo().to_be(),
sin6_addr: libc::in6_addr {
s6_addr: v6.ip().to_bits().to_be_bytes(),
},
sin6_scope_id: v6.scope_id().to_be(),
})
};
std::mem::size_of::<libc::sockaddr_in6>() as libc::socklen_t
}
};

(storage, len)
}

#[cfg(test)]
mod tests {
#![allow(unsafe_code, clippy::undocumented_unsafe_blocks)]

use std::{
net::{IpAddr, Ipv4Addr, Ipv6Addr},
str::FromStr as _,
};

use super::*;

use test_case::test_case;

#[test]
fn socket_addr_from_sockaddr_unknown_af() {
// Test assumes these don't match the zero initialized
// libc::sockaddr_storage::ss_family.
assert_ne!(libc::AF_INET, 0);
assert_ne!(libc::AF_INET6, 0);

let storage = unsafe { std::mem::zeroed() };
let err =
socket_addr_from_sockaddr(&storage, std::mem::size_of::<libc::sockaddr_storage>() as _)
.unwrap_err();

assert!(matches!(err.kind(), std::io::ErrorKind::InvalidInput));
assert!(err.to_string().contains("invalid argument (ss_family)"));
}

#[test]
fn socket_addr_from_sockaddr_unknown_af_inet_short() {
let mut storage: libc::sockaddr_storage = unsafe { std::mem::zeroed() };
storage.ss_family = libc::AF_INET as libc::sa_family_t;

let err = socket_addr_from_sockaddr(
&storage,
(std::mem::size_of::<libc::sockaddr_in>() - 1) as _,
)
.unwrap_err();

assert!(matches!(err.kind(), std::io::ErrorKind::InvalidInput));
assert!(err.to_string().contains("invalid argument (inet len)"));
}

#[test]
fn socket_addr_from_sockaddr_unknown_af_inet6_short() {
let mut storage: libc::sockaddr_storage = unsafe { std::mem::zeroed() };
storage.ss_family = libc::AF_INET6 as libc::sa_family_t;

let err = socket_addr_from_sockaddr(
&storage,
(std::mem::size_of::<libc::sockaddr_in6>() - 1) as _,
)
.unwrap_err();

assert!(matches!(err.kind(), std::io::ErrorKind::InvalidInput));
assert!(err.to_string().contains("invalid argument (inet6 len)"));
}

#[test]
fn sockaddr_from_socket_addr_inet() {
let socket_addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 8080);
let (storage, len) = sockaddr_from_socket_addr(socket_addr);
assert_eq!(storage.ss_family, libc::AF_INET as libc::sa_family_t);
assert_eq!(len as usize, std::mem::size_of::<libc::sockaddr_in>());
}

#[test]
fn sockaddr_from_socket_addr_inet6() {
let socket_addr = SocketAddr::new(IpAddr::V6(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1)), 8080);
let (storage, len) = sockaddr_from_socket_addr(socket_addr);
assert_eq!(storage.ss_family, libc::AF_INET6 as libc::sa_family_t);
assert_eq!(len as usize, std::mem::size_of::<libc::sockaddr_in6>());
}

#[test_case("127.0.0.1:443")]
#[test_case("[::1]:8888")]
fn round_trip(addr: &str) {
let orig = SocketAddr::from_str(addr).unwrap();
let (storage, len) = sockaddr_from_socket_addr(orig);
let round_tripped = socket_addr_from_sockaddr(&storage, len).unwrap();
assert_eq!(orig, round_tripped)
}
}
6 changes: 3 additions & 3 deletions lightway-client/src/io/outside/tcp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use std::{net::SocketAddr, sync::Arc};
use tokio::net::TcpStream;

use super::OutsideIO;
use lightway_core::{IOCallbackResult, OutsideIOSendCallback, OutsideIOSendCallbackArg};
use lightway_core::{CowBytes, IOCallbackResult, OutsideIOSendCallback, OutsideIOSendCallbackArg};

pub struct Tcp(tokio::net::TcpStream, SocketAddr);

Expand Down Expand Up @@ -58,8 +58,8 @@ impl OutsideIO for Tcp {
}

impl OutsideIOSendCallback for Tcp {
fn send(&self, buf: &[u8]) -> IOCallbackResult<usize> {
match self.0.try_write(buf) {
fn send(&self, buf: CowBytes) -> IOCallbackResult<usize> {
match self.0.try_write(buf.as_bytes()) {
Ok(nr) => IOCallbackResult::Ok(nr),
Err(err) if matches!(err.kind(), std::io::ErrorKind::WouldBlock) => {
IOCallbackResult::WouldBlock
Expand Down
6 changes: 3 additions & 3 deletions lightway-client/src/io/outside/udp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use tokio::net::UdpSocket;

use super::OutsideIO;
use lightway_app_utils::sockopt;
use lightway_core::{IOCallbackResult, OutsideIOSendCallback, OutsideIOSendCallbackArg};
use lightway_core::{CowBytes, IOCallbackResult, OutsideIOSendCallback, OutsideIOSendCallbackArg};

pub struct Udp {
sock: tokio::net::UdpSocket,
Expand Down Expand Up @@ -67,8 +67,8 @@ impl OutsideIO for Udp {
}

impl OutsideIOSendCallback for Udp {
fn send(&self, buf: &[u8]) -> IOCallbackResult<usize> {
match self.sock.try_send_to(buf, self.peer_addr) {
fn send(&self, buf: CowBytes) -> IOCallbackResult<usize> {
match self.sock.try_send_to(buf.as_bytes(), self.peer_addr) {
Ok(nr) => IOCallbackResult::Ok(nr),
Err(err) if matches!(err.kind(), std::io::ErrorKind::WouldBlock) => {
IOCallbackResult::WouldBlock
Expand Down
16 changes: 10 additions & 6 deletions lightway-core/src/connection/io_adapter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@ use more_asserts::*;
use wolfssl::IOCallbackResult;

use crate::{
plugin::PluginList, wire, ConnectionType, OutsideIOSendCallbackArg, PluginResult, Version,
plugin::PluginList, wire, ConnectionType, CowBytes, OutsideIOSendCallbackArg, PluginResult,
Version,
};

pub(crate) struct SendBuffer {
Expand Down Expand Up @@ -164,26 +165,28 @@ impl WolfSSLIOAdapter {
}
}

let b = b.freeze();

// Send header + buf. If we are in aggressive mode we send it
// a total of three times. On any send error we return
// immediately without the remaining tries, otherwise we
// return the result of the final attempt.

if self.aggressive_send {
match self.io.send(&b[..]) {
match self.io.send(CowBytes::Owned(b.clone())) {
IOCallbackResult::Ok(_) => {}
wb @ IOCallbackResult::WouldBlock => return wb,
err @ IOCallbackResult::Err(_) => return err,
}

match self.io.send(&b[..]) {
match self.io.send(CowBytes::Owned(b.clone())) {
IOCallbackResult::Ok(_) => {}
wb @ IOCallbackResult::WouldBlock => return wb,
err @ IOCallbackResult::Err(_) => return err,
}
}

match self.io.send(&b[..]) {
match self.io.send(CowBytes::Owned(b)) {
IOCallbackResult::Ok(n) => {
// We've sent `n` bytes successfully out of
// `wire::Header::WIRE_SIZE` + `b.len()` that we
Expand Down Expand Up @@ -250,7 +253,7 @@ impl WolfSSLIOAdapter {
debug_assert_le!(send_buffer.original_len(), buf.len());
}

match self.io.send(send_buffer.as_bytes()) {
match self.io.send(CowBytes::Borrowed(send_buffer.as_bytes())) {
IOCallbackResult::Ok(n) if n == send_buffer.actual_len() => {
// We've now sent everything we were originally
// asked to, so signal completion of that original
Expand Down Expand Up @@ -335,7 +338,8 @@ mod tests {
}

impl OutsideIOSendCallback for FakeOutsideIOSend {
fn send(&self, buf: &[u8]) -> IOCallbackResult<usize> {
fn send(&self, buf: CowBytes) -> IOCallbackResult<usize> {
let buf = buf.as_bytes();
let (fakes, sent) = &mut *self.0.lock().unwrap();
match fakes.pop_front() {
Some(IOCallbackResult::Ok(n)) => {
Expand Down
Loading
Loading