Skip to content

Commit 8a36a80

Browse files
committed
feat: close ImapStream after an error
This prevents calls to poll_next() from reaching the underlying stream after an error is returned once. Previously calls to poll_next() to ImapStream built on top of a network connection could have returned infinite stream of Some(Err(_)) values. This is dangerous for the code that looks for the end of stream without processing errors such as `while stream.next().await.is_some() {}` because it may result in infinite loop. It is still better to process errors by writing `while stream.try_next().await?.is_some() {}`, but the change protects in case of incorrect library user code.
1 parent 4d2d23f commit 8a36a80

File tree

1 file changed

+184
-37
lines changed

1 file changed

+184
-37
lines changed

src/imap_stream.rs

Lines changed: 184 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,15 @@ pub struct ImapStream<R: Read + Write> {
2525
decode_needs: usize,
2626
/// The buffer.
2727
buffer: Buffer,
28+
29+
/// True if the stream should not return any more items.
30+
///
31+
/// This is set when reading from a stream
32+
/// returns an error.
33+
/// Afterwards the stream returns only `None`
34+
/// and `poll_next()` does not read
35+
/// from the underlying stream.
36+
read_closed: bool,
2837
}
2938

3039
impl<R: Read + Write + Unpin> ImapStream<R> {
@@ -34,6 +43,7 @@ impl<R: Read + Write + Unpin> ImapStream<R> {
3443
inner,
3544
buffer: Buffer::new(),
3645
decode_needs: 0,
46+
read_closed: false,
3747
}
3848
}
3949

@@ -132,6 +142,52 @@ impl<R: Read + Write + Unpin> ImapStream<R> {
132142
}
133143
}
134144
}
145+
146+
fn do_poll_next(
147+
mut self: Pin<&mut Self>,
148+
cx: &mut Context<'_>,
149+
) -> Poll<Option<io::Result<ResponseData>>> {
150+
let this = &mut *self;
151+
if let Some(response) = this.decode()? {
152+
return Poll::Ready(Some(Ok(response)));
153+
}
154+
loop {
155+
this.buffer.ensure_capacity(this.decode_needs)?;
156+
let buf = this.buffer.free_as_mut_slice();
157+
158+
// The buffer should have at least one byte free
159+
// before we try reading into it
160+
// so we can treat 0 bytes read as EOF.
161+
// This is guaranteed by `ensure_capacity()` above
162+
// even if it is called with 0 as an argument.
163+
debug_assert!(!buf.is_empty());
164+
165+
#[cfg(feature = "runtime-async-std")]
166+
let num_bytes_read = ready!(Pin::new(&mut this.inner).poll_read(cx, buf))?;
167+
168+
#[cfg(feature = "runtime-tokio")]
169+
let num_bytes_read = {
170+
let buf = &mut tokio::io::ReadBuf::new(buf);
171+
let start = buf.filled().len();
172+
ready!(Pin::new(&mut this.inner).poll_read(cx, buf))?;
173+
buf.filled().len() - start
174+
};
175+
176+
if num_bytes_read == 0 {
177+
if this.buffer.used() > 0 {
178+
return Poll::Ready(Some(Err(io::Error::new(
179+
io::ErrorKind::UnexpectedEof,
180+
"bytes remaining in stream",
181+
))));
182+
}
183+
return Poll::Ready(None);
184+
}
185+
this.buffer.extend_used(num_bytes_read);
186+
if let Some(response) = this.decode()? {
187+
return Poll::Ready(Some(Ok(response)));
188+
}
189+
}
190+
}
135191
}
136192

