Skip to content

Add streaming decompression for ZSTD_CONTENTSIZE_UNKNOWN case #707

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 15 commits into from
Jul 10, 2025
Merged
Show file tree
Hide file tree
Changes from 14 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions docs/compression/zstd.rst
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@ Zstd
.. autoattribute:: codec_id
.. automethod:: encode
.. automethod:: decode
.. note::
If the compressed data does not contain the decompressed size, streaming
decompression will be used.
.. automethod:: get_config
.. automethod:: from_config

Expand Down
2 changes: 2 additions & 0 deletions docs/release.rst
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,8 @@ Maintenance

Improvements
~~~~~~~~~~~~
* Add streaming decompression for ZSTD (:issue:`699`)
By :user:`Mark Kittisopikul <mkitti>`.
* Raise a custom `UnknownCodecError` when trying to retrieve an unavailable codec.
By :user:`Cas Wognum <cwognum>`.

Expand Down
58 changes: 58 additions & 0 deletions numcodecs/tests/test_pyzstd.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
# Check Zstd against pyzstd package

import numpy as np
import pytest
import pyzstd

from numcodecs.zstd import Zstd

test_data = [
b"Hello World!",
np.arange(113).tobytes(),
np.arange(10, 15).tobytes(),
np.random.randint(3, 50, size=(53,), dtype=np.uint16).tobytes(),
]


@pytest.mark.parametrize("input", test_data)
def test_pyzstd_simple(input):
z = Zstd()
assert z.decode(pyzstd.compress(input)) == input
assert pyzstd.decompress(z.encode(input)) == input


@pytest.mark.xfail
@pytest.mark.parametrize("input", test_data)
def test_pyzstd_simple_multiple_frames_decode(input):
z = Zstd()
assert z.decode(pyzstd.compress(input) * 2) == input * 2


@pytest.mark.parametrize("input", test_data)
def test_pyzstd_simple_multiple_frames_encode(input):
z = Zstd()
assert pyzstd.decompress(z.encode(input) * 2) == input * 2


@pytest.mark.parametrize("input", test_data)
def test_pyzstd_streaming(input):
pyzstd_c = pyzstd.ZstdCompressor()
pyzstd_d = pyzstd.ZstdDecompressor()
pyzstd_e = pyzstd.EndlessZstdDecompressor()
z = Zstd()

d_bytes = input
pyzstd_c.compress(d_bytes)
c_bytes = pyzstd_c.flush()
assert z.decode(c_bytes) == d_bytes
assert pyzstd_d.decompress(z.encode(d_bytes)) == d_bytes

# Test multiple streaming frames
assert z.decode(c_bytes * 2) == pyzstd_e.decompress(c_bytes * 2)
assert z.decode(c_bytes * 3) == pyzstd_e.decompress(c_bytes * 3)
assert z.decode(c_bytes * 4) == pyzstd_e.decompress(c_bytes * 4)
assert z.decode(c_bytes * 5) == pyzstd_e.decompress(c_bytes * 5)
assert z.decode(c_bytes * 7) == pyzstd_e.decompress(c_bytes * 7)
assert z.decode(c_bytes * 11) == pyzstd_e.decompress(c_bytes * 11)
assert z.decode(c_bytes * 13) == pyzstd_e.decompress(c_bytes * 13)
assert z.decode(c_bytes * 99) == pyzstd_e.decompress(c_bytes * 99)
78 changes: 72 additions & 6 deletions numcodecs/tests/test_zstd.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,9 @@
import itertools
import subprocess

import numpy as np
import pytest

try:
from numcodecs.zstd import Zstd
except ImportError: # pragma: no cover
pytest.skip("numcodecs.zstd not available", allow_module_level=True)


from numcodecs.tests.common import (
check_backwards_compatibility,
check_config,
Expand All @@ -17,6 +12,7 @@
check_err_encode_object_buffer,
check_repr,
)
from numcodecs.zstd import Zstd

