Skip to content

Commit e2ecc29

Browse files
feature: expose redirect policy through extension
1 parent cf3046f commit e2ecc29

File tree

5 files changed

+397
-52
lines changed

5 files changed

+397
-52
lines changed

tower-http/src/builder.rs

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -160,6 +160,25 @@ pub trait ServiceBuilderExt<L>: sealed::Sealed<L> + Sized {
160160
>,
161161
>;
162162

163+
/// Follow redirect responses using the [`Standard`] policy,
164+
/// storing it as an extension
165+
///
166+
/// See [`tower_http::follow_redirect::extension`] for more details.
167+
///
168+
/// [`tower_http::follow_redirect::extension`]: crate::follow_redirect::extension
169+
/// [`Standard`]: crate::follow_redirect::policy::Standard
170+
#[cfg(feature = "follow-redirect")]
171+
fn follow_redirects_extension(
172+
self,
173+
) -> ServiceBuilder<
174+
Stack<
175+
crate::follow_redirect::extension::FollowRedirectExtensionLayer<
176+
crate::follow_redirect::policy::Standard,
177+
>,
178+
L,
179+
>,
180+
>;
181+
163182
/// Mark headers as [sensitive] on both requests and responses.
164183
///
165184
/// See [`tower_http::sensitive_headers`] for more details.
@@ -459,6 +478,20 @@ impl<L> ServiceBuilderExt<L> for ServiceBuilder<L> {
459478
self.layer(crate::follow_redirect::FollowRedirectLayer::new())
460479
}
461480

481+
#[cfg(feature = "follow-redirect")]
482+
fn follow_redirects_extension(
483+
self,
484+
) -> ServiceBuilder<
485+
Stack<
486+
crate::follow_redirect::extension::FollowRedirectExtensionLayer<
487+
crate::follow_redirect::policy::Standard,
488+
>,
489+
L,
490+
>,
491+
> {
492+
self.layer(crate::follow_redirect::extension::FollowRedirectExtensionLayer::new())
493+
}
494+
462495
#[cfg(feature = "sensitive-headers")]
463496
fn sensitive_headers<I>(
464497
self,
Lines changed: 250 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,250 @@
1+
//! The [`FollowRedirectExtension`] middleware works just like [`super::FollowRedirect`]
2+
//! and also stores a copy of the [`Policy`] in a [`FollowedPolicy`] extension.
3+
//! see [`FollowRedirect`](super) for usage.
4+
5+
use super::policy::{Policy, Standard};
6+
use super::RedirectingRequest;
7+
use futures_util::future::Either;
8+
use http::{Request, Response};
9+
use http_body::Body;
10+
use pin_project_lite::pin_project;
11+
use std::future::Future;
12+
use std::mem;
13+
use std::pin::Pin;
14+
use std::task::{ready, Context, Poll};
15+
use tower::util::Oneshot;
16+
use tower::{Layer, Service};
17+
18+
/// [`Layer`] for retrying requests with a [`Service`] to follow redirection responses.
19+
///
20+
/// See the [module docs](self) for more details.
21+
#[derive(Clone, Copy, Debug, Default)]
22+
pub struct FollowRedirectExtensionLayer<P = Standard> {
23+
policy: P,
24+
}
25+
26+
impl FollowRedirectExtensionLayer {
27+
/// Create a new [`FollowRedirectExtension`] with a [`Standard`] redirection policy.
28+
pub fn new() -> Self {
29+
Self::default()
30+
}
31+
}
32+
33+
impl<P> FollowRedirectExtensionLayer<P> {
34+
/// Create a new [`FollowRedirectExtension`] with the given redirection [`Policy`].
35+
pub fn with_policy(policy: P) -> Self {
36+
Self { policy }
37+
}
38+
}
39+
40+
impl<S, P> Layer<S> for FollowRedirectExtensionLayer<P>
41+
where
42+
S: Clone,
43+
P: Clone + Send + Sync + 'static,
44+
{
45+
type Service = FollowRedirectExtension<S, P>;
46+
47+
fn layer(&self, inner: S) -> Self::Service {
48+
FollowRedirectExtension::with_policy(inner, self.policy.clone())
49+
}
50+
}
51+
52+
/// Middleware that retries requests with a [`Service`] to follow redirection responses.
53+
/// Stores the redirect [`Policy`] that was run before the last request of the redirect chain
54+
/// in the [`FollowedPolicy`] [extension](http::Extensions)
55+
///
56+
/// See the [module docs](super) for more details.
57+
#[derive(Clone, Copy, Debug)]
58+
pub struct FollowRedirectExtension<S, P = Standard> {
59+
inner: S,
60+
policy: P,
61+
}
62+
63+
impl<S> FollowRedirectExtension<S> {
64+
/// Create a new [`FollowRedirectExtension`] with a [`Standard`] redirection policy.
65+
pub fn new(inner: S) -> Self {
66+
Self::with_policy(inner, Standard::default())
67+
}
68+
69+
/// Returns a new [`Layer`] that wraps services with a [`FollowRedirectExtension`] middleware.
70+
///
71+
/// [`Layer`]: tower_layer::Layer
72+
pub fn layer() -> FollowRedirectExtensionLayer {
73+
FollowRedirectExtensionLayer::new()
74+
}
75+
}
76+
77+
impl<S, P> FollowRedirectExtension<S, P>
78+
where
79+
P: Clone + Send + Sync + 'static,
80+
{
81+
/// Create a new [`FollowRedirectExtension`] with the given redirection [`Policy`].
82+
pub fn with_policy(inner: S, policy: P) -> Self {
83+
FollowRedirectExtension { inner, policy }
84+
}
85+
86+
/// Returns a new [`Layer`] that wraps services with a [`FollowRedirectExtension`] middleware
87+
/// with the given redirection [`Policy`].
88+
///
89+
/// [`Layer`]: tower_layer::Layer
90+
pub fn layer_with_policy(policy: P) -> FollowRedirectExtensionLayer<P> {
91+
FollowRedirectExtensionLayer::with_policy(policy)
92+
}
93+
94+
define_inner_service_accessors!();
95+
}
96+
97+
impl<ReqBody, ResBody, S, P> Service<Request<ReqBody>> for FollowRedirectExtension<S, P>
98+
where
99+
S: Service<Request<ReqBody>, Response = Response<ResBody>> + Clone,
100+
ReqBody: Body + Default,
101+
P: Policy<ReqBody, S::Error> + Clone + Send + Sync + 'static,
102+
{
103+
type Response = Response<ResBody>;
104+
type Error = S::Error;
105+
type Future = ResponseFuture<S, ReqBody, P>;
106+
107+
fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
108+
self.inner.poll_ready(cx)
109+
}
110+
111+
fn call(&mut self, mut req: Request<ReqBody>) -> Self::Future {
112+
let service = self.inner.clone();
113+
let mut request = RedirectingRequest::new(
114+
mem::replace(&mut self.inner, service),
115+
self.policy.clone(),
116+
&mut req,
117+
);
118+
ResponseFuture {
119+
future: Either::Left(request.service.call(req)),
120+
request,
121+
}
122+
}
123+
}
124+
125+
/// Response [`Extensions`][http::Extensions] value that contains the redirect [`Policy`] that
126+
/// was run before the last request of the redirect chain by a [`FollowRedirectExtension`] middleware.
127+
#[derive(Clone)]
128+
pub struct FollowedPolicy<P>(pub P);
129+
130+
pin_project! {
131+
/// Response future for [`FollowRedirectExtension`].
132+
#[derive(Debug)]
133+
pub struct ResponseFuture<S, B, P>
134+
where
135+
S: Service<Request<B>>,
136+
{
137+
#[pin]
138+
future: Either<S::Future, Oneshot<S, Request<B>>>,
139+
request: RedirectingRequest<S, B, P>
140+
}
141+
}
142+
143+
impl<S, ReqBody, ResBody, P> Future for ResponseFuture<S, ReqBody, P>
144+
where
145+
S: Service<Request<ReqBody>, Response = Response<ResBody>> + Clone,
146+
ReqBody: Body + Default,
147+
P: Policy<ReqBody, S::Error> + Clone + Send + Sync + 'static,
148+
{
149+
type Output = Result<Response<ResBody>, S::Error>;
150+
151+
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
152+
let mut this = self.project();
153+
let mut res = ready!(this.future.as_mut().poll(cx)?);
154+
155+
res.extensions_mut()
156+
.insert(FollowedPolicy(this.request.policy.clone()));
157+
158+
match this.request.handle_response(&mut res) {
159+
Ok(Some(pending)) => {
160+
this.future.set(Either::Right(pending));
161+
cx.waker().wake_by_ref();
162+
Poll::Pending
163+
}
164+
Ok(None) => Poll::Ready(Ok(res)),
165+
Err(e) => Poll::Ready(Err(e)),
166+
}
167+
}
168+
}
169+
170+
#[cfg(test)]
171+
mod tests {
172+
use super::super::{policy::*, tests::handle, *};
173+
use super::*;
174+
use crate::test_helpers::Body;
175+
use tower::{ServiceBuilder, ServiceExt};
176+
177+
#[tokio::test]
178+
async fn follows() {
179+
let svc = ServiceBuilder::new()
180+
.layer(FollowRedirectExtensionLayer::with_policy(Action::Follow))
181+
.buffer(1)
182+
.service_fn(handle);
183+
let req = Request::builder()
184+
.uri("http://example.com/42")
185+
.body(Body::empty())
186+
.unwrap();
187+
let res = svc.oneshot(req).await.unwrap();
188+
assert_eq!(*res.body(), 0);
189+
assert_eq!(
190+
res.extensions().get::<RequestUri>().unwrap().0,
191+
"http://example.com/0"
192+
);
193+
assert!(res
194+
.extensions()
195+
.get::<FollowedPolicy<Action>>()
196+
.unwrap()
197+
.0
198+
.is_follow());
199+
}
200+
201+
#[tokio::test]
202+
async fn stops() {
203+
let svc = ServiceBuilder::new()
204+
.layer(FollowRedirectExtensionLayer::with_policy(Action::Stop))
205+
.buffer(1)
206+
.service_fn(handle);
207+
let req = Request::builder()
208+
.uri("http://example.com/42")
209+
.body(Body::empty())
210+
.unwrap();
211+
let res = svc.oneshot(req).await.unwrap();
212+
assert_eq!(*res.body(), 42);
213+
assert_eq!(
214+
res.extensions().get::<RequestUri>().unwrap().0,
215+
"http://example.com/42"
216+
);
217+
assert!(res
218+
.extensions()
219+
.get::<FollowedPolicy<Action>>()
220+
.unwrap()
221+
.0
222+
.is_stop());
223+
}
224+
225+
#[tokio::test]
226+
async fn limited() {
227+
let svc = ServiceBuilder::new()
228+
.layer(FollowRedirectExtensionLayer::with_policy(Limited::new(10)))
229+
.buffer(1)
230+
.service_fn(handle);
231+
let req = Request::builder()
232+
.uri("http://example.com/42")
233+
.body(Body::empty())
234+
.unwrap();
235+
let res = svc.oneshot(req).await.unwrap();
236+
assert_eq!(*res.body(), 42 - 10);
237+
assert_eq!(
238+
res.extensions().get::<RequestUri>().unwrap().0,
239+
"http://example.com/32"
240+
);
241+
assert_eq!(
242+
res.extensions()
243+
.get::<FollowedPolicy<Limited>>()
244+
.unwrap()
245+
.0
246+
.remaining,
247+
0
248+
);
249+
}
250+
}

0 commit comments

Comments
 (0)