Skip to content

Commit 9ff4f7b

Browse files
feat(transport): Support timeouts with "grpc-timeout" header (#606)
* transport: Support timeouts with "grpc-timeout" header * Apply suggestions from code review Co-authored-by: Lucio Franco <[email protected]> * Timeout -> GrpcTimeout and export TimeoutExpired * Clean up imports * Give header name a more proper home * Add fuzz tests for parsing header value into `grpc-timeout` * Map `TimeoutExpired` to `cancelled` status * Recover from timeout errors in the service * Refactor tests * Fix CI * Fix CI, again Co-authored-by: Lucio Franco <[email protected]>
1 parent 4926c60 commit 9ff4f7b

File tree

13 files changed

+494
-14
lines changed

13 files changed

+494
-14
lines changed

.github/workflows/CI.yml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,8 @@ jobs:
4949

5050
env:
5151
RUSTFLAGS: "-D warnings"
52+
# run a lot of quickcheck iterations
53+
QUICKCHECK_TESTS: 1000
5254

5355
steps:
5456
- uses: hecrj/setup-rust-action@master

tests/integration_tests/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ bytes = "1.0"
1616

1717
[dev-dependencies]
1818
tokio = { version = "1.0", features = ["macros", "rt-multi-thread", "net"] }
19+
tokio-stream = { version = "0.1.5", features = ["net"] }
1920

2021
[build-dependencies]
2122
tonic-build = { path = "../../tonic-build" }
Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
use integration_tests::pb::{test_client, test_server, Input, Output};
2+
use std::{net::SocketAddr, time::Duration};
3+
use tokio::net::TcpListener;
4+
use tonic::{transport::Server, Code, Request, Response, Status};
5+
6+
#[tokio::test]
7+
async fn cancelation_on_timeout() {
8+
let addr = run_service_in_background(Duration::from_secs(1), Duration::from_secs(100)).await;
9+
10+
let mut client = test_client::TestClient::connect(format!("http://{}", addr))
11+
.await
12+
.unwrap();
13+
14+
let mut req = Request::new(Input {});
15+
req.metadata_mut()
16+
// 500 ms
17+
.insert("grpc-timeout", "500m".parse().unwrap());
18+
19+
let res = client.unary_call(req).await;
20+
21+
let err = res.unwrap_err();
22+
assert!(err.message().contains("Timeout expired"));
23+
assert_eq!(err.code(), Code::Cancelled);
24+
}
25+
26+
#[tokio::test]
27+
async fn picks_server_timeout_if_thats_sorter() {
28+
let addr = run_service_in_background(Duration::from_secs(1), Duration::from_millis(100)).await;
29+
30+
let mut client = test_client::TestClient::connect(format!("http://{}", addr))
31+
.await
32+
.unwrap();
33+
34+
let mut req = Request::new(Input {});
35+
req.metadata_mut()
36+
// 10 hours
37+
.insert("grpc-timeout", "10H".parse().unwrap());
38+
39+
let res = client.unary_call(req).await;
40+
let err = res.unwrap_err();
41+
assert!(err.message().contains("Timeout expired"));
42+
assert_eq!(err.code(), Code::Cancelled);
43+
}
44+
45+
#[tokio::test]
46+
async fn picks_client_timeout_if_thats_sorter() {
47+
let addr = run_service_in_background(Duration::from_secs(1), Duration::from_secs(100)).await;
48+
49+
let mut client = test_client::TestClient::connect(format!("http://{}", addr))
50+
.await
51+
.unwrap();
52+
53+
let mut req = Request::new(Input {});
54+
req.metadata_mut()
55+
// 100 ms
56+
.insert("grpc-timeout", "100m".parse().unwrap());
57+
58+
let res = client.unary_call(req).await;
59+
let err = res.unwrap_err();
60+
assert!(err.message().contains("Timeout expired"));
61+
assert_eq!(err.code(), Code::Cancelled);
62+
}
63+
64+
async fn run_service_in_background(latency: Duration, server_timeout: Duration) -> SocketAddr {
65+
struct Svc {
66+
latency: Duration,
67+
}
68+
69+
#[tonic::async_trait]
70+
impl test_server::Test for Svc {
71+
async fn unary_call(&self, _req: Request<Input>) -> Result<Response<Output>, Status> {
72+
tokio::time::sleep(self.latency).await;
73+
Ok(Response::new(Output {}))
74+
}
75+
}
76+
77+
let svc = test_server::TestServer::new(Svc { latency });
78+
79+
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
80+
let addr = listener.local_addr().unwrap();
81+
82+
tokio::spawn(async move {
83+
Server::builder()
84+
.timeout(server_timeout)
85+
.add_service(svc)
86+
.serve_with_incoming(tokio_stream::wrappers::TcpListenerStream::new(listener))
87+
.await
88+
.unwrap();
89+
});
90+
91+
addr
92+
}

tonic/Cargo.toml

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,8 @@ transport = [
3131
"tokio",
3232
"tower",
3333
"tracing-futures",
34-
"tokio/macros"
34+
"tokio/macros",
35+
"tokio/time",
3536
]
3637
tls = ["transport", "tokio-rustls"]
3738
tls-roots = ["tls", "rustls-native-certs"]
@@ -68,7 +69,7 @@ h2 = { version = "0.3", optional = true }
6869
hyper = { version = "0.14.2", features = ["full"], optional = true }
6970
tokio = { version = "1.0.1", features = ["net"], optional = true }
7071
tokio-stream = "0.1"
71-
tower = { version = "0.4.4", features = ["balance", "buffer", "discover", "limit", "load", "make", "timeout", "util"], optional = true }
72+
tower = { version = "0.4.7", features = ["balance", "buffer", "discover", "limit", "load", "make", "timeout", "util"], optional = true }
7273
tracing-futures = { version = "0.2", optional = true }
7374

7475
# rustls
@@ -80,6 +81,8 @@ tokio = { version = "1.0", features = ["rt", "macros"] }
8081
static_assertions = "1.0"
8182
rand = "0.8"
8283
bencher = "0.1.5"
84+
quickcheck = "1.0"
85+
quickcheck_macros = "1.0"
8386

8487
[package.metadata.docs.rs]
8588
all-features = true

tonic/src/metadata/map.rs

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -194,15 +194,17 @@ pub struct OccupiedEntry<'a, VE: ValueEncoding> {
194194
phantom: PhantomData<VE>,
195195
}
196196