codecs = [
Zstd(),
Expand Down Expand Up @@ -90,3 +86,73 @@ def test_native_functions():
assert Zstd.default_level() == 3
assert Zstd.min_level() == -131072
assert Zstd.max_level() == 22


def test_streaming_decompression():
# Test input frames with unknown frame content size
codec = Zstd()

# If the zstd command line interface is available, check the bytes
cli = zstd_cli_available()
if cli:
view_zstd_streaming_bytes()

# Encode bytes directly that were the result of streaming compression
bytes_val = b'(\xb5/\xfd\x00Xa\x00\x00Hello World!'
dec = codec.decode(bytes_val)
dec_expected = b'Hello World!'
assert dec == dec_expected
if cli:
assert bytes_val == generate_zstd_streaming_bytes(dec_expected)
assert dec_expected == generate_zstd_streaming_bytes(bytes_val, decompress=True)

# Two consecutive frames given as input
bytes2 = bytes(bytearray(bytes_val * 2))
dec2 = codec.decode(bytes2)
dec2_expected = b'Hello World!Hello World!'
assert dec2 == dec2_expected
if cli:
assert dec2_expected == generate_zstd_streaming_bytes(bytes2, decompress=True)

# Single long frame that decompresses to a large output
bytes3 = b'(\xb5/\xfd\x00X$\x02\x00\xa4\x03ABCDEFGHIJKLMNOPQRSTUVWXYZ[\\]^_`abcdefghijklmnopqrstuvwxyz\x01\x00:\xfc\xdfs\x05\x05L\x00\x00\x08s\x01\x00\xfc\xff9\x10\x02L\x00\x00\x08k\x01\x00\xfc\xff9\x10\x02L\x00\x00\x08c\x01\x00\xfc\xff9\x10\x02L\x00\x00\x08[\x01\x00\xfc\xff9\x10\x02L\x00\x00\x08S\x01\x00\xfc\xff9\x10\x02L\x00\x00\x08K\x01\x00\xfc\xff9\x10\x02L\x00\x00\x08C\x01\x00\xfc\xff9\x10\x02L\x00\x00\x08u\x01\x00\xfc\xff9\x10\x02L\x00\x00\x08m\x01\x00\xfc\xff9\x10\x02L\x00\x00\x08e\x01\x00\xfc\xff9\x10\x02L\x00\x00\x08]\x01\x00\xfc\xff9\x10\x02L\x00\x00\x08U\x01\x00\xfc\xff9\x10\x02L\x00\x00\x08M\x01\x00\xfc\xff9\x10\x02M\x00\x00\x08E\x01\x00\xfc\x7f\x1d\x08\x01'
dec3 = codec.decode(bytes3)
dec3_expected = b'ABCDEFGHIJKLMNOPQRSTUVWXYZ[\\]^_`abcdefghijklmnopqrstuvwxyz' * 1024 * 32
assert dec3 == dec3_expected
if cli:
assert bytes3 == generate_zstd_streaming_bytes(dec3_expected)
assert dec3_expected == generate_zstd_streaming_bytes(bytes3, decompress=True)

# Garbage input results in an error
bytes4 = bytes(bytearray([0, 0, 0, 0, 0, 0, 0, 0]))
with pytest.raises(RuntimeError, match='Zstd decompression error: invalid input data'):
codec.decode(bytes4)


def generate_zstd_streaming_bytes(input: bytes, *, decompress: bool = False) -> bytes:
"""
Use the zstd command line interface to compress or decompress bytes in streaming mode.
"""
if decompress:
args = ["-d"]
else:
args = []

p = subprocess.run(["zstd", "--no-check", *args], input=input, capture_output=True)
return p.stdout


def view_zstd_streaming_bytes():
bytes_val = generate_zstd_streaming_bytes(b"Hello world!")
print(f" bytes_val = {bytes_val}")

bytes3 = generate_zstd_streaming_bytes(
b"ABCDEFGHIJKLMNOPQRSTUVWXYZ[\\]^_`abcdefghijklmnopqrstuvwxyz" * 1024 * 32
)
print(f" bytes3 = {bytes3}")


def zstd_cli_available() -> bool:
return not subprocess.run(
["zstd", "-V"], stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL
).returncode
135 changes: 131 additions & 4 deletions numcodecs/zstd.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ from .compat_ext cimport PyBytes_RESIZE, ensure_continguous_memoryview
from .compat import ensure_contiguous_ndarray
from .abc import Codec

from libc.stdlib cimport malloc, realloc, free

cdef extern from "zstd.h":

Expand All @@ -21,6 +22,23 @@ cdef extern from "zstd.h":
struct ZSTD_CCtx_s:
pass
ctypedef ZSTD_CCtx_s ZSTD_CCtx

struct ZSTD_DStream_s:
pass
ctypedef ZSTD_DStream_s ZSTD_DStream

struct ZSTD_inBuffer_s:
const void* src
size_t size
size_t pos
ctypedef ZSTD_inBuffer_s ZSTD_inBuffer

struct ZSTD_outBuffer_s:
void* dst
size_t size
size_t pos
ctypedef ZSTD_outBuffer_s ZSTD_outBuffer

cdef enum ZSTD_cParameter:
ZSTD_c_compressionLevel=100
ZSTD_c_checksumFlag=201
Expand All @@ -36,12 +54,20 @@ cdef extern from "zstd.h":
size_t dstCapacity,
const void* src,
size_t srcSize) nogil

size_t ZSTD_decompress(void* dst,
size_t dstCapacity,
const void* src,
size_t compressedSize) nogil

size_t ZSTD_decompressStream(ZSTD_DStream* zds,
ZSTD_outBuffer* output,
ZSTD_inBuffer* input) nogil

size_t ZSTD_DStreamOutSize() nogil
ZSTD_DStream* ZSTD_createDStream() nogil
size_t ZSTD_freeDStream(ZSTD_DStream* zds) nogil
size_t ZSTD_initDStream(ZSTD_DStream* zds) nogil

