diff --git a/src/ext/mod.rs b/src/ext/mod.rs index b59d809dea..a349121ff9 100644 --- a/src/ext/mod.rs +++ b/src/ext/mod.rs @@ -42,8 +42,12 @@ use bytes::Bytes; feature = "ffi" ))] use http::header::HeaderName; +#[cfg(all(feature = "http1", feature = "ffi"))] +use http::header::IntoHeaderName; +#[cfg(all(any(feature = "client", feature = "server"), feature = "http1"))] +use http::header::InvalidHeaderName; #[cfg(all(any(feature = "client", feature = "server"), feature = "http1"))] -use http::header::{HeaderMap, IntoHeaderName, ValueIter}; +use http::header::{HeaderMap, ValueIter}; #[cfg(feature = "ffi")] use std::collections::HashMap; #[cfg(feature = "http2")] @@ -157,15 +161,15 @@ impl fmt::Debug for Protocol { /// /// [`preserve_header_case`]: /client/struct.Client.html#method.preserve_header_case #[cfg(all(any(feature = "client", feature = "server"), feature = "http1"))] -#[derive(Clone, Debug)] -pub(crate) struct HeaderCaseMap(HeaderMap); +#[derive(Clone, Debug, Default)] +pub struct HeaderCaseMap(HeaderMap); #[cfg(all(any(feature = "client", feature = "server"), feature = "http1"))] impl HeaderCaseMap { /// Returns a view of all spellings associated with that header name, /// in the order they were found. #[cfg(feature = "client")] - pub(crate) fn get_all<'a>( + pub fn get_all<'a>( &'a self, name: &HeaderName, ) -> impl Iterator + 'a> + 'a { @@ -179,22 +183,62 @@ impl HeaderCaseMap { self.0.get_all(name).into_iter() } - #[cfg(any(feature = "client", feature = "server"))] - pub(crate) fn default() -> Self { - Self(Default::default()) + /// Inserts a header spelling, replacing any existing ones associated with that header name. + #[cfg(all(any(feature = "client", feature = "server"), feature = "http1"))] + pub fn insert(&mut self, name: CasedHeaderName) { + self.0.insert(name.0, name.1); } - #[cfg(any(test, feature = "ffi"))] - pub(crate) fn insert(&mut self, name: HeaderName, orig: Bytes) { - self.0.insert(name, orig); + /// Inserts a header spelling in addition to any existing ones associated with that header name. + #[cfg(all(any(feature = "client", feature = "server"), feature = "http1"))] + pub fn append(&mut self, name: CasedHeaderName) { + self.0.append(name.0, name.1); } +} - #[cfg(any(feature = "client", feature = "server"))] - pub(crate) fn append(&mut self, name: N, orig: Bytes) - where - N: IntoHeaderName, - { - self.0.append(name, orig); +/// An error converting a header name spelling to a [`CasedHeaderName`]. +#[cfg(all(any(feature = "client", feature = "server"), feature = "http1"))] +#[derive(Debug)] +pub enum CasedHeaderNameError { + /// Error parsing the header name + Invalid(InvalidHeaderName), + /// The parsed header name doesn't match the spelling's + NoMatch, +} + +/// A header casing representation, guaranteed to be a valid [`http::HeaderName`]. +#[cfg(all(any(feature = "client", feature = "server"), feature = "http1"))] +#[derive(Debug)] +pub struct CasedHeaderName(HeaderName, Bytes); + +#[cfg(all(any(feature = "client", feature = "server"), feature = "http1"))] +impl CasedHeaderName { + /// Constructs a header casing representation. + pub fn new(name: HeaderName, orig: Bytes) -> Result { + let orig_parsed = + HeaderName::from_bytes(&orig).map_err(|err| CasedHeaderNameError::Invalid(err))?; + + if orig_parsed != name { + Err(CasedHeaderNameError::NoMatch) + } else { + Ok(CasedHeaderName(name.into(), orig)) + } + } +} + +#[cfg(all(any(feature = "client", feature = "server"), feature = "http1"))] +impl TryFrom for CasedHeaderName { + type Error = InvalidHeaderName; + + fn try_from(orig: Bytes) -> Result { + HeaderName::from_bytes(&orig).map(|name| CasedHeaderName(name, orig)) + } +} + +#[cfg(all(any(feature = "client", feature = "server"), feature = "http1"))] +impl From for HeaderName { + fn from(value: CasedHeaderName) -> Self { + value.0 } } diff --git a/src/ffi/http_types.rs b/src/ffi/http_types.rs index 3dc4a2549d..264f7a56fa 100644 --- a/src/ffi/http_types.rs +++ b/src/ffi/http_types.rs @@ -7,7 +7,7 @@ use super::error::hyper_code; use super::task::{hyper_task_return_type, AsTaskType}; use super::{UserDataPointer, HYPER_ITER_CONTINUE}; use crate::body::Incoming as IncomingBody; -use crate::ext::{HeaderCaseMap, OriginalHeaderOrder, ReasonPhrase}; +use crate::ext::{CasedHeaderName, HeaderCaseMap, OriginalHeaderOrder, ReasonPhrase}; use crate::ffi::size_t; use crate::header::{HeaderName, HeaderValue}; use crate::{HeaderMap, Method, Request, Response, Uri}; @@ -513,7 +513,7 @@ ffi_fn! { match unsafe { raw_name_value(name, name_len, value, value_len) } { Ok((name, value, orig_name)) => { headers.headers.insert(&name, value); - headers.orig_casing.insert(name.clone(), orig_name.clone()); + headers.orig_casing.insert(CasedHeaderName::new(name.clone(), orig_name).unwrap()).unwrap(); headers.orig_order.insert(name); hyper_code::HYPERE_OK } @@ -533,7 +533,7 @@ ffi_fn! { match unsafe { raw_name_value(name, name_len, value, value_len) } { Ok((name, value, orig_name)) => { headers.headers.append(&name, value); - headers.orig_casing.append(&name, orig_name.clone()); + headers.orig_casing.append(CasedHeaderName::new(name.clone(), orig_name).unwrap()).unwrap(); headers.orig_order.append(name); hyper_code::HYPERE_OK } diff --git a/src/proto/h1/role.rs b/src/proto/h1/role.rs index 1674e26bc6..3f4cc77c25 100644 --- a/src/proto/h1/role.rs +++ b/src/proto/h1/role.rs @@ -315,7 +315,12 @@ impl Http1Transaction for Server { } if let Some(ref mut header_case_map) = header_case_map { - header_case_map.append(&name, slice.slice(header.name.0..header.name.1)); + use crate::ext::CasedHeaderName; + + header_case_map.append( + CasedHeaderName::new(name.clone(), slice.slice(header.name.0..header.name.1)) + .unwrap(), + ); } #[cfg(feature = "ffi")] @@ -1106,7 +1111,15 @@ impl Http1Transaction for Client { } if let Some(ref mut header_case_map) = header_case_map { - header_case_map.append(&name, slice.slice(header.name.0..header.name.1)); + use crate::ext::CasedHeaderName; + + header_case_map.append( + CasedHeaderName::new( + name.clone(), + slice.slice(header.name.0..header.name.1), + ) + .unwrap(), + ); } #[cfg(feature = "ffi")] @@ -1641,6 +1654,8 @@ fn extend(dst: &mut Vec, data: &[u8]) { mod tests { use bytes::BytesMut; + use crate::ext::CasedHeaderName; + use super::*; #[cfg(feature = "server")] @@ -2487,7 +2502,7 @@ mod tests { .insert("content-type", HeaderValue::from_static("application/json")); let mut orig_headers = HeaderCaseMap::default(); - orig_headers.insert(CONTENT_LENGTH, "CONTENT-LENGTH".into()); + orig_headers.insert(CasedHeaderName::new(CONTENT_LENGTH, "CONTENT-LENGTH".into()).unwrap()); head.extensions.insert(orig_headers); let mut vec = Vec::new(); @@ -2524,7 +2539,7 @@ mod tests { .insert("content-type", HeaderValue::from_static("application/json")); let mut orig_headers = HeaderCaseMap::default(); - orig_headers.insert(CONTENT_LENGTH, "CONTENT-LENGTH".into()); + orig_headers.insert(CasedHeaderName::new(CONTENT_LENGTH, "CONTENT-LENGTH".into()).unwrap()); head.extensions.insert(orig_headers); let mut vec = Vec::new(); @@ -2619,7 +2634,7 @@ mod tests { .insert("content-type", HeaderValue::from_static("application/json")); let mut orig_headers = HeaderCaseMap::default(); - orig_headers.insert(CONTENT_LENGTH, "CONTENT-LENGTH".into()); + orig_headers.insert(CasedHeaderName::new(CONTENT_LENGTH, "CONTENT-LENGTH".into()).unwrap()); head.extensions.insert(orig_headers); let mut vec = Vec::new(); @@ -2655,7 +2670,7 @@ mod tests { .insert("content-type", HeaderValue::from_static("application/json")); let mut orig_headers = HeaderCaseMap::default(); - orig_headers.insert(CONTENT_LENGTH, "CONTENT-LENGTH".into()); + orig_headers.insert(CasedHeaderName::new(CONTENT_LENGTH, "CONTENT-LENGTH".into()).unwrap()); head.extensions.insert(orig_headers); let mut vec = Vec::new(); @@ -2692,7 +2707,7 @@ mod tests { .insert("content-type", HeaderValue::from_static("application/json")); let mut orig_headers = HeaderCaseMap::default(); - orig_headers.insert(CONTENT_LENGTH, "CONTENT-LENGTH".into()); + orig_headers.insert(CasedHeaderName::new(CONTENT_LENGTH, "CONTENT-LENGTH".into()).unwrap()); head.extensions.insert(orig_headers); let mut vec = Vec::new(); @@ -2897,7 +2912,7 @@ mod tests { let name = http::header::HeaderName::from_static("x-empty"); headers.insert(&name, "".parse().expect("parse empty")); let mut orig_cases = HeaderCaseMap::default(); - orig_cases.insert(name, Bytes::from_static(b"X-EmptY")); + orig_cases.insert(CasedHeaderName::new(name, Bytes::from_static(b"X-EmptY")).unwrap()); let mut dst = Vec::new(); super::write_headers_original_case(&headers, &orig_cases, &mut dst, false); @@ -2916,8 +2931,9 @@ mod tests { headers.append(&name, "b".parse().unwrap()); let mut orig_cases = HeaderCaseMap::default(); - orig_cases.insert(name.clone(), Bytes::from_static(b"X-Empty")); - orig_cases.append(name, Bytes::from_static(b"X-EMPTY")); + orig_cases + .insert(CasedHeaderName::new(name.clone(), Bytes::from_static(b"X-Empty")).unwrap()); + orig_cases.append(CasedHeaderName::new(name, Bytes::from_static(b"X-EMPTY")).unwrap()); let mut dst = Vec::new(); super::write_headers_original_case(&headers, &orig_cases, &mut dst, false);