137193
/// Abstraction around needed buffer management.
@@ -273,54 +329,145 @@ impl<R: Read + Write + Unpin> Stream for ImapStream<R> {
273329
type Item = io::Result<ResponseData>;
274330

275331
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
276-
let this = &mut *self;
277-
if let Some(response) = this.decode()? {
278-
return Poll::Ready(Some(Ok(response)));
332+
if self.read_closed {
333+
return Poll::Ready(None);
279334
}
280-
loop {
281-
this.buffer.ensure_capacity(this.decode_needs)?;
282-
let buf = this.buffer.free_as_mut_slice();
335+
let res = match ready!(self.as_mut().do_poll_next(cx)) {
336+
None => None,
337+
Some(Err(err)) => {
338+
self.read_closed = true;
339+
Some(Err(err))
340+
}
341+
Some(Ok(item)) => Some(Ok(item)),
342+
};
343+
Poll::Ready(res)
344+
}
345+
}
283346

284-
// The buffer should have at least one byte free
285-
// before we try reading into it
286-
// so we can treat 0 bytes read as EOF.
287-
// This is guaranteed by `ensure_capacity()` above
288-
// even if it is called with 0 as an argument.
289-
debug_assert!(!buf.is_empty());
347+
#[cfg(test)]
348+
mod tests {
349+
use super::*;
290350

291-
#[cfg(feature = "runtime-async-std")]
292-
let num_bytes_read = ready!(Pin::new(&mut this.inner).poll_read(cx, buf))?;
351+
use pin_project::pin_project;
352+
use std::io::Write as _;
293353

294-
#[cfg(feature = "runtime-tokio")]
295-
let num_bytes_read = {
296-
let buf = &mut tokio::io::ReadBuf::new(buf);
297-
let start = buf.filled().len();
298-
ready!(Pin::new(&mut this.inner).poll_read(cx, buf))?;
299-
buf.filled().len() - start
300-
};
354+
/// Wrapper for a stream that
355+
/// fails once on a first read.
356+
///
357+
/// Writes are discarded.
358+
#[pin_project]
359+
struct FailingStream {
360+
#[pin]
361+
inner: &'static [u8],
362+
has_failed: bool,
363+
}
301364

302-
if num_bytes_read == 0 {
303-
if this.buffer.used() > 0 {
304-
return Poll::Ready(Some(Err(io::Error::new(
305-
io::ErrorKind::UnexpectedEof,
306-
"bytes remaining in stream",
307-
))));
308-
}
309-
return Poll::Ready(None);
365+
impl FailingStream {
366+
fn new(buf: &'static [u8]) -> Self {
367+
Self {
368+
inner: buf,
369+
has_failed: false,
310370
}
311-
this.buffer.extend_used(num_bytes_read);
312-
if let Some(response) = this.decode()? {
313-
return Poll::Ready(Some(Ok(response)));
371+
}
372+
}
373+
374+
#[cfg(feature = "runtime-tokio")]
375+
impl Read for FailingStream {
376+
fn poll_read(
377+
self: Pin<&mut Self>,
378+
cx: &mut Context<'_>,
379+
buf: &mut tokio::io::ReadBuf<'_>,
380+
) -> Poll<tokio::io::Result<()>> {
381+
let this = self.project();
382+
if !*this.has_failed {
383+
*this.has_failed = true;
384+
385+
Poll::Ready(Err(std::io::Error::other("Failure")))
386+
} else {
387+
this.inner.poll_read(cx, buf)
314388
}
315389
}
316390
}
317-
}
318391

319-
#[cfg(test)]
320-
mod tests {
321-
use super::*;
392+
#[cfg(feature = "runtime-async-std")]
393+
impl Read for FailingStream {
394+
fn poll_read(
395+
self: Pin<&mut Self>,
396+
cx: &mut Context<'_>,
397+
buf: &mut [u8],
398+
) -> Poll<async_std::io::Result<usize>> {
399+
let this = self.project();
400+
if !*this.has_failed {
401+
*this.has_failed = true;
402+
403+
Poll::Ready(Err(std::io::Error::other("Failure")))
404+
} else {
405+
this.inner.poll_read(cx, buf)
406+
}
407+
}
408+
}
409+
410+
#[cfg(feature = "runtime-tokio")]
411+
impl Write for FailingStream {
412+
fn poll_write(
413+
self: Pin<&mut Self>,
414+
_cx: &mut Context<'_>,
415+
buf: &[u8],
416+
) -> Poll<tokio::io::Result<usize>> {
417+
Poll::Ready(Ok(buf.len()))
418+
}
419+
420+
fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<tokio::io::Result<()>> {
421+
Poll::Ready(Ok(()))
422+
}
322423

323-
use std::io::Write;
424+
fn poll_shutdown(
425+
self: Pin<&mut Self>,
426+
_cx: &mut Context<'_>,
427+
) -> Poll<tokio::io::Result<()>> {
428+
Poll::Ready(Ok(()))
429+
}
430+
}
431+
432+
#[cfg(feature = "runtime-async-std")]
433+
impl Write for FailingStream {
434+
fn poll_write(
435+
self: Pin<&mut Self>,
436+
_cx: &mut Context<'_>,
437+
buf: &[u8],
438+
) -> Poll<async_std::io::Result<usize>> {
439+
Poll::Ready(Ok(buf.len()))
440+
}
441+
442+
fn poll_flush(
443+
self: Pin<&mut Self>,
444+
_cx: &mut Context<'_>,
445+
) -> Poll<async_std::io::Result<()>> {
446+
Poll::Ready(Ok(()))
447+
}
448+
449+
fn poll_close(
450+
self: Pin<&mut Self>,
451+
_cx: &mut Context<'_>,
452+
) -> Poll<async_std::io::Result<()>> {
453+
Poll::Ready(Ok(()))
454+
}
455+
}
456+
457+
#[cfg_attr(feature = "runtime-tokio", tokio::test)]
458+
#[cfg_attr(feature = "runtime-async-std", async_std::test)]
459+
async fn test_imap_stream_error() {
460+
use futures::StreamExt;
461+
462+
let mock_stream = FailingStream::new(b"* OK\r\n");
463+
let mut imap_stream = ImapStream::new(mock_stream);
464+
465+
// First call is an error because underlying stream fails.
466+
assert!(imap_stream.next().await.unwrap().is_err());
467+
468+
// IMAP stream should end even though underlying stream fails only once.
469+
assert!(imap_stream.next().await.is_none());
470+
}
324471

325472
#[test]
326473
fn test_buffer_empty() {

0 commit comments

Comments
 (0)