Skip to content

feat: close ImapStream after an error #122

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
228 changes: 191 additions & 37 deletions src/imap_stream.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,15 @@ pub struct ImapStream<R: Read + Write> {
decode_needs: usize,
/// The buffer.
buffer: Buffer,

/// True if the stream should not return any more items.
///
/// This is set when reading from a stream
/// returns an error.
/// Afterwards the stream returns only `None`
/// and `poll_next()` does not read
/// from the underlying stream.
read_closed: bool,
}

impl<R: Read + Write + Unpin> ImapStream<R> {
Expand All @@ -34,6 +43,7 @@ impl<R: Read + Write + Unpin> ImapStream<R> {
inner,
buffer: Buffer::new(),
decode_needs: 0,
read_closed: false,
}
}

Expand Down Expand Up @@ -132,6 +142,52 @@ impl<R: Read + Write + Unpin> ImapStream<R> {
}
}
}

fn do_poll_next(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Option<io::Result<ResponseData>>> {
let this = &mut *self;
if let Some(response) = this.decode()? {
return Poll::Ready(Some(Ok(response)));
}
loop {
this.buffer.ensure_capacity(this.decode_needs)?;
let buf = this.buffer.free_as_mut_slice();

// The buffer should have at least one byte free
// before we try reading into it
// so we can treat 0 bytes read as EOF.
// This is guaranteed by `ensure_capacity()` above
// even if it is called with 0 as an argument.
debug_assert!(!buf.is_empty());

#[cfg(feature = "runtime-async-std")]
let num_bytes_read = ready!(Pin::new(&mut this.inner).poll_read(cx, buf))?;

#[cfg(feature = "runtime-tokio")]
let num_bytes_read = {
let buf = &mut tokio::io::ReadBuf::new(buf);
let start = buf.filled().len();
ready!(Pin::new(&mut this.inner).poll_read(cx, buf))?;
buf.filled().len() - start
};

if num_bytes_read == 0 {
if this.buffer.used() > 0 {
return Poll::Ready(Some(Err(io::Error::new(
io::ErrorKind::UnexpectedEof,
"bytes remaining in stream",
))));
}
return Poll::Ready(None);
}
this.buffer.extend_used(num_bytes_read);
if let Some(response) = this.decode()? {
return Poll::Ready(Some(Ok(response)));
}
}
}
}

/// Abstraction around needed buffer management.
Expand Down Expand Up @@ -273,54 +329,152 @@ impl<R: Read + Write + Unpin> Stream for ImapStream<R> {
type Item = io::Result<ResponseData>;

fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
let this = &mut *self;
if let Some(response) = this.decode()? {
return Poll::Ready(Some(Ok(response)));
if self.read_closed {
return Poll::Ready(None);
}
loop {
this.buffer.ensure_capacity(this.decode_needs)?;
let buf = this.buffer.free_as_mut_slice();
let res = match ready!(self.as_mut().do_poll_next(cx)) {
None => None,
Some(Err(err)) => {
self.read_closed = true;
Some(Err(err))
}
Some(Ok(item)) => Some(Ok(item)),
};
Poll::Ready(res)
}
}

// The buffer should have at least one byte free
// before we try reading into it
// so we can treat 0 bytes read as EOF.
// This is guaranteed by `ensure_capacity()` above
// even if it is called with 0 as an argument.
debug_assert!(!buf.is_empty());
#[cfg(test)]
mod tests {
use super::*;

#[cfg(feature = "runtime-async-std")]
let num_bytes_read = ready!(Pin::new(&mut this.inner).poll_read(cx, buf))?;
use pin_project::pin_project;
use std::io::Write as _;

#[cfg(feature = "runtime-tokio")]
let num_bytes_read = {
let buf = &mut tokio::io::ReadBuf::new(buf);
let start = buf.filled().len();
ready!(Pin::new(&mut this.inner).poll_read(cx, buf))?;
buf.filled().len() - start
};
/// Wrapper for a stream that
/// fails once on a first read.
///
/// Writes are discarded.
#[pin_project]
struct FailingStream {
#[pin]
inner: &'static [u8],
has_failed: bool,
}

if num_bytes_read == 0 {
if this.buffer.used() > 0 {
return Poll::Ready(Some(Err(io::Error::new(
io::ErrorKind::UnexpectedEof,
"bytes remaining in stream",
))));
}
return Poll::Ready(None);
impl FailingStream {
fn new(buf: &'static [u8]) -> Self {
Self {
inner: buf,
has_failed: false,
}
this.buffer.extend_used(num_bytes_read);
if let Some(response) = this.decode()? {
return Poll::Ready(Some(Ok(response)));
}
}

#[cfg(feature = "runtime-tokio")]
impl Read for FailingStream {
fn poll_read(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut tokio::io::ReadBuf<'_>,
) -> Poll<tokio::io::Result<()>> {
let this = self.project();
if !*this.has_failed {
*this.has_failed = true;

Poll::Ready(Err(std::io::Error::other("Failure")))
} else {
this.inner.poll_read(cx, buf)
}
}
}
}

#[cfg(test)]
mod tests {
use super::*;
#[cfg(feature = "runtime-async-std")]
impl Read for FailingStream {
fn poll_read(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut [u8],
) -> Poll<async_std::io::Result<usize>> {
let this = self.project();
if !*this.has_failed {
*this.has_failed = true;

Poll::Ready(Err(std::io::Error::other("Failure")))
} else {
this.inner.poll_read(cx, buf)
}
}
}

use std::io::Write;
#[cfg(feature = "runtime-tokio")]
impl Write for FailingStream {
fn poll_write(
self: Pin<&mut Self>,
_cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<tokio::io::Result<usize>> {
Poll::Ready(Ok(buf.len()))
}

fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<tokio::io::Result<()>> {
Poll::Ready(Ok(()))
}

fn poll_shutdown(
self: Pin<&mut Self>,
_cx: &mut Context<'_>,
) -> Poll<tokio::io::Result<()>> {
Poll::Ready(Ok(()))
}
}

#[cfg(feature = "runtime-async-std")]
impl Write for FailingStream {
fn poll_write(
self: Pin<&mut Self>,
_cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<async_std::io::Result<usize>> {
Poll::Ready(Ok(buf.len()))
}

fn poll_flush(
self: Pin<&mut Self>,
_cx: &mut Context<'_>,
) -> Poll<async_std::io::Result<()>> {
Poll::Ready(Ok(()))
}

fn poll_close(
self: Pin<&mut Self>,
_cx: &mut Context<'_>,
) -> Poll<async_std::io::Result<()>> {
Poll::Ready(Ok(()))
}
}

/// Tests that stream returns `None` after
/// a single error of the underlying stream.
///
/// This is need to prevent accidental
/// reading from a network stream
/// after a temporary error such as a timeout
/// or returning an inifinite stream of errors.
#[cfg_attr(feature = "runtime-tokio", tokio::test)]
#[cfg_attr(feature = "runtime-async-std", async_std::test)]
async fn test_imap_stream_error() {
use futures::StreamExt;

let mock_stream = FailingStream::new(b"* OK\r\n");
let mut imap_stream = ImapStream::new(mock_stream);

// First call is an error because underlying stream fails.
assert!(imap_stream.next().await.unwrap().is_err());

// IMAP stream should end even though underlying stream fails only once.
assert!(imap_stream.next().await.is_none());
}

#[test]
fn test_buffer_empty() {
Expand Down
Loading