diff --git a/invoke/runners.py b/invoke/runners.py index cb8929b0..df2a9d57 100644 --- a/invoke/runners.py +++ b/invoke/runners.py @@ -1,11 +1,12 @@ +import codecs import errno import locale import os +import signal import struct import sys import threading import time -import signal from subprocess import Popen, PIPE from types import TracebackType from typing import ( @@ -692,8 +693,9 @@ def read_proc_output(self, reader: Callable) -> Generator[str, None, None]: :returns: A generator yielding strings. - Specifically, each resulting string is the result of decoding - `read_chunk_size` bytes read from the subprocess' out/err stream. + Specifically, each resulting string is the result of incrementally + decoding up to `read_chunk_size` bytes from the subprocess' out/err + stream. The decoder ensures that encoding boundaries are respected. .. versionadded:: 1.0 """ @@ -703,11 +705,18 @@ def read_proc_output(self, reader: Callable) -> Generator[str, None, None]: # process is done running" because sometimes that signal will appear # before we've actually read all the data in the stream (i.e.: a race # condition). + decoder_cls = codecs.getincrementaldecoder(self.encoding) + decoder = decoder_cls("replace") while True: data = reader(self.read_chunk_size) if not data: break - yield self.decode(data) + # The incremental decoder will deal with partial characters. + yield decoder.decode(data) + pending_buf, _ = decoder.getstate() + if pending_buf: + # Emit the final chunk of data + yield decoder.decode(b"", True) def write_our_output(self, stream: IO, string: str) -> None: """ @@ -1020,6 +1029,13 @@ def decode(self, data: bytes) -> str: """ Decode some ``data`` bytes, returning Unicode. + .. warning:: + This function should not be used for streaming data. When data is + streamed in chunks, one chunk can end with only parts of a + multi-byte codepoint. This function will return a replacement + character for the incomplete byte sequence. + Use a ``codecs.IncrementalDecoder`` instead. + .. versionadded:: 1.0 """ # NOTE: yes, this is a 1-liner. The point is to make it much harder to diff --git a/tests/runners.py b/tests/runners.py index 49b86cc7..6e27b26a 100644 --- a/tests/runners.py +++ b/tests/runners.py @@ -70,8 +70,8 @@ def _runner(out="", err="", **kwargs): runner = klass(Context(config=Config(overrides=kwargs))) if "exits" in kwargs: runner.returncode = Mock(return_value=kwargs.pop("exits")) - out_file = BytesIO(out.encode()) - err_file = BytesIO(err.encode()) + out_file = BytesIO(out.encode() if isinstance(out, str) else out) + err_file = BytesIO(err.encode() if isinstance(err, str) else err) runner.read_proc_stdout = out_file.read runner.read_proc_stderr = err_file.read return runner @@ -539,6 +539,23 @@ def writes_and_flushes_to_stderr(self): err.write.assert_called_once_with("whatever") err.flush.assert_called_once_with() + def handles_incremental_decoding(self): + # 𠜎 is 4 bytes in utf-8 + expected_out = "𠜎" * 50 + runner = self._runner(out=expected_out) + # Make sure every read returns a partial character. + runner.read_chunk_size = 3 + out = StringIO() + runner.run(_, out_stream=out) + assert out.getvalue() == expected_out + + def handles_trailing_partial_character(self): + out = StringIO() + # Only output the first 3 out of the 4 bytes in 𠜎 + self._runner(out=b"\xf0\xa0\x9c").run(_, out_stream=out) + # Should produce a single unicode replacement character + assert out.getvalue() == "�" + class input_stream_handling: # NOTE: actual autoresponder tests are elsewhere. These just test that # stdin works normally & can be overridden.