Skip to content

Commit 35d5c99

Browse files
committed
splittable RequestStream
1 parent 689878d commit 35d5c99

File tree

5 files changed

+135
-47
lines changed

5 files changed

+135
-47
lines changed

examples/server.rs

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ use structopt::StructOpt;
88
use tokio::{fs::File, io::AsyncReadExt};
99
use tracing::{debug, error, info, trace_span, warn};
1010

11-
use h3::{quic::BidiStream, server::RequestStream};
11+
use h3::{quic::RecvStream, quic::SendStream, server::RequestStream};
1212

1313
#[derive(StructOpt, Debug)]
1414
#[structopt(name = "server")]
@@ -112,13 +112,14 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
112112
Ok(())
113113
}
114114

115-
async fn handle_request<T>(
115+
async fn handle_request<S, R>(
116116
req: Request<()>,
117-
mut stream: RequestStream<T, Bytes>,
117+
mut stream: RequestStream<S, R, Bytes>,
118118
serve_root: Arc<Option<PathBuf>>,
119119
) -> Result<(), Box<dyn std::error::Error>>
120120
where
121-
T: BidiStream<Bytes>,
121+
S: SendStream<Bytes>,
122+
R: RecvStream,
122123
{
123124
let (status, to_serve) = match serve_root.as_deref() {
124125
None => (StatusCode::OK, None),

h3/src/connection.rs

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,7 @@ use crate::{
1818
stream::{StreamId, StreamType},
1919
varint::VarInt,
2020
},
21-
qpack,
22-
quic::{self, SendStream as _},
21+
qpack, quic,
2322
stream::{self, AcceptRecvStream, AcceptedRecvStream},
2423
};
2524

@@ -406,9 +405,9 @@ where
406405
}
407406
}
408407

