Skip to content
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions changes/3264.fix.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
- Expand the range of types accepted by ``parse_data_type`` to include strings and Sequences.
- Move the functionality of ``parse_data_type`` to a new function called ``parse_dtype``. This change
ensures that nomenclature is consistent across the codebase. ``parse_data_type`` remains, so this
change is not breaking.
14 changes: 7 additions & 7 deletions docs/user-guide/data_types.rst
Original file line number Diff line number Diff line change
Expand Up @@ -412,17 +412,17 @@ attempt data type resolution against *every* data type class, and if, for some r
type matches multiple Zarr data types, we treat this as an error and raise an exception.

If you have a NumPy data type and you want to get the corresponding ``ZDType`` instance, you can use
the ``parse_data_type`` function, which will use the dynamic resolution described above. ``parse_data_type``
the ``parse_dtype`` function, which will use the dynamic resolution described above. ``parse_dtype``
handles a range of input types:

- NumPy data types:

.. code-block:: python

>>> import numpy as np
>>> from zarr.dtype import parse_data_type
>>> from zarr.dtype import parse_dtype
>>> my_dtype = np.dtype('>M8[10s]')
>>> parse_data_type(my_dtype, zarr_format=2)
>>> parse_dtype(my_dtype, zarr_format=2)
DateTime64(endianness='big', scale_factor=10, unit='s')


Expand All @@ -431,7 +431,7 @@ handles a range of input types:
.. code-block:: python

>>> dtype_str = '>M8[10s]'
>>> parse_data_type(dtype_str, zarr_format=2)
>>> parse_dtype(dtype_str, zarr_format=2)
DateTime64(endianness='big', scale_factor=10, unit='s')

- ``ZDType`` instances:
Expand All @@ -440,7 +440,7 @@ handles a range of input types:

>>> from zarr.dtype import DateTime64
>>> zdt = DateTime64(endianness='big', scale_factor=10, unit='s')
>>> parse_data_type(zdt, zarr_format=2) # Use a ZDType (this is a no-op)
>>> parse_dtype(zdt, zarr_format=2) # Use a ZDType (this is a no-op)
DateTime64(endianness='big', scale_factor=10, unit='s')

- Python dictionaries (requires ``zarr_format=3``). These dictionaries must be consistent with the
Expand All @@ -449,7 +449,7 @@ handles a range of input types:
.. code-block:: python

