diff --git a/src/imap_stream.rs b/src/imap_stream.rs index 58f03d0..6420196 100644 --- a/src/imap_stream.rs +++ b/src/imap_stream.rs @@ -25,6 +25,15 @@ pub struct ImapStream { decode_needs: usize, /// The buffer. buffer: Buffer, + + /// True if the stream should not return any more items. + /// + /// This is set when reading from a stream + /// returns an error. + /// Afterwards the stream returns only `None` + /// and `poll_next()` does not read + /// from the underlying stream. + read_closed: bool, } impl ImapStream { @@ -34,6 +43,7 @@ impl ImapStream { inner, buffer: Buffer::new(), decode_needs: 0, + read_closed: false, } } @@ -132,6 +142,52 @@ impl ImapStream { } } } + + fn do_poll_next( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll>> { + let this = &mut *self; + if let Some(response) = this.decode()? { + return Poll::Ready(Some(Ok(response))); + } + loop { + this.buffer.ensure_capacity(this.decode_needs)?; + let buf = this.buffer.free_as_mut_slice(); + + // The buffer should have at least one byte free + // before we try reading into it + // so we can treat 0 bytes read as EOF. + // This is guaranteed by `ensure_capacity()` above + // even if it is called with 0 as an argument. + debug_assert!(!buf.is_empty()); + + #[cfg(feature = "runtime-async-std")] + let num_bytes_read = ready!(Pin::new(&mut this.inner).poll_read(cx, buf))?; + + #[cfg(feature = "runtime-tokio")] + let num_bytes_read = { + let buf = &mut tokio::io::ReadBuf::new(buf); + let start = buf.filled().len(); + ready!(Pin::new(&mut this.inner).poll_read(cx, buf))?; + buf.filled().len() - start + }; + + if num_bytes_read == 0 { + if this.buffer.used() > 0 { + return Poll::Ready(Some(Err(io::Error::new( + io::ErrorKind::UnexpectedEof, + "bytes remaining in stream", + )))); + } + return Poll::Ready(None); + } + this.buffer.extend_used(num_bytes_read); + if let Some(response) = this.decode()? { + return Poll::Ready(Some(Ok(response))); + } + } + } } /// Abstraction around needed buffer management. @@ -273,54 +329,152 @@ impl Stream for ImapStream { type Item = io::Result; fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - let this = &mut *self; - if let Some(response) = this.decode()? { - return Poll::Ready(Some(Ok(response))); + if self.read_closed { + return Poll::Ready(None); } - loop { - this.buffer.ensure_capacity(this.decode_needs)?; - let buf = this.buffer.free_as_mut_slice(); + let res = match ready!(self.as_mut().do_poll_next(cx)) { + None => None, + Some(Err(err)) => { + self.read_closed = true; + Some(Err(err)) + } + Some(Ok(item)) => Some(Ok(item)), + }; + Poll::Ready(res) + } +} - // The buffer should have at least one byte free - // before we try reading into it - // so we can treat 0 bytes read as EOF. - // This is guaranteed by `ensure_capacity()` above - // even if it is called with 0 as an argument. - debug_assert!(!buf.is_empty()); +#[cfg(test)] +mod tests { + use super::*; - #[cfg(feature = "runtime-async-std")] - let num_bytes_read = ready!(Pin::new(&mut this.inner).poll_read(cx, buf))?; + use pin_project::pin_project; + use std::io::Write as _; - #[cfg(feature = "runtime-tokio")] - let num_bytes_read = { - let buf = &mut tokio::io::ReadBuf::new(buf); - let start = buf.filled().len(); - ready!(Pin::new(&mut this.inner).poll_read(cx, buf))?; - buf.filled().len() - start - }; + /// Wrapper for a stream that + /// fails once on a first read. + /// + /// Writes are discarded. + #[pin_project] + struct FailingStream { + #[pin] + inner: &'static [u8], + has_failed: bool, + } - if num_bytes_read == 0 { - if this.buffer.used() > 0 { - return Poll::Ready(Some(Err(io::Error::new( - io::ErrorKind::UnexpectedEof, - "bytes remaining in stream", - )))); - } - return Poll::Ready(None); + impl FailingStream { + fn new(buf: &'static [u8]) -> Self { + Self { + inner: buf, + has_failed: false, } - this.buffer.extend_used(num_bytes_read); - if let Some(response) = this.decode()? { - return Poll::Ready(Some(Ok(response))); + } + } + + #[cfg(feature = "runtime-tokio")] + impl Read for FailingStream { + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut tokio::io::ReadBuf<'_>, + ) -> Poll> { + let this = self.project(); + if !*this.has_failed { + *this.has_failed = true; + + Poll::Ready(Err(std::io::Error::other("Failure"))) + } else { + this.inner.poll_read(cx, buf) } } } -} -#[cfg(test)] -mod tests { - use super::*; + #[cfg(feature = "runtime-async-std")] + impl Read for FailingStream { + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut [u8], + ) -> Poll> { + let this = self.project(); + if !*this.has_failed { + *this.has_failed = true; + + Poll::Ready(Err(std::io::Error::other("Failure"))) + } else { + this.inner.poll_read(cx, buf) + } + } + } - use std::io::Write; + #[cfg(feature = "runtime-tokio")] + impl Write for FailingStream { + fn poll_write( + self: Pin<&mut Self>, + _cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + Poll::Ready(Ok(buf.len())) + } + + fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } + + fn poll_shutdown( + self: Pin<&mut Self>, + _cx: &mut Context<'_>, + ) -> Poll> { + Poll::Ready(Ok(())) + } + } + + #[cfg(feature = "runtime-async-std")] + impl Write for FailingStream { + fn poll_write( + self: Pin<&mut Self>, + _cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + Poll::Ready(Ok(buf.len())) + } + + fn poll_flush( + self: Pin<&mut Self>, + _cx: &mut Context<'_>, + ) -> Poll> { + Poll::Ready(Ok(())) + } + + fn poll_close( + self: Pin<&mut Self>, + _cx: &mut Context<'_>, + ) -> Poll> { + Poll::Ready(Ok(())) + } + } + + /// Tests that stream returns `None` after + /// a single error of the underlying stream. + /// + /// This is need to prevent accidental + /// reading from a network stream + /// after a temporary error such as a timeout + /// or returning an inifinite stream of errors. + #[cfg_attr(feature = "runtime-tokio", tokio::test)] + #[cfg_attr(feature = "runtime-async-std", async_std::test)] + async fn test_imap_stream_error() { + use futures::StreamExt; + + let mock_stream = FailingStream::new(b"* OK\r\n"); + let mut imap_stream = ImapStream::new(mock_stream); + + // First call is an error because underlying stream fails. + assert!(imap_stream.next().await.unwrap().is_err()); + + // IMAP stream should end even though underlying stream fails only once. + assert!(imap_stream.next().await.is_none()); + } #[test] fn test_buffer_empty() {