Skip to content

Commit f348c2b

Browse files
feature: expose redirect policy through extension
1 parent cf3046f commit f348c2b

File tree

5 files changed

+401
-54
lines changed

5 files changed

+401
-54
lines changed

tower-http/src/builder.rs

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -160,6 +160,25 @@ pub trait ServiceBuilderExt<L>: sealed::Sealed<L> + Sized {
160160
>,
161161
>;
162162

163+
/// Follow redirect responses using the [`Standard`] policy,
164+
/// storing it as an extension
165+
///
166+
/// See [`tower_http::follow_redirect::extension`] for more details.
167+
///
168+
/// [`tower_http::follow_redirect::extension`]: crate::follow_redirect::extension
169+
/// [`Standard`]: crate::follow_redirect::policy::Standard
170+
#[cfg(feature = "follow-redirect")]
171+
fn follow_redirects_extension(
172+
self,
173+
) -> ServiceBuilder<
174+
Stack<
175+
crate::follow_redirect::extension::FollowRedirectExtensionLayer<
176+
crate::follow_redirect::policy::Standard,
177+
>,
178+
L,
179+
>,
180+
>;
181+
163182
/// Mark headers as [sensitive] on both requests and responses.
164183
///
165184
/// See [`tower_http::sensitive_headers`] for more details.
@@ -459,6 +478,20 @@ impl<L> ServiceBuilderExt<L> for ServiceBuilder<L> {
459478
self.layer(crate::follow_redirect::FollowRedirectLayer::new())
460479
}
461480

481+
#[cfg(feature = "follow-redirect")]
482+
fn follow_redirects_extension(
483+
self,
484+
) -> ServiceBuilder<
485+
Stack<
486+
crate::follow_redirect::extension::FollowRedirectExtensionLayer<
487+
crate::follow_redirect::policy::Standard,
488+
>,
489+
L,
490+
>,
491+
> {
492+
self.layer(crate::follow_redirect::extension::FollowRedirectExtensionLayer::new())
493+
}
494+
462495
#[cfg(feature = "sensitive-headers")]
463496
fn sensitive_headers<I>(
464497
self,
Lines changed: 251 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,251 @@
1+
//! The [`FollowRedirectExtension`] middleware works just like [`super::FollowRedirect`]
2+
//! and also stores a copy of the [`Policy`] in a [`FollowedPolicy`] extension.
3+
//! see [`FollowRedirect`](super) for usage.
4+
5+
use super::policy::{Policy, Standard};
6+
use super::RedirectingRequest;
7+
use futures_util::future::Either;
8+
use http::{Request, Response};
9+
use http_body::Body;
10+
use pin_project_lite::pin_project;
11+
use std::future::Future;
12+
use std::mem;
13+
use std::pin::Pin;
14+
use std::task::{ready, Context, Poll};
15+
use tower::util::Oneshot;
16+
use tower::{Layer, Service};
17+
18+
/// [`Layer`] for retrying requests with a [`Service`] to follow redirection responses.
19+
///
20+
/// See the [module docs](self) for more details.
21+
#[derive(Clone, Copy, Debug, Default)]
22+
pub struct FollowRedirectExtensionLayer<P = Standard> {
23+
policy: P,
24+
}
25+
26+
impl FollowRedirectExtensionLayer {
27+
/// Create a new [`FollowRedirectExtension`] with a [`Standard`] redirection policy.
28+
pub fn new() -> Self {
29+
Self::default()
30+
}
31+
}
32+
33+
impl<P> FollowRedirectExtensionLayer<P> {
34+
/// Create a new [`FollowRedirectExtension`] with the given redirection [`Policy`].
35+
pub fn with_policy(policy: P) -> Self {
36+
Self { policy }
37+
}
38+
}
39+
40+
impl<S, P> Layer<S> for FollowRedirectExtensionLayer<P>
41+
where
42+
S: Clone,
43+
P: Clone + Send + Sync + 'static,
44+
{
45+
type Service = FollowRedirectExtension<S, P>;
46+
47+
fn layer(&self, inner: S) -> Self::Service {
48+
FollowRedirectExtension::with_policy(inner, self.policy.clone())
49+
}
50+
}
51+
52+
/// Middleware that retries requests with a [`Service`] to follow redirection responses.
53+
/// Stores the redirect [`Policy`] that was run before the last request of the redirect chain
54+
/// in the [`FollowedPolicy`] [extension](http::Extensions)
55+
///
56+
/// See the [module docs](super) for more details.
57+
#[derive(Clone, Copy, Debug)]
58+
pub struct FollowRedirectExtension<S, P = Standard> {
59+
inner: S,
60+
policy: P,
61+
}
62+
63+
impl<S> FollowRedirectExtension<S> {
64+
/// Create a new [`FollowRedirectExtension`] with a [`Standard`] redirection policy.
65+
pub fn new(inner: S) -> Self {
66+
Self::with_policy(inner, Standard::default())
67+
}
68+
69+
/// Returns a new [`Layer`] that wraps services with a [`FollowRedirectExtension`] middleware.
70+
///
71+
/// [`Layer`]: tower_layer::Layer
72+
pub fn layer() -> FollowRedirectExtensionLayer {
73+
FollowRedirectExtensionLayer::new()
74+
}
75+
}
76+
77+
impl<S, P> FollowRedirectExtension<S, P>
78+
where
79+
P: Clone + Send + Sync + 'static,
80+
{
81+
/// Create a new [`FollowRedirectExtension`] with the given redirection [`Policy`].
82+
pub fn with_policy(inner: S, policy: P) -> Self {
83+
FollowRedirectExtension { inner, policy }
84+
}
85+
86+
/// Returns a new [`Layer`] that wraps services with a [`FollowRedirectExtension`] middleware
87+
/// with the given redirection [`Policy`].
88+
///
89+
/// [`Layer`]: tower_layer::Layer
90+
pub fn layer_with_policy(policy: P) -> FollowRedirectExtensionLayer<P> {
91+
FollowRedirectExtensionLayer::with_policy(policy)
92+
}
93+
94+
define_inner_service_accessors!();
95+
}
96+
97+
impl<ReqBody, ResBody, S, P> Service<Request<ReqBody>> for FollowRedirectExtension<S, P>
98+
where
99+
S: Service<Request<ReqBody>, Response = Response<ResBody>> + Clone,
100+
ReqBody: Body + Default,
101+
P: Policy<ReqBody, S::Error> + Clone + Send + Sync + 'static,
102+
{
103+
type Response = Response<ResBody>;
104+
type Error = S::Error;
105+
type Future = ResponseFuture<S, ReqBody, P>;
106+
107+
fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
108+
self.inner.poll_ready(cx)
109+
}
110+
111+
fn call(&mut self, mut req: Request<ReqBody>) -> Self::Future {
112+
let service = self.inner.clone();
113+
let mut request = RedirectingRequest::new(
114+
mem::replace(&mut self.inner, service),
115+
self.policy.clone(),
116+
&mut req,
117+
);
118+
ResponseFuture {
119+
future: Either::Left(request.service.call(req)),
120+
request,
121+
}
122+
}
123+
}
124+
125+
/// Response [`Extensions`][http::Extensions] value that contains the redirect [`Policy`] that
126+
/// was run before the last request of the redirect chain by a [`FollowRedirectExtension`] middleware.
127+
#[derive(Clone)]
128+
pub struct FollowedPolicy<P>(pub P);
129+
130+
pin_project! {
131+
/// Response future for [`FollowRedirectExtension`].
132+
#[derive(Debug)]
133+
pub struct ResponseFuture<S, B, P>
134+
where
135+
S: Service<Request<B>>,
136+
{
137+
#[pin]
138+
future: Either<S::Future, Oneshot<S, Request<B>>>,
139+
request: RedirectingRequest<S, B, P>
140+
}
141+
}
142+
143+
impl<S, ReqBody, ResBody, P> Future for ResponseFuture<S, ReqBody, P>
144+
where
145+
S: Service<Request<ReqBody>, Response = Response<ResBody>> + Clone,
146+
ReqBody: Body + Default,
147+
P: Policy<ReqBody, S::Error> + Clone + Send + Sync + 'static,
148+
{
149+
type Output = Result<Response<ResBody>, S::Error>;
150+
151+
#[inline]
152+
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
153+
let mut this = self.project();
154+
let mut res = ready!(this.future.as_mut().poll(cx)?);
155+
156+
res.extensions_mut()
157+
.insert(FollowedPolicy(this.request.policy.clone()));
158+
159+
match this.request.handle_response(&mut res) {
160+
Ok(Some(pending)) => {
161+
this.future.set(Either::Right(pending));
162+
cx.waker().wake_by_ref();
163+
Poll::Pending
164+
}
165+
Ok(None) => Poll::Ready(Ok(res)),
166+
Err(e) => Poll::Ready(Err(e)),
167+
}
168+
}
169+
}
170+
171+
#[cfg(test)]
172+
mod tests {
173+
use super::super::{policy::*, tests::handle, *};
174+
use super::*;
175+
use crate::test_helpers::Body;
176+
use tower::{ServiceBuilder, ServiceExt};
177+
178+
#[tokio::test]
179+
async fn follows() {
180+
let svc = ServiceBuilder::new()
181+
.layer(FollowRedirectExtensionLayer::with_policy(Action::Follow))
182+
.buffer(1)
183+
.service_fn(handle);
184+
let req = Request::builder()
185+
.uri("http://example.com/42")
186+
.body(Body::empty())
187+
.unwrap();
188+
let res = svc.oneshot(req).await.unwrap();
189+
assert_eq!(*res.body(), 0);
190+
assert_eq!(
191+
res.extensions().get::<RequestUri>().unwrap().0,
192+
"http://example.com/0"
193+
);
194+
assert!(res
195+
.extensions()
196+
.get::<FollowedPolicy<Action>>()
197+
.unwrap()
198+
.0
199+
.is_follow());
200+
}
201+
202+
#[tokio::test]
203+
async fn stops() {
204+
let svc = ServiceBuilder::new()
205+
.layer(FollowRedirectExtensionLayer::with_policy(Action::Stop))
206+
.buffer(1)
207+
.service_fn(handle);
208+
let req = Request::builder()
209+
.uri("http://example.com/42")
210+
.body(Body::empty())
211+
.unwrap();
212+
let res = svc.oneshot(req).await.unwrap();
213+
assert_eq!(*res.body(), 42);
214+
assert_eq!(
215+
res.extensions().get::<RequestUri>().unwrap().0,
216+
"http://example.com/42"
217+
);
218+
assert!(res
219+
.extensions()
220+
.get::<FollowedPolicy<Action>>()
221+
.unwrap()
222+
.0
223+
.is_stop());
224+
}
225+
226+
#[tokio::test]
227+
async fn limited() {
228+
let svc = ServiceBuilder::new()
229+
.layer(FollowRedirectExtensionLayer::with_policy(Limited::new(10)))
230+
.buffer(1)
231+
.service_fn(handle);
232+
let req = Request::builder()
233+
.uri("http://example.com/42")
234+
.body(Body::empty())
235+
.unwrap();
236+
let res = svc.oneshot(req).await.unwrap();
237+
assert_eq!(*res.body(), 42 - 10);
238+
assert_eq!(
239+
res.extensions().get::<RequestUri>().unwrap().0,
240+
"http://example.com/32"
241+
);
242+
assert_eq!(
243+
res.extensions()
244+
.get::<FollowedPolicy<Limited>>()
245+
.unwrap()
246+
.0
247+
.remaining,
248+
0
249+
);
250+
}
251+
}

0 commit comments

Comments
 (0)