Skip to content

Allow Zstandard to decompress multiple concatenated frames #757

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 9 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions docs/release.rst
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ Improvements
By :user:`John Kirkham <jakirkham>`, :issue:`723`
* All codecs are now pickleable.
By :user:`Tom Nicholas <TomNicholas>`, :issue:`744`
* The Zstandard codec can now decode bytes containing multiple frames.
By :user:`Mark Kittisopikul <mkitti>`, :issue:`757`

Fixes
~~~~~
Expand Down
1 change: 0 additions & 1 deletion numcodecs/tests/test_pyzstd.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@ def test_pyzstd_simple(input):
assert pyzstd.decompress(z.encode(input)) == input


@pytest.mark.xfail
@pytest.mark.parametrize("input", test_data)
def test_pyzstd_simple_multiple_frames_decode(input):
"""
Expand Down
31 changes: 31 additions & 0 deletions numcodecs/tests/test_zstd.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,3 +156,34 @@ def zstd_cli_available() -> bool:
return not subprocess.run(
["zstd", "-V"], stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL
).returncode


def test_multi_frame():
    """Check that ``Zstd.decode`` handles buffers made of several concatenated frames."""
    zstd = Zstd()

    # One frame on its own, then the very same frame repeated back to back.
    greeting_frame = zstd.encode(b"Hello world!")
    assert zstd.decode(greeting_frame) == b"Hello world!"
    assert zstd.decode(greeting_frame * 2) == b"Hello world!Hello world!"

    # Two different frames, decoded separately and then concatenated.
    hola_frame = zstd.encode(b"Hola ")
    mundo_frame = zstd.encode(b"Mundo!")
    assert zstd.decode(hola_frame) == b"Hola "
    assert zstd.decode(mundo_frame) == b"Mundo!"
    assert zstd.decode(hola_frame + mundo_frame) == b"Hola Mundo!"

    # Hard-coded frame used below as the "unknown size" case.
    streamed_frame = b'(\xb5/\xfd\x00Xa\x00\x00Hello World!'
    streamed_plain = b'Hello World!'
    decoded = zstd.decode(streamed_frame)
    assert decoded == streamed_plain
    # Cross-check the fixture against the zstd command line tool when it is installed.
    if zstd_cli_available():
        assert streamed_frame == generate_zstd_streaming_bytes(streamed_plain)
        assert streamed_plain == generate_zstd_streaming_bytes(
            streamed_frame, decompress=True
        )

    # Concatenate frames of known sizes and unknown sizes
    # unknown size frame at the end
    assert (
        zstd.decode(hola_frame + mundo_frame + streamed_frame)
        == b"Hola Mundo!Hello World!"
    )
    # unknown size frame at the beginning
    assert (
        zstd.decode(streamed_frame + hola_frame + mundo_frame)
        == b"Hello World!Hola Mundo!"
    )
    # unknown size frame in the middle
    assert (
        zstd.decode(hola_frame + streamed_frame + mundo_frame)
        == b"Hola Hello World!Mundo!"
    )
52 changes: 49 additions & 3 deletions numcodecs/zstd.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -68,10 +68,12 @@ cdef extern from "zstd.h":
size_t ZSTD_freeDStream(ZSTD_DStream* zds) nogil
size_t ZSTD_initDStream(ZSTD_DStream* zds) nogil

cdef long ZSTD_CONTENTSIZE_UNKNOWN
cdef long ZSTD_CONTENTSIZE_ERROR
cdef unsigned long long ZSTD_CONTENTSIZE_UNKNOWN
cdef unsigned long long ZSTD_CONTENTSIZE_ERROR

unsigned long long ZSTD_getFrameContentSize(const void* src,
size_t srcSize) nogil
size_t ZSTD_findFrameCompressedSize(const void* src, size_t srcSize) nogil

int ZSTD_minCLevel() nogil
int ZSTD_maxCLevel() nogil
Expand Down Expand Up @@ -216,7 +218,11 @@ def decompress(source, dest=None):
try:

# determine uncompressed size
dest_size = ZSTD_getFrameContentSize(source_ptr, source_size)
try:
dest_size = findTotalContentSize(source_ptr, source_size)
except RuntimeError:
raise RuntimeError('Zstd decompression error: invalid input data')

if dest_size == 0 or dest_size == ZSTD_CONTENTSIZE_ERROR:
raise RuntimeError('Zstd decompression error: invalid input data')

Expand Down Expand Up @@ -353,6 +359,46 @@ cdef stream_decompress(const Py_buffer* source_pb):

return dest

cdef unsigned long long findTotalContentSize(const char* source_ptr, size_t source_size):
    """Find the total uncompressed content size of all frames in the source buffer

    Parameters
    ----------
    source_ptr : Pointer to the beginning of the buffer
    source_size : Size of the buffer containing the frame sizes to sum

    Returns
    -------
    total_content_size: Sum of the content size of all frames within the source buffer
        If any of the frame sizes is unknown, return ZSTD_CONTENTSIZE_UNKNOWN.
        If any of the frames causes ZSTD_getFrameContentSize to error, or if the
        sum of the frame sizes cannot be represented without colliding with the
        ZSTD_CONTENTSIZE_* sentinel values, return ZSTD_CONTENTSIZE_ERROR.

    Raises
    ------
    RuntimeError
        If ZSTD_findFrameCompressedSize reports an error for any frame, i.e. the
        compressed extent of a frame cannot be determined.
    """
    cdef:
        unsigned long long frame_content_size = 0
        unsigned long long total_content_size = 0
        size_t frame_compressed_size = 0
        size_t offset = 0

    # Walk the buffer one frame at a time; each iteration consumes exactly the
    # compressed bytes of the frame starting at `offset`.
    while offset < source_size:
        frame_compressed_size = ZSTD_findFrameCompressedSize(source_ptr + offset, source_size - offset)

        if ZSTD_isError(frame_compressed_size):
            error = ZSTD_getErrorName(frame_compressed_size)
            raise RuntimeError('Could not determine zstd frame size: %s' % error)

        frame_content_size = ZSTD_getFrameContentSize(source_ptr + offset, frame_compressed_size)

        if frame_content_size == ZSTD_CONTENTSIZE_ERROR:
            return ZSTD_CONTENTSIZE_ERROR

        if frame_content_size == ZSTD_CONTENTSIZE_UNKNOWN:
            return ZSTD_CONTENTSIZE_UNKNOWN

        # Guard the accumulation: total_content_size is always below
        # ZSTD_CONTENTSIZE_ERROR here, so the subtraction cannot underflow.
        # A sum that would reach the sentinel values (or wrap around) is
        # reported as an error rather than as a bogus size.
        if frame_content_size >= ZSTD_CONTENTSIZE_ERROR - total_content_size:
            return ZSTD_CONTENTSIZE_ERROR

        total_content_size += frame_content_size
        offset += frame_compressed_size

    return total_content_size

class Zstd(Codec):
"""Codec providing compression using Zstandard.

Expand Down
Loading