Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
750 changes: 431 additions & 319 deletions Cargo.lock

Large diffs are not rendered by default.

8 changes: 3 additions & 5 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ edition = "2024"
rust-version = "1.85"
license = "MIT"
description = "A flexible and lightweight messaging library for distributed systems"
authors = ["Jonas Bostoen", "Nicolas Racchi"]
authors = ["Chainbound Developers <[email protected]>"]
homepage = "https://github.com/chainbound/msg-rs"
repository = "https://github.com/chainbound/msg-rs"
keywords = [
Expand Down Expand Up @@ -57,10 +57,8 @@ rustc-hash = "1"
rand = "0.8"

# networking
quinn = "0.10"
# (rustls needs to be the same version as the one used by quinn)
rustls = { version = "0.21", features = ["quic", "dangerous_configuration"] }
rcgen = "0.12"
quinn = "0.11.9"
rcgen = "0.14"

# benchmarking & profiling
criterion = { version = "0.5", features = ["async_tokio"] }
Expand Down
2 changes: 1 addition & 1 deletion justfile
Original file line number Diff line number Diff line change
Expand Up @@ -13,4 +13,4 @@ fmt:
cargo +nightly fmt --all -- --check

test:
cargo nextest run --workspace --retries 3
cargo nextest run --workspace --all-features --retries 3
1 change: 1 addition & 0 deletions msg-socket/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ parking_lot.workspace = true

[dev-dependencies]
rand.workspace = true
msg-transport = { workspace = true, features = ["quic"] }

msg-sim.workspace = true

Expand Down
10 changes: 7 additions & 3 deletions msg-transport/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,13 @@ tokio.workspace = true
tracing.workspace = true
thiserror.workspace = true

quinn.workspace = true
rustls.workspace = true
rcgen.workspace = true
# QUIC
quinn = { workspace = true, optional = true }
rcgen = { workspace = true, optional = true }

[dev-dependencies]
tracing-subscriber = "0.3"

[features]
default = []
quic = ["dep:quinn", "dep:rcgen"]
50 changes: 40 additions & 10 deletions msg-transport/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ use futures::{Future, FutureExt};
use tokio::io::{AsyncRead, AsyncWrite};

pub mod ipc;
#[cfg(feature = "quic")]
pub mod quic;
pub mod tcp;

Expand Down Expand Up @@ -77,14 +78,28 @@ pub trait TransportExt<A: Address>: Transport<A> {
}
}

pub struct Acceptor<'a, T, A> {
/// An `await`-friendly interface for accepting inbound connections.
///
/// This struct is used to accept inbound connections from a transport. It is
/// created using the [`TransportExt::accept`] method.
pub struct Acceptor<'a, T, A>
where
T: Transport<A>,
A: Address,
{
inner: &'a mut T,
/// The pending [`Transport::Accept`] future.
pending: Option<T::Accept>,
_marker: PhantomData<A>,
}

impl<'a, T, A> Acceptor<'a, T, A> {
impl<'a, T, A> Acceptor<'a, T, A>
where
T: Transport<A>,
A: Address,
{
fn new(inner: &'a mut T) -> Self {
Self { inner, _marker: PhantomData }
Self { inner, pending: None, _marker: PhantomData }
}
}

Expand All @@ -96,13 +111,28 @@ where
type Output = Result<T::Io, T::Error>;

fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
match Pin::new(&mut *self.get_mut().inner).poll_accept(cx) {
Poll::Ready(mut accept) => match accept.poll_unpin(cx) {
Poll::Ready(Ok(output)) => Poll::Ready(Ok(output)),
Poll::Ready(Err(e)) => Poll::Ready(Err(e)),
Poll::Pending => Poll::Pending,
},
Poll::Pending => Poll::Pending,
let this = self.get_mut();

loop {
// If there's a pending accept future, poll it to completion
if let Some(pending) = this.pending.as_mut() {
match pending.poll_unpin(cx) {
Poll::Ready(res) => {
this.pending = None;
return Poll::Ready(res);
}
Poll::Pending => return Poll::Pending,
}
}

// Otherwise, poll the transport for a new accept future
match Pin::new(&mut *this.inner).poll_accept(cx) {
Poll::Ready(accept) => {
this.pending = Some(accept);
continue;
}
Poll::Pending => return Poll::Pending,
}
}
}
}
Expand Down
17 changes: 6 additions & 11 deletions msg-transport/src/quic/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@ use std::{sync::Arc, time::Duration};

use quinn::{IdleTimeout, congestion::ControllerFactory};

use super::tls::{self_signed_certificate, unsafe_client_config};
use crate::quic::tls::tls_server_config;

use super::tls::unsafe_client_config;

use msg_common::constants::MiB;

