91
91
//! # Ok(())
92
92
//! # }
93
93
//! ```
94
+ //!
95
+ //! ## Customizing extensions
96
+ //!
97
+ //! You can use [`FollowRedirectLayer::with_policy_extension()`]
98
+ //! to also set the [`FollowedPolicy`] extension on the response.
99
+ //!
100
+ //! ```
101
+ //! use http::{Request, Response};
102
+ //! use bytes::Bytes;
103
+ //! use http_body_util::Full;
104
+ //! use tower::{Service, ServiceBuilder, ServiceExt};
105
+ //! use tower_http::follow_redirect::{FollowRedirectLayer, FollowedPolicy, policy};
106
+ //!
107
+ //! # #[tokio::main]
108
+ //! # async fn main() -> Result<(), std::convert::Infallible> {
109
+ //! # let http_client =
110
+ //! # tower::service_fn(|_: Request<Full<Bytes>>| async { Ok::<_, std::convert::Infallible>(Response::new(Full::<Bytes>::default())) });
111
+ //! let mut client = ServiceBuilder::new()
112
+ //! .layer(FollowRedirectLayer::with_policy_extension(policy::Limited::new(10)))
113
+ //! .service(http_client);
114
+ //!
115
+ //! let res = client.ready().await?.call(Request::default()).await?;
116
+ //! assert_eq!(
117
+ //! res.extensions()
118
+ //! .get::<FollowedPolicy<policy::Limited>>()
119
+ //! .unwrap()
120
+ //! .0
121
+ //! .remaining,
122
+ //! 10
123
+ //! );
124
+ //! # Ok(())
125
+ //! # }
126
+ //! ```
94
127
95
128
pub mod policy;
96
129
@@ -120,9 +153,9 @@ use tower_service::Service;
120
153
///
121
154
/// See the [module docs](self) for more details.
122
155
#[ derive( Clone , Copy , Debug , Default ) ]
123
- pub struct FollowRedirectLayer < P = Standard , CB = NoOp > {
156
+ pub struct FollowRedirectLayer < P = Standard , CB = UriExtension > {
124
157
policy : P ,
125
- callback : CB ,
158
+ handler : CB ,
126
159
}
127
160
128
161
impl FollowRedirectLayer {
@@ -137,12 +170,12 @@ impl<P> FollowRedirectLayer<P> {
137
170
pub fn with_policy ( policy : P ) -> Self {
138
171
Self {
139
172
policy,
140
- callback : NoOp :: default ( ) ,
173
+ handler : UriExtension :: default ( ) ,
141
174
}
142
175
}
143
176
}
144
177
145
- impl < P > FollowRedirectLayer < P , PolicyExtension >
178
+ impl < P > FollowRedirectLayer < P , UriAndPolicyExtensions >
146
179
where
147
180
P : Send + Sync + ' static ,
148
181
{
@@ -151,7 +184,7 @@ where
151
184
pub fn with_policy_extension ( policy : P ) -> Self {
152
185
Self {
153
186
policy,
154
- callback : PolicyExtension :: default ( ) ,
187
+ handler : UriAndPolicyExtensions :: default ( ) ,
155
188
}
156
189
}
157
190
}
@@ -165,18 +198,18 @@ where
165
198
type Service = FollowRedirect < S , P , CB > ;
166
199
167
200
fn layer ( & self , inner : S ) -> Self :: Service {
168
- FollowRedirect :: with_policy_callback ( inner, self . policy . clone ( ) , self . callback )
201
+ FollowRedirect :: with_policy_handler ( inner, self . policy . clone ( ) , self . handler )
169
202
}
170
203
}
171
204
172
205
/// Middleware that retries requests with a [`Service`] to follow redirection responses.
173
206
///
174
207
/// See the [module docs](self) for more details.
175
208
#[ derive( Clone , Copy , Debug ) ]
176
- pub struct FollowRedirect < S , P = Standard , CB = NoOp > {
209
+ pub struct FollowRedirect < S , P = Standard , CB = UriExtension > {
177
210
inner : S ,
178
211
policy : P ,
179
- callback : CB ,
212
+ handler : CB ,
180
213
}
181
214
182
215
impl < S > FollowRedirect < S > {
@@ -193,18 +226,22 @@ impl<S> FollowRedirect<S> {
193
226
}
194
227
}
195
228
196
- impl < S > FollowRedirect < S , Standard , PolicyExtension > {
229
+ impl < S > FollowRedirect < S , Standard , UriAndPolicyExtensions > {
197
230
/// Create a new [`FollowRedirect`] with a [`Standard`] redirection policy,
198
231
/// that inserts the [`FollowedPolicy`] extension.
199
232
pub fn with_extension ( inner : S ) -> Self {
200
- Self :: with_policy_callback ( inner, Standard :: default ( ) , PolicyExtension :: default ( ) )
233
+ Self :: with_policy_handler (
234
+ inner,
235
+ Standard :: default ( ) ,
236
+ UriAndPolicyExtensions :: default ( ) ,
237
+ )
201
238
}
202
239
203
240
/// Returns a new [`Layer`] that wraps services with a `FollowRedirect` middleware
204
241
/// that inserts the [`FollowedPolicy`] extension.
205
242
///
206
243
/// [`Layer`]: tower_layer::Layer
207
- pub fn layer_with_extension ( ) -> FollowRedirectLayer < Standard , PolicyExtension > {
244
+ pub fn layer_with_extension ( ) -> FollowRedirectLayer < Standard , UriAndPolicyExtensions > {
208
245
FollowRedirectLayer :: with_policy_extension ( Standard :: default ( ) )
209
246
}
210
247
}
@@ -218,7 +255,7 @@ where
218
255
FollowRedirect {
219
256
inner,
220
257
policy,
221
- callback : NoOp :: default ( ) ,
258
+ handler : UriExtension :: default ( ) ,
222
259
}
223
260
}
224
261
@@ -235,53 +272,53 @@ impl<S, P, CB> FollowRedirect<S, P, CB>
235
272
where
236
273
P : Clone ,
237
274
{
238
- /// Create a new [`FollowRedirect`] with the given redirection [`Policy`] and [`ResponseCallback `].
239
- fn with_policy_callback ( inner : S , policy : P , callback : CB ) -> Self {
275
+ /// Create a new [`FollowRedirect`] with the given redirection [`Policy`] and [`ResponseHandler `].
276
+ fn with_policy_handler ( inner : S , policy : P , handler : CB ) -> Self {
240
277
FollowRedirect {
241
278
inner,
242
279
policy,
243
- callback ,
280
+ handler ,
244
281
}
245
282
}
246
283
247
284
define_inner_service_accessors ! ( ) ;
248
285
}
249
286
250
287
/// Called on each new response, can be used for example to add [`http::Extensions`]
251
- trait ResponseCallback < ReqBody , ResBody , S , P > : Sized
288
+ trait ResponseHandler < ReqBody , ResBody , S , P > : Sized
252
289
where
253
290
S : Service < Request < ReqBody > > ,
254
291
{
255
- fn handle ( res : & mut Response < ResBody > , req : & RedirectingRequest < S , ReqBody , P > ) ;
292
+ fn on_response ( res : & mut Response < ResBody > , req : & RedirectingRequest < S , ReqBody , P > ) ;
256
293
}
257
294
258
- /// Default behavior: doesn't do anything
295
+ /// Default behavior: adds a [`RequestUri`] extension to the response.
259
296
#[ derive( Default , Clone , Copy ) ]
260
- pub struct NoOp { }
297
+ pub struct UriExtension { }
261
298
262
- impl < ReqBody , ResBody , S , P > ResponseCallback < ReqBody , ResBody , S , P > for NoOp
299
+ impl < ReqBody , ResBody , S , P > ResponseHandler < ReqBody , ResBody , S , P > for UriExtension
263
300
where
264
301
S : Service < Request < ReqBody > > ,
265
302
{
266
- fn handle ( _res : & mut Response < ResBody > , _req : & RedirectingRequest < S , ReqBody , P > ) { }
303
+ #[ inline]
304
+ fn on_response ( res : & mut Response < ResBody > , req : & RedirectingRequest < S , ReqBody , P > ) {
305
+ res. extensions_mut ( ) . insert ( RequestUri ( req. uri . clone ( ) ) ) ;
306
+ }
267
307
}
268
308
269
- /// Response [`Extensions`][http::Extensions] value that contains the redirect [`Policy`] that
270
- /// was run before the last request of the redirect chain by a [`FollowRedirectExtension`] middleware.
271
- #[ derive( Clone ) ]
272
- pub struct FollowedPolicy < P > ( pub P ) ;
273
-
274
- /// Adds a [`FollowedPolicy`] extension to the response
275
-
309
+ /// Adds a [`FollowedPolicy`] and [`RequestUri`] extension to the response.
276
310
#[ derive( Default , Clone , Copy ) ]
277
- pub struct PolicyExtension { }
311
+ pub struct UriAndPolicyExtensions { }
278
312
279
- impl < ReqBody , ResBody , S , P > ResponseCallback < ReqBody , ResBody , S , P > for PolicyExtension
313
+ impl < ReqBody , ResBody , S , P > ResponseHandler < ReqBody , ResBody , S , P > for UriAndPolicyExtensions
280
314
where
281
315
S : Service < Request < ReqBody > > ,
282
316
P : Clone + Send + Sync + ' static ,
283
317
{
284
- fn handle ( res : & mut Response < ResBody > , req : & RedirectingRequest < S , ReqBody , P > ) {
318
+ #[ inline]
319
+ fn on_response ( res : & mut Response < ResBody > , req : & RedirectingRequest < S , ReqBody , P > ) {
320
+ UriExtension :: on_response ( res, req) ;
321
+
285
322
res. extensions_mut ( )
286
323
. insert ( FollowedPolicy ( req. policy . clone ( ) ) ) ;
287
324
}
@@ -292,7 +329,7 @@ where
292
329
S : Service < Request < ReqBody > , Response = Response < ResBody > > + Clone ,
293
330
ReqBody : Body + Default ,
294
331
P : Policy < ReqBody , S :: Error > + Clone ,
295
- CB : ResponseCallback < ReqBody , ResBody , S , P > + Copy ,
332
+ CB : ResponseHandler < ReqBody , ResBody , S , P > + Copy ,
296
333
{
297
334
type Response = Response < ResBody > ;
298
335
type Error = S :: Error ;
@@ -312,7 +349,7 @@ where
312
349
ResponseFuture {
313
350
future : Either :: Left ( request. service . call ( req) ) ,
314
351
request,
315
- callback : self . callback ,
352
+ handler : self . handler ,
316
353
}
317
354
}
318
355
}
@@ -327,7 +364,7 @@ pin_project! {
327
364
#[ pin]
328
365
future: Either <S :: Future , Oneshot <S , Request <B >>>,
329
366
request: RedirectingRequest <S , B , P >,
330
- callback : CB
367
+ handler : CB
331
368
}
332
369
}
333
370
@@ -336,14 +373,14 @@ where
336
373
S : Service < Request < ReqBody > , Response = Response < ResBody > > + Clone ,
337
374
ReqBody : Body + Default ,
338
375
P : Policy < ReqBody , S :: Error > ,
339
- CB : ResponseCallback < ReqBody , ResBody , S , P > ,
376
+ CB : ResponseHandler < ReqBody , ResBody , S , P > ,
340
377
{
341
378
type Output = Result < Response < ResBody > , S :: Error > ;
342
379
343
380
fn poll ( self : Pin < & mut Self > , cx : & mut Context < ' _ > ) -> Poll < Self :: Output > {
344
381
let mut this = self . project ( ) ;
345
382
let mut res = ready ! ( this. future. as_mut( ) . poll( cx) ?) ;
346
- CB :: handle ( & mut res, & this. request ) ;
383
+ CB :: on_response ( & mut res, & this. request ) ;
347
384
348
385
match this. request . handle_response ( & mut res) {
349
386
Ok ( Some ( pending) ) => {
@@ -402,8 +439,6 @@ where
402
439
& mut self ,
403
440
res : & mut Response < ResBody > ,
404
441
) -> Result < Option < Oneshot < S , Request < ReqBody > > > , S :: Error > {
405
- res. extensions_mut ( ) . insert ( RequestUri ( self . uri . clone ( ) ) ) ;
406
-
407
442
let drop_payload_headers = |headers : & mut HeaderMap | {
408
443
for header in & [
409
444
CONTENT_TYPE ,
@@ -483,6 +518,11 @@ where
483
518
#[ derive( Clone ) ]
484
519
pub struct RequestUri ( pub Uri ) ;
485
520
521
+ /// Response [`Extensions`][http::Extensions] value that contains the redirect [`Policy`] that
522
+ /// was run before the last request of the redirect chain by a [`FollowRedirectExtension`] middleware.
523
+ #[ derive( Clone ) ]
524
+ pub struct FollowedPolicy < P > ( pub P ) ;
525
+
486
526
#[ derive( Debug ) ]
487
527
enum BodyRepr < B > {
488
528
Some ( B ) ,
@@ -551,7 +591,7 @@ mod tests {
551
591
#[ tokio:: test]
552
592
async fn follows ( ) {
553
593
let svc = ServiceBuilder :: new ( )
554
- . layer ( FollowRedirectLayer :: with_policy ( Action :: Follow ) )
594
+ . layer ( FollowRedirectLayer :: with_policy_extension ( Action :: Follow ) )
555
595
. buffer ( 1 )
556
596
. service_fn ( handle) ;
557
597
let req = Request :: builder ( )
@@ -564,12 +604,18 @@ mod tests {
564
604
res. extensions( ) . get:: <RequestUri >( ) . unwrap( ) . 0 ,
565
605
"http://example.com/0"
566
606
) ;
607
+ assert ! ( res
608
+ . extensions( )
609
+ . get:: <FollowedPolicy <Action >>( )
610
+ . unwrap( )
611
+ . 0
612
+ . is_follow( ) ) ;
567
613
}
568
614
569
615
#[ tokio:: test]
570
616
async fn stops ( ) {
571
617
let svc = ServiceBuilder :: new ( )
572
- . layer ( FollowRedirectLayer :: with_policy ( Action :: Stop ) )
618
+ . layer ( FollowRedirectLayer :: with_policy_extension ( Action :: Stop ) )
573
619
. buffer ( 1 )
574
620
. service_fn ( handle) ;
575
621
let req = Request :: builder ( )
@@ -582,12 +628,18 @@ mod tests {
582
628
res. extensions( ) . get:: <RequestUri >( ) . unwrap( ) . 0 ,
583
629
"http://example.com/42"
584
630
) ;
631
+ assert ! ( res
632
+ . extensions( )
633
+ . get:: <FollowedPolicy <Action >>( )
634
+ . unwrap( )
635
+ . 0
636
+ . is_stop( ) ) ;
585
637
}
586
638
587
639
#[ tokio:: test]
588
640
async fn limited ( ) {
589
641
let svc = ServiceBuilder :: new ( )
590
- . layer ( FollowRedirectLayer :: with_policy ( Limited :: new ( 10 ) ) )
642
+ . layer ( FollowRedirectLayer :: with_policy_extension ( Limited :: new ( 10 ) ) )
591
643
. buffer ( 1 )
592
644
. service_fn ( handle) ;
593
645
let req = Request :: builder ( )
@@ -600,6 +652,14 @@ mod tests {
600
652
res. extensions( ) . get:: <RequestUri >( ) . unwrap( ) . 0 ,
601
653
"http://example.com/32"
602
654
) ;
655
+ assert_eq ! (
656
+ res. extensions( )
657
+ . get:: <FollowedPolicy <Limited >>( )
658
+ . unwrap( )
659
+ . 0
660
+ . remaining,
661
+ 0
662
+ ) ;
603
663
}
604
664
605
665
/// A server with an endpoint `GET /{n}` which redirects to `/{n-1}` unless `n` equals zero,
0 commit comments