Skip to content

Handle zarr 3.1.0 #766

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 10 commits into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 39 additions & 11 deletions numcodecs/zarr3.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,17 +29,19 @@
import math
from dataclasses import dataclass, replace
from functools import cached_property
from importlib.metadata import version
from typing import Any, Self
from warnings import warn

import numpy as np
from packaging.version import Version

import numcodecs

try:
import zarr
import zarr # noqa: F401

if zarr.__version__ < "3.0.0": # pragma: no cover
if Version(version('zarr')) < Version("3.0.0"): # pragma: no cover
raise ImportError("zarr 3.0.0 or later is required to use the numcodecs zarr integration.")
except ImportError as e: # pragma: no cover
raise ImportError(
Expand All @@ -56,6 +58,23 @@
CODEC_PREFIX = "numcodecs."


def from_zarr_dtype(dtype: Any) -> np.dtype:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
def from_zarr_dtype(dtype: Any) -> np.dtype:
def _from_zarr_dtype(dtype: Any) -> np.dtype:

"""
Get a numpy data type from the dtype carried by an array spec, converting it only when the installed zarr version requires it.
"""
if Version(version('zarr')) >= Version("3.1.0"):
return dtype.to_native_dtype()
return dtype # pragma: no cover


def to_zarr_dtype(dtype: np.dtype) -> Any:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
def to_zarr_dtype(dtype: np.dtype) -> Any:
def _to_zarr_dtype(dtype: np.dtype) -> Any:

if Version(version('zarr')) >= Version("3.1.0"):
from zarr.dtype import parse_data_type

return parse_data_type(dtype, zarr_format=3)
return dtype # pragma: no cover


def _expect_name_prefix(codec_name: str) -> str:
if not codec_name.startswith(CODEC_PREFIX):
raise ValueError(
Expand Down Expand Up @@ -224,15 +243,17 @@ class LZMA(_NumcodecsBytesBytesCodec, codec_name="lzma"):
class Shuffle(_NumcodecsBytesBytesCodec, codec_name="shuffle"):
def evolve_from_array_spec(self, array_spec: ArraySpec) -> Shuffle:
if self.codec_config.get("elementsize") is None:
return Shuffle(**{**self.codec_config, "elementsize": array_spec.dtype.itemsize})
dtype = from_zarr_dtype(array_spec.dtype)
return Shuffle(**{**self.codec_config, "elementsize": dtype.itemsize})
return self # pragma: no cover


# array-to-array codecs ("filters")
class Delta(_NumcodecsArrayArrayCodec, codec_name="delta"):
def resolve_metadata(self, chunk_spec: ArraySpec) -> ArraySpec:
if astype := self.codec_config.get("astype"):
return replace(chunk_spec, dtype=np.dtype(astype)) # type: ignore[call-overload]
dtype = to_zarr_dtype(np.dtype(astype)) # type: ignore[call-overload]
return replace(chunk_spec, dtype=dtype)
return chunk_spec


Expand All @@ -243,12 +264,14 @@ class BitRound(_NumcodecsArrayArrayCodec, codec_name="bitround"):
class FixedScaleOffset(_NumcodecsArrayArrayCodec, codec_name="fixedscaleoffset"):
def resolve_metadata(self, chunk_spec: ArraySpec) -> ArraySpec:
if astype := self.codec_config.get("astype"):
return replace(chunk_spec, dtype=np.dtype(astype)) # type: ignore[call-overload]
dtype = to_zarr_dtype(np.dtype(astype)) # type: ignore[call-overload]
return replace(chunk_spec, dtype=dtype)
return chunk_spec

def evolve_from_array_spec(self, array_spec: ArraySpec) -> FixedScaleOffset:
if self.codec_config.get("dtype") is None:
return FixedScaleOffset(**{**self.codec_config, "dtype": str(array_spec.dtype)})
dtype = from_zarr_dtype(array_spec.dtype)
return FixedScaleOffset(**{**self.codec_config, "dtype": str(dtype)})
return self


Expand All @@ -258,7 +281,8 @@ def __init__(self, **codec_config: JSON) -> None:

def evolve_from_array_spec(self, array_spec: ArraySpec) -> Quantize:
if self.codec_config.get("dtype") is None:
return Quantize(**{**self.codec_config, "dtype": str(array_spec.dtype)})
dtype = from_zarr_dtype(array_spec.dtype)
return Quantize(**{**self.codec_config, "dtype": str(dtype)})
return self


Expand All @@ -267,21 +291,25 @@ def resolve_metadata(self, chunk_spec: ArraySpec) -> ArraySpec:
return replace(
chunk_spec,
shape=(1 + math.ceil(product(chunk_spec.shape) / 8),),
dtype=np.dtype("uint8"),
dtype=to_zarr_dtype(np.dtype("uint8")),
)

def validate(self, *, dtype: np.dtype[Any], **_kwargs) -> None:
if dtype != np.dtype("bool"):
_dtype = from_zarr_dtype(dtype)
if _dtype != np.dtype("bool"):
raise ValueError(f"Packbits filter requires bool dtype. Got {dtype}.")


class AsType(_NumcodecsArrayArrayCodec, codec_name="astype"):
def resolve_metadata(self, chunk_spec: ArraySpec) -> ArraySpec:
return replace(chunk_spec, dtype=np.dtype(self.codec_config["encode_dtype"])) # type: ignore[arg-type]
dtype = to_zarr_dtype(np.dtype(self.codec_config["encode_dtype"])) # type: ignore[arg-type]
return replace(chunk_spec, dtype=dtype)

def evolve_from_array_spec(self, array_spec: ArraySpec) -> AsType:
if self.codec_config.get("decode_dtype") is None:
return AsType(**{**self.codec_config, "decode_dtype": str(array_spec.dtype)})
# TODO: remove these coverage exemptions the correct way, i.e. with tests
dtype = from_zarr_dtype(array_spec.dtype) # pragma: no cover
return AsType(**{**self.codec_config, "decode_dtype": str(dtype)}) # pragma: no cover
return self


Expand Down
Loading