Skip to content

Commit 4ff7d4c

Browse files
better names
1 parent 63a2eec commit 4ff7d4c

File tree

4 files changed

+106
-45
lines changed

4 files changed

+106
-45
lines changed

tower-http/src/builder.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -174,7 +174,7 @@ pub trait ServiceBuilderExt<L>: sealed::Sealed<L> + Sized {
174174
Stack<
175175
crate::follow_redirect::FollowRedirectLayer<
176176
crate::follow_redirect::policy::Standard,
177-
crate::follow_redirect::PolicyExtension,
177+
crate::follow_redirect::UriAndPolicyExtensions,
178178
>,
179179
L,
180180
>,
@@ -486,7 +486,7 @@ impl<L> ServiceBuilderExt<L> for ServiceBuilder<L> {
486486
Stack<
487487
crate::follow_redirect::FollowRedirectLayer<
488488
crate::follow_redirect::policy::Standard,
489-
crate::follow_redirect::PolicyExtension,
489+
crate::follow_redirect::UriAndPolicyExtensions,
490490
>,
491491
L,
492492
>,

tower-http/src/follow_redirect/mod.rs

Lines changed: 101 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,39 @@
9191
//! # Ok(())
9292
//! # }
9393
//! ```
94+
//!
95+
//! ## Customizing extensions
96+
//!
97+
//! You can use [`FollowRedirectLayer::with_policy_extension()`]
98+
//! to also set the [`FollowedPolicy`] extension on the response.
99+
//!
100+
//! ```
101+
//! use http::{Request, Response};
102+
//! use bytes::Bytes;
103+
//! use http_body_util::Full;
104+
//! use tower::{Service, ServiceBuilder, ServiceExt};
105+
//! use tower_http::follow_redirect::{FollowRedirectLayer, FollowedPolicy, policy};
106+
//!
107+
//! # #[tokio::main]
108+
//! # async fn main() -> Result<(), std::convert::Infallible> {
109+
//! # let http_client =
110+
//! # tower::service_fn(|_: Request<Full<Bytes>>| async { Ok::<_, std::convert::Infallible>(Response::new(Full::<Bytes>::default())) });
111+
//! let mut client = ServiceBuilder::new()
112+
//! .layer(FollowRedirectLayer::with_policy_extension(policy::Limited::new(10)))
113+
//! .service(http_client);
114+
//!
115+
//! let res = client.ready().await?.call(Request::default()).await?;
116+
//! assert_eq!(
117+
//! res.extensions()
118+
//! .get::<FollowedPolicy<policy::Limited>>()
119+
//! .unwrap()
120+
//! .0
121+
//! .remaining,
122+
//! 10
123+
//! );
124+
//! # Ok(())
125+
//! # }
126+
//! ```
94127
95128
pub mod policy;
96129

@@ -120,9 +153,9 @@ use tower_service::Service;
120153
///
121154
/// See the [module docs](self) for more details.
122155
#[derive(Clone, Copy, Debug, Default)]
123-
pub struct FollowRedirectLayer<P = Standard, CB = NoOp> {
156+
pub struct FollowRedirectLayer<P = Standard, CB = UriExtension> {
124157
policy: P,
125-
callback: CB,
158+
handler: CB,
126159
}
127160

128161
impl FollowRedirectLayer {
@@ -137,12 +170,12 @@ impl<P> FollowRedirectLayer<P> {
137170
pub fn with_policy(policy: P) -> Self {
138171
Self {
139172
policy,
140-
callback: NoOp::default(),
173+
handler: UriExtension::default(),
141174
}
142175
}
143176
}
144177

145-
impl<P> FollowRedirectLayer<P, PolicyExtension>
178+
impl<P> FollowRedirectLayer<P, UriAndPolicyExtensions>
146179
where
147180
P: Send + Sync + 'static,
148181
{
@@ -151,7 +184,7 @@ where
151184
pub fn with_policy_extension(policy: P) -> Self {
152185
Self {
153186
policy,
154-
callback: PolicyExtension::default(),
187+
handler: UriAndPolicyExtensions::default(),
155188
}
156189
}
157190
}
@@ -165,18 +198,18 @@ where
165198
type Service = FollowRedirect<S, P, CB>;
166199

167200
fn layer(&self, inner: S) -> Self::Service {
168-
FollowRedirect::with_policy_callback(inner, self.policy.clone(), self.callback)
201+
FollowRedirect::with_policy_handler(inner, self.policy.clone(), self.handler)
169202
}
170203
}
171204

172205
/// Middleware that retries requests with a [`Service`] to follow redirection responses.
173206
///
174207
/// See the [module docs](self) for more details.
175208
#[derive(Clone, Copy, Debug)]
176-
pub struct FollowRedirect<S, P = Standard, CB = NoOp> {
209+
pub struct FollowRedirect<S, P = Standard, CB = UriExtension> {
177210
inner: S,
178211
policy: P,
179-
callback: CB,
212+
handler: CB,
180213
}
181214

182215
impl<S> FollowRedirect<S> {
@@ -193,18 +226,22 @@ impl<S> FollowRedirect<S> {
193226
}
194227
}
195228

196-
impl<S> FollowRedirect<S, Standard, PolicyExtension> {
229+
impl<S> FollowRedirect<S, Standard, UriAndPolicyExtensions> {
197230
/// Create a new [`FollowRedirect`] with a [`Standard`] redirection policy,
198231
/// that inserts the [`FollowedPolicy`] extension.
199232
pub fn with_extension(inner: S) -> Self {
200-
Self::with_policy_callback(inner, Standard::default(), PolicyExtension::default())
233+
Self::with_policy_handler(
234+
inner,
235+
Standard::default(),
236+
UriAndPolicyExtensions::default(),
237+
)
201238
}
202239

203240
/// Returns a new [`Layer`] that wraps services with a `FollowRedirect` middleware
204241
/// that inserts the [`FollowedPolicy`] extension.
205242
///
206243
/// [`Layer`]: tower_layer::Layer
207-
pub fn layer_with_extension() -> FollowRedirectLayer<Standard, PolicyExtension> {
244+
pub fn layer_with_extension() -> FollowRedirectLayer<Standard, UriAndPolicyExtensions> {
208245
FollowRedirectLayer::with_policy_extension(Standard::default())
209246
}
210247
}
@@ -218,7 +255,7 @@ where
218255
FollowRedirect {
219256
inner,
220257
policy,
221-
callback: NoOp::default(),
258+
handler: UriExtension::default(),
222259
}
223260
}
224261

@@ -235,53 +272,53 @@ impl<S, P, CB> FollowRedirect<S, P, CB>
235272
where
236273
P: Clone,
237274
{
238-
/// Create a new [`FollowRedirect`] with the given redirection [`Policy`] and [`ResponseCallback`].
239-
fn with_policy_callback(inner: S, policy: P, callback: CB) -> Self {
275+
/// Create a new [`FollowRedirect`] with the given redirection [`Policy`] and [`ResponseHandler`].
276+
fn with_policy_handler(inner: S, policy: P, handler: CB) -> Self {
240277
FollowRedirect {
241278
inner,
242279
policy,
243-
callback,
280+
handler,
244281
}
245282
}
246283

247284
define_inner_service_accessors!();
248285
}
249286

250287
/// Called on each new response, can be used for example to add [`http::Extensions`]
251-
trait ResponseCallback<ReqBody, ResBody, S, P>: Sized
288+
trait ResponseHandler<ReqBody, ResBody, S, P>: Sized
252289
where
253290
S: Service<Request<ReqBody>>,
254291
{
255-
fn handle(res: &mut Response<ResBody>, req: &RedirectingRequest<S, ReqBody, P>);
292+
fn on_response(res: &mut Response<ResBody>, req: &RedirectingRequest<S, ReqBody, P>);
256293
}
257294

258-
/// Default behavior: doesn't do anything
295+
/// Default behavior: adds a [`RequestUri`] extension to the response.
259296
#[derive(Default, Clone, Copy)]
260-
pub struct NoOp {}
297+
pub struct UriExtension {}
261298

262-
impl<ReqBody, ResBody, S, P> ResponseCallback<ReqBody, ResBody, S, P> for NoOp
299+
impl<ReqBody, ResBody, S, P> ResponseHandler<ReqBody, ResBody, S, P> for UriExtension
263300
where
264301
S: Service<Request<ReqBody>>,
265302
{
266-
fn handle(_res: &mut Response<ResBody>, _req: &RedirectingRequest<S, ReqBody, P>) {}
303+
#[inline]
304+
fn on_response(res: &mut Response<ResBody>, req: &RedirectingRequest<S, ReqBody, P>) {
305+
res.extensions_mut().insert(RequestUri(req.uri.clone()));
306+
}
267307
}
268308

269-
/// Response [`Extensions`][http::Extensions] value that contains the redirect [`Policy`] that
270-
/// was run before the last request of the redirect chain by a [`FollowRedirectExtension`] middleware.
271-
#[derive(Clone)]
272-
pub struct FollowedPolicy<P>(pub P);
273-
274-
/// Adds a [`FollowedPolicy`] extension to the response
275-
309+
/// Adds a [`FollowedPolicy`] and [`RequestUri`] extension to the response.
276310
#[derive(Default, Clone, Copy)]
277-
pub struct PolicyExtension {}
311+
pub struct UriAndPolicyExtensions {}
278312

279-
impl<ReqBody, ResBody, S, P> ResponseCallback<ReqBody, ResBody, S, P> for PolicyExtension
313+
impl<ReqBody, ResBody, S, P> ResponseHandler<ReqBody, ResBody, S, P> for UriAndPolicyExtensions
280314
where
281315
S: Service<Request<ReqBody>>,
282316
P: Clone + Send + Sync + 'static,
283317
{
284-
fn handle(res: &mut Response<ResBody>, req: &RedirectingRequest<S, ReqBody, P>) {
318+
#[inline]
319+
fn on_response(res: &mut Response<ResBody>, req: &RedirectingRequest<S, ReqBody, P>) {
320+
UriExtension::on_response(res, req);
321+
285322
res.extensions_mut()
286323
.insert(FollowedPolicy(req.policy.clone()));
287324
}
@@ -292,7 +329,7 @@ where
292329
S: Service<Request<ReqBody>, Response = Response<ResBody>> + Clone,
293330
ReqBody: Body + Default,
294331
P: Policy<ReqBody, S::Error> + Clone,
295-
CB: ResponseCallback<ReqBody, ResBody, S, P> + Copy,
332+
CB: ResponseHandler<ReqBody, ResBody, S, P> + Copy,
296333
{
297334
type Response = Response<ResBody>;
298335
type Error = S::Error;
@@ -312,7 +349,7 @@ where
312349
ResponseFuture {
313350
future: Either::Left(request.service.call(req)),
314351
request,
315-
callback: self.callback,
352+
handler: self.handler,
316353
}
317354
}
318355
}
@@ -327,7 +364,7 @@ pin_project! {
327364
#[pin]
328365
future: Either<S::Future, Oneshot<S, Request<B>>>,
329366
request: RedirectingRequest<S, B, P>,
330-
callback: CB
367+
handler: CB
331368
}
332369
}
333370

@@ -336,14 +373,14 @@ where
336373
S: Service<Request<ReqBody>, Response = Response<ResBody>> + Clone,
337374
ReqBody: Body + Default,
338375
P: Policy<ReqBody, S::Error>,
339-
CB: ResponseCallback<ReqBody, ResBody, S, P>,
376+
CB: ResponseHandler<ReqBody, ResBody, S, P>,
340377
{
341378
type Output = Result<Response<ResBody>, S::Error>;
342379

343380
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
344381
let mut this = self.project();
345382
let mut res = ready!(this.future.as_mut().poll(cx)?);
346-
CB::handle(&mut res, &this.request);
383+
CB::on_response(&mut res, &this.request);
347384

348385
match this.request.handle_response(&mut res) {
349386
Ok(Some(pending)) => {
@@ -402,8 +439,6 @@ where
402439
&mut self,
403440
res: &mut Response<ResBody>,
404441
) -> Result<Option<Oneshot<S, Request<ReqBody>>>, S::Error> {
405-
res.extensions_mut().insert(RequestUri(self.uri.clone()));
406-
407442
let drop_payload_headers = |headers: &mut HeaderMap| {
408443
for header in &[
409444
CONTENT_TYPE,
@@ -483,6 +518,11 @@ where
483518
#[derive(Clone)]
484519
pub struct RequestUri(pub Uri);
485520

521+
/// Response [`Extensions`][http::Extensions] value that contains the redirect [`Policy`] that
522+
/// was run before the last request of the redirect chain by a [`FollowRedirectExtension`] middleware.
523+
#[derive(Clone)]
524+
pub struct FollowedPolicy<P>(pub P);
525+
486526
#[derive(Debug)]
487527
enum BodyRepr<B> {
488528
Some(B),
@@ -551,7 +591,7 @@ mod tests {
551591
#[tokio::test]
552592
async fn follows() {
553593
let svc = ServiceBuilder::new()
554-
.layer(FollowRedirectLayer::with_policy(Action::Follow))
594+
.layer(FollowRedirectLayer::with_policy_extension(Action::Follow))
555595
.buffer(1)
556596
.service_fn(handle);
557597
let req = Request::builder()
@@ -564,12 +604,18 @@ mod tests {
564604
res.extensions().get::<RequestUri>().unwrap().0,
565605
"http://example.com/0"
566606
);
607+
assert!(res
608+
.extensions()
609+
.get::<FollowedPolicy<Action>>()
610+
.unwrap()
611+
.0
612+
.is_follow());
567613
}
568614

569615
#[tokio::test]
570616
async fn stops() {
571617
let svc = ServiceBuilder::new()
572-
.layer(FollowRedirectLayer::with_policy(Action::Stop))
618+
.layer(FollowRedirectLayer::with_policy_extension(Action::Stop))
573619
.buffer(1)
574620
.service_fn(handle);
575621
let req = Request::builder()
@@ -582,12 +628,18 @@ mod tests {
582628
res.extensions().get::<RequestUri>().unwrap().0,
583629
"http://example.com/42"
584630
);
631+
assert!(res
632+
.extensions()
633+
.get::<FollowedPolicy<Action>>()
634+
.unwrap()
635+
.0
636+
.is_stop());
585637
}
586638

587639
#[tokio::test]
588640
async fn limited() {
589641
let svc = ServiceBuilder::new()
590-
.layer(FollowRedirectLayer::with_policy(Limited::new(10)))
642+
.layer(FollowRedirectLayer::with_policy_extension(Limited::new(10)))
591643
.buffer(1)
592644
.service_fn(handle);
593645
let req = Request::builder()
@@ -600,6 +652,14 @@ mod tests {
600652
res.extensions().get::<RequestUri>().unwrap().0,
601653
"http://example.com/32"
602654
);
655+
assert_eq!(
656+
res.extensions()
657+
.get::<FollowedPolicy<Limited>>()
658+
.unwrap()
659+
.0
660+
.remaining,
661+
0
662+
);
603663
}
604664

605665
/// A server with an endpoint `GET /{n}` which redirects to `/{n-1}` unless `n` equals zero,

tower-http/src/follow_redirect/policy/limited.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,8 @@ use super::{Action, Attempt, Policy};
33
/// A redirection [`Policy`] that limits the number of successive redirections.
44
#[derive(Clone, Copy, Debug)]
55
pub struct Limited {
6-
pub(crate) remaining: usize,
6+
/// The number or possible redirections remaining.
7+
pub remaining: usize,
78
}
89

910
impl Limited {

tower-http/src/service_ext.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -160,7 +160,7 @@ pub trait ServiceExt {
160160
) -> crate::follow_redirect::FollowRedirect<
161161
Self,
162162
crate::follow_redirect::policy::Standard,
163-
crate::follow_redirect::PolicyExtension,
163+
crate::follow_redirect::UriAndPolicyExtensions,
164164
>
165165
where
166166
Self: Sized,

0 commit comments

Comments
 (0)