Skip to content

Commit d32d35b

Browse files
committed
fix(client): close connection when there is an Error
1 parent e305a2e commit d32d35b

File tree

3 files changed

+86
-10
lines changed

3 files changed

+86
-10
lines changed

src/client/request.rs

Lines changed: 43 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -101,11 +101,17 @@ impl Request<Fresh> {
101101
/// Consume a Fresh Request, writing the headers and method,
102102
/// returning a Streaming Request.
103103
pub fn start(mut self) -> ::Result<Request<Streaming>> {
104-
let head = try!(self.message.set_outgoing(RequestHead {
104+
let head = match self.message.set_outgoing(RequestHead {
105105
headers: self.headers,
106106
method: self.method,
107107
url: self.url,
108-
}));
108+
}) {
109+
Ok(head) => head,
110+
Err(e) => {
111+
let _ = self.message.close_connection();
112+
return Err(From::from(e));
113+
}
114+
};
109115

110116
Ok(Request {
111117
method: head.method,
@@ -134,17 +140,30 @@ impl Request<Streaming> {
134140
impl Write for Request<Streaming> {
135141
#[inline]
136142
fn write(&mut self, msg: &[u8]) -> io::Result<usize> {
137-
self.message.write(msg)
143+
match self.message.write(msg) {
144+
Ok(n) => Ok(n),
145+
Err(e) => {
146+
let _ = self.message.close_connection();
147+
Err(e)
148+
}
149+
}
138150
}
139151

140152
#[inline]
141153
fn flush(&mut self) -> io::Result<()> {
142-
self.message.flush()
154+
match self.message.flush() {
155+
Ok(r) => Ok(r),
156+
Err(e) => {
157+
let _ = self.message.close_connection();
158+
Err(e)
159+
}
160+
}
143161
}
144162
}
145163

146164
#[cfg(test)]
147165
mod tests {
166+
use std::io::Write;
148167
use std::str::from_utf8;
149168
use url::Url;
150169
use method::Method::{Get, Head, Post};
@@ -237,4 +256,24 @@ mod tests {
237256
assert!(!s.contains("Content-Length:"));
238257
assert!(s.contains("Transfer-Encoding:"));
239258
}
259+
260+
#[test]
261+
fn test_write_error_closes() {
262+
let url = Url::parse("http://hyper.rs").unwrap();
263+
let req = Request::with_connector(
264+
Get, url, &mut MockConnector
265+
).unwrap();
266+
let mut req = req.start().unwrap();
267+
268+
req.message.downcast_mut::<Http11Message>().unwrap()
269+
.get_mut().downcast_mut::<MockStream>().unwrap()
270+
.error_on_write = true;
271+
272+
req.write(b"foo").unwrap();
273+
assert!(req.flush().is_err());
274+
275+
assert!(req.message.downcast_ref::<Http11Message>().unwrap()
276+
.get_ref().downcast_ref::<MockStream>().unwrap()
277+
.is_closed);
278+
}
240279
}

src/client/response.rs

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,13 @@ impl Response {
3737
/// Creates a new response received from the server on the given `HttpMessage`.
3838
pub fn with_message(url: Url, mut message: Box<HttpMessage>) -> ::Result<Response> {
3939
trace!("Response::with_message");
40-
let ResponseHead { headers, raw_status, version } = try!(message.get_incoming());
40+
let ResponseHead { headers, raw_status, version } = match message.get_incoming() {
41+
Ok(head) => head,
42+
Err(e) => {
43+
let _ = message.close_connection();
44+
return Err(From::from(e));
45+
}
46+
};
4147
let status = status::StatusCode::from_u16(raw_status.0);
4248
debug!("version={:?}, status={:?}", version, status);
4349
debug!("headers={:?}", headers);
@@ -54,6 +60,7 @@ impl Response {
5460
}
5561

5662
/// Get the raw status code and reason.
63+
#[inline]
5764
pub fn status_raw(&self) -> &RawStatus {
5865
&self.status_raw
5966
}
@@ -68,6 +75,10 @@ impl Read for Response {
6875
self.is_drained = true;
6976
Ok(0)
7077
},
78+
Err(e) => {
79+
let _ = self.message.close_connection();
80+
Err(e)
81+
}
7182
r => r
7283
}
7384
}

src/mock.rs

Lines changed: 31 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ use std::fmt;
22
use std::ascii::AsciiExt;
33
use std::io::{self, Read, Write, Cursor};
44
use std::cell::RefCell;
5-
use std::net::SocketAddr;
5+
use std::net::{SocketAddr, Shutdown};
66
use std::sync::{Arc, Mutex};
77
#[cfg(feature = "timeouts")]
88
use std::time::Duration;
@@ -21,10 +21,13 @@ use net::{NetworkStream, NetworkConnector};
2121
pub struct MockStream {
2222
pub read: Cursor<Vec<u8>>,
2323
pub write: Vec<u8>,
24+
pub is_closed: bool,
25+
pub error_on_write: bool,
26+
pub error_on_read: bool,
2427
#[cfg(feature = "timeouts")]
2528
pub read_timeout: Cell<Option<Duration>>,
2629
#[cfg(feature = "timeouts")]
27-
pub write_timeout: Cell<Option<Duration>>
30+
pub write_timeout: Cell<Option<Duration>>,
2831
}
2932

3033
impl fmt::Debug for MockStream {
@@ -48,7 +51,10 @@ impl MockStream {
4851
pub fn with_input(input: &[u8]) -> MockStream {
4952
MockStream {
5053
read: Cursor::new(input.to_vec()),
51-
write: vec![]
54+
write: vec![],
55+
is_closed: false,
56+
error_on_write: false,
57+
error_on_read: false,
5258
}
5359
}
5460

@@ -57,6 +63,9 @@ impl MockStream {
5763
MockStream {
5864
read: Cursor::new(input.to_vec()),
5965
write: vec![],
66+
is_closed: false,
67+
error_on_write: false,
68+
error_on_read: false,
6069
read_timeout: Cell::new(None),
6170
write_timeout: Cell::new(None),
6271
}
@@ -65,13 +74,21 @@ impl MockStream {
6574

6675
impl Read for MockStream {
6776
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
68-
self.read.read(buf)
77+
if self.error_on_read {
78+
Err(io::Error::new(io::ErrorKind::Other, "mock error"))
79+
} else {
80+
self.read.read(buf)
81+
}
6982
}
7083
}
7184

7285
impl Write for MockStream {
7386
fn write(&mut self, msg: &[u8]) -> io::Result<usize> {
74-
Write::write(&mut self.write, msg)
87+
if self.error_on_write {
88+
Err(io::Error::new(io::ErrorKind::Other, "mock error"))
89+
} else {
90+
Write::write(&mut self.write, msg)
91+
}
7592
}
7693

7794
fn flush(&mut self) -> io::Result<()> {
@@ -95,6 +112,11 @@ impl NetworkStream for MockStream {
95112
self.write_timeout.set(dur);
96113
Ok(())
97114
}
115+
116+
fn close(&mut self, _how: Shutdown) -> io::Result<()> {
117+
self.is_closed = true;
118+
Ok(())
119+
}
98120
}
99121

100122
/// A wrapper around a `MockStream` that allows one to clone it and keep an independent copy to the
@@ -144,6 +166,10 @@ impl NetworkStream for CloneableMockStream {
144166
fn set_write_timeout(&self, dur: Option<Duration>) -> io::Result<()> {
145167
self.inner.lock().unwrap().set_write_timeout(dur)
146168
}
169+
170+
fn close(&mut self, how: Shutdown) -> io::Result<()> {
171+
NetworkStream::close(&mut *self.inner.lock().unwrap(), how)
172+
}
147173
}
148174

149175
impl CloneableMockStream {

0 commit comments

Comments
 (0)