From 2363f754c67d309f1029207aa70c2eae5ba2b124 Mon Sep 17 00:00:00 2001 From: maxim Date: Thu, 7 May 2015 19:34:55 +0300 Subject: [PATCH] Added support for writing into streams, thus receiving a live feed about the the tasks being performed. --- pywinrm.iml | 10 ++++ winrm/__init__.py | 14 ++++-- winrm/protocol.py | 29 +++++++----- winrm/tests/test_integration_protocol.py | 58 ++++++++++++++++++++++++ 4 files changed, 94 insertions(+), 17 deletions(-) create mode 100644 pywinrm.iml diff --git a/pywinrm.iml b/pywinrm.iml new file mode 100644 index 00000000..0a6b2fb7 --- /dev/null +++ b/pywinrm.iml @@ -0,0 +1,10 @@ + + + + + + + + + + diff --git a/winrm/__init__.py b/winrm/__init__.py index ee40fa9b..6ac6ed0f 100644 --- a/winrm/__init__.py +++ b/winrm/__init__.py @@ -7,6 +7,7 @@ class Response(object): """Response from a remote command execution""" + def __init__(self, args): self.std_out, self.std_err, self.status_code = args @@ -24,23 +25,25 @@ def __init__(self, target, auth, transport='plaintext'): self.protocol = Protocol(self.url, transport=transport, username=username, password=password) - def run_cmd(self, command, args=()): + def run_cmd(self, command, args=(), out_stream=None, err_stream=None): # TODO optimize perf. Do not call open/close shell every time shell_id = self.protocol.open_shell() command_id = self.protocol.run_command(shell_id, command, args) - rs = Response(self.protocol.get_command_output(shell_id, command_id)) + rs = Response(self.protocol.get_command_output(shell_id, command_id, out_stream, err_stream)) self.protocol.cleanup_command(shell_id, command_id) self.protocol.close_shell(shell_id) return rs - def run_ps(self, script): + def run_ps(self, script, out_stream=None, err_stream=None): """base64 encodes a Powershell script and executes the powershell encoded script command """ # must use utf16 little endian on windows base64_script = base64.b64encode(script.encode("utf_16_le")) - rs = self.run_cmd("powershell -encodedcommand %s" % (base64_script)) + rs = self.run_cmd("powershell -OutputFormat {0} -encodedcommand {1}".format("TEXT", base64_script), + out_stream=out_stream, err_stream=err_stream) + if len(rs.std_err): # if there was an error message, clean it it up and make it human # readable @@ -94,7 +97,8 @@ def strip_namespace(self, xml): @staticmethod def _build_url(target, transport): match = re.match( - '(?i)^((?Phttp[s]?)://)?(?P[0-9a-z-_.]+)(:(?P\d+))?(?P(/)?(wsman)?)?', target) # NOQA + '(?i)^((?Phttp[s]?)://)?(?P[0-9a-z-_.]+)(:(?P\d+))?(?P(/)?(wsman)?)?', + target) # NOQA scheme = match.group('scheme') if not scheme: # TODO do we have anything other than HTTP/HTTPS diff --git a/winrm/protocol.py b/winrm/protocol.py index 0884721f..99841c26 100644 --- a/winrm/protocol.py +++ b/winrm/protocol.py @@ -5,6 +5,7 @@ from isodate.isoduration import duration_isoformat import xmltodict from winrm.transport import HttpPlaintext, HttpKerberos, HttpSSL +import sys class Protocol(object): @@ -105,7 +106,7 @@ def open_shell(self, i_stream='stdin', o_stream='stdout stderr', # TODO: research Lifetime a bit more: # http://msdn.microsoft.com/en-us/library/cc251546(v=PROT.13).aspx # if lifetime: - # shell['rsp:Lifetime'] = iso8601_duration.sec_to_dur(lifetime) + # shell['rsp:Lifetime'] = iso8601_duration.sec_to_dur(lifetime) # TODO: make it so the input is given in milliseconds and converted # to xs:duration if idle_timeout: @@ -114,7 +115,6 @@ def open_shell(self, i_stream='stdin', o_stream='stdout stderr', env = shell.setdefault('rsp:Environment', {}) for key, value in env_vars.items(): env['rsp:Variable'] = {'@Name': key, '#text': value} - rs = self.send_message(xmltodict.unparse(rq)) # rs = xmltodict.parse(rs) # return rs['s:Envelope']['s:Body']['x:ResourceCreated']['a:ReferenceParameters']['w:SelectorSet']['w:Selector']['#text'] # NOQA @@ -247,7 +247,7 @@ def run_command(self, shell_id, command, arguments=(), } ] } - cmd_line = rq['env:Envelope'].setdefault('env:Body', {})\ + cmd_line = rq['env:Envelope'].setdefault('env:Body', {}) \ .setdefault('rsp:CommandLine', {}) cmd_line['rsp:Command'] = {'#text': command} if arguments: @@ -291,13 +291,15 @@ def cleanup_command(self, shell_id, command_id): # TODO change assert into user-friendly exception assert uuid.UUID(relates_to.replace('uuid:', '')) == message_id - def get_command_output(self, shell_id, command_id): + def get_command_output(self, shell_id, command_id, out_stream=None, err_stream=None): """ Get the Output of the given shell and command @param string shell_id: The shell id on the remote machine. See #open_shell @param string command_id: The command id on the remote machine. See #run_command + @param stream out_stream: The stream of which the std_out would be directed to. (optional) + @param stream err_stream: The stream of which the std_err would be directed to. (optional) #@return [Hash] Returns a Hash with a key :exitcode and :data. Data is an Array of Hashes where the cooresponding key # is either :stdout or :stderr. The reason it is in an Array so so @@ -307,20 +309,20 @@ def get_command_output(self, shell_id, command_id): stdout_buffer, stderr_buffer = [], [] command_done = False while not command_done: - stdout, stderr, return_code, command_done = \ - self._raw_get_command_output(shell_id, command_id) + stdout, stderr, return_code, command_done = self._raw_get_command_output(shell_id, command_id, out_stream, + err_stream) stdout_buffer.append(stdout) stderr_buffer.append(stderr) return ''.join(stdout_buffer), ''.join(stderr_buffer), return_code - def _raw_get_command_output(self, shell_id, command_id): + def _raw_get_command_output(self, shell_id, command_id, out_stream=None, err_stream=None): rq = {'env:Envelope': self._get_soap_header( resource_uri='http://schemas.microsoft.com/wbem/wsman/1/windows/shell/cmd', # NOQA action='http://schemas.microsoft.com/wbem/wsman/1/windows/shell/Receive', # NOQA shell_id=shell_id)} stream = rq['env:Envelope'].setdefault( - 'env:Body', {}).setdefault('rsp:Receive', {})\ + 'env:Body', {}).setdefault('rsp:Receive', {}) \ .setdefault('rsp:DesiredStream', {}) stream['@CommandId'] = command_id stream['#text'] = 'stdout stderr' @@ -333,12 +335,15 @@ def _raw_get_command_output(self, shell_id, command_id): return_code = -1 for stream_node in stream_nodes: if stream_node.text: + content = str(base64.b64decode(stream_node.text.encode('ascii'))) if stream_node.attrib['Name'] == 'stdout': - stdout += str(base64.b64decode( - stream_node.text.encode('ascii'))) + if out_stream: + out_stream.write(content) + stdout += content elif stream_node.attrib['Name'] == 'stderr': - stderr += str(base64.b64decode( - stream_node.text.encode('ascii'))) + if err_stream: + err_stream.write(content) + stderr += content # We may need to get additional output if the stream has not finished. # The CommandState will change from Running to Done like so: diff --git a/winrm/tests/test_integration_protocol.py b/winrm/tests/test_integration_protocol.py index 0c4f5645..b6804efa 100644 --- a/winrm/tests/test_integration_protocol.py +++ b/winrm/tests/test_integration_protocol.py @@ -1,5 +1,6 @@ import re import pytest +import sys xfail = pytest.mark.xfail @@ -42,6 +43,63 @@ def test_get_command_output(protocol_real): protocol_real.close_shell(shell_id) +def test_get_legal_command_output_live_and_cleanup_command(protocol_real): + if sys.version[0] == '2': + from StringIO import StringIO + else: + from io import StringIO + import threading + + shell_id = protocol_real.open_shell() + command_id = protocol_real.run_command(shell_id, 'ping', 'localhost'.split()) + + class CmdTask: + def __init__(self): + self.stat, self.o_std, self.e_std = None, None, None + self.o_stream, self.e_stream = StringIO(), StringIO + + def get_response(self): + self.o_std, self.e_std, self.stat = protocol_real.get_command_output(shell_id, command_id, + out_stream=self.o_stream, + err_stream=self.e_stream) + tsk = CmdTask() + threading.Thread(target=tsk.get_response).start() + + # Waiting for the stream to get some input + while not tsk.o_stream: + pass + + tmp = tsk.o_stream.getvalue() + is_different = False + + while tsk.stat is None or tsk.stat != 0: + if tmp == tsk.o_stream.getvalue(): + is_different = True + + # Checking if ever the stream was updated. + # assert is_different + # Checking of the final print to std_out is the same as in the stream + assert tsk.o_stream.getvalue() == tsk.o_std + + +def test_get_illegal_command_output_live_and_cleanup_command(protocol_real): + if sys.version[0] == '2': + from StringIO import StringIO + else: + from io import StringIO + + shell_id = protocol_real.open_shell() + command_id = protocol_real.run_command(shell_id, 'fake_cmd') + o_stream, e_stream = StringIO(), StringIO() + + o_std, e_std, stat = protocol_real.get_command_output(shell_id, command_id, out_stream=o_stream, + err_stream=e_stream) + + # Checking of the final print to std_out is the same as in the stream + assert stat != 0 + assert e_stream.getvalue() == e_std + + def test_run_command_taking_more_than_60_seconds(protocol_real): shell_id = protocol_real.open_shell() command_id = protocol_real.run_command(