diff --git a/docs/release.rst b/docs/release.rst index ac4f851d..5e80861b 100644 --- a/docs/release.rst +++ b/docs/release.rst @@ -27,6 +27,8 @@ Improvements By :user:`John Kirkham `, :issue:`723` * All codecs are now pickleable. By :user:`Tom Nicholas `, :issue:`744` +* The Zstandard codec can now decode bytes containing multiple frames + By :user:`Mark Kittisopikul `, :issue:`757` Fixes ~~~~~ diff --git a/numcodecs/tests/test_pyzstd.py b/numcodecs/tests/test_pyzstd.py index b9dd6db2..7ee6084b 100644 --- a/numcodecs/tests/test_pyzstd.py +++ b/numcodecs/tests/test_pyzstd.py @@ -25,7 +25,6 @@ def test_pyzstd_simple(input): assert pyzstd.decompress(z.encode(input)) == input -@pytest.mark.xfail @pytest.mark.parametrize("input", test_data) def test_pyzstd_simple_multiple_frames_decode(input): """ diff --git a/numcodecs/tests/test_zstd.py b/numcodecs/tests/test_zstd.py index 04b474df..a3a926eb 100644 --- a/numcodecs/tests/test_zstd.py +++ b/numcodecs/tests/test_zstd.py @@ -156,3 +156,34 @@ def zstd_cli_available() -> bool: return not subprocess.run( ["zstd", "-V"], stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL ).returncode + + +def test_multi_frame(): + codec = Zstd() + + hello_world = codec.encode(b"Hello world!") + assert codec.decode(hello_world) == b"Hello world!" + assert codec.decode(hello_world * 2) == b"Hello world!Hello world!" + + hola = codec.encode(b"Hola ") + mundo = codec.encode(b"Mundo!") + assert codec.decode(hola) == b"Hola " + assert codec.decode(mundo) == b"Mundo!" + assert codec.decode(hola + mundo) == b"Hola Mundo!" + + bytes_val = b'(\xb5/\xfd\x00Xa\x00\x00Hello World!' + dec = codec.decode(bytes_val) + dec_expected = b'Hello World!' + assert dec == dec_expected + cli = zstd_cli_available() + if cli: + assert bytes_val == generate_zstd_streaming_bytes(dec_expected) + assert dec_expected == generate_zstd_streaming_bytes(bytes_val, decompress=True) + + # Concatenate frames of known sizes and unknown sizes + # unknown size frame at the end + assert codec.decode(hola + mundo + bytes_val) == b"Hola Mundo!Hello World!" + # unknown size frame at the beginning + assert codec.decode(bytes_val + hola + mundo) == b"Hello World!Hola Mundo!" + # unknown size frame in the middle + assert codec.decode(hola + bytes_val + mundo) == b"Hola Hello World!Mundo!" diff --git a/numcodecs/zstd.pyx b/numcodecs/zstd.pyx index f93da633..b3cc19f3 100644 --- a/numcodecs/zstd.pyx +++ b/numcodecs/zstd.pyx @@ -68,10 +68,12 @@ cdef extern from "zstd.h": size_t ZSTD_freeDStream(ZSTD_DStream* zds) nogil size_t ZSTD_initDStream(ZSTD_DStream* zds) nogil - cdef long ZSTD_CONTENTSIZE_UNKNOWN - cdef long ZSTD_CONTENTSIZE_ERROR + cdef unsigned long long ZSTD_CONTENTSIZE_UNKNOWN + cdef unsigned long long ZSTD_CONTENTSIZE_ERROR + unsigned long long ZSTD_getFrameContentSize(const void* src, size_t srcSize) nogil + size_t ZSTD_findFrameCompressedSize(const void* src, size_t srcSize) nogil int ZSTD_minCLevel() nogil int ZSTD_maxCLevel() nogil @@ -216,7 +218,11 @@ def decompress(source, dest=None): try: # determine uncompressed size - dest_size = ZSTD_getFrameContentSize(source_ptr, source_size) + try: + dest_size = findTotalContentSize(source_ptr, source_size) + except RuntimeError: + raise RuntimeError('Zstd decompression error: invalid input data') + if dest_size == 0 or dest_size == ZSTD_CONTENTSIZE_ERROR: raise RuntimeError('Zstd decompression error: invalid input data') @@ -353,6 +359,46 @@ cdef stream_decompress(const Py_buffer* source_pb): return dest +cdef unsigned long long findTotalContentSize(const char* source_ptr, size_t source_size): + """Find the total uncompressed content size of all frames in the source buffer + + Parameters + ---------- + source_ptr : Pointer to the beginning of the buffer + source_size : Size of the buffer containing the frame sizes to sum + + Returns + ------- + total_content_size: Sum of the content size of all frames within the source buffer + If any of the frame sizes is unknown, return ZSTD_CONTENTSIZE_UNKNOWN. + If any of the frames causes ZSTD_getFrameContentSize to error, return ZSTD_CONTENTSIZE_ERROR. + """ + cdef: + unsigned long long frame_content_size = 0 + unsigned long long total_content_size = 0 + size_t frame_compressed_size = 0 + size_t offset = 0 + + while offset < source_size: + frame_compressed_size = ZSTD_findFrameCompressedSize(source_ptr + offset, source_size - offset); + + if ZSTD_isError(frame_compressed_size): + error = ZSTD_getErrorName(frame_compressed_size) + raise RuntimeError('Could not set determine zstd frame size: %s' % error) + + frame_content_size = ZSTD_getFrameContentSize(source_ptr + offset, frame_compressed_size); + + if frame_content_size == ZSTD_CONTENTSIZE_ERROR: + return ZSTD_CONTENTSIZE_ERROR + + if frame_content_size == ZSTD_CONTENTSIZE_UNKNOWN: + return ZSTD_CONTENTSIZE_UNKNOWN + + total_content_size += frame_content_size + offset += frame_compressed_size + + return total_content_size + class Zstd(Codec): """Codec providing compression using Zstandard.