Skip to content

Commit 7063268

Browse files
authored
Fix closing pending frames (#194)
* send pending frames after collecting them * test * Apply suggestions from code review * Update yamux/src/connection/closing.rs
1 parent 8bd5d40 commit 7063268

File tree

1 file changed

+110
-11
lines changed

1 file changed

+110
-11
lines changed

yamux/src/connection/closing.rs

Lines changed: 110 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ where
3030
socket: Fuse<frame::Io<T>>,
3131
) -> Self {
3232
Self {
33-
state: State::FlushingPendingFrames,
33+
state: State::ClosingStreamReceiver,
3434
stream_receivers,
3535
pending_frames,
3636
socket,
@@ -49,14 +49,6 @@ where
4949

5050
loop {
5151
match this.state {
52-
State::FlushingPendingFrames => {
53-
ready!(this.socket.poll_ready_unpin(cx))?;
54-
55-
match this.pending_frames.pop_front() {
56-
Some(frame) => this.socket.start_send_unpin(frame)?,
57-
None => this.state = State::ClosingStreamReceiver,
58-
}
59-
}
6052
State::ClosingStreamReceiver => {
6153
for stream in this.stream_receivers.iter_mut() {
6254
stream.inner_mut().close();
@@ -77,11 +69,19 @@ where
7769
Poll::Pending | Poll::Ready(None) => {
7870
// No more frames from streams, append `Term` frame and flush them all.
7971
this.pending_frames.push_back(Frame::term().into());
80-
this.state = State::ClosingSocket;
72+
this.state = State::FlushingPendingFrames;
8173
continue;
8274
}
8375
}
8476
}
77+
State::FlushingPendingFrames => {
78+
ready!(this.socket.poll_ready_unpin(cx))?;
79+
80+
match this.pending_frames.pop_front() {
81+
Some(frame) => this.socket.start_send_unpin(frame)?,
82+
None => this.state = State::ClosingSocket,
83+
}
84+
}
8585
State::ClosingSocket => {
8686
ready!(this.socket.poll_close_unpin(cx))?;
8787

@@ -93,8 +93,107 @@ where
9393
}
9494

9595
enum State {
96-
FlushingPendingFrames,
9796
ClosingStreamReceiver,
9897
DrainingStreamReceiver,
98+
FlushingPendingFrames,
9999
ClosingSocket,
100100
}
101+
102+
#[cfg(test)]
103+
mod tests {
104+
use super::*;
105+
use futures::future::poll_fn;
106+
use futures::FutureExt;
107+
108+
struct Socket {
109+
written: Vec<u8>,
110+
closed: bool,
111+
}
112+
impl AsyncRead for Socket {
113+
fn poll_read(
114+
self: Pin<&mut Self>,
115+
_: &mut Context<'_>,
116+
_: &mut [u8],
117+
) -> Poll<std::io::Result<usize>> {
118+
unimplemented!()
119+
}
120+
}
121+
impl AsyncWrite for Socket {
122+
fn poll_write(
123+
mut self: Pin<&mut Self>,
124+
_: &mut Context<'_>,
125+
buf: &[u8],
126+
) -> Poll<std::io::Result<usize>> {
127+
assert!(!self.closed);
128+
self.written.extend_from_slice(buf);
129+
Poll::Ready(Ok(buf.len()))
130+
}
131+
132+
fn poll_flush(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<std::io::Result<()>> {
133+
unimplemented!()
134+
}
135+
136+
fn poll_close(mut self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<std::io::Result<()>> {
137+
assert!(!self.closed);
138+
self.closed = true;
139+
Poll::Ready(Ok(()))
140+
}
141+
}
142+
143+
#[test]
144+
fn pending_frames() {
145+
let frame_pending = Frame::data(StreamId::new(1), vec![2]).unwrap().into();
146+
let frame_data = Frame::data(StreamId::new(3), vec![4]).unwrap().into();
147+
let frame_close = Frame::close_stream(StreamId::new(5), false).into();
148+
let frame_close_ack = Frame::close_stream(StreamId::new(6), true).into();
149+
let frame_term = Frame::term().into();
150+
fn encode(buf: &mut Vec<u8>, frame: &Frame<()>) {
151+
buf.extend_from_slice(&frame::header::encode(frame.header()));
152+
if frame.header().tag() == frame::header::Tag::Data {
153+
buf.extend_from_slice(frame.clone().into_data().body());
154+
}
155+
}
156+
let mut expected_written = vec![];
157+
encode(&mut expected_written, &frame_pending);
158+
encode(&mut expected_written, &frame_data);
159+
encode(&mut expected_written, &frame_close);
160+
encode(&mut expected_written, &frame_close_ack);
161+
encode(&mut expected_written, &frame_term);
162+
163+
let receiver = |frame: &Frame<_>, command: StreamCommand| {
164+
TaggedStream::new(frame.header().stream_id(), {
165+
let (mut tx, rx) = mpsc::channel(1);
166+
tx.try_send(command).unwrap();
167+
rx
168+
})
169+
};
170+
171+
let mut stream_receivers: SelectAll<_> = Default::default();
172+
stream_receivers.push(receiver(
173+
&frame_data,
174+
StreamCommand::SendFrame(frame_data.clone().into_data().left()),
175+
));
176+
stream_receivers.push(receiver(
177+
&frame_close,
178+
StreamCommand::CloseStream { ack: false },
179+
));
180+
stream_receivers.push(receiver(
181+
&frame_close_ack,
182+
StreamCommand::CloseStream { ack: true },
183+
));
184+
let pending_frames = vec![frame_pending.into()];
185+
let mut socket = Socket {
186+
written: vec![],
187+
closed: false,
188+
};
189+
let mut closing = Closing::new(
190+
stream_receivers,
191+
pending_frames.into(),
192+
frame::Io::new(crate::connection::Id(0), &mut socket).fuse(),
193+
);
194+
futures::executor::block_on(async { poll_fn(|cx| closing.poll_unpin(cx)).await.unwrap() });
195+
assert!(closing.pending_frames.is_empty());
196+
assert!(socket.closed);
197+
assert_eq!(socket.written, expected_written);
198+
}
199+
}

0 commit comments

Comments
 (0)