@@ -2,7 +2,7 @@ use std::fmt;
2
2
use std:: ascii:: AsciiExt ;
3
3
use std:: io:: { self , Read , Write , Cursor } ;
4
4
use std:: cell:: RefCell ;
5
- use std:: net:: SocketAddr ;
5
+ use std:: net:: { SocketAddr , Shutdown } ;
6
6
use std:: sync:: { Arc , Mutex } ;
7
7
#[ cfg( feature = "timeouts" ) ]
8
8
use std:: time:: Duration ;
@@ -21,10 +21,13 @@ use net::{NetworkStream, NetworkConnector};
21
21
pub struct MockStream {
22
22
pub read : Cursor < Vec < u8 > > ,
23
23
pub write : Vec < u8 > ,
24
+ pub is_closed : bool ,
25
+ pub error_on_write : bool ,
26
+ pub error_on_read : bool ,
24
27
#[ cfg( feature = "timeouts" ) ]
25
28
pub read_timeout : Cell < Option < Duration > > ,
26
29
#[ cfg( feature = "timeouts" ) ]
27
- pub write_timeout : Cell < Option < Duration > >
30
+ pub write_timeout : Cell < Option < Duration > > ,
28
31
}
29
32
30
33
impl fmt:: Debug for MockStream {
@@ -48,7 +51,10 @@ impl MockStream {
48
51
pub fn with_input ( input : & [ u8 ] ) -> MockStream {
49
52
MockStream {
50
53
read : Cursor :: new ( input. to_vec ( ) ) ,
51
- write : vec ! [ ]
54
+ write : vec ! [ ] ,
55
+ is_closed : false ,
56
+ error_on_write : false ,
57
+ error_on_read : false ,
52
58
}
53
59
}
54
60
@@ -57,6 +63,9 @@ impl MockStream {
57
63
MockStream {
58
64
read : Cursor :: new ( input. to_vec ( ) ) ,
59
65
write : vec ! [ ] ,
66
+ is_closed : false ,
67
+ error_on_write : false ,
68
+ error_on_read : false ,
60
69
read_timeout : Cell :: new ( None ) ,
61
70
write_timeout : Cell :: new ( None ) ,
62
71
}
@@ -65,13 +74,21 @@ impl MockStream {
65
74
66
75
impl Read for MockStream {
67
76
fn read ( & mut self , buf : & mut [ u8 ] ) -> io:: Result < usize > {
68
- self . read . read ( buf)
77
+ if self . error_on_read {
78
+ Err ( io:: Error :: new ( io:: ErrorKind :: Other , "mock error" ) )
79
+ } else {
80
+ self . read . read ( buf)
81
+ }
69
82
}
70
83
}
71
84
72
85
impl Write for MockStream {
73
86
fn write ( & mut self , msg : & [ u8 ] ) -> io:: Result < usize > {
74
- Write :: write ( & mut self . write , msg)
87
+ if self . error_on_write {
88
+ Err ( io:: Error :: new ( io:: ErrorKind :: Other , "mock error" ) )
89
+ } else {
90
+ Write :: write ( & mut self . write , msg)
91
+ }
75
92
}
76
93
77
94
fn flush ( & mut self ) -> io:: Result < ( ) > {
@@ -95,6 +112,11 @@ impl NetworkStream for MockStream {
95
112
self . write_timeout . set ( dur) ;
96
113
Ok ( ( ) )
97
114
}
115
+
116
+ fn close ( & mut self , _how : Shutdown ) -> io:: Result < ( ) > {
117
+ self . is_closed = true ;
118
+ Ok ( ( ) )
119
+ }
98
120
}
99
121
100
122
/// A wrapper around a `MockStream` that allows one to clone it and keep an independent copy to the
@@ -144,6 +166,10 @@ impl NetworkStream for CloneableMockStream {
144
166
fn set_write_timeout ( & self , dur : Option < Duration > ) -> io:: Result < ( ) > {
145
167
self . inner . lock ( ) . unwrap ( ) . set_write_timeout ( dur)
146
168
}
169
+
170
+ fn close ( & mut self , how : Shutdown ) -> io:: Result < ( ) > {
171
+ NetworkStream :: close ( & mut * self . inner . lock ( ) . unwrap ( ) , how)
172
+ }
147
173
}
148
174
149
175
impl CloneableMockStream {
0 commit comments