diff --git a/tower-http/Cargo.toml b/tower-http/Cargo.toml index b88ddeb3..598c376f 100644 --- a/tower-http/Cargo.toml +++ b/tower-http/Cargo.toml @@ -20,6 +20,8 @@ futures-util = { version = "0.3.14", default_features = false, features = [] } http = "0.2.7" http-body = "0.4.5" pin-project-lite = "0.2.7" +# Required for some semantics in some middlewares that pin-project-lite does not support +pin-project = "1.1.3" tower-layer = "0.3" tower-service = "0.3" @@ -59,6 +61,7 @@ full = [ "auth", "catch-panic", "compression-full", + "conditional-response", "cors", "decompression-full", "follow-redirect", @@ -83,6 +86,7 @@ full = [ add-extension = [] auth = ["base64", "validate-request"] catch-panic = ["tracing", "futures-util/std"] +conditional-response = [] cors = [] follow-redirect = ["iri-string", "tower/util"] fs = ["tokio/fs", "tokio-util/io", "tokio/io-util", "dep:http-range-header", "mime_guess", "mime", "percent-encoding", "httpdate", "set-status", "futures-util/alloc", "tracing"] diff --git a/tower-http/src/builder.rs b/tower-http/src/builder.rs index 2cb4f94a..c8985a3d 100644 --- a/tower-http/src/builder.rs +++ b/tower-http/src/builder.rs @@ -66,6 +66,17 @@ pub trait ServiceBuilderExt: crate::sealed::Sealed + Sized { value: T, ) -> ServiceBuilder, L>>; + /// Conditionally bypass the inner service by providing an "early" response. + /// + /// See [`tower_http::conditional_response`] for more details. + /// + /// [`tower_http::conditional_response`]: crate::conditional_response + #[cfg(feature = "conditional-response")] + fn conditional_response( + self, + responder: R, + ) -> ServiceBuilder, L>>; + /// Apply a transformation to the request body. /// /// See [`tower_http::map_request_body`] for more details. @@ -388,6 +399,14 @@ impl ServiceBuilderExt for ServiceBuilder { self.layer(crate::add_extension::AddExtensionLayer::new(value)) } + #[cfg(feature = "conditional-response")] + fn conditional_response( + self, + responder: R, + ) -> ServiceBuilder, L>> { + self.layer(crate::conditional_response::ConditionalResponseLayer::new(responder)) + } + #[cfg(feature = "map-request-body")] fn map_request_body( self, diff --git a/tower-http/src/conditional_response.rs b/tower-http/src/conditional_response.rs new file mode 100644 index 00000000..3c6095a3 --- /dev/null +++ b/tower-http/src/conditional_response.rs @@ -0,0 +1,317 @@ +//! +//! Conditionally provide a response instead of calling the inner service. +//! +//! This middleware provides a way to conditionally skip calling the inner service +//! if a response is already available for the request. +//! +//! Probably the simplest visual for this is providing a cached response, though it +//! is unlikely that this middleware is suitable for a robust response cache interface +//! (or, more accurately, it's not the motivation for developing this so I haven't +//! looked into it adequately enough to provide a robust argument for it being so!). +//! +//! The premise is simple - write a (non-async) function that assesses the current request +//! for the possibility of providing a response before invoking the inner service. Return +//! the "early" response if that is possible, otherwise return the request. +//! +//! The differences between using this and returning an error from a pre-inner layer are. +//! +//! 1. The response will still pass through any _post_-inner layer processing +//! 2. You aren't "hacking" the idea of an error when all you are trying to do is avoid +//! calling the inner service when it isn't necessary. +//! +//! Possible uses: +//! +//! * A pre-inner layer produces a successful response before the inner service is called +//! * Caching (though see above - this could, however, be the layer that skips the inner +//! call while a more robust pre-inner layer implements the actual caching) +//! * Mocking +//! * Debugging +//! * ... + +//! The function signature has to be: +//! +//! ```ignore +//! fn responder(request: request type) -> conditional_response::ConditionalResponse +//! ``` +//! +//! Note in particular that there is no [`Result`] - if you have an error you should just generate an error response +//! (or panic and rely on a panic_trapping layer to sort things out) +//! +//! # Example +//! +//! ```rust +//! use http::{Request, Response}; +//! use std::convert::Infallible; +//! use tower::{Service, ServiceExt, ServiceBuilder}; +//! use tower_http::conditional_response::ConditionalResponse; +//! use tower_http::ServiceBuilderExt; +//! +//! # #[tokio::main] +//! # async fn main() -> Result<(), Box> { +//! +//! // +//! // The responder function here decides whether to return an early response based +//! // upon the presence and value of a specific header.L +//! // +//! fn responder(request: Request) -> ConditionalResponse,Response> { +//! match request.headers().get("x-so-we-skip") { +//! Some(a) if a.to_str().unwrap() == "true" => ConditionalResponse::Response(Response::new("We skipped it".to_string())), +//! _ => ConditionalResponse::Request(request) +//! } +//! } +//! +//! async fn handle(_req: Request) -> Result, Infallible> { +//! // ... +//! Ok(Response::new("We ran it".to_string())) +//! } +//! +//! let mut svc = ServiceBuilder::new() +//! // +//! // Directly wrap the target service with the conditional responder layer +//! // +//! .conditional_response(responder) +//! .service_fn(handle); +//! +//! let request = Request::builder().header("x-so-we-skip","true").body("".to_string()).expect("Expected an empty body"); + +//! // Call the service. +//! let ready = futures::executor::block_on(svc.ready()).expect("Expected the service to be ready"); +//! let response = futures::executor::block_on(ready.call(request)).expect("Expected the service to be successful"); +//! assert_eq!(response.body(), "We skipped it"); +//! # +//! # Ok(()) +//! # } +//! ``` + +use http::{Request, Response}; +use std::future::Future; +use std::{ + pin::Pin, + task::{Context, Poll}, +}; +use tower_layer::Layer; +use tower_service::Service; +use pin_project::pin_project; + +/// Layer that applies [`ConditionalResponseService`] which allows the caller to generate/return a response instead of calling the +/// inner service - useful for stacks where a default response (rather than an error) is determined by a pre-service +/// filter. +/// +/// See the [module docs](crate::conditional_response) for more details. +#[derive(Clone, Debug)] +pub struct ConditionalResponseLayer

