diff --git a/hyperactor/Cargo.toml b/hyperactor/Cargo.toml index 1fc5ec007..e9e590bd7 100644 --- a/hyperactor/Cargo.toml +++ b/hyperactor/Cargo.toml @@ -43,7 +43,6 @@ opentelemetry = "0.29" paste = "1.0.14" rand = { version = "0.8", features = ["small_rng"] } regex = "1.11.1" -rustls = "0.21.6" rustls-pemfile = "1.0.0" rustls-webpki = { version = "0.101.4", features = ["alloc", "std"], default-features = false } serde = { version = "1.0.185", features = ["derive", "rc"] } @@ -54,7 +53,7 @@ serde_yaml = "0.9.25" signal-hook-tokio = { version = "0.3", features = ["futures-v0_3"] } thiserror = "2.0.12" tokio = { version = "1.46.1", features = ["full", "test-util", "tracing"] } -tokio-rustls = { git = "https://github.com/shayne-fletcher/tokio-rustls", rev = "62b6a48e4c14a05c193508b9d98a0be6b0cb4baa", features = ["dangerous_configuration"] } +tokio-rustls = "0.26.2" tokio-stream = { version = "0.1.17", features = ["fs", "io-util", "net", "signal", "sync", "time"] } tokio-util = { version = "0.7.15", features = ["full"] } tracing = { version = "0.1.41", features = ["attributes", "valuable"] } diff --git a/hyperactor/src/channel/net.rs b/hyperactor/src/channel/net.rs index 39c1a5659..9a67c1861 100644 --- a/hyperactor/src/channel/net.rs +++ b/hyperactor/src/channel/net.rs @@ -1630,14 +1630,13 @@ pub(crate) mod meta { use anyhow::Context; use anyhow::Result; - use rustls::RootCertStore; + use tokio_rustls::rustls::RootCertStore; + use tokio_rustls::rustls::pki_types::{CertificateDer, PrivateKeyDer, ServerName}; use tokio::net::TcpListener; use tokio::net::TcpStream; use tokio_rustls::TlsAcceptor; use tokio_rustls::TlsConnector; use tokio_rustls::client::TlsStream; - use tokio_rustls::rustls::Certificate; - use tokio_rustls::rustls::PrivateKey; use super::*; use crate::RemoteMessage; @@ -1665,24 +1664,15 @@ pub(crate) mod meta { /// Returns the root cert store fn root_cert_store() -> Result { - let mut root_cert_store = rustls::RootCertStore::empty(); + let mut root_cert_store = tokio_rustls::rustls::RootCertStore::empty(); let ca_cert_path = std::env::var_os(THRIFT_TLS_SRV_CA_PATH_ENV).unwrap_or(DEFAULT_SRV_CA_PATH.into()); let ca_certs = rustls_pemfile::certs(&mut BufReader::new( File::open(ca_cert_path).context("Failed to open {ca_cert_path:?}")?, ))?; - let trust_anchors = ca_certs.iter().filter_map(|cert| { - webpki::TrustAnchor::try_from_cert_der(&cert[..]) - .map(|ta| { - rustls::OwnedTrustAnchor::from_subject_spki_name_constraints( - ta.subject, - ta.spki, - ta.name_constraints, - ) - }) - .ok() - }); - root_cert_store.add_trust_anchors(trust_anchors); + for cert in ca_certs { + root_cert_store.add(cert.into()).context("Failed to add certificate to root store")?; + } Ok(root_cert_store) } @@ -1693,7 +1683,7 @@ pub(crate) mod meta { File::open(server_cert_path).context("failed to open {server_cert_path}")?, ))? .into_iter() - .map(Certificate) + .map(CertificateDer::from) .collect(); // certs are good here let server_key_path = DEFAULT_SERVER_PEM_PATH; @@ -1712,22 +1702,22 @@ pub(crate) mod meta { }; }; - let config = rustls::ServerConfig::builder().with_safe_defaults(); + let config = tokio_rustls::rustls::ServerConfig::builder(); let config = if enforce_client_tls { - let client_cert_verifier = Arc::new(rustls::server::AllowAnyAuthenticatedClient::new( - root_cert_store()?, - )); + let client_cert_verifier = tokio_rustls::rustls::server::WebPkiClientVerifier::builder( + Arc::new(root_cert_store()?) + ).build().map_err(|e| anyhow::anyhow!("Failed to build client verifier: {}", e))?; config.with_client_cert_verifier(client_cert_verifier) } else { config.with_no_client_auth() } - .with_single_cert(certs, PrivateKey(key))?; + .with_single_cert(certs, PrivateKeyDer::try_from(key).map_err(|_| anyhow::anyhow!("Invalid private key format"))?)?; Ok(TlsAcceptor::from(Arc::new(config))) } - fn load_client_pem() -> Result, rustls::PrivateKey)>> { + fn load_client_pem() -> Result>, PrivateKeyDer<'static>)>> { let Some(cert_path) = std::env::var_os(THRIFT_TLS_CL_CERT_PATH_ENV) else { return Ok(None); }; @@ -1738,7 +1728,7 @@ pub(crate) mod meta { File::open(cert_path).context("failed to open {cert_path}")?, ))? .into_iter() - .map(rustls::Certificate) + .map(CertificateDer::from) .collect(); let mut key_reader = BufReader::new(File::open(key_path).context("failed to open {key_path}")?); @@ -1752,15 +1742,13 @@ pub(crate) mod meta { }; }; // Certs are verified to be good here. - Ok(Some((certs, rustls::PrivateKey(key)))) + Ok(Some((certs, PrivateKeyDer::try_from(key).map_err(|_| anyhow::anyhow!("Invalid private key format"))?))) } /// Creates a TLS connector by looking for necessary certs and keys in a Meta server environment. fn tls_connector() -> Result { // TODO (T208180540): try to simplify the logic here. - let config = rustls::ClientConfig::builder() - .with_safe_defaults() - .with_root_certificates(root_cert_store()?); + let config = tokio_rustls::rustls::ClientConfig::builder().with_root_certificates(Arc::new(root_cert_store()?)); let result = load_client_pem()?; let config = if let Some((certs, key)) = result { config @@ -1772,9 +1760,9 @@ pub(crate) mod meta { Ok(TlsConnector::from(Arc::new(config))) } - fn tls_connector_config(peer_host_name: &str) -> Result<(TlsConnector, rustls::ServerName)> { + fn tls_connector_config(peer_host_name: &str) -> Result<(TlsConnector, ServerName<'static>)> { let connector = tls_connector()?; - let server_name = rustls::ServerName::try_from(peer_host_name)?; + let server_name = ServerName::try_from(peer_host_name.to_string())?; Ok((connector, server_name)) }