@@ -5,8 +5,8 @@ use std::process::ExitStatus;
5
5
6
6
use anyhow:: Result ;
7
7
use process_wrap:: tokio:: { TokioChildWrapper , TokioCommandWrap } ;
8
- use tokio:: io:: { AsyncBufReadExt , BufReader } ;
9
- use tokio:: process:: { ChildStderr , ChildStdout } ;
8
+ use tokio:: io:: { AsyncBufReadExt , AsyncWriteExt , BufReader } ;
9
+ use tokio:: process:: { ChildStderr , ChildStdin , ChildStdout } ;
10
10
use tokio:: sync:: { mpsc, oneshot} ;
11
11
12
12
use crate :: agent:: event:: AgentCommandStatus ;
@@ -27,22 +27,25 @@ pub struct ProcessRegistry {
27
27
struct ProcessData {
28
28
command : String ,
29
29
output : String ,
30
- exit_status : Option < ExitStatus > ,
30
+ exit_status : Option < i32 > ,
31
31
receiver : mpsc:: UnboundedReceiver < ProcessOutput > ,
32
32
terminate_sender : Option < oneshot:: Sender < ( ) > > ,
33
+ input_sender : Option < mpsc:: UnboundedSender < Vec < u8 > > > ,
33
34
}
34
35
35
36
enum ProcessOutput {
36
- Exited ( ExitStatus ) ,
37
+ Exited ( Option < ExitStatus > ) ,
37
38
Output ( String ) ,
38
39
Error ( String ) ,
39
40
}
40
41
41
42
struct ProcessRuntime {
42
43
_process : Box < dyn TokioChildWrapper > ,
43
44
stdout : ChildStdout ,
45
+ stdin : ChildStdin ,
44
46
stderr : ChildStderr ,
45
47
sender : mpsc:: UnboundedSender < ProcessOutput > ,
48
+ input_signal : mpsc:: UnboundedReceiver < Vec < u8 > > ,
46
49
terminate_signal : oneshot:: Receiver < ( ) > ,
47
50
}
48
51
@@ -52,29 +55,37 @@ impl ProcessRuntime {
52
55
53
56
let stdout = Self :: handle_stdout ( self . stdout , self . sender . clone ( ) ) ;
54
57
let stderr = Self :: handle_stderr ( self . stderr , self . sender . clone ( ) ) ;
58
+ let stdin = Self :: handle_stdin ( self . stdin , self . input_signal ) ;
59
+
55
60
let status = Box :: into_pin ( self . _process . wait ( ) ) ;
56
61
pin ! ( stdout) ;
57
62
pin ! ( stderr) ;
58
- let mut exit_status = ExitStatus :: default ( ) ;
63
+ pin ! ( stdin) ;
64
+
65
+ let mut exit_status = None ;
59
66
tokio:: select! {
60
67
result = & mut stdout => {
61
68
tracing:: trace!( "Stdout handler completed: {:?}" , result) ;
69
+ exit_status = Some ( ExitStatus :: default ( ) ) ;
62
70
}
63
71
result = & mut stderr => {
64
72
tracing:: trace!( "Stderr handler completed: {:?}" , result) ;
65
73
}
66
74
// capture the status so we don't need to wait for a timeout
67
75
result = status => {
68
76
if let Ok ( result) = result {
69
- exit_status = result;
77
+ exit_status = Some ( result) ;
70
78
}
71
79
tracing:: trace!( "Process exited with status: {:?}" , result) ;
72
80
}
81
+ result = & mut stdin => {
82
+ tracing:: trace!( "Stdin handler completed: {:?}" , result) ;
83
+ }
73
84
_ = self . terminate_signal => {
74
85
tracing:: debug!( "Receive terminal_signal" ) ;
75
86
if self . _process. start_kill( ) . is_ok( ) {
76
87
if let Ok ( status) = Box :: into_pin( self . _process. wait( ) ) . await {
77
- exit_status = status;
88
+ exit_status = Some ( status) ;
78
89
}
79
90
}
80
91
}
@@ -123,17 +134,36 @@ impl ProcessRuntime {
123
134
}
124
135
}
125
136
}
137
+
138
+ async fn handle_stdin ( mut stdin : ChildStdin , mut receiver : mpsc:: UnboundedReceiver < Vec < u8 > > ) {
139
+ while let Some ( data) = receiver. recv ( ) . await {
140
+ tracing:: trace!( "Writing data to stdin: {:?}" , data) ;
141
+ if let Err ( e) = stdin. write_all ( data. as_slice ( ) ) . await {
142
+ tracing:: error!( error = ?e, "Error writing data to child process" ) ;
143
+ break ;
144
+ }
145
+ if let Err ( e) = stdin. flush ( ) . await {
146
+ tracing:: error!( error = ?e, "Error flushing data to child process" ) ;
147
+ break ;
148
+ }
149
+ }
150
+ }
126
151
}
127
152
128
153
impl ProcessRegistry {
129
154
async fn spawn_process (
130
155
& self ,
131
156
command : & str ,
132
157
cwd : & str ,
133
- ) -> Result < ( Box < dyn TokioChildWrapper > , ChildStdout , ChildStderr ) > {
158
+ ) -> Result < (
159
+ Box < dyn TokioChildWrapper > ,
160
+ ChildStdout ,
161
+ ChildStderr ,
162
+ ChildStdin ,
163
+ ) > {
134
164
let mut child = TokioCommandWrap :: with_new ( SHELL , |cmd| {
135
165
cmd. current_dir ( cwd)
136
- . stdin ( std:: process:: Stdio :: null ( ) )
166
+ . stdin ( std:: process:: Stdio :: piped ( ) )
137
167
. stdout ( std:: process:: Stdio :: piped ( ) )
138
168
. stderr ( std:: process:: Stdio :: piped ( ) ) ;
139
169
@@ -164,20 +194,28 @@ impl ProcessRegistry {
164
194
. take ( )
165
195
. ok_or_else ( || anyhow:: anyhow!( "Failed to get stdout" ) ) ?;
166
196
167
- Ok ( ( process, stdout, stderr) )
197
+ let stdin = process
198
+ . stdin ( )
199
+ . take ( )
200
+ . ok_or_else ( || anyhow:: anyhow!( "Failed to get stdin" ) ) ?;
201
+
202
+ Ok ( ( process, stdout, stderr, stdin) )
168
203
}
169
204
170
205
pub async fn execute_command ( & mut self , command : & str , cwd : & str ) -> Result < usize > {
171
206
self . counter = self . counter . saturating_add ( 1 ) ;
172
- let ( process, stdout, stderr) = self . spawn_process ( command, cwd) . await ?;
207
+ let ( process, stdout, stderr, stdin ) = self . spawn_process ( command, cwd) . await ?;
173
208
let ( tx, rx) = mpsc:: unbounded_channel ( ) ;
174
209
let ( t_tx, t_rx) = tokio:: sync:: oneshot:: channel ( ) ;
210
+ let ( in_tx, in_rx) = mpsc:: unbounded_channel ( ) ;
175
211
176
212
let runtime = ProcessRuntime {
177
213
_process : process,
178
214
stdout,
179
215
stderr,
216
+ stdin,
180
217
sender : tx,
218
+ input_signal : in_rx,
181
219
terminate_signal : t_rx,
182
220
} ;
183
221
@@ -191,6 +229,7 @@ impl ProcessRegistry {
191
229
exit_status : None ,
192
230
receiver : rx,
193
231
terminate_sender : Some ( t_tx) ,
232
+ input_sender : Some ( in_tx) ,
194
233
} ,
195
234
) ;
196
235
Ok ( self . counter )
@@ -214,7 +253,11 @@ impl ProcessRegistry {
214
253
while let Ok ( output) = process. receiver . try_recv ( ) {
215
254
match output {
216
255
ProcessOutput :: Exited ( exit_status) => {
217
- process. exit_status = Some ( exit_status)
256
+ process. exit_status = Some (
257
+ exit_status
258
+ . map ( |s| s. code ( ) . unwrap_or_default ( ) )
259
+ . unwrap_or ( 1 ) ,
260
+ )
218
261
}
219
262
ProcessOutput :: Output ( str) => process. output += & str,
220
263
ProcessOutput :: Error ( str) => process. output += & str,
@@ -231,12 +274,12 @@ impl ProcessRegistry {
231
274
modified_terminal_states
232
275
}
233
276
234
- pub fn get_process ( & self , id : usize ) -> Option < ( Option < ExitStatus > , & String ) > {
277
+ pub fn get_process ( & self , id : usize ) -> Option < ( Option < i32 > , & String ) > {
235
278
let process = self . processes . get ( & id) ?;
236
279
Some ( ( process. exit_status , & process. output ) )
237
280
}
238
281
239
- pub fn processes ( & self ) -> impl Iterator < Item = ( usize , Option < ExitStatus > , & String ) > {
282
+ pub fn processes ( & self ) -> impl Iterator < Item = ( usize , Option < i32 > , & String ) > {
240
283
self . processes
241
284
. iter ( )
242
285
. map ( |( key, value) | ( * key, value. exit_status , & value. command ) )
@@ -253,4 +296,12 @@ impl ProcessRegistry {
253
296
}
254
297
Ok ( ( ) )
255
298
}
299
+
300
+ pub fn send_data ( & self , idx : usize , data : Vec < u8 > ) {
301
+ if let Some ( process) = self . processes . get ( & idx) {
302
+ if let Some ( sender) = process. input_sender . as_ref ( ) {
303
+ sender. send ( data) . ok ( ) ;
304
+ }
305
+ }
306
+ }
256
307
}
0 commit comments