409-
impl<S, B> RequestStream<FrameStream<S>, B>
408+
impl<S, B> RequestStream<S, B>
410409
where
411-
S: quic::SendStream<B> + quic::RecvStream,
410+
S: quic::SendStream<B>,
412411
B: Buf,
413412
{
414413
/// Send some data on the response body.

h3/src/quic.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ impl<'a, E: Error + 'a> From<E> for Box<dyn Error + 'a> {
3232
/// Trait representing a QUIC connection.
3333
pub trait Connection<B: Buf> {
3434
/// The type produced by `poll_accept_bidi()`
35-
type BidiStream: SendStream<B> + RecvStream;
35+
type BidiStream: BidiStream<B>;
3636
/// The type of the sending part of `BidiStream`
3737
type SendStream: SendStream<B>;
3838
/// The type produced by `poll_accept_recv()`
@@ -80,7 +80,7 @@ pub trait Connection<B: Buf> {
8080
/// Trait for opening outgoing streams
8181
pub trait OpenStreams<B: Buf> {
8282
/// The type produced by `poll_open_bidi()`
83-
type BidiStream: SendStream<B> + RecvStream;
83+
type BidiStream: BidiStream<B, SendStream = Self::SendStream, RecvStream = Self::RecvStream>;
8484
/// The type produced by `poll_open_send()`
8585
type SendStream: SendStream<B>;
8686
/// The type of the receiving part of `BidiStream`

h3/src/server.rs

Lines changed: 121 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ use crate::{
1616
frame::FrameStream,
1717
proto::{frame::Frame, headers::Header, varint::VarInt},
1818
qpack,
19-
quic::{self, RecvStream as _, SendStream as _},
19+
quic::{self, BidiStream, RecvStream as _, SendStream as _},
2020
stream,
2121
};
2222
use tracing::{error, trace, warn};
@@ -65,24 +65,38 @@ where
6565
{
6666
pub async fn accept(
6767
&mut self,
68-
) -> Result<Option<(Request<()>, RequestStream<C::BidiStream, B>)>, Error> {
69-
let mut stream = match future::poll_fn(|cx| self.poll_accept_request(cx)).await {
70-
Ok(Some(s)) => FrameStream::new(s),
71-
Ok(None) => {
72-
// We always send a last GoAway frame to the client, so it knows which was the last
73-
// non-rejected request.
74-
self.inner.shutdown(0).await?;
75-
return Ok(None);
76-
}
77-
Err(e) => {
78-
if e.is_closed() {
68+
) -> Result<
69+
Option<(
70+
Request<()>,
71+
RequestStream<
72+
<C::BidiStream as BidiStream<B>>::SendStream,
73+
<C::BidiStream as BidiStream<B>>::RecvStream,
74+
B,
75+
>,
76+
)>,
77+
Error,
78+
> {
79+
let (send_stream, mut recv_stream) =
80+
match future::poll_fn(|cx| self.poll_accept_request(cx)).await {
81+
Ok(Some(s)) => {
82+
let (send, recv) = s.split();
83+
(send, FrameStream::new(recv))
84+
}
85+
Ok(None) => {
86+
// We always send a last GoAway frame to the client, so it knows which was the last
87+
// non-rejected request.
88+
self.inner.shutdown(0).await?;
7989
return Ok(None);
8090
}
81-
return Err(e);
82-
}
83-
};
91+
Err(e) => {
92+
if e.is_closed() {
93+
return Ok(None);
94+
}
95+
return Err(e);
96+
}
97+
};
8498

85-
let frame = future::poll_fn(|cx| stream.poll_next(cx)).await;
99+
let frame = future::poll_fn(|cx| recv_stream.poll_next(cx)).await;
86100

87101
let mut encoded = match frame {
88102
Ok(Some(Frame::Headers(h))) => h,
@@ -105,14 +119,24 @@ where
105119
}
106120
};
107121

122+
let stream_id = send_stream.id();
108123
let mut request_stream = RequestStream {
109-
stream_id: stream.id(),
110-
request_end: self.request_end_send.clone(),
111-
inner: connection::RequestStream::new(
112-
stream,
113-
self.max_field_section_size,
114-
self.inner.shared.clone(),
115-
),
124+
send: RequestSendStream {
125+
inner: connection::RequestStream::new(
126+
send_stream,
127+
self.max_field_section_size,
128+
self.inner.shared.clone(),
129+
),
130+
},
131+
recv: RequestRecvStream {
132+
stream_id,
133+
request_end: self.request_end_send.clone(),
134+
inner: connection::RequestStream::new(
135+
recv_stream,
136+
self.max_field_section_size,
137+
self.inner.shared.clone(),
138+
),
139+
},
116140
};
117141

118142
let qpack::Decoded { fields, .. } =
@@ -295,7 +319,11 @@ impl Builder {
295319
}
296320
}
297321

298-
pub struct RequestStream<S, B>
322+
pub struct RequestSendStream<S, B> {
323+
inner: connection::RequestStream<S, B>,
324+
}
325+
326+
pub struct RequestRecvStream<S, B>
299327
where
300328
S: quic::RecvStream,
301329
{
@@ -304,16 +332,30 @@ where
304332
request_end: mpsc::UnboundedSender<StreamId>,
305333
}
306334

307-
impl<S, B> ConnectionState for RequestStream<S, B>
335+
pub struct RequestStream<S, R, B>
308336
where
309-
S: quic::RecvStream,
337+
R: quic::RecvStream,
310338
{
339+
send: RequestSendStream<S, B>,
340+
recv: RequestRecvStream<R, B>,
341+
}
342+
343+
impl<S, B> ConnectionState for RequestSendStream<S, B> {
311344
fn shared_state(&self) -> &SharedStateRef {
312345
&self.inner.conn_state
313346
}
314347
}
315348

316-
impl<S, B> RequestStream<S, B>
349+
impl<S, R, B> ConnectionState for RequestStream<S, R, B>
350+
where
351+
R: quic::RecvStream,
352+
{
353+
fn shared_state(&self) -> &SharedStateRef {
354+
self.send.shared_state()
355+
}
356+
}
357+
358+
impl<S, B> RequestRecvStream<S, B>
317359
where
318360
S: quic::RecvStream,
319361
{
@@ -326,9 +368,9 @@ where
326368
}
327369
}
328370

329-
impl<S, B> RequestStream<S, B>
371+
impl<S, B> RequestSendStream<S, B>
330372
where
331-
S: quic::RecvStream + quic::SendStream<B>,
373+
S: quic::SendStream<B>,
332374
B: Buf,
333375
{
334376
pub async fn send_response(&mut self, resp: Response<()>) -> Result<(), Error> {
@@ -370,13 +412,59 @@ where
370412
}
371413
}
372414

373-
impl<S, B> RequestStream<S, B>
415+
impl<S, R, B> RequestStream<S, R, B>
416+
where
417+
S: quic::SendStream<B>,
418+
R: quic::RecvStream,
419+
B: Buf,
420+
{
421+
pub async fn recv_data(&mut self) -> Result<Option<impl Buf>, Error> {
422+
self.recv.recv_data().await
423+
}
424+
425+
pub fn stop_sending(&mut self, error_code: crate::error::Code) {
426+
self.recv.stop_sending(error_code)
427+
}
428+
429+
pub async fn send_response(&mut self, resp: Response<()>) -> Result<(), Error> {
430+
self.send.send_response(resp).await
431+
}
432+
433+
pub async fn send_data(&mut self, buf: B) -> Result<(), Error> {
434+
self.send.send_data(buf).await
435+
}
436+
437+
pub async fn send_trailers(&mut self, trailers: HeaderMap) -> Result<(), Error> {
438+
self.send.send_trailers(trailers).await
439+
}
440+
441+
pub async fn finish(&mut self) -> Result<(), Error> {
442+
self.send.finish().await
443+
}
444+
445+
pub fn split(self) -> (RequestSendStream<S, B>, RequestRecvStream<R, B>) {
446+
(self.send, self.recv)
447+
}
448+
}
449+
450+
impl<S, B> RequestRecvStream<S, B>
451+
where
452+
S: quic::RecvStream, /*+ quic::SendStream<B>*/
453+
B: Buf,
454+
{
455+
async fn recv_trailers(&mut self) -> Result<Option<HeaderMap>, Error> {
456+
self.inner.recv_trailers().await
457+
}
458+
}
459+
460+
impl<S, R, B> RequestStream<S, R, B>
374461
where
375-
S: quic::RecvStream + quic::SendStream<B>,
462+
S: quic::SendStream<B>,
463+
R: quic::RecvStream,
376464
B: Buf,
377465
{
378466
pub async fn recv_trailers(&mut self) -> Result<Option<HeaderMap>, Error> {
379-
let res = self.inner.recv_trailers().await;
467+
let res = self.recv.recv_trailers().await;
380468
if let Err(ref e) = res {
381469
if e.is_header_too_big() {
382470
self.send_response(
@@ -392,7 +480,7 @@ where
392480
}
393481
}
394482

395-
impl<S, B> Drop for RequestStream<S, B>
483+
impl<S, B> Drop for RequestRecvStream<S, B>
396484
where
397485
S: quic::RecvStream,
398486
{

tests/h3-tests/tests/connection.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,7 @@ use http::{Request, Response, StatusCode};
88
use h3::{
99
client::{self, SendRequest},
1010
error::{Code, Kind},
11-
quic::{self, SendStream},
12-
server,
11+
quic, server,
1312
test_helpers::{
1413
proto::{
1514
coding::Encode as _,
@@ -667,9 +666,10 @@ where
667666
request_stream.recv_response().await
668667
}
669668

670-
async fn response<S, B>(mut stream: server::RequestStream<S, B>)
669+
async fn response<S, R, B>(mut stream: server::RequestStream<S, R, B>)
671670
where
672-
S: quic::RecvStream + SendStream<B>,
671+
S: quic::SendStream<B>,
672+
R: quic::RecvStream,
673673
B: Buf,
674674
{
675675
stream

0 commit comments

Comments
 (0)