From bd97600c54e94ce53e1a3aa7dd692cac694017a9 Mon Sep 17 00:00:00 2001 From: Bouke van der Bijl Date: Mon, 11 Sep 2023 09:54:06 +0200 Subject: [PATCH] Do compress grpc-web responses --- tower-http/src/compression/mod.rs | 25 +++++++++++++++++++++++++ tower-http/src/compression/predicate.rs | 10 +++++++--- 2 files changed, 32 insertions(+), 3 deletions(-) diff --git a/tower-http/src/compression/mod.rs b/tower-http/src/compression/mod.rs index dbfd2364..b29110cf 100644 --- a/tower-http/src/compression/mod.rs +++ b/tower-http/src/compression/mod.rs @@ -365,6 +365,31 @@ mod tests { assert_eq!(res.headers()[CONTENT_ENCODING], "gzip"); } + #[tokio::test] + async fn does_compress_grpc_web() { + async fn handle(_req: Request) -> Result, Error> { + let mut res = Response::new(Body::from( + "a".repeat((SizeAbove::DEFAULT_MIN_SIZE * 2) as usize), + )); + res.headers_mut() + .insert(CONTENT_TYPE, "application/grpc-web+proto".parse().unwrap()); + Ok(res) + } + + let svc = Compression::new(service_fn(handle)); + + let res = svc + .oneshot( + Request::builder() + .header(ACCEPT_ENCODING, "gzip") + .body(Body::empty()) + .unwrap(), + ) + .await + .unwrap(); + assert_eq!(res.headers()[CONTENT_ENCODING], "gzip"); + } + #[tokio::test] async fn compress_with_quality() { const DATA: &str = "Check compression quality level! Check compression quality level! Check compression quality level!"; diff --git a/tower-http/src/compression/predicate.rs b/tower-http/src/compression/predicate.rs index 2bb37c22..381e5fd4 100644 --- a/tower-http/src/compression/predicate.rs +++ b/tower-http/src/compression/predicate.rs @@ -192,7 +192,10 @@ pub struct NotForContentType { impl NotForContentType { /// Predicate that wont compress gRPC responses. - pub const GRPC: Self = Self::const_new("application/grpc"); + pub const GRPC: Self = Self { + content_type: Str::Static("application/grpc"), + exception: Some(Str::Static("application/grpc-web")), + }; /// Predicate that wont compress images. pub const IMAGES: Self = Self { @@ -222,13 +225,14 @@ impl Predicate for NotForContentType { where B: Body, { + let cty = content_type(response); if let Some(except) = &self.exception { - if content_type(response) == except.as_str() { + if cty.starts_with(except.as_str()) { return true; } } - !content_type(response).starts_with(self.content_type.as_str()) + !cty.starts_with(self.content_type.as_str()) } }