diff --git a/src/net.rs b/src/net.rs index 01bd8ed9b2..168e63be00 100644 --- a/src/net.rs +++ b/src/net.rs @@ -16,12 +16,14 @@ use crate::sql::Sql; use crate::tools::time; pub(crate) mod dns; +pub(crate) mod error_capturing_stream; pub(crate) mod http; pub(crate) mod proxy; pub(crate) mod session; pub(crate) mod tls; use dns::lookup_host_with_cache; +pub(crate) use error_capturing_stream::ErrorCapturingStream; pub use http::{Response as HttpResponse, read_url, read_url_blob}; use tls::wrap_tls; @@ -105,7 +107,7 @@ pub(crate) async fn load_connection_timestamp( /// to the network, which is important to reduce the latency of interactive protocols such as IMAP. pub(crate) async fn connect_tcp_inner( addr: SocketAddr, -) -> Result>>> { +) -> Result>>>> { let tcp_stream = timeout(TIMEOUT, TcpStream::connect(addr)) .await .context("connection timeout")? @@ -118,7 +120,9 @@ pub(crate) async fn connect_tcp_inner( timeout_stream.set_write_timeout(Some(TIMEOUT)); timeout_stream.set_read_timeout(Some(TIMEOUT)); - Ok(Box::pin(timeout_stream)) + let error_capturing_stream = ErrorCapturingStream::new(timeout_stream); + + Ok(Box::pin(error_capturing_stream)) } /// Attempts to establish TLS connection @@ -235,7 +239,7 @@ pub(crate) async fn connect_tcp( host: &str, port: u16, load_cache: bool, -) -> Result>>> { +) -> Result>>>> { let connection_futures = lookup_host_with_cache(context, host, port, "", load_cache) .await? .into_iter() diff --git a/src/net/error_capturing_stream.rs b/src/net/error_capturing_stream.rs new file mode 100644 index 0000000000..4edbb5bfce --- /dev/null +++ b/src/net/error_capturing_stream.rs @@ -0,0 +1,136 @@ +use std::io::IoSlice; +use std::net::SocketAddr; +use std::pin::Pin; +use std::task::{Context, Poll}; +use std::time::Duration; +use tokio::io::{self, AsyncRead, AsyncWrite, ReadBuf}; + +use pin_project::pin_project; + +use crate::net::SessionStream; + +/// Stream that remembers the first error +/// and keeps returning it afterwards. +/// +/// It is needed to avoid accidentally using +/// the stream after read timeout. +#[derive(Debug)] +#[pin_project] +pub(crate) struct ErrorCapturingStream { + #[pin] + inner: T, + + /// If true, the stream has already returned an error once. + /// + /// All read and write operations return error in this case. + is_broken: bool, +} + +impl ErrorCapturingStream { + pub fn new(inner: T) -> Self { + Self { + inner, + is_broken: false, + } + } + + /// Gets a reference to the underlying stream. + pub fn get_ref(&self) -> &T { + &self.inner + } + + /// Gets a pinned mutable reference to the underlying stream. + pub fn get_pin_mut(self: Pin<&mut Self>) -> Pin<&mut T> { + self.project().inner + } +} + +impl AsyncRead for ErrorCapturingStream { + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut ReadBuf, + ) -> Poll> { + let this = self.project(); + if *this.is_broken { + return Poll::Ready(Err(io::Error::other("Broken stream"))); + } + let res = this.inner.poll_read(cx, buf); + if let Poll::Ready(Err(_)) = res { + *this.is_broken = true; + } + res + } +} + +impl AsyncWrite for ErrorCapturingStream { + fn poll_write( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + let this = self.project(); + if *this.is_broken { + return Poll::Ready(Err(io::Error::other("Broken stream"))); + } + let res = this.inner.poll_write(cx, buf); + if let Poll::Ready(Err(_)) = res { + *this.is_broken = true; + } + res + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let this = self.project(); + if *this.is_broken { + return Poll::Ready(Err(io::Error::other("Broken stream"))); + } + let res = this.inner.poll_flush(cx); + if let Poll::Ready(Err(_)) = res { + *this.is_broken = true; + } + res + } + + fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let this = self.project(); + if *this.is_broken { + return Poll::Ready(Err(io::Error::other("Broken stream"))); + } + let res = this.inner.poll_shutdown(cx); + if let Poll::Ready(Err(_)) = res { + *this.is_broken = true; + } + res + } + + fn poll_write_vectored( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + bufs: &[IoSlice<'_>], + ) -> Poll> { + let this = self.project(); + if *this.is_broken { + return Poll::Ready(Err(io::Error::other("Broken stream"))); + } + let res = this.inner.poll_write_vectored(cx, bufs); + if let Poll::Ready(Err(_)) = res { + *this.is_broken = true; + } + res + } + + fn is_write_vectored(&self) -> bool { + self.inner.is_write_vectored() + } +} + +impl SessionStream for ErrorCapturingStream { + fn set_read_timeout(&mut self, timeout: Option) { + self.inner.set_read_timeout(timeout) + } + + fn peer_addr(&self) -> anyhow::Result { + self.inner.peer_addr() + } +} diff --git a/src/net/proxy.rs b/src/net/proxy.rs index 0f657b5439..6c4797e9fb 100644 --- a/src/net/proxy.rs +++ b/src/net/proxy.rs @@ -21,9 +21,9 @@ use url::Url; use crate::config::Config; use crate::constants::NON_ALPHANUMERIC_WITHOUT_DOT; use crate::context::Context; -use crate::net::connect_tcp; use crate::net::session::SessionStream; use crate::net::tls::wrap_rustls; +use crate::net::{ErrorCapturingStream, connect_tcp}; use crate::sql::Sql; /// Default SOCKS5 port according to [RFC 1928](https://tools.ietf.org/html/rfc1928). @@ -118,7 +118,7 @@ impl Socks5Config { target_host: &str, target_port: u16, load_dns_cache: bool, - ) -> Result>>>> { + ) -> Result>>>>> { let tcp_stream = connect_tcp(context, &self.host, self.port, load_dns_cache) .await .context("Failed to connect to SOCKS5 proxy")?; diff --git a/src/net/session.rs b/src/net/session.rs index 981e01fd4d..f3d16dc2bc 100644 --- a/src/net/session.rs +++ b/src/net/session.rs @@ -7,6 +7,8 @@ use tokio::io::{AsyncBufRead, AsyncRead, AsyncWrite, BufStream, BufWriter}; use tokio::net::TcpStream; use tokio_io_timeout::TimeoutStream; +use crate::net::ErrorCapturingStream; + pub(crate) trait SessionStream: AsyncRead + AsyncWrite + Unpin + Send + Sync + std::fmt::Debug { @@ -61,13 +63,13 @@ impl SessionStream for BufWriter { self.get_ref().peer_addr() } } -impl SessionStream for Pin>> { +impl SessionStream for Pin>>> { fn set_read_timeout(&mut self, timeout: Option) { - self.as_mut().set_read_timeout_pinned(timeout); + self.as_mut().get_pin_mut().set_read_timeout_pinned(timeout); } fn peer_addr(&self) -> Result { - Ok(self.get_ref().peer_addr()?) + Ok(self.get_ref().get_ref().peer_addr()?) } } impl SessionStream for Socks5Stream {