diff --git a/bitreq/examples/custom_cert.rs b/bitreq/examples/custom_cert.rs new file mode 100644 index 000000000..6d88aa8c9 --- /dev/null +++ b/bitreq/examples/custom_cert.rs @@ -0,0 +1,37 @@ +//! This example demonstrates the client builder with custom DER certificate. +//! to run: cargo run --example custom_cert --features async-https-rustls + +#[cfg(not(feature = "async-https-rustls"))] +fn main() { + println!("This example requires the 'async-https-rustls' feature."); +} + +#[cfg(feature = "async-https-rustls")] +fn main() -> Result<(), bitreq::Error> { + let runtime = tokio::runtime::Builder::new_current_thread() + .enable_io() + .build() + .expect("failed to build Tokio runtime"); + + runtime.block_on(request_with_client()) +} + +#[cfg(feature = "async-https-rustls")] +async fn request_with_client() -> Result<(), bitreq::Error> { + let url = "http://example.com"; + let cert_der = include_bytes!("../tests/test_cert.der"); + let client = bitreq::Client::builder().with_root_certificate(cert_der.as_slice()).build(); + // OR + // let cert_der: &[u8] = include_bytes!("../tests/test_cert.der"); + // let client = bitreq::Client::builder().with_root_certificate(cert_der).build(); + // OR + // let cert_vec: Vec = include_bytes!("../tests/test_cert.der").to_vec(); + // let client = bitreq::Client::builder().with_root_certificate(cert_vec.as_slice()).build(); + + let response = client.send_async(bitreq::get(url)).await.unwrap(); + + println!("Status: {}", response.status_code); + println!("Body: {}", response.as_str()?); + + Ok(()) +} diff --git a/bitreq/src/client.rs b/bitreq/src/client.rs index b5de6f2fb..28daf8536 100644 --- a/bitreq/src/client.rs +++ b/bitreq/src/client.rs @@ -9,6 +9,7 @@ use std::collections::{hash_map, HashMap, VecDeque}; use std::sync::{Arc, Mutex}; +use crate::connection::certificates::Certificates; use crate::connection::AsyncConnection; use crate::request::{OwnedConnectionParams as ConnectionKey, ParsedRequest}; use crate::{Error, Request, Response}; @@ -39,10 +40,141 @@ struct ClientImpl { connections: HashMap>, lru_order: VecDeque, capacity: usize, + client_config: Option, +} + +pub struct ClientBuilder { + capacity: usize, + client_config: Option, +} + +#[derive(Clone)] +pub(crate) struct ClientConfig { + pub(crate) tls: Option, +} + +#[derive(Clone)] +pub(crate) struct TlsConfig { + pub(crate) certificates: Certificates, +} + +impl TlsConfig { + fn new(cert_der: Vec) -> Self { + let certificates = Certificates::new(Some(cert_der)).expect("failed to append certificate"); + + Self { certificates } + } +} + +/// Builder for configuring a `Client` with custom settings. +/// +/// The builder allows you to set the connection pool capacity and add +/// custom root certificates for TLS verification before constructing the client. +/// +/// # Example +/// +/// ```no_run +/// # async fn example() -> Result<(), bitreq::Error> { +/// use bitreq::{Client, RequestExt}; +/// +/// let cert_der = include_bytes!("../tests/test_cert.der"); +/// let client = Client::builder() +/// .with_root_certificate(cert_der.as_slice()) +/// .with_capacity(20) +/// .build(); +/// +/// let response = bitreq::get("https://example.com") +/// .send_async_with_client(&client) +/// .await?; +/// # Ok(()) +/// # } +/// ``` +impl ClientBuilder { + /// Creates a new `ClientBuilder` with default settings. + /// + /// Default configuration: + /// * `capacity` - 1 (single connection) + /// * `root_certificates` - None (uses system certificates) + pub fn new() -> Self { Self { capacity: 1, client_config: None } } + + /// Adds a custom root certificate for TLS verification. + /// + /// The certificate must be provided in DER format. This method accepts any type + /// that can be converted into a `Vec`, such as `Vec`, `&[u8]`, or arrays. + /// This is useful when connecting to servers using self-signed certificates + /// or custom Certificate Authorities. + /// + /// # Arguments + /// + /// * `cert_der` - A DER-encoded X.509 certificate. Accepts any type that implements + /// `Into>` (e.g., `&[u8]`, `Vec`, or `[u8; N]`). + /// + /// # Example + /// + /// ```no_run + /// # use bitreq::Client; + /// // Using a byte slice + /// let cert_der: &[u8] = include_bytes!("../tests/test_cert.der"); + /// let client = Client::builder() + /// .with_root_certificate(cert_der) + /// .build(); + /// + /// // Using a Vec + /// let cert_vec: Vec = cert_der.to_vec(); + /// let client = Client::builder() + /// .with_root_certificate(cert_vec) + /// .build(); + /// ``` + pub fn with_root_certificate>>(mut self, cert_der: T) -> Self { + let tls_config = TlsConfig::new(cert_der.into()); + self.client_config = Some(ClientConfig { tls: Some(tls_config) }); + self + } + + /// Sets the maximum number of connections to keep in the pool. + /// + /// When the pool reaches this capacity, the least recently used connection + /// is evicted to make room for new connections. + /// + /// # Arguments + /// + /// * `capacity` - Maximum number of cached connections + /// + /// # Example + /// + /// ```no_run + /// # use bitreq::Client; + /// let client = Client::builder() + /// .with_capacity(10) + /// .build(); + /// ``` + pub fn with_capacity(mut self, capacity: usize) -> Self { + self.capacity = capacity; + self + } + + /// Builds the `Client` with the configured settings. + /// + /// Consumes the builder and returns a configured `Client` instance + /// ready to send requests with connection pooling. + pub fn build(self) -> Client { + Client { + r#async: Arc::new(Mutex::new(ClientImpl { + connections: HashMap::new(), + lru_order: VecDeque::new(), + capacity: self.capacity, + client_config: self.client_config, + })), + } + } +} + +impl Default for ClientBuilder { + fn default() -> Self { Self::new() } } impl Client { - /// Creates a new `Client` with the specified connection cache capacity. + /// Creates a new `Client` with the specified connection pool capacity. /// /// # Arguments /// @@ -54,10 +186,14 @@ impl Client { connections: HashMap::new(), lru_order: VecDeque::new(), capacity, + client_config: None, })), } } + /// Create a builder for a client + pub fn builder() -> ClientBuilder { ClientBuilder::new() } + /// Sends a request asynchronously using a cached connection if available. pub async fn send_async(&self, request: Request) -> Result { let parsed_request = ParsedRequest::new(request)?; @@ -77,7 +213,13 @@ impl Client { let conn = if let Some(conn) = conn_opt { conn } else { - let connection = AsyncConnection::new(key, parsed_request.timeout_at).await?; + let client_config = { + let state = self.r#async.lock().unwrap(); + state.client_config.clone() + }; + + let connection = + AsyncConnection::new(key, parsed_request.timeout_at, client_config).await?; let connection = Arc::new(connection); let mut state = self.r#async.lock().unwrap(); diff --git a/bitreq/src/connection.rs b/bitreq/src/connection.rs index 37cbd12c6..ec76f8aff 100644 --- a/bitreq/src/connection.rs +++ b/bitreq/src/connection.rs @@ -22,6 +22,8 @@ use tokio::net::TcpStream as AsyncTcpStream; #[cfg(feature = "async")] use tokio::sync::Mutex as AsyncMutex; +#[cfg(feature = "async")] +use crate::client::ClientConfig; use crate::request::{ConnectionParams, OwnedConnectionParams, ParsedRequest}; #[cfg(feature = "async")] use crate::Response; @@ -29,6 +31,8 @@ use crate::{Error, Method, ResponseLazy}; type UnsecuredStream = TcpStream; +#[cfg(feature = "rustls")] +pub(crate) mod certificates; #[cfg(feature = "rustls")] mod rustls_stream; #[cfg(feature = "rustls")] @@ -266,15 +270,13 @@ impl AsyncConnection { pub(crate) async fn new( params: ConnectionParams<'_>, timeout_at: Option, + client_config: Option, ) -> Result { let future = async move { let socket = Self::connect(params).await?; if params.https { - #[cfg(not(feature = "tokio-rustls"))] - return Err(Error::HttpsFeatureNotEnabled); - #[cfg(feature = "tokio-rustls")] - rustls_stream::wrap_async_stream(socket, params.host).await + Self::wrap_async_stream(socket, params.host, client_config).await } else { Ok(AsyncHttpStream::Unsecured(socket)) } @@ -298,6 +300,30 @@ impl AsyncConnection { })))) } + /// Call the correct wrapper function depending on whether client_configs are present + #[cfg(all(feature = "rustls", feature = "tokio-rustls"))] + async fn wrap_async_stream( + socket: AsyncTcpStream, + host: &str, + client_config: Option, + ) -> Result { + if let Some(client_config) = client_config { + rustls_stream::wrap_async_stream_with_configs(socket, host, client_config).await + } else { + rustls_stream::wrap_async_stream(socket, host).await + } + } + + /// Error treatment function, should not be called under normal circustances + #[cfg(not(all(feature = "rustls", feature = "tokio-rustls")))] + async fn wrap_async_stream( + _socket: AsyncTcpStream, + _host: &str, + _client_config: Option, + ) -> Result { + Err(Error::HttpsFeatureNotEnabled) + } + async fn tcp_connect(host: &str, port: u16) -> Result { #[cfg(feature = "log")] log::trace!("Looking up host {host}"); @@ -447,7 +473,7 @@ impl AsyncConnection { }; (_internal) => { let new_connection = - AsyncConnection::new(request.connection_params(), request.timeout_at) + AsyncConnection::new(request.connection_params(), request.timeout_at, None) .await?; *self.0.lock().unwrap() = Arc::clone(&*new_connection.0.lock().unwrap()); core::mem::drop(read); @@ -806,7 +832,8 @@ async fn async_handle_redirects( let new_connection; if needs_new_connection { new_connection = - AsyncConnection::new(request.connection_params(), request.timeout_at).await?; + AsyncConnection::new(request.connection_params(), request.timeout_at, None) + .await?; connection = &new_connection; } connection.send(request).await diff --git a/bitreq/src/connection/certificates.rs b/bitreq/src/connection/certificates.rs new file mode 100644 index 000000000..4bfb26a73 --- /dev/null +++ b/bitreq/src/connection/certificates.rs @@ -0,0 +1,63 @@ +#[cfg(feature = "rustls")] +use rustls::RootCertStore; +#[cfg(feature = "rustls-webpki")] +use webpki_roots::TLS_SERVER_ROOTS; + +use crate::Error; + +#[derive(Clone)] +pub(crate) struct Certificates { + pub(crate) inner: RootCertStore, +} + +impl Certificates { + pub(crate) fn new(cert_der: Option>) -> Result { + let certificates = Self { inner: RootCertStore::empty() }; + + if let Some(cert_der) = cert_der { + certificates.append_certificate(cert_der) + } else { + Ok(certificates) + } + } + + #[cfg(feature = "rustls")] + pub(crate) fn append_certificate(mut self, cert_der: Vec) -> Result { + let mut certificates = self.inner; + certificates + .add(&rustls::Certificate(cert_der)) + .map_err(Error::RustlsAppendCert)?; + self.inner = certificates; + Ok(self) + } + + #[cfg(feature = "rustls")] + pub(crate) fn with_root_certificates(mut self) -> Self { + let mut root_certificates = self.inner; + + // Try to load native certs + #[cfg(feature = "https-rustls-probe")] + if let Ok(os_roots) = rustls_native_certs::load_native_certs() { + for root_cert in os_roots { + // Ignore erroneous OS certificates, there's nothing + // to do differently in that situation anyways. + let _ = root_certificates.add(&rustls::Certificate(root_cert.0)); + } + } + + #[cfg(feature = "rustls-webpki")] + { + #[allow(deprecated)] + // Need to use add_server_trust_anchors to compile with rustls 0.21.1 + root_certificates.add_server_trust_anchors(TLS_SERVER_ROOTS.iter().map(|ta| { + rustls::OwnedTrustAnchor::from_subject_spki_name_constraints( + ta.subject, + ta.spki, + ta.name_constraints, + ) + })); + } + self.inner = root_certificates; + self + } +} diff --git a/bitreq/src/connection/rustls_stream.rs b/bitreq/src/connection/rustls_stream.rs index 01a3c417f..530ba20c5 100644 --- a/bitreq/src/connection/rustls_stream.rs +++ b/bitreq/src/connection/rustls_stream.rs @@ -24,6 +24,8 @@ use webpki_roots::TLS_SERVER_ROOTS; use super::{AsyncHttpStream, AsyncTcpStream}; #[cfg(all(feature = "native-tls", not(feature = "rustls"), feature = "tokio-native-tls"))] use super::{AsyncHttpStream, AsyncTcpStream}; +#[cfg(feature = "async")] +use crate::client::ClientConfig as CustomClientConfig; use crate::Error; #[cfg(feature = "rustls")] @@ -63,6 +65,15 @@ fn build_client_config() -> Arc { Arc::new(config) } +#[cfg(feature = "rustls")] +fn build_rustls_client_config(certificates: RootCertStore) -> Arc { + let config = ClientConfig::builder() + .with_safe_defaults() + .with_root_certificates(certificates) + .with_no_client_auth(); + Arc::new(config) +} + #[cfg(feature = "rustls")] pub(super) fn wrap_stream(tcp: TcpStream, host: &str) -> Result { #[cfg(feature = "log")] @@ -106,6 +117,33 @@ pub(super) async fn wrap_async_stream( Ok(AsyncHttpStream::Secured(Box::new(tls))) } +#[cfg(all(feature = "rustls", feature = "tokio-rustls"))] +pub(super) async fn wrap_async_stream_with_configs( + tcp: AsyncTcpStream, + host: &str, + custom_client_config: CustomClientConfig, +) -> Result { + #[cfg(feature = "log")] + log::trace!("Setting up TLS parameters for {host}."); + let dns_name = match ServerName::try_from(host) { + Ok(result) => result, + Err(err) => return Err(Error::IoError(io::Error::new(io::ErrorKind::Other, err))), + }; + + let mut certificates = custom_client_config.tls.unwrap().certificates; + certificates = certificates.with_root_certificates(); + + let client_config = build_rustls_client_config(certificates.inner); + let connector = TlsConnector::from(CONFIG.get_or_init(|| client_config).clone()); + + #[cfg(feature = "log")] + log::trace!("Establishing TLS session to {host}."); + + let tls = connector.connect(dns_name, tcp).await.map_err(Error::IoError)?; + + Ok(AsyncHttpStream::Secured(Box::new(tls))) +} + #[cfg(all(feature = "native-tls", not(feature = "rustls")))] pub type SecuredStream = TlsStream; diff --git a/bitreq/src/error.rs b/bitreq/src/error.rs index 9eb4346d1..7014d73a2 100644 --- a/bitreq/src/error.rs +++ b/bitreq/src/error.rs @@ -19,6 +19,9 @@ pub enum Error { #[cfg(feature = "rustls")] /// Ran into a rustls error while creating the connection. RustlsCreateConnection(rustls::Error), + #[cfg(feature = "rustls")] + /// Ran into a rustls error while appending a certificate. + RustlsAppendCert(rustls::Error), #[cfg(feature = "native-tls")] /// Ran into a native-tls error while creating the connection. NativeTlsCreateConnection(native_tls::Error), @@ -101,6 +104,8 @@ impl fmt::Display for Error { #[cfg(feature = "rustls")] RustlsCreateConnection(err) => write!(f, "error creating rustls connection: {}", err), + #[cfg(feature = "rustls")] + RustlsAppendCert(err) => write!(f, "error appending certificate: {}", err), #[cfg(feature = "native-tls")] NativeTlsCreateConnection(err) => write!(f, "error creating native-tls connection: {err}"), MalformedChunkLength => write!(f, "non-usize chunk length with transfer-encoding: chunked"), @@ -143,6 +148,8 @@ impl error::Error for Error { InvalidUtf8InBody(err) => Some(err), #[cfg(feature = "rustls")] RustlsCreateConnection(err) => Some(err), + #[cfg(feature = "rustls")] + RustlsAppendCert(err) => Some(err), _ => None, } } diff --git a/bitreq/src/request.rs b/bitreq/src/request.rs index 2f1755782..9271a5ce9 100644 --- a/bitreq/src/request.rs +++ b/bitreq/src/request.rs @@ -338,7 +338,7 @@ impl Request { #[cfg(feature = "async")] pub async fn send_async(self) -> Result { let parsed_request = ParsedRequest::new(self)?; - AsyncConnection::new(parsed_request.connection_params(), parsed_request.timeout_at) + AsyncConnection::new(parsed_request.connection_params(), parsed_request.timeout_at, None) .await? .send(parsed_request) .await diff --git a/bitreq/tests/main.rs b/bitreq/tests/main.rs index 8d357f354..e59c81e34 100644 --- a/bitreq/tests/main.rs +++ b/bitreq/tests/main.rs @@ -16,6 +16,34 @@ async fn test_https() { assert_eq!(get_status_code(bitreq::get("https://example.com")).await, 200); } +#[tokio::test] +#[cfg(feature = "rustls")] +async fn test_https_with_client() { + setup(); + let client = bitreq::Client::new(1); + let response = client.send_async(bitreq::get("https://example.com")).await.unwrap(); + assert_eq!(response.status_code, 200); +} + +#[tokio::test] +#[cfg(feature = "rustls")] +async fn test_https_with_client_builder() { + setup(); + let client = bitreq::Client::builder().build(); + let response = client.send_async(bitreq::get("https://example.com")).await.unwrap(); + assert_eq!(response.status_code, 200); +} + +#[tokio::test] +#[cfg(feature = "rustls")] +async fn test_https_with_client_builder_and_cert() { + setup(); + let cert_der = include_bytes!("test_cert.der"); + let client = bitreq::Client::builder().with_root_certificate(cert_der.as_slice()).build(); + let response = client.send_async(bitreq::get("https://example.com")).await.unwrap(); + assert_eq!(response.status_code, 200); +} + #[tokio::test] #[cfg(feature = "json-using-serde")] async fn test_json_using_serde() { diff --git a/bitreq/tests/test_cert.der b/bitreq/tests/test_cert.der new file mode 100644 index 000000000..f8d4129e3 Binary files /dev/null and b/bitreq/tests/test_cert.der differ