Expand Down Expand Up @@ -96,17 +98,13 @@ where
.min_mtu(self.initial_mtu)
.allow_spin(false)
.stream_receive_window((8 * stream_rwnd).into())
.congestion_controller_factory(self.cc)
.congestion_controller_factory(Arc::new(self.cc))
.initial_rtt(Duration::from_millis(self.expected_rtt.into()))
.send_window((8 * stream_rwnd).into());

let transport = Arc::new(transport);
let (cert, key) = self_signed_certificate();

let mut server_config =
quinn::ServerConfig::with_single_cert(cert, key).expect("Valid rustls config");

server_config.use_retry(true);
let mut server_config = quinn::ServerConfig::with_crypto(Arc::new(tls_server_config()));
server_config.transport_config(Arc::clone(&transport));

let mut client_config = quinn::ClientConfig::new(Arc::new(unsafe_client_config()));
Expand Down Expand Up @@ -154,12 +152,9 @@ impl Default for Config {
.send_window((8 * STREAM_RWND).into());

let transport = Arc::new(transport);
let (cert, key) = self_signed_certificate();

let mut server_config =
quinn::ServerConfig::with_single_cert(cert, key).expect("Valid rustls config");
let mut server_config = quinn::ServerConfig::with_crypto(Arc::new(tls_server_config()));

server_config.use_retry(true);
server_config.transport_config(Arc::clone(&transport));

let mut client_config = quinn::ClientConfig::new(Arc::new(unsafe_client_config()));
Expand Down
26 changes: 15 additions & 11 deletions msg-transport/src/quic/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ use stream::QuicStream;

mod tls;

pub(crate) const ALPN_PROTOCOL: &[u8] = b"msg";

/// A QUIC error.
#[derive(Debug, Error)]
pub enum Error {
Expand Down Expand Up @@ -52,7 +54,7 @@ pub struct Quic {
endpoint: Option<quinn::Endpoint>,

/// A receiver for incoming connections waiting to be handled.
incoming: Option<Receiver<Result<quinn::Connecting, Error>>>,
incoming: Option<Receiver<Result<quinn::Incoming, Error>>>,
}

impl Quic {
Expand Down Expand Up @@ -111,25 +113,26 @@ impl Transport<SocketAddr> for Quic {
let endpoint = if let Some(endpoint) = self.endpoint.clone() {
endpoint
} else {
let Ok(endpoint) = self.new_endpoint(None, None) else {
let Ok(mut endpoint) = self.new_endpoint(None, None) else {
return async_error(Error::ClosedEndpoint);
};

endpoint.set_default_client_config(self.config.client_config.clone());

self.endpoint = Some(endpoint.clone());

endpoint
};

let client_config = self.config.client_config.clone();

Box::pin(async move {
debug!(target = %addr, "Initiating connection");

// This `"l"` seems necessary because an empty string is an invalid domain
// name. While we don't use domain names, the underlying rustls library
// is based upon the assumption that we do.
let connection =
endpoint.connect_with(client_config, addr, "l")?.await.map_err(Error::from)?;
let connection = endpoint.connect(addr, "l")?.await?;

debug!("Connected to {}, opening stream", addr);
debug!(target = %addr, "Connected, opening stream...");

// Open a bi-directional stream and return it. We'll think about multiplexing per topic
// later.
Expand All @@ -149,14 +152,15 @@ impl Transport<SocketAddr> for Quic {
if let Some(ref mut incoming) = this.incoming {
// Incoming channel and task are spawned, so we can poll it.
match ready!(incoming.poll_recv(cx)) {
Some(Ok(connecting)) => {
let peer = connecting.remote_address();
Some(Ok(incoming)) => {
let peer = incoming.remote_address();

debug!("New incoming connection from {}", peer);

// Return a future that resolves to the output.
return Poll::Ready(Box::pin(async move {
let connection = connecting.await.map_err(Error::from)?;
debug!(client = %peer, "Accepting connection...");
let connection = incoming.accept()?.await?;
debug!(
"Accepted connection from {}, opening stream",
connection.remote_address()
Expand Down Expand Up @@ -236,7 +240,7 @@ mod tests {
use super::*;

#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn test_quic_connection() {
async fn test_quic_connection_simple() {
let _ = tracing_subscriber::fmt::try_init();

let config = Config::default();
Expand Down
115 changes: 97 additions & 18 deletions msg-transport/src/quic/tls.rs
Original file line number Diff line number Diff line change
@@ -1,41 +1,120 @@
use std::sync::Arc;

use rustls::client::{ServerCertVerified, ServerCertVerifier};
use quinn::{
crypto::rustls::{QuicClientConfig, QuicServerConfig},
rustls::{
self, SignatureScheme,
client::danger::{ServerCertVerified, ServerCertVerifier},
pki_types::{CertificateDer, PrivateKeyDer, PrivatePkcs8KeyDer},
},
};

use crate::quic::ALPN_PROTOCOL;

/// A server certificate verifier that automatically passes all checks.
#[derive(Debug)]
pub(crate) struct SkipServerVerification;
pub(crate) struct SkipServerVerification(Arc<rustls::crypto::CryptoProvider>);

impl SkipServerVerification {
fn new() -> Arc<Self> {
Arc::new(Self(Arc::new(rustls::crypto::ring::default_provider())))
}
}

impl ServerCertVerifier for SkipServerVerification {
fn verify_server_cert(
&self,
_end_entity: &rustls::Certificate,
_intermediates: &[rustls::Certificate],
_server_name: &rustls::ServerName,
_scts: &mut dyn Iterator<Item = &[u8]>,
_end_entity: &rustls::pki_types::CertificateDer<'_>,
_intermediates: &[rustls::pki_types::CertificateDer<'_>],
_server_name: &rustls::pki_types::ServerName<'_>,
_ocsp_response: &[u8],
_now: std::time::SystemTime,
) -> Result<rustls::client::ServerCertVerified, rustls::Error> {
_now: rustls::pki_types::UnixTime,
) -> Result<rustls::client::danger::ServerCertVerified, rustls::Error> {
tracing::debug!("skipping server verification");
Ok(ServerCertVerified::assertion())
}

fn verify_tls12_signature(
&self,
message: &[u8],
cert: &rustls::pki_types::CertificateDer<'_>,
dss: &rustls::DigitallySignedStruct,
) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
tracing::debug!("verifying TLS 1.2 signature");
rustls::crypto::verify_tls12_signature(
message,
cert,
dss,
&self.0.signature_verification_algorithms,
)
}

fn verify_tls13_signature(
&self,
message: &[u8],
cert: &rustls::pki_types::CertificateDer<'_>,
dss: &rustls::DigitallySignedStruct,
) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
tracing::debug!("verifying TLS 1.3 signature");
rustls::crypto::verify_tls13_signature(
message,
cert,
dss,
&self.0.signature_verification_algorithms,
)
}

fn supported_verify_schemes(&self) -> Vec<SignatureScheme> {
self.0.signature_verification_algorithms.supported_schemes()
}
}

/// Returns a TLS configuration that skips all server verification and doesn't do any client
/// authentication.
pub(crate) fn unsafe_client_config() -> rustls::ClientConfig {
rustls::ClientConfig::builder()
.with_safe_defaults()
.with_custom_certificate_verifier(Arc::new(SkipServerVerification))
/// authentication, with the correct ALPN protocol.
pub(crate) fn unsafe_client_config() -> QuicClientConfig {
let provider = Arc::new(rustls::crypto::ring::default_provider());

let mut rustls_config = rustls::ClientConfig::builder_with_provider(provider)
.with_protocol_versions(&[&rustls::version::TLS13])
.expect("ring provider supports TLS 1.3")
.dangerous()
.with_custom_certificate_verifier(SkipServerVerification::new())
.with_no_client_auth();

rustls_config.alpn_protocols = vec![ALPN_PROTOCOL.to_vec()];
rustls_config.enable_early_data = true;

rustls_config.try_into().expect("Valid rustls config")
}

/// Returns a self-signed TLS server configuration that doesn't do any client authentication, with
/// the correct ALPN protocol.
pub(crate) fn tls_server_config() -> QuicServerConfig {
let (cert_chain, key_der) = self_signed_certificate();
let provider = Arc::new(rustls::crypto::ring::default_provider());

let mut rustls_config = rustls::ServerConfig::builder_with_provider(provider)
.with_protocol_versions(&[&rustls::version::TLS13])
.expect("ring provider supports TLS 1.3")
.with_no_client_auth()
.with_single_cert(cert_chain, key_der)
.expect("Valid rustls config");

rustls_config.alpn_protocols = vec![ALPN_PROTOCOL.to_vec()];
rustls_config.max_early_data_size = u32::MAX;

rustls_config.try_into().expect("Valid rustls config")
}

/// Generates a self-signed certificate chain and private key.
pub(crate) fn self_signed_certificate() -> (Vec<rustls::Certificate>, rustls::PrivateKey) {
let cert = rcgen::generate_simple_self_signed(vec![]).expect("Generates valid certificate");
let cert_der = cert.serialize_der().expect("Serializes certificate");
let priv_key = rustls::PrivateKey(cert.serialize_private_key_der());
pub(crate) fn self_signed_certificate() -> (Vec<CertificateDer<'static>>, PrivateKeyDer<'static>) {
let cert = rcgen::generate_simple_self_signed(vec!["localhost".to_string()])
.expect("Generates valid certificate");

let cert_der = CertificateDer::from(cert.cert);
let priv_key = PrivatePkcs8KeyDer::from(cert.signing_key.serialize_der());

(vec![rustls::Certificate(cert_der)], priv_key)
(vec![cert_der], priv_key.into())
}

#[cfg(test)]
Expand Down
Loading
Loading