197+
#[cfg(feature = "transport")]
198+
pub(crate) const GRPC_TIMEOUT_HEADER: &str = "grpc-timeout";
199+
197200
// ===== impl MetadataMap =====
198201

199202
impl MetadataMap {
200203
// Headers reserved by the gRPC protocol.
201-
pub(crate) const GRPC_RESERVED_HEADERS: [&'static str; 8] = [
204+
pub(crate) const GRPC_RESERVED_HEADERS: [&'static str; 7] = [
202205
"te",
203206
"user-agent",
204207
"content-type",
205-
"grpc-timeout",
206208
"grpc-message",
207209
"grpc-encoding",
208210
"grpc-message-type",

tonic/src/metadata/mod.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,9 @@ pub use self::value::AsciiMetadataValue;
2929
pub use self::value::BinaryMetadataValue;
3030
pub use self::value::MetadataValue;
3131

32+
#[cfg(feature = "transport")]
33+
pub(crate) use self::map::GRPC_TIMEOUT_HEADER;
34+
3235
/// The metadata::errors module contains types for errors that can occur
3336
/// while handling gRPC custom metadata.
3437
pub mod errors {

tonic/src/status.rs

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -313,7 +313,7 @@ impl Status {
313313
Status::try_from_error(err).unwrap_or_else(|| Status::new(Code::Unknown, err.to_string()))
314314
}
315315

316-
fn try_from_error(err: &(dyn Error + 'static)) -> Option<Status> {
316+
pub(crate) fn try_from_error(err: &(dyn Error + 'static)) -> Option<Status> {
317317
let mut cause = Some(err);
318318

319319
while let Some(err) = cause {
@@ -331,6 +331,10 @@ impl Status {
331331
if let Some(h2) = err.downcast_ref::<h2::Error>() {
332332
return Some(Status::from_h2_error(h2));
333333
}
334+
335+
if let Some(timeout) = err.downcast_ref::<crate::transport::TimeoutExpired>() {
336+
return Some(Status::cancelled(timeout.to_string()));
337+
}
334338
}
335339

336340
cause = err.source();

tonic/src/transport/mod.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,8 @@ pub use self::channel::{Channel, Endpoint};
9898
pub use self::error::Error;
9999
#[doc(inline)]
100100
pub use self::server::{NamedService, Server};
101+
#[doc(inline)]
102+
pub use self::service::TimeoutExpired;
101103
pub use self::tls::{Certificate, Identity};
102104
pub use hyper::{Body, Uri};
103105

tonic/src/transport/server/mod.rs

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
33
mod conn;
44
mod incoming;
5+
mod recover_error;
56
#[cfg(feature = "tls")]
67
#[cfg_attr(docsrs, doc(cfg(feature = "tls")))]
78
mod tls;
@@ -21,8 +22,9 @@ pub(crate) use tokio_rustls::server::TlsStream;
2122
#[cfg(feature = "tls")]
2223
use crate::transport::Error;
2324

25+
use self::recover_error::RecoverError;
2426
use super::{
25-
service::{Or, Routes, ServerIo},
27+
service::{GrpcTimeout, Or, Routes, ServerIo},
2628
BoxFuture,
2729
};
2830
use crate::{body::BoxBody, request::ConnectionInfo};
@@ -42,10 +44,7 @@ use std::{
4244
time::Duration,
4345
};
4446
use tokio::io::{AsyncRead, AsyncWrite};
45-
use tower::{
46-
limit::concurrency::ConcurrencyLimitLayer, timeout::TimeoutLayer, util::Either, Service,
47-
ServiceBuilder,
48-
};
47+
use tower::{limit::concurrency::ConcurrencyLimitLayer, util::Either, Service, ServiceBuilder};
4948
use tracing_futures::{Instrument, Instrumented};
5049

5150
type BoxService = tower::util::BoxService<Request<Body>, Response<BoxBody>, crate::Error>;
@@ -655,8 +654,9 @@ where
655654

656655
Box::pin(async move {
657656
let svc = ServiceBuilder::new()
657+
.layer_fn(RecoverError::new)
658658
.option_layer(concurrency_limit.map(ConcurrencyLimitLayer::new))
659-
.option_layer(timeout.map(TimeoutLayer::new))
659+
.layer_fn(|s| GrpcTimeout::new(s, timeout))
660660
.service(svc);
661661

662662
let svc = BoxService::new(Svc {
Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
use crate::{body::BoxBody, Status};
2+
use futures_util::ready;
3+
use http::Response;
4+
use pin_project::pin_project;
5+
use std::{
6+
future::Future,
7+
pin::Pin,
8+
task::{Context, Poll},
9+
};
10+
use tower::Service;
11+
12+
/// Middleware that attempts to recover from service errors by turning them into a response built
13+
/// from the `Status`.
14+
#[derive(Debug, Clone)]
15+
pub(crate) struct RecoverError<S> {
16+
inner: S,
17+
}
18+
19+
impl<S> RecoverError<S> {
20+
pub(crate) fn new(inner: S) -> Self {
21+
Self { inner }
22+
}
23+
}
24+
25+
impl<S, R> Service<R> for RecoverError<S>
26+
where
27+
S: Service<R, Response = Response<BoxBody>>,
28+
S::Error: Into<crate::Error>,
29+
{
30+
type Response = Response<BoxBody>;
31+
type Error = crate::Error;
32+
type Future = ResponseFuture<S::Future>;
33+
34+
fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
35+
self.inner.poll_ready(cx).map_err(Into::into)
36+
}
37+
38+
fn call(&mut self, req: R) -> Self::Future {
39+
ResponseFuture {
40+
inner: self.inner.call(req),
41+
}
42+
}
43+
}
44+
45+
#[pin_project]
46+
pub(crate) struct ResponseFuture<F> {
47+
#[pin]
48+
inner: F,
49+
}
50+
51+
impl<F, E> Future for ResponseFuture<F>
52+
where
53+
F: Future<Output = Result<Response<BoxBody>, E>>,
54+
E: Into<crate::Error>,
55+
{
56+
type Output = Result<Response<BoxBody>, crate::Error>;
57+
58+
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
59+
let result: Result<Response<BoxBody>, crate::Error> =
60+
ready!(self.project().inner.poll(cx)).map_err(Into::into);
61+
62+
match result {
63+
Ok(res) => Poll::Ready(Ok(res)),
64+
Err(err) => {
65+
if let Some(status) = Status::try_from_error(&*err) {
66+
let mut res = Response::new(BoxBody::empty());
67+
status.add_header(res.headers_mut()).unwrap();
68+
Poll::Ready(Ok(res))
69+
} else {
70+
Poll::Ready(Err(err))
71+
}
72+
}
73+
}
74+
}
75+
}

0 commit comments

Comments
 (0)