{ + responder: P +} + +impl

ConditionalResponseLayer

+{ + /// Create a new [`ConditionalResponseLayer`]. + pub fn new(responder:P) -> Self { + Self { responder } + } +} + +impl Layer for ConditionalResponseLayer

+where + P: Clone +{ + type Service = ConditionalResponseService; + + fn layer(&self, inner: S) -> Self::Service { + ConditionalResponseService:: { + inner, + responder: self.responder.clone(), + } + } +} + +/// Middleware that conditionally provides a response to a request in lieu of calling the inner service. +/// +/// See the [module docs](crate::conditional_response) for more details. +#[derive(Clone,Debug)] +pub struct ConditionalResponseService { + inner: S, + responder: P, +} + +impl ConditionalResponseService +{ + /// Create a new [`ConditionalResponseService`] with the inner service and the "responder" function. + pub fn new(inner: S, responder: P) -> Self { + Self { inner, responder } + } + + define_inner_service_accessors!(); + + /// Returns a new [`Layer`] that wraps services with a `ConditionalResponseService` middleware. + /// + /// [`Layer`]: tower_layer::Layer + pub fn layer(responder: P) -> ConditionalResponseLayer

{ + ConditionalResponseLayer::

::new(responder) + } +} + +impl Service> for ConditionalResponseService +where + S: Service, Response = Response> + Clone + Send + 'static, + P: ConditionalResponder,Response>, + ReqBody: Send + Sync + Clone +{ + type Response = S::Response; + type Error = S::Error; + type Future = ResponseFuture; + + #[inline] + fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { + self.inner.poll_ready(cx) + } + + fn call(&mut self, req: Request) -> Self::Future { + match self.responder.has_response(req) { + ConditionalResponse::Request(t) => ResponseFuture::::Future(self.inner.call(t)), + ConditionalResponse::Response(r) => ResponseFuture::::Response(Some(r)) + } + } +} + + +/// Response future for [`ConditionalResponseService`]. +/// +/// We use an enum because the inner content may be a future or +/// or may be a direct response. +/// +/// We use an option for the direct response so that ownership can be taken. +/// +#[derive(Debug)] +#[pin_project(project = ResponseFutureProj)] +pub enum ResponseFuture { + /// + /// The future contains a direct response to return on first poll + /// + Response(Option), + /// + /// The future contains a "child" future that should be polled + /// + Future(#[pin] F), +} + +impl Future for ResponseFuture> +where + F: Future, E>>, +{ + type Output = F::Output; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + match self.project() { + ResponseFutureProj::Response(r) => Poll::Ready(Ok(r.take().unwrap())), + ResponseFutureProj::Future(ref mut future) => future.as_mut().poll(cx) + } + } +} + +///////////////////////////////////////////////////////////////////////// + +/// +/// The response required from the responder function. +/// +pub enum ConditionalResponse { + /// + /// No response is available, so return the request + /// + Request(T), + /// + /// A response is available, so return the response + /// + Response(R) +} + +/// +/// Fn trait for functions that consume a request and return a +/// ConditionalResponse variant. +/// + +pub trait ConditionalResponder { + /// The type of requests returned by [`has_response`]. + /// + /// This request is forwarded to the inner service if the responder + /// succeeds. + /// + /// [`has_response`]: crate::filter::responder::has_response + /// has_response whether the given request should be forwarded. + /// + /// If the future resolves with [`Ok`], the request is forwarded to the inner service. + fn has_response(&mut self, request: T) -> ConditionalResponse; +} + +impl ConditionalResponder for F +where + F: FnMut(T) -> ConditionalResponse, +{ + fn has_response(&mut self, request: T) -> ConditionalResponse { + self(request) + } +} + +#[cfg(test)] + mod tests { + use super::*; + + use http::{Request, Response}; + use std::convert::Infallible; + use tower::{Service, ServiceExt, ServiceBuilder}; + use crate::builder::ServiceBuilderExt; + use crate::conditional_response::ConditionalResponseLayer; + + fn responder(request: Request) -> ConditionalResponse,Response> { + match request.headers().get("x-so-we-skip") { + Some(a) if a.to_str().unwrap() == "true" => ConditionalResponse::Response(Response::new("We skipped it".to_string())), + _ => ConditionalResponse::Request(request) + } + } + + async fn handle(_req: Request) -> Result, Infallible> { + Ok(Response::new("We ran it".to_string())) + } + + #[test] + fn skip_test() { + let mut svc = ServiceBuilder::new() + .layer(ConditionalResponseLayer::new(responder)) + .service_fn(handle); + + let request = Request::builder().header("x-so-we-skip","true").body("".to_string()).expect("Expected an empty body"); + + // Call the service. + let ready = futures::executor::block_on(svc.ready()).expect("Expected the service to be ready"); + let response = futures::executor::block_on(ready.call(request)).expect("Expected the service to be successful"); + assert_eq!(response.body(), "We skipped it"); + } + + #[test] + fn no_skip_test_header() { + let mut svc = ServiceBuilder::new() + .layer(ConditionalResponseLayer::new(responder)) + .service_fn(handle); + + let request = Request::builder().header("x-so-we-skip","not true").body("".to_string()).expect("Expected an empty body"); + + // Call the service. + let ready = futures::executor::block_on(svc.ready()).expect("Expected the service to be ready"); + let response = futures::executor::block_on(ready.call(request)).expect("Expected the service to be successful"); + assert_eq!(response.body(), "We ran it"); + } + + #[test] + fn no_skip_test_no_header() { + let mut svc = ServiceBuilder::new() + .conditional_response(responder) + .service_fn(handle); + + let request = Request::builder().body("".to_string()).expect("Expected an empty body"); + + // Call the service. + let ready = futures::executor::block_on(svc.ready()).expect("Expected the service to be ready"); + let response = futures::executor::block_on(ready.call(request)).expect("Expected the service to be successful"); + assert_eq!(response.body(), "We ran it"); + } +} diff --git a/tower-http/src/lib.rs b/tower-http/src/lib.rs index 6719ddbd..734c4484 100644 --- a/tower-http/src/lib.rs +++ b/tower-http/src/lib.rs @@ -320,6 +320,9 @@ pub mod request_id; #[cfg(feature = "catch-panic")] pub mod catch_panic; +#[cfg(feature = "conditional-response")] +pub mod conditional_response; + #[cfg(feature = "set-status")] pub mod set_status;