diff --git a/h3/src/client.rs b/h3/src/client.rs index 990bb493..d7c0e77e 100644 --- a/h3/src/client.rs +++ b/h3/src/client.rs @@ -261,17 +261,11 @@ impl Builder { } } -pub struct RequestStream -where - S: quic::RecvStream, -{ - inner: connection::RequestStream, B>, +pub struct RequestStream { + inner: connection::RequestStream, } -impl ConnectionState for RequestStream -where - S: quic::RecvStream, -{ +impl ConnectionState for RequestStream { fn shared_state(&self) -> &SharedStateRef { &self.inner.conn_state } @@ -339,7 +333,7 @@ where impl RequestStream where - S: quic::RecvStream + quic::SendStream, + S: quic::SendStream, B: Buf, { pub async fn send_data(&mut self, buf: B) -> Result<(), Error> { @@ -354,3 +348,19 @@ where self.inner.finish().await } } + +impl RequestStream +where + S: quic::BidiStream, + B: Buf, +{ + pub fn split( + self, + ) -> ( + RequestStream, + RequestStream, + ) { + let (send, recv) = self.inner.split(); + (RequestStream { inner: send }, RequestStream { inner: recv }) + } +} diff --git a/h3/src/connection.rs b/h3/src/connection.rs index 9ef99ab9..e3ee633c 100644 --- a/h3/src/connection.rs +++ b/h3/src/connection.rs @@ -1,6 +1,5 @@ use std::{ convert::TryFrom, - marker::PhantomData, sync::{Arc, RwLock, RwLockReadGuard, RwLockWriteGuard}, task::{Context, Poll}, }; @@ -79,9 +78,9 @@ where pub(super) shared: SharedStateRef, conn: C, control_send: C::SendStream, - control_recv: Option>, - decoder_recv: Option>, - encoder_recv: Option>, + control_recv: Option>, + decoder_recv: Option>, + encoder_recv: Option>, pending_recv_streams: Vec>, // The id of the last stream received by this connection: // request and push stream for server and clients respectively. @@ -307,21 +306,23 @@ where } pub struct RequestStream { - pub(super) stream: S, + pub(super) stream: FrameStream, pub(super) trailers: Option, pub(super) conn_state: SharedStateRef, pub(super) max_field_section_size: u64, - _phantom_buffer: PhantomData, } impl RequestStream { - pub fn new(stream: S, max_field_section_size: u64, conn_state: SharedStateRef) -> Self { + pub fn new( + stream: FrameStream, + max_field_section_size: u64, + conn_state: SharedStateRef, + ) -> Self { Self { stream, conn_state, max_field_section_size, trailers: None, - _phantom_buffer: PhantomData, } } } @@ -332,7 +333,7 @@ impl ConnectionState for RequestStream { } } -impl RequestStream, B> +impl RequestStream where S: quic::RecvStream, { @@ -406,9 +407,9 @@ where } } -impl RequestStream, B> +impl RequestStream where - S: quic::SendStream + quic::RecvStream, + S: quic::SendStream, B: Buf, { /// Send some data on the response body. @@ -449,3 +450,33 @@ where .map_err(|e| self.maybe_conn_err(e)) } } + +impl RequestStream +where + S: quic::BidiStream, + B: Buf, +{ + pub(crate) fn split( + self, + ) -> ( + RequestStream, + RequestStream, + ) { + let (send, recv) = self.stream.split(); + + ( + RequestStream { + stream: send, + trailers: None, + conn_state: self.conn_state.clone(), + max_field_section_size: 0, + }, + RequestStream { + stream: recv, + trailers: self.trailers, + conn_state: self.conn_state, + max_field_section_size: self.max_field_section_size, + }, + ) + } +} diff --git a/h3/src/frame.rs b/h3/src/frame.rs index 4a19f93d..f414bc31 100644 --- a/h3/src/frame.rs +++ b/h3/src/frame.rs @@ -1,3 +1,4 @@ +use std::marker::PhantomData; use std::task::{Context, Poll}; use bytes::{Buf, Bytes}; @@ -12,26 +13,21 @@ use crate::{ frame::{self, Frame, PayloadLen}, stream::StreamId, }, - quic::{RecvStream, SendStream}, + quic::{BidiStream, RecvStream, SendStream}, stream::WriteBuf, }; -pub struct FrameStream -where - S: RecvStream, -{ +pub struct FrameStream { stream: S, bufs: BufList, decoder: FrameDecoder, remaining_data: usize, /// Set to true when `stream` reaches the end. is_eos: bool, + _phantom_buffer: PhantomData, } -impl FrameStream -where - S: RecvStream, -{ +impl FrameStream { pub fn new(stream: S) -> Self { Self::with_bufs(stream, BufList::new()) } @@ -43,9 +39,15 @@ where decoder: FrameDecoder::default(), remaining_data: 0, is_eos: false, + _phantom_buffer: PhantomData, } } +} +impl FrameStream +where + S: RecvStream, +{ pub fn poll_next( &mut self, cx: &mut Context<'_>, @@ -136,9 +138,9 @@ where } } -impl SendStream for FrameStream +impl SendStream for FrameStream where - T: SendStream + RecvStream, + T: SendStream, B: Buf, { type Error = >::Error; @@ -164,6 +166,34 @@ where } } +impl FrameStream +where + S: BidiStream, + B: Buf, +{ + pub(crate) fn split(self) -> (FrameStream, FrameStream) { + let (send, recv) = self.stream.split(); + ( + FrameStream { + stream: send, + bufs: BufList::new(), + decoder: FrameDecoder::default(), + remaining_data: 0, + is_eos: false, + _phantom_buffer: PhantomData, + }, + FrameStream { + stream: recv, + bufs: self.bufs, + decoder: self.decoder, + remaining_data: self.remaining_data, + is_eos: self.is_eos, + _phantom_buffer: PhantomData, + }, + ) + } +} + #[derive(Default)] pub struct FrameDecoder { expected: Option, @@ -338,7 +368,7 @@ mod tests { Frame::headers(&b"trailer"[..]).encode_with_payload(&mut buf); recv.chunk(buf.freeze()); - let mut stream = FrameStream::new(recv); + let mut stream: FrameStream<_, ()> = FrameStream::new(recv); assert_poll_matches!( |mut cx| stream.poll_next(&mut cx), @@ -366,7 +396,7 @@ mod tests { Frame::headers(&b"header"[..]).encode_with_payload(&mut buf); let mut buf = buf.freeze(); recv.chunk(buf.split_to(buf.len() - 1)); - let mut stream = FrameStream::new(recv); + let mut stream: FrameStream<_, ()> = FrameStream::new(recv); assert_poll_matches!( |mut cx| stream.poll_next(&mut cx), @@ -385,7 +415,7 @@ mod tests { FrameType::DATA.encode(&mut buf); VarInt::from(4u32).encode(&mut buf); recv.chunk(buf.freeze()); - let mut stream = FrameStream::new(recv); + let mut stream: FrameStream<_, ()> = FrameStream::new(recv); assert_poll_matches!( |mut cx| stream.poll_next(&mut cx), @@ -407,7 +437,7 @@ mod tests { let mut buf = buf.freeze(); recv.chunk(buf.split_to(buf.len() - 2)); recv.chunk(buf); - let mut stream = FrameStream::new(recv); + let mut stream: FrameStream<_, ()> = FrameStream::new(recv); // We get the total size of data about to be received assert_poll_matches!( @@ -436,7 +466,7 @@ mod tests { VarInt::from(4u32).encode(&mut buf); buf.put_slice(&b"b"[..]); recv.chunk(buf.freeze()); - let mut stream = FrameStream::new(recv); + let mut stream: FrameStream<_, ()> = FrameStream::new(recv); assert_poll_matches!( |mut cx| stream.poll_next(&mut cx), @@ -468,7 +498,7 @@ mod tests { Frame::Data(Bytes::from("body")).encode_with_payload(&mut buf); recv.chunk(buf.freeze()); - let mut stream = FrameStream::new(recv); + let mut stream: FrameStream<_, ()> = FrameStream::new(recv); assert_poll_matches!( |mut cx| stream.poll_next(&mut cx), @@ -490,7 +520,7 @@ mod tests { buf.put_slice(&b"bo"[..]); recv.chunk(buf.clone().freeze()); - let mut stream = FrameStream::new(recv); + let mut stream: FrameStream<_, ()> = FrameStream::new(recv); assert_poll_matches!( |mut cx| stream.poll_next(&mut cx), diff --git a/h3/src/server.rs b/h3/src/server.rs index 13662f80..6084bb17 100644 --- a/h3/src/server.rs +++ b/h3/src/server.rs @@ -1,6 +1,7 @@ use std::{ collections::HashSet, convert::TryFrom, + sync::Arc, task::{Context, Poll}, }; @@ -107,8 +108,10 @@ where }; let mut request_stream = RequestStream { - stream_id: stream.id(), - request_end: self.request_end_send.clone(), + request_end: Arc::new(RequestEnd { + request_end: self.request_end_send.clone(), + stream_id: stream.id(), + }), inner: connection::RequestStream::new( stream, self.max_field_section_size, @@ -296,19 +299,23 @@ impl Builder { } } -pub struct RequestStream -where - S: quic::RecvStream, -{ - inner: connection::RequestStream, B>, - stream_id: StreamId, +pub struct RequestEnd { request_end: mpsc::UnboundedSender, + stream_id: StreamId, } -impl ConnectionState for RequestStream -where - S: quic::RecvStream, -{ +pub struct RequestStream { + inner: connection::RequestStream, + request_end: Arc, +} + +impl AsMut> for RequestStream { + fn as_mut(&mut self) -> &mut connection::RequestStream { + &mut self.inner + } +} + +impl ConnectionState for RequestStream { fn shared_state(&self) -> &SharedStateRef { &self.inner.conn_state } @@ -329,7 +336,7 @@ where impl RequestStream where - S: quic::RecvStream + quic::SendStream, + S: quic::SendStream, B: Buf, { pub async fn send_response(&mut self, resp: Response<()>) -> Result<(), Error> { @@ -393,10 +400,32 @@ where } } -impl Drop for RequestStream +impl RequestStream where - S: quic::RecvStream, + S: quic::BidiStream, + B: Buf, { + pub fn split( + self, + ) -> ( + RequestStream, + RequestStream, + ) { + let (send, recv) = self.inner.split(); + ( + RequestStream { + inner: send, + request_end: self.request_end.clone(), + }, + RequestStream { + inner: recv, + request_end: self.request_end, + }, + ) + } +} + +impl Drop for RequestEnd { fn drop(&mut self) { if let Err(e) = self.request_end.send(self.stream_id) { error!( diff --git a/h3/src/stream.rs b/h3/src/stream.rs index ae43bdae..ba3b4eed 100644 --- a/h3/src/stream.rs +++ b/h3/src/stream.rs @@ -159,12 +159,12 @@ where } } -pub(super) enum AcceptedRecvStream +pub(super) enum AcceptedRecvStream where S: quic::RecvStream, { - Control(FrameStream), - Push(u64, FrameStream), + Control(FrameStream), + Push(u64, FrameStream), Encoder(S), Decoder(S), Reserved, @@ -195,7 +195,7 @@ where } } - pub fn into_stream(self) -> Result, Error> { + pub fn into_stream(self) -> Result, Error> { Ok(match self.ty.expect("Stream type not resolved yet") { StreamType::CONTROL => { AcceptedRecvStream::Control(FrameStream::with_bufs(self.stream, self.buf))