2
2
import logging
3
3
import shlex
4
4
import subprocess
5
+ from typing import Union
5
6
6
- import lib . config as config
7
+ from _pytest . fixtures import _teardown_yield_fixture
7
8
9
+ import lib .config as config
8
10
from lib .netutil import wrap_ip
9
11
12
+
10
13
class BaseCommandFailed (Exception ):
11
14
__slots__ = 'returncode' , 'stdout' , 'cmd'
12
15
@@ -61,7 +64,7 @@ def _ellide_log_lines(log):
61
64
return "\n {}" .format ("\n " .join (reduced_message ))
62
65
63
66
def _ssh (hostname_or_ip , cmd , check , simple_output , suppress_fingerprint_warnings ,
64
- background , target_os , decode , options ):
67
+ background , target_os , decode , options ) -> Union [ SSHResult , SSHCommandFailed , str , bytes , subprocess . Popen ] :
65
68
opts = list (options )
66
69
opts .append ('-o "BatchMode yes"' )
67
70
if suppress_fingerprint_warnings :
@@ -86,6 +89,7 @@ def _ssh(hostname_or_ip, cmd, check, simple_output, suppress_fingerprint_warning
86
89
87
90
windows_background = background and target_os == "windows"
88
91
# Fetch banner and remove it to avoid stdout/stderr pollution.
92
+ banner_res = None
89
93
if config .ignore_ssh_banner and not windows_background :
90
94
banner_res = subprocess .run (
91
95
"ssh root@%s %s '%s'" % (hostname_or_ip , ' ' .join (opts ), '\n ' ),
@@ -103,9 +107,10 @@ def _ssh(hostname_or_ip, cmd, check, simple_output, suppress_fingerprint_warning
103
107
)
104
108
logging .debug (f"[{ hostname_or_ip } ] { command } " )
105
109
if windows_background :
106
- return True , process
110
+ return process
107
111
108
112
stdout = []
113
+ assert process .stdout is not None
109
114
for line in iter (process .stdout .readline , b'' ):
110
115
readable_line = line .decode (errors = 'replace' ).strip ()
111
116
stdout .append (line )
@@ -118,34 +123,56 @@ def _ssh(hostname_or_ip, cmd, check, simple_output, suppress_fingerprint_warning
118
123
119
124
# Even if check is False, we still raise in case of return code 255, which means a SSH error.
120
125
if res .returncode == 255 :
121
- return False , SSHCommandFailed (255 , "SSH Error: %s" % output_for_errors , command )
126
+ return SSHCommandFailed (255 , "SSH Error: %s" % output_for_errors , command )
122
127
123
- output = res .stdout
124
- if config . ignore_ssh_banner :
128
+ output : Union [ bytes , str ] = res .stdout
129
+ if banner_res :
125
130
if banner_res .returncode == 255 :
126
- return False , SSHCommandFailed (255 , "SSH Error: %s" % banner_res .stdout .decode (errors = 'replace' ), command )
131
+ return SSHCommandFailed (255 , "SSH Error: %s" % banner_res .stdout .decode (errors = 'replace' ), command )
127
132
output = output [len (banner_res .stdout ):]
128
133
129
134
if decode :
135
+ assert isinstance (output , bytes )
130
136
output = output .decode ()
131
137
132
138
if res .returncode and check :
133
- return False , SSHCommandFailed (res .returncode , output_for_errors , command )
139
+ return SSHCommandFailed (res .returncode , output_for_errors , command )
134
140
135
141
if simple_output :
136
- return True , output .strip ()
137
- return True , SSHResult (res .returncode , output )
142
+ return output .strip ()
143
+ return SSHResult (res .returncode , output )
138
144
139
145
# The actual code is in _ssh().
140
146
# This function is kept short for shorter pytest traces upon SSH failures, which are common,
141
147
# as pytest prints the whole function definition that raised the SSHCommandFailed exception
142
- def ssh (hostname_or_ip , cmd , check = True , simple_output = True , suppress_fingerprint_warnings = True ,
143
- background = False , target_os = 'linux' , decode = True , options = []):
144
- success , result_or_exc = _ssh (hostname_or_ip , cmd , check , simple_output , suppress_fingerprint_warnings ,
145
- background , target_os , decode , options )
146
- if not success :
148
+ def ssh (hostname_or_ip , cmd , check = True , simple_output = True , suppress_fingerprint_warnings = True , background = False ,
149
+ target_os = 'linux' , decode = True , options = []) -> Union [SSHResult , str , bytes , subprocess .Popen ]:
150
+ result_or_exc = _ssh (hostname_or_ip , cmd , check , simple_output , suppress_fingerprint_warnings ,
151
+ background , target_os , decode , options )
152
+ if isinstance (result_or_exc , SSHCommandFailed ):
153
+ raise result_or_exc
154
+ else :
155
+ return result_or_exc
156
+
157
+ def ssh_str (hostname_or_ip , cmd , check = True , suppress_fingerprint_warnings = True ,
158
+ background = False , target_os = 'linux' , options = []) -> str :
159
+ result_or_exc = _ssh (hostname_or_ip , cmd , check , True , suppress_fingerprint_warnings ,
160
+ background , target_os , True , options )
161
+ if isinstance (result_or_exc , SSHCommandFailed ):
162
+ raise result_or_exc
163
+ elif isinstance (result_or_exc , str ):
164
+ return result_or_exc
165
+ assert False , "unexpected type"
166
+
167
+ def ssh_with_result (hostname_or_ip , cmd , suppress_fingerprint_warnings = True ,
168
+ background = False , target_os = 'linux' , decode = True , options = []) -> SSHResult :
169
+ result_or_exc = _ssh (hostname_or_ip , cmd , False , False , suppress_fingerprint_warnings ,
170
+ background , target_os , decode , options )
171
+ if isinstance (result_or_exc , SSHCommandFailed ):
147
172
raise result_or_exc
148
- return result_or_exc
173
+ elif isinstance (result_or_exc , SSHResult ):
174
+ return result_or_exc
175
+ assert False , "unexpected type"
149
176
150
177
def scp (hostname_or_ip , src , dest , check = True , suppress_fingerprint_warnings = True , local_dest = False ):
151
178
opts = '-o "BatchMode yes"'
@@ -179,6 +206,7 @@ def scp(hostname_or_ip, src, dest, check=True, suppress_fingerprint_warnings=Tru
179
206
return res
180
207
181
208
def sftp (hostname_or_ip , cmds , check = True , suppress_fingerprint_warnings = True ):
209
+ opts = ''
182
210
if suppress_fingerprint_warnings :
183
211
# Suppress warnings and questions related to host key fingerprints
184
212
# because on a test network IPs get reused, VMs are reinstalled, etc.
0 commit comments