cdef long ZSTD_CONTENTSIZE_UNKNOWN
cdef long ZSTD_CONTENTSIZE_ERROR
unsigned long long ZSTD_getFrameContentSize(const void* src,
Expand All @@ -55,7 +81,7 @@ cdef extern from "zstd.h":

unsigned ZSTD_isError(size_t code) nogil

const char* ZSTD_getErrorName(size_t code)
const char* ZSTD_getErrorName(size_t code) nogil


VERSION_NUMBER = ZSTD_versionNumber()
Expand Down Expand Up @@ -157,7 +183,10 @@ def decompress(source, dest=None):
source : bytes-like
Compressed data. Can be any object supporting the buffer protocol.
dest : array-like, optional
Object to decompress into.
Object to decompress into. If the content size is unknown, the
length of dest must match the decompressed size. If the content size
is unknown and dest is not provided, streaming decompression will be
used.

Returns
-------
Expand All @@ -174,6 +203,7 @@ def decompress(source, dest=None):
char* dest_ptr
size_t source_size, dest_size, decompressed_size
size_t nbytes, cbytes, blocksize
size_t dest_nbytes

# obtain source memoryview
source_mv = ensure_continguous_memoryview(source)
Expand All @@ -187,9 +217,12 @@ def decompress(source, dest=None):

# determine uncompressed size
dest_size = ZSTD_getFrameContentSize(source_ptr, source_size)
if dest_size == 0 or dest_size == ZSTD_CONTENTSIZE_UNKNOWN or dest_size == ZSTD_CONTENTSIZE_ERROR:
if dest_size == 0 or dest_size == ZSTD_CONTENTSIZE_ERROR:
raise RuntimeError('Zstd decompression error: invalid input data')

if dest_size == ZSTD_CONTENTSIZE_UNKNOWN and dest is None:
return stream_decompress(source_pb)

# setup destination buffer
if dest is None:
# allocate memory
Expand All @@ -203,6 +236,9 @@ def decompress(source, dest=None):
dest_ptr = <char*>dest_pb.buf
dest_nbytes = dest_pb.len

if dest_size == ZSTD_CONTENTSIZE_UNKNOWN:
dest_size = dest_nbytes

# validate output buffer
if dest_nbytes < dest_size:
raise ValueError('destination buffer too small; expected at least %s, '
Expand All @@ -225,6 +261,97 @@ def decompress(source, dest=None):

return dest

cdef stream_decompress(const Py_buffer* source_pb):
"""Decompress data of unknown size

Parameters
----------
source : Py_buffer
Compressed data buffer

Returns
-------
dest : bytes
Object containing decompressed data.
"""

cdef:
const char *source_ptr
void *dest_ptr
void *new_dst
size_t source_size, dest_size, decompressed_size
size_t DEST_GROWTH_SIZE, status
ZSTD_inBuffer input
ZSTD_outBuffer output
ZSTD_DStream *zds

# Recommended size for output buffer, guaranteed to flush at least
# one completely block in all circumstances
DEST_GROWTH_SIZE = ZSTD_DStreamOutSize();

source_ptr = <const char*>source_pb.buf
source_size = source_pb.len

# unknown content size, guess it is twice the size as the source
dest_size = source_size * 2

if dest_size < DEST_GROWTH_SIZE:
# minimum dest_size is DEST_GROWTH_SIZE
dest_size = DEST_GROWTH_SIZE

dest_ptr = <char *>malloc(dest_size)
zds = ZSTD_createDStream()

try:

with nogil:

status = ZSTD_initDStream(zds)
if ZSTD_isError(status):
error = ZSTD_getErrorName(status)
ZSTD_freeDStream(zds);
raise RuntimeError('Zstd stream decompression error on ZSTD_initDStream: %s' % error)

input = ZSTD_inBuffer(source_ptr, source_size, 0)
output = ZSTD_outBuffer(dest_ptr, dest_size, 0)

# Initialize to 1 to force a loop iteration
status = 1
while(status > 0 or input.pos < input.size):
# Possible returned values of ZSTD_decompressStream:
# 0: frame is completely decoded and fully flushed
# error (<0)
# >0: suggested next input size
status = ZSTD_decompressStream(zds, &output, &input)

if ZSTD_isError(status):
error = ZSTD_getErrorName(status)
raise RuntimeError('Zstd stream decompression error on ZSTD_decompressStream: %s' % error)

# There is more to decompress, grow the buffer
if status > 0 and output.pos == output.size:
new_size = output.size + DEST_GROWTH_SIZE

if new_size < output.size or new_size < DEST_GROWTH_SIZE:
raise RuntimeError('Zstd stream decompression error: output buffer overflow')

new_dst = realloc(output.dst, new_size)

if new_dst == NULL:
# output.dst freed in finally block
raise RuntimeError('Zstd stream decompression error on realloc: could not expand output buffer')

output.dst = new_dst
output.size = new_size

# Copy the output to a bytes object
dest = PyBytes_FromStringAndSize(<char *>output.dst, output.pos)

finally:
ZSTD_freeDStream(zds)
free(output.dst)

return dest

class Zstd(Codec):
"""Codec providing compression using Zstandard.
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ test = [
"coverage",
"pytest",
"pytest-cov",
"pyzstd"
]
test_extras = [
"importlib_metadata",
Expand Down
Loading