>>> dt_dict = {"name": "numpy.datetime64", "configuration": {"unit": "s", "scale_factor": 10}}
>>> parse_data_type(dt_dict, zarr_format=3)
>>> parse_dtype(dt_dict, zarr_format=3)
DateTime64(endianness='little', scale_factor=10, unit='s')
>>> parse_data_type(dt_dict, zarr_format=3).to_json(zarr_format=3)
>>> parse_dtype(dt_dict, zarr_format=3).to_json(zarr_format=3)
{'name': 'numpy.datetime64', 'configuration': {'unit': 's', 'scale_factor': 10}}
6 changes: 3 additions & 3 deletions src/zarr/core/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@
VariableLengthUTF8,
ZDType,
ZDTypeLike,
parse_data_type,
parse_dtype,
)
from zarr.core.dtype.common import HasEndianness, HasItemSize, HasObjectCodec
from zarr.core.indexing import (
Expand Down Expand Up @@ -618,7 +618,7 @@ async def _create(
Deprecated in favor of :func:`zarr.api.asynchronous.create_array`.
"""

dtype_parsed = parse_data_type(dtype, zarr_format=zarr_format)
dtype_parsed = parse_dtype(dtype, zarr_format=zarr_format)
store_path = await make_store_path(store)

shape = parse_shapelike(shape)
Expand Down Expand Up @@ -4239,7 +4239,7 @@ async def init_array(

from zarr.codecs.sharding import ShardingCodec, ShardingCodecIndexLocation

zdtype = parse_data_type(dtype, zarr_format=zarr_format)
zdtype = parse_dtype(dtype, zarr_format=zarr_format)
shape_parsed = parse_shapelike(shape)
chunk_key_encoding_parsed = _parse_chunk_key_encoding(
chunk_key_encoding, zarr_format=zarr_format
Expand Down
68 changes: 50 additions & 18 deletions src/zarr/core/dtype/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

from collections.abc import Sequence
from typing import TYPE_CHECKING, Final, TypeAlias

from zarr.core.dtype.common import (
Expand Down Expand Up @@ -94,6 +95,7 @@
"ZDType",
"data_type_registry",
"parse_data_type",
"parse_dtype",
]

data_type_registry = DataTypeRegistry()
Expand Down Expand Up @@ -188,39 +190,69 @@ def parse_data_type(
zarr_format: ZarrFormat,
) -> ZDType[TBaseDType, TBaseScalar]:
"""
Interpret the input as a ZDType instance.
Interpret the input as a ZDType.

This function wraps ``parse_dtype``. The only difference is the function name. This function may
be deprecated in a future version of Zarr Python in favor of ``parse_dtype``.

Parameters
----------
dtype_spec : ZDTypeLike
The input to be interpreted as a ZDType. This could be a ZDType, which will be returned
directly, or a JSON representation of a ZDType, or a native dtype, or a python object that
can be converted into a native dtype.
zarr_format : ZarrFormat
The Zarr format version.

Returns
-------
ZDType[TBaseDType, TBaseScalar]
The ZDType corresponding to the input.

Examples
--------
>>> parse_data_type("int32", zarr_format=2)
Int32(endianness='little')
"""
return parse_dtype(dtype_spec, zarr_format=zarr_format)


def parse_dtype(
dtype_spec: ZDTypeLike,
*,
zarr_format: ZarrFormat,
) -> ZDType[TBaseDType, TBaseScalar]:
"""
Interpret the input as a ZDType.

Parameters
----------
dtype_spec : ZDTypeLike
The input to be interpreted as a ZDType instance. This could be a native data type
(e.g., a NumPy data type), a Python object that can be converted into a native data type,
a ZDType instance (in which case the input is returned unchanged), or a JSON object
representation of a data type.
The input to be interpreted as a ZDType. This could be a ZDType, which will be returned
directly, or a JSON representation of a ZDType, or a native dtype, or a python object that
can be converted into a native dtype.
zarr_format : ZarrFormat
The zarr format version.
The Zarr format version.

Returns
-------
ZDType[TBaseDType, TBaseScalar]
The ZDType instance corresponding to the input.
The ZDType corresponding to the input.

Examples
--------
>>> from zarr.dtype import parse_data_type
>>> import numpy as np
>>> parse_data_type("int32", zarr_format=2)
Int32(endianness='little')
>>> parse_data_type(np.dtype('S10'), zarr_format=2)
NullTerminatedBytes(length=10)
>>> parse_data_type({"name": "numpy.datetime64", "configuration": {"unit": "s", "scale_factor": 10}}, zarr_format=3)
DateTime64(endianness='little', scale_factor=10, unit='s')
>>> parse_dtype("int32", zarr_format=2)
Int32(endianness='little')
"""
if isinstance(dtype_spec, ZDType):
return dtype_spec
# dict and zarr_format 3 means that we have a JSON object representation of the dtype
if zarr_format == 3 and isinstance(dtype_spec, Mapping):
return get_data_type_from_json(dtype_spec, zarr_format=3)
# First attempt to interpret the input as JSON
if isinstance(dtype_spec, Mapping | str | Sequence):
try:
return get_data_type_from_json(dtype_spec, zarr_format=3) # type: ignore[arg-type]
except ValueError:
# no data type matched this JSON-like input
pass
if dtype_spec in VLEN_UTF8_ALIAS:
# If the dtype request is one of the aliases for variable-length UTF-8 strings,
# return that dtype.
Expand Down
2 changes: 2 additions & 0 deletions src/zarr/dtype.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
ZDType,
data_type_registry,
parse_data_type,
parse_dtype,
)

__all__ = [
Expand Down Expand Up @@ -84,4 +85,5 @@
"data_type_registry",
"data_type_registry",
"parse_data_type",
"parse_dtype",
]
4 changes: 2 additions & 2 deletions tests/test_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@
VariableLengthBytes,
VariableLengthUTF8,
ZDType,
parse_data_type,
parse_dtype,
)
from zarr.core.dtype.common import ENDIANNESS_STR, EndiannessStr
from zarr.core.dtype.npy.common import NUMPY_ENDIANNESS_STR, endianness_from_numpy_str
Expand Down Expand Up @@ -1308,7 +1308,7 @@ async def test_v2_chunk_encoding(
filters=filters,
)
filters_expected, compressor_expected = _parse_chunk_encoding_v2(
filters=filters, compressor=compressors, dtype=parse_data_type(dtype, zarr_format=2)
filters=filters, compressor=compressors, dtype=parse_dtype(dtype, zarr_format=2)
)
assert arr.metadata.zarr_format == 2 # guard for mypy
assert arr.metadata.compressor == compressor_expected
Expand Down
66 changes: 40 additions & 26 deletions tests/test_dtype_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,17 +15,14 @@
AnyDType,
Bool,
DataTypeRegistry,
DateTime64,
FixedLengthUTF32,
Int8,
Int16,
TBaseDType,
TBaseScalar,
VariableLengthUTF8,
ZDType,
data_type_registry,
get_data_type_from_json,
parse_data_type,
parse_dtype,
)

if TYPE_CHECKING:
Expand Down Expand Up @@ -174,28 +171,45 @@ def test_entrypoint_dtype(zarr_format: ZarrFormat) -> None:
data_type_registry.unregister(TestDataType._zarr_v3_name)


@pytest.mark.parametrize(
("dtype_params", "expected", "zarr_format"),
[
("str", VariableLengthUTF8(), 2),
("str", VariableLengthUTF8(), 3),
("int8", Int8(), 3),
(Int8(), Int8(), 3),
(">i2", Int16(endianness="big"), 2),
("datetime64[10s]", DateTime64(unit="s", scale_factor=10), 2),
(
{"name": "numpy.datetime64", "configuration": {"unit": "s", "scale_factor": 10}},
DateTime64(unit="s", scale_factor=10),
3,
),
],
)
def test_parse_data_type(
dtype_params: Any, expected: ZDType[Any, Any], zarr_format: ZarrFormat
) -> None:
@pytest.mark.filterwarnings("ignore::zarr.core.dtype.common.UnstableSpecificationWarning")
@pytest.mark.parametrize("data_type", zdtype_examples, ids=str)
def test_parse_data_type(data_type: ZDType[Any, Any], zarr_format: ZarrFormat) -> None:
"""
Test that parse_data_type accepts alternative representations of ZDType instances, and resolves
Test that parse_dtype accepts alternative representations of ZDType instances, and resolves
those inputs to the expected ZDType instance.
"""
observed = parse_data_type(dtype_params, zarr_format=zarr_format)
assert observed == expected
dtype_spec: Any
if zarr_format == 2:
dtype_spec = data_type.to_json(zarr_format=zarr_format)["name"]
else:
dtype_spec = data_type.to_json(zarr_format=zarr_format)
if dtype_spec == "|O":
msg = "Zarr data type resolution from object failed."
with pytest.raises(ValueError, match=msg):
parse_dtype(dtype_spec, zarr_format=zarr_format)
else:
observed = parse_dtype(dtype_spec, zarr_format=zarr_format)
assert observed == data_type


@pytest.mark.filterwarnings("ignore::zarr.core.dtype.common.UnstableSpecificationWarning")
@pytest.mark.parametrize("data_type", zdtype_examples, ids=str)
def test_parse_data_type_funcs(data_type: ZDType[Any, Any], zarr_format: ZarrFormat) -> None:
    """
    Check that ``parse_data_type`` and ``parse_dtype`` agree on every example data type:
    either both raise the same ``ValueError``, or both return equal results.
    """
    serialized = data_type.to_json(zarr_format=zarr_format)
    # For Zarr format 2 only the "name" entry of the JSON form is used as the spec;
    # otherwise the whole JSON form is passed through.
    spec: Any = serialized["name"] if zarr_format == 2 else serialized

    if spec != "|O":
        assert parse_dtype(spec, zarr_format=zarr_format) == parse_data_type(
            spec, zarr_format=zarr_format
        )
        return

    # The "|O" spec is expected to be rejected by both entry points with the same message.
    expected_msg = "Zarr data type resolution from object failed."
    for parser in (parse_dtype, parse_data_type):
        with pytest.raises(ValueError, match=expected_msg):
            parser(spec, zarr_format=zarr_format)
4 changes: 2 additions & 2 deletions tests/test_metadata/test_consolidated.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
open_consolidated,
)
from zarr.core.buffer import cpu, default_buffer_prototype
from zarr.core.dtype import parse_data_type
from zarr.core.dtype import parse_dtype
from zarr.core.group import ConsolidatedMetadata, GroupMetadata
from zarr.core.metadata import ArrayV3Metadata
from zarr.core.metadata.v2 import ArrayV2Metadata
Expand Down Expand Up @@ -504,7 +504,7 @@ async def test_consolidated_metadata_backwards_compatibility(
async def test_consolidated_metadata_v2(self):
store = zarr.storage.MemoryStore()
g = await AsyncGroup.from_store(store, attributes={"key": "root"}, zarr_format=2)
dtype = parse_data_type("uint8", zarr_format=2)
dtype = parse_dtype("uint8", zarr_format=2)
await g.create_array(name="a", shape=(1,), attributes={"key": "a"}, dtype=dtype)
g1 = await g.create_group(name="g1", attributes={"key": "g1"})
await g1.create_group(name="g2", attributes={"key": "g2"})
Expand Down
Loading