From 07edf73bdd380aa361c74fc207304c42d577a1f0 Mon Sep 17 00:00:00 2001 From: Florian Pinault Date: Wed, 26 Feb 2025 16:46:30 +0000 Subject: [PATCH 01/16] feat: zarr3 --- src/anemoi/datasets/__init__.py | 2 + src/anemoi/datasets/add_zarr_support.py | 153 ++++++++++++++++++++++ src/anemoi/datasets/commands/copy.py | 48 ++----- src/anemoi/datasets/create/__init__.py | 7 +- src/anemoi/datasets/create/patch.py | 4 +- src/anemoi/datasets/create/synchronise.py | 82 ++++++++++++ src/anemoi/datasets/create/zarr.py | 106 +++++++-------- src/anemoi/datasets/data/misc.py | 5 +- src/anemoi/datasets/data/stores.py | 71 +++++----- tests/create/test_create.py | 1 + tests/test_data.py | 83 ++++++++---- 11 files changed, 402 insertions(+), 160 deletions(-) create mode 100644 src/anemoi/datasets/add_zarr_support.py create mode 100644 src/anemoi/datasets/create/synchronise.py diff --git a/src/anemoi/datasets/__init__.py b/src/anemoi/datasets/__init__.py index b7689466d..4c875f0f7 100644 --- a/src/anemoi/datasets/__init__.py +++ b/src/anemoi/datasets/__init__.py @@ -9,6 +9,7 @@ from typing import List +from .add_zarr_support import Zarr2AndZarr3 from .data import MissingDateError from .data import add_dataset_path from .data import add_named_dataset @@ -30,4 +31,5 @@ "MissingDateError", "open_dataset", "__version__", + "Zarr2AndZarr3", ] diff --git a/src/anemoi/datasets/add_zarr_support.py b/src/anemoi/datasets/add_zarr_support.py new file mode 100644 index 000000000..28a469c06 --- /dev/null +++ b/src/anemoi/datasets/add_zarr_support.py @@ -0,0 +1,153 @@ +# (C) Copyright 2025 Anemoi contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. +import logging + +import zarr + +LOG = logging.getLogger(__name__) + + +class Zarr2: + @classmethod + def base_store(cls): + return zarr.storage.BaseStore + + @classmethod + def is_zarr_group(cls, obj): + return isinstance(obj, zarr.hierarchy.Group) + + @classmethod + def create_array(cls, zarr_root, *args, **kwargs): + return zarr_root.create_dataset(*args, **kwargs) + + @classmethod + def change_dtype_datetime64(cls, dtype): + return dtype + + @classmethod + def cast_dtype_datetime64(cls, array, dtype): + return array, dtype + + @classmethod + def get_not_found_exception(cls): + return zarr.errors.PathNotFoundError + + @classmethod + def zarr_open_mode_append(cls): + return "w+" + + @classmethod + def zarr_open_to_patch_in_tests(cls): + return "zarr.convenience.open" + + @classmethod + def zarr_open(cls, *args, **kwargs): + return zarr.convenience.open(*args, **kwargs) + + @classmethod + def get_read_only_store_class(cls): + class ReadOnlyStore(zarr.storage.BaseStore): + """A base class for read-only stores.""" + + def __delitem__(self, key: str) -> None: + """Prevent deletion of items.""" + raise NotImplementedError() + + def __setitem__(self, key: str, value: bytes) -> None: + """Prevent setting of items.""" + raise NotImplementedError() + + def __len__(self) -> int: + """Return the number of items in the store.""" + raise NotImplementedError() + + def __iter__(self) -> iter: + """Return an iterator over the store.""" + raise NotImplementedError() + + return ReadOnlyStore + + @classmethod + def raise_if_not_supported(cls, msg): + pass + + +class Zarr3: + @classmethod + def base_store(cls): + return zarr.abc.store.Store + + @classmethod + def is_zarr_group(cls, obj): + return isinstance(obj, zarr.Group) + + @classmethod + def create_array(cls, zarr_root, *args, **kwargs): + if "compressor" in kwargs and kwargs["compressor"] is None: + # compressor is deprecated, use compressors instead + kwargs.pop("compressor") + kwargs["compressors"] = () + return zarr_root.create_array(*args, **kwargs) + + @classmethod + def get_not_found_exception(cls): + return FileNotFoundError + + @classmethod + def zarr_open_mode_append(cls): + return "a" + + @classmethod + def change_dtype_datetime64(cls, dtype): + # remove this flag (and the relevant code) when Zarr 3 supports datetime64 + # https://github.com/zarr-developers/zarr-python/issues/2616 + import numpy as np + + if dtype == "datetime64[s]": + dtype = np.dtype("int64") + return dtype + + @classmethod + def cast_dtype_datetime64(cls, array, dtype): + # remove this flag (and the relevant code) when Zarr 3 supports datetime64 + # https://github.com/zarr-developers/zarr-python/issues/2616 + import numpy as np + + if dtype == np.dtype("datetime64[s]"): + dtype = "int64" + array = array.astype(dtype) + + return array, dtype + + @classmethod + def zarr_open_to_patch_in_tests(cls): + return "zarr.open" + + @classmethod + def zarr_open(cls, *args, **kwargs): + return zarr.open(*args, **kwargs) + + @classmethod + def get_read_only_store_class(cls): + class ReadOnlyStore(zarr.abc.store.Store): + def __init__(self, *args, **kwargs): + raise NotImplementedError("Zarr 3 is not for this kind of store : {}".format(args)) + + return ReadOnlyStore + + @classmethod + def raise_if_not_supported(cls, msg="Zarr 3 is not supported in this context"): + raise NotImplementedError(msg) + + +if zarr.__version__.startswith("3"): + Zarr2AndZarr3 = Zarr3 +else: + LOG.warning("Using Zarr 2 : only zarr datasets build with zarr 2 are supported") + Zarr2AndZarr3 = Zarr2 diff --git a/src/anemoi/datasets/commands/copy.py b/src/anemoi/datasets/commands/copy.py index 43bf51da7..694714a87 100644 --- a/src/anemoi/datasets/commands/copy.py +++ b/src/anemoi/datasets/commands/copy.py @@ -20,6 +20,8 @@ from anemoi.utils.remote import Transfer from anemoi.utils.remote import TransferMethodNotImplementedError +from anemoi.datasets import Zarr2AndZarr3 + from . import Command LOG = logging.getLogger(__name__) @@ -49,8 +51,6 @@ class ZarrCopier: Flag to resume copying an existing dataset. verbosity : int Verbosity level of logging. - nested : bool - Flag to use ZARR's nested directory backend. rechunk : str Rechunk size for the target data array. """ @@ -64,7 +64,6 @@ def __init__( overwrite: bool, resume: bool, verbosity: int, - nested: bool, rechunk: str, **kwargs: Any, ) -> None: @@ -86,8 +85,6 @@ def __init__( Flag to resume copying an existing dataset. verbosity : int Verbosity level of logging. - nested : bool - Flag to use ZARR's nested directory backend. rechunk : str Rechunk size for the target data array. **kwargs : Any @@ -100,7 +97,6 @@ def __init__( self.overwrite = overwrite self.resume = resume self.verbosity = verbosity - self.nested = nested self.rechunk = rechunk self.rechunking = rechunk.split(",") if rechunk else [] @@ -113,27 +109,6 @@ def __init__( raise NotImplementedError("Rechunking with SSH not implemented.") assert NotImplementedError("SSH not implemented.") - def _store(self, path: str, nested: bool = False) -> Any: - """Get the storage path. - - Parameters - ---------- - path : str - Path to the storage. - nested : bool, optional - Flag to use nested directory storage. - - Returns - ------- - Any - Storage path. - """ - if nested: - import zarr - - return zarr.storage.NestedDirectoryStore(path) - return path - def copy_chunk(self, n: int, m: int, source: Any, target: Any, _copy: Any, verbosity: int) -> Optional[slice]: """Copy a chunk of data from source to target. @@ -237,7 +212,8 @@ def copy_data(self, source: Any, target: Any, _copy: Any, verbosity: int) -> Non target_data = ( target["data"] if "data" in target - else target.create_dataset( + else Zarr2AndZarr3.create_array( + target, "data", shape=source_data.shape, chunks=self.data_chunks, @@ -317,13 +293,12 @@ def copy_group(self, source: Any, target: Any, _copy: Any, verbosity: int) -> No verbosity : int Verbosity level of logging. """ - import zarr for k, v in source.attrs.items(): target.attrs[k] = v for name in sorted(source.keys()): - if isinstance(source[name], zarr.hierarchy.Group): + if Zarr2AndZarr3.is_zarr_group(source[name]): group = target[name] if name in target else target.create_group(name) self.copy_group( source[name], @@ -376,13 +351,13 @@ def run(self) -> None: def target_exists() -> bool: try: - zarr.open(self._store(self.target), mode="r") + zarr.open(self.target, mode="r") return True except ValueError: return False def target_finished() -> bool: - target = zarr.open(self._store(self.target), mode="r") + target = zarr.open(self.target, mode="r") if "_copy" in target: done = sum(1 if x else 0 for x in target["_copy"]) todo = len(target["_copy"]) @@ -400,11 +375,11 @@ def target_finished() -> bool: def open_target() -> Any: if not target_exists(): - return zarr.open(self._store(self.target, self.nested), mode="w") + return zarr.open(self.target, mode="w") if self.overwrite: LOG.error("Target already exists, overwriting.") - return zarr.open(self._store(self.target, self.nested), mode="w") + return zarr.open(self.target, mode="w") if self.resume: if target_finished(): @@ -412,7 +387,7 @@ def open_target() -> Any: sys.exit(0) LOG.error("Target already exists, resuming copy.") - return zarr.open(self._store(self.target, self.nested), mode="w+") + return zarr.open(self.target, mode=Zarr2AndZarr3.zarr_open_mode_append()) LOG.error("Target already exists, use either --overwrite or --resume.") sys.exit(1) @@ -421,7 +396,7 @@ def open_target() -> Any: assert target is not None, target - source = zarr.open(self._store(self.source), mode="r") + source = zarr.open(self.source, mode="r") self.copy(source, target, self.verbosity) @@ -455,7 +430,6 @@ def add_arguments(self, command_parser: Any) -> None: help="Verbosity level. 0 is silent, 1 is normal, 2 is verbose.", default=1, ) - command_parser.add_argument("--nested", action="store_true", help="Use ZARR's nested directpry backend.") command_parser.add_argument( "--rechunk", help="Rechunk the target data array. Rechunk size should be a diviser of the block size." ) diff --git a/src/anemoi/datasets/create/__init__.py b/src/anemoi/datasets/create/__init__.py index a51a43412..f9c7bfbaf 100644 --- a/src/anemoi/datasets/create/__init__.py +++ b/src/anemoi/datasets/create/__init__.py @@ -32,6 +32,7 @@ from earthkit.data.core.order import build_remapping from anemoi.datasets import MissingDateError +from anemoi.datasets import Zarr2AndZarr3 from anemoi.datasets import open_dataset from anemoi.datasets.create.input.trace import enable_trace from anemoi.datasets.create.persistent import build_storage @@ -154,7 +155,7 @@ def _path_readable(path: str) -> bool: try: zarr.open(path, "r") return True - except zarr.errors.PathNotFoundError: + except Zarr2AndZarr3.get_not_found_exception(): return False @@ -208,7 +209,7 @@ def update_metadata(self, **kwargs: Any) -> None: import zarr LOG.debug(f"Updating metadata {kwargs}") - z = zarr.open(self.path, mode="w+") + z = zarr.open(self.path, mode=Zarr2AndZarr3.zarr_open_mode_append()) for k, v in kwargs.items(): if isinstance(v, np.datetime64): v = v.astype(datetime.datetime) @@ -1520,7 +1521,7 @@ def run(self) -> None: LOG.info(stats) - if not all(self.registry.get_flags(sync=False)): + if not all(self.registry.get_flags()): raise Exception(f"❗Zarr {self.path} is not fully built, not writing statistics into dataset.") for k in ["mean", "stdev", "minimum", "maximum", "sums", "squares", "count", "has_nans"]: diff --git a/src/anemoi/datasets/create/patch.py b/src/anemoi/datasets/create/patch.py index e8de85851..7cd9e0ff4 100755 --- a/src/anemoi/datasets/create/patch.py +++ b/src/anemoi/datasets/create/patch.py @@ -14,6 +14,8 @@ import zarr +from anemoi.datasets import Zarr2AndZarr3 + LOG = logging.getLogger(__name__) @@ -134,7 +136,7 @@ def apply_patch(path: str, verbose: bool = True, dry_run: bool = False) -> None: try: attrs = zarr.open(path, mode="r").attrs.asdict() - except zarr.errors.PathNotFoundError as e: + except Zarr2AndZarr3.get_not_found_exception() as e: LOG.error(f"Failed to open {path}") LOG.error(e) exit(0) diff --git a/src/anemoi/datasets/create/synchronise.py b/src/anemoi/datasets/create/synchronise.py new file mode 100644 index 000000000..b0990f4af --- /dev/null +++ b/src/anemoi/datasets/create/synchronise.py @@ -0,0 +1,82 @@ +# (C) Copyright 2025 Anemoi contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + +import logging +import shutil +import time + +from filelock import FileLock +from filelock import Timeout + +LOG = logging.getLogger(__name__) + + +class NoSynchroniser: + def __init__(self): + pass + + def __enter__(self): + pass + + def __exit__(self, exc_type, exc_val, exc_tb): + pass + + def clean(self): + pass + + +class Synchroniser: + def __init__(self, lock_file_path, timeout=10): + """Initialize the Synchroniser with the path to the lock file and an optional timeout. + Parameters + ---------- + lock_file_path + Path to the lock file on a shared filesystem. + timeout + Timeout for acquiring the lock in seconds. + """ + self.lock_file_path = lock_file_path + self.timeout = timeout + self.lock = FileLock(lock_file_path) + + def __enter__(self): + """Acquire the lock when entering the context.""" + try: + self.lock.acquire(timeout=self.timeout) + print("Lock acquired.") + except Timeout: + print("Could not acquire lock, another process might be holding it.") + raise + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + """Release the lock when exiting the context.""" + self.lock.release() + print("Lock released.") + + def clean(self): + try: + shutil.rmtree(self.lock_file_path) + except FileNotFoundError: + pass + + +# Example usage +if __name__ == "__main__": + + def example_operation(): + print("Performing operation...") + time.sleep(2) # Simulate some work + print("Operation complete.") + + lock_path = "/path/to/shared/lockfile.lock" + + # Use the Synchroniser as a context manager + with Synchroniser(lock_path) as sync: + example_operation() diff --git a/src/anemoi/datasets/create/zarr.py b/src/anemoi/datasets/create/zarr.py index 574f603d0..f6c9c966d 100644 --- a/src/anemoi/datasets/create/zarr.py +++ b/src/anemoi/datasets/create/zarr.py @@ -8,8 +8,6 @@ # nor does it submit to any jurisdiction. import datetime -import logging -import shutil from typing import Any from typing import Optional @@ -17,7 +15,10 @@ import zarr from numpy.typing import NDArray -LOG = logging.getLogger(__name__) +from anemoi.datasets import Zarr2AndZarr3 + +from .synchronise import NoSynchroniser +from .synchronise import Synchroniser def add_zarr_dataset( @@ -72,8 +73,11 @@ def add_zarr_dataset( shape = array.shape if array is not None: + array, dtype = Zarr2AndZarr3.cast_dtype_datetime64(array, dtype) + assert array.shape == shape, (array.shape, shape) - a = zarr_root.create_dataset( + a = Zarr2AndZarr3.create_array( + zarr_root, name, shape=shape, dtype=dtype, @@ -100,7 +104,9 @@ def add_zarr_dataset( else: raise ValueError(f"No fill_value for dtype={dtype}") - a = zarr_root.create_dataset( + dtype = Zarr2AndZarr3.change_dtype_datetime64(dtype) + a = Zarr2AndZarr3.create_array( + zarr_root, name, shape=shape, dtype=dtype, @@ -132,33 +138,19 @@ def __init__(self, path: str, synchronizer_path: Optional[str] = None, use_threa use_threads : bool Whether to use thread-based synchronization. """ - import zarr assert isinstance(path, str), path self.zarr_path = path - if use_threads: - self.synchronizer = zarr.ThreadSynchronizer() - self.synchronizer_path = None - else: - if synchronizer_path is None: - synchronizer_path = self.zarr_path + ".sync" - self.synchronizer_path = synchronizer_path - self.synchronizer = zarr.ProcessSynchronizer(self.synchronizer_path) + self.synchronizer = Synchroniser(synchronizer_path) if synchronizer_path else NoSynchroniser() def clean(self) -> None: """Clean up the synchronizer path.""" - if self.synchronizer_path is not None: - try: - shutil.rmtree(self.synchronizer_path) - except FileNotFoundError: - pass + self.synchronizer.clean() def _open_write(self) -> zarr.Group: """Open the Zarr store in write mode.""" - import zarr - - return zarr.open(self.zarr_path, mode="r+", synchronizer=self.synchronizer) + return zarr.open(self.zarr_path, mode="r+") def _open_read(self, sync: bool = True) -> zarr.Group: """Open the Zarr store in read mode. @@ -173,12 +165,7 @@ def _open_read(self, sync: bool = True) -> zarr.Group: zarr.Group The opened Zarr group. """ - import zarr - - if sync: - return zarr.open(self.zarr_path, mode="r", synchronizer=self.synchronizer) - else: - return zarr.open(self.zarr_path, mode="r") + return zarr.open(self.zarr_path, mode="r") def new_dataset(self, *args, **kwargs) -> None: """Create a new dataset in the Zarr store. @@ -190,9 +177,11 @@ def new_dataset(self, *args, **kwargs) -> None: **kwargs Keyword arguments for dataset creation. """ - z = self._open_write() - zarr_root = z["_build"] - add_zarr_dataset(*args, zarr_root=zarr_root, overwrite=True, dimensions=("tmp",), **kwargs) + with self.synchronizer: + z = self._open_write() + zarr_root = z["_build"] + add_zarr_dataset(*args, zarr_root=zarr_root, overwrite=True, dimensions=("tmp",), **kwargs) + del z def add_to_history(self, action: str, **kwargs) -> None: """Add an action to the history attribute of the Zarr store. @@ -210,10 +199,12 @@ def add_to_history(self, action: str, **kwargs) -> None: ) new.update(kwargs) - z = self._open_write() - history = z.attrs.get("history", []) - history.append(new) - z.attrs["history"] = history + with self.synchronizer: + z = self._open_write() + history = z.attrs.get("history", []) + history.append(new) + z.attrs["history"] = history + del z def get_lengths(self) -> list[int]: """Get the lengths dataset. @@ -223,8 +214,11 @@ def get_lengths(self) -> list[int]: list[int] The lengths dataset. """ - z = self._open_read() - return list(z["_build"][self.name_lengths][:]) + with self.synchronizer: + z = self._open_read() + lengths = list(z["_build"][self.name_lengths][:]) + del z + return lengths def get_flags(self, **kwargs) -> list[bool]: """Get the flags dataset. @@ -239,8 +233,11 @@ def get_flags(self, **kwargs) -> list[bool]: list[bool] The flags dataset. """ - z = self._open_read(**kwargs) - return list(z["_build"][self.name_flags][:]) + with self.synchronizer: + z = self._open_read(**kwargs) + flags = list(z["_build"][self.name_flags][:]) + del z + return flags def get_flag(self, i: int) -> bool: """Get a specific flag. @@ -255,8 +252,11 @@ def get_flag(self, i: int) -> bool: bool The flag value. """ - z = self._open_read() - return z["_build"][self.name_flags][i] + with self.synchronizer: + z = self._open_read() + flag = z["_build"][self.name_flags][i] + del z + return flag def set_flag(self, i: int, value: bool = True) -> None: """Set a specific flag. @@ -268,11 +268,13 @@ def set_flag(self, i: int, value: bool = True) -> None: value : bool Value to set the flag to. """ - z = self._open_write() - z.attrs["latest_write_timestamp"] = ( - datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None).isoformat() - ) - z["_build"][self.name_flags][i] = value + with self.synchronizer: + z = self._open_write() + z.attrs["latest_write_timestamp"] = ( + datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None).isoformat() + ) + z["_build"][self.name_flags][i] = value + del z def ready(self) -> bool: """Check if all flags are set. @@ -316,11 +318,13 @@ def add_provenance(self, name: str) -> None: name : str Name of the provenance attribute. """ - z = self._open_write() + with self.synchronizer: + z = self._open_write() - if name in z.attrs: - return + if name in z.attrs: + return - from anemoi.utils.provenance import gather_provenance_info + from anemoi.utils.provenance import gather_provenance_info - z.attrs[name] = gather_provenance_info() + z.attrs[name] = gather_provenance_info() + del z diff --git a/src/anemoi/datasets/data/misc.py b/src/anemoi/datasets/data/misc.py index 84dc7da8d..40e98aeeb 100644 --- a/src/anemoi/datasets/data/misc.py +++ b/src/anemoi/datasets/data/misc.py @@ -21,10 +21,11 @@ from typing import Union import numpy as np -import zarr from anemoi.utils.config import load_config as load_settings from numpy.typing import NDArray +from anemoi.datasets import Zarr2AndZarr3 + if TYPE_CHECKING: from .dataset import Dataset @@ -320,7 +321,7 @@ def _open(a: Union[str, PurePath, Dict[str, Any], List[Any], Tuple[Any, ...]]) - if isinstance(a, Dataset): return a.mutate() - if isinstance(a, zarr.hierarchy.Group): + if Zarr2AndZarr3.is_zarr_group(a): return Zarr(a).mutate() if isinstance(a, str): diff --git a/src/anemoi/datasets/data/stores.py b/src/anemoi/datasets/data/stores.py index 4e2bc9a9a..b3a08d61a 100644 --- a/src/anemoi/datasets/data/stores.py +++ b/src/anemoi/datasets/data/stores.py @@ -26,6 +26,8 @@ from anemoi.utils.dates import frequency_to_timedelta from numpy.typing import NDArray +from anemoi.datasets import Zarr2AndZarr3 + from . import MissingDateError from .dataset import Dataset from .dataset import FullIndex @@ -41,24 +43,7 @@ LOG = logging.getLogger(__name__) -class ReadOnlyStore(zarr.storage.BaseStore): - """A base class for read-only stores.""" - - def __delitem__(self, key: str) -> None: - """Prevent deletion of items.""" - raise NotImplementedError() - - def __setitem__(self, key: str, value: bytes) -> None: - """Prevent setting of items.""" - raise NotImplementedError() - - def __len__(self) -> int: - """Return the number of items in the store.""" - raise NotImplementedError() - - def __iter__(self) -> iter: - """Return an iterator over the store.""" - raise NotImplementedError() +ReadOnlyStore = Zarr2AndZarr3.get_read_only_store_class() class HTTPStore(ReadOnlyStore): @@ -154,7 +139,7 @@ def __getitem__(self, key: str) -> bytes: class DebugStore(ReadOnlyStore): """A store to debug the zarr loading.""" - def __init__(self, store: ReadOnlyStore) -> None: + def __init__(self, store: Any) -> None: """Initialize the DebugStore with another store.""" assert not isinstance(store, DebugStore) self.store = store @@ -180,7 +165,7 @@ def __contains__(self, key: str) -> bool: return key in self.store -def name_to_zarr_store(path_or_url: str) -> ReadOnlyStore: +def name_to_zarr_store(path_or_url: str) -> Any: """Convert a path or URL to a zarr store.""" store = path_or_url @@ -202,7 +187,7 @@ def name_to_zarr_store(path_or_url: str) -> ReadOnlyStore: return store -def open_zarr(path: str, dont_fail: bool = False, cache: int = None) -> zarr.hierarchy.Group: +def open_zarr(path: str, dont_fail: bool = False, cache: int = None) -> Any: """Open a zarr store from a path.""" try: store = name_to_zarr_store(path) @@ -222,18 +207,18 @@ def open_zarr(path: str, dont_fail: bool = False, cache: int = None) -> zarr.hie if cache is not None: store = zarr.LRUStoreCache(store, max_size=cache) - return zarr.convenience.open(store, "r") - except zarr.errors.PathNotFoundError: + return Zarr2AndZarr3.zarr_open(store, "r") + except Zarr2AndZarr3.get_not_found_exception(): if not dont_fail: - raise zarr.errors.PathNotFoundError(path) + raise FileNotFoundError(f"Zarr store not found: {path}") class Zarr(Dataset): """A zarr dataset.""" - def __init__(self, path: Union[str, zarr.hierarchy.Group]) -> None: + def __init__(self, path: Union[str, Any]) -> None: """Initialize the Zarr dataset with a path or zarr group.""" - if isinstance(path, zarr.hierarchy.Group): + if Zarr2AndZarr3.is_zarr_group(path): self.was_zarr = True self.path = str(id(path)) self.z = path @@ -243,7 +228,7 @@ def __init__(self, path: Union[str, zarr.hierarchy.Group]) -> None: self.z = open_zarr(self.path) # This seems to speed up the reading of the data a lot - self.data = self.z.data + self.data = self.z["data"] self._missing = set() @property @@ -294,7 +279,7 @@ def _unwind(self, index: Union[int, slice, list, tuple], rest: list, shape: tupl @cached_property def chunks(self) -> TupleIndex: """Return the chunks of the dataset.""" - return self.z.data.chunks + return self.data.chunks @cached_property def shape(self) -> Shape: @@ -304,39 +289,45 @@ def shape(self) -> Shape: @cached_property def dtype(self) -> np.dtype: """Return the data type of the dataset.""" - return self.z.data.dtype + return self.data.dtype @cached_property def dates(self) -> NDArray[np.datetime64]: """Return the dates of the dataset.""" - return self.z.dates[:] # Convert to numpy + dates = self.z["dates"][:] + if not dates.dtype == np.dtype("datetime64[s]"): + # The datasets created with zarr3 will have the dates as int64 as long + # as zarr3 does not support datetime64 + LOG.warning("Converting dates to 'datetime64[s]'") + dates = dates.astype("datetime64[s]") + return dates @property def latitudes(self) -> NDArray[Any]: """Return the latitudes of the dataset.""" try: - return self.z.latitudes[:] + return self.z["latitudes"][:] except AttributeError: LOG.warning("No 'latitudes' in %r, trying 'latitude'", self) - return self.z.latitude[:] + return self.z["latitude"][:] @property def longitudes(self) -> NDArray[Any]: """Return the longitudes of the dataset.""" try: - return self.z.longitudes[:] + return self.z["longitudes"][:] except AttributeError: LOG.warning("No 'longitudes' in %r, trying 'longitude'", self) - return self.z.longitude[:] + return self.z["longitude"][:] @property def statistics(self) -> Dict[str, NDArray[Any]]: """Return the statistics of the dataset.""" return dict( - mean=self.z.mean[:], - stdev=self.z.stdev[:], - maximum=self.z.maximum[:], - minimum=self.z.minimum[:], + mean=self.z["mean"][:], + stdev=self.z["stdev"][:], + maximum=self.z["maximum"][:], + minimum=self.z["minimum"][:], ) def statistics_tendencies(self, delta: Optional[datetime.timedelta] = None) -> Dict[str, NDArray[Any]]: @@ -465,7 +456,7 @@ def collect_input_sources(self, collected: set) -> None: class ZarrWithMissingDates(Zarr): """A zarr dataset with missing dates.""" - def __init__(self, path: Union[str, zarr.hierarchy.Group]) -> None: + def __init__(self, path: Union[str, Any]) -> None: """Initialize the ZarrWithMissingDates dataset with a path or zarr group.""" super().__init__(path) @@ -564,7 +555,7 @@ def zarr_lookup(name: str, fail: bool = True) -> Optional[str]: LOG.info("Opening `%s` as `%s`", name, full) QUIET.add(name) return full - except zarr.errors.PathNotFoundError: + except Zarr2AndZarr3.get_not_found_exception(): pass if fail: diff --git a/tests/create/test_create.py b/tests/create/test_create.py index 47762abf5..426b33db7 100755 --- a/tests/create/test_create.py +++ b/tests/create/test_create.py @@ -204,6 +204,7 @@ def compare_dot_zattrs(a: dict, b: dict, path: str, errors: list) -> None: "description", "config_path", "total_size", + "total_number_of_files", # expected to differ when comparing datasets generated with zarr 2 vs zarr 3 ]: if type(a[k]) is not type(b[k]): errors.append(f"❌ {path}.{k} : type differs {type(a[k])} != {type(b[k])}") diff --git a/tests/test_data.py b/tests/test_data.py index c2fba512d..dfa56d9dd 100644 --- a/tests/test_data.py +++ b/tests/test_data.py @@ -23,6 +23,7 @@ from anemoi.utils.dates import frequency_to_string from anemoi.utils.dates import frequency_to_timedelta +from anemoi.datasets import Zarr2AndZarr3 from anemoi.datasets import open_dataset from anemoi.datasets.data.concat import Concat from anemoi.datasets.data.ensemble import Ensemble @@ -55,7 +56,7 @@ def mockup_open_zarr(func: Callable) -> Callable: @wraps(func) def wrapper(*args, **kwargs): - with patch("zarr.convenience.open", zarr_from_str): + with patch(Zarr2AndZarr3.zarr_open_to_patch_in_tests(), zarr_from_str): with patch("anemoi.datasets.data.stores.zarr_lookup", lambda name: name): return func(*args, **kwargs) @@ -140,7 +141,7 @@ def create_zarr( dates.append(date) date += frequency - dates = np.array(dates, dtype="datetime64") + dates = np.array(dates, dtype="datetime64[s]") ensembles = ensemble if ensemble is not None else 1 values = grids if grids is not None else VALUES @@ -152,28 +153,42 @@ def create_zarr( for e in range(ensembles): data[i, j, e] = _(date.astype(object), var, k, e, values) - root.create_dataset( + Zarr2AndZarr3.create_array( + root, "data", - data=data, dtype=data.dtype, chunks=data.shape, compressor=None, - ) - root.create_dataset( + shape=data.shape, + )[...] = data + + dates, dtype_ = Zarr2AndZarr3.cast_dtype_datetime64(dates, dates.dtype) + del dtype_ + Zarr2AndZarr3.create_array( + root, "dates", - data=dates, compressor=None, - ) - root.create_dataset( + dtype=dates.dtype, + shape=dates.shape, + )[...] = dates + + latitudes = np.array([x + values for x in range(values)]) + Zarr2AndZarr3.create_array( + root, "latitudes", - data=np.array([x + values for x in range(values)]), compressor=None, - ) - root.create_dataset( + dtype=latitudes.dtype, + shape=latitudes.shape, + )[...] = latitudes + + longitudes = np.array([x + values for x in range(values)]) + Zarr2AndZarr3.create_array( + root, "longitudes", - data=np.array([x + values for x in range(values)]), compressor=None, - ) + dtype=longitudes.dtype, + shape=longitudes.shape, + )[...] = longitudes root.attrs["frequency"] = frequency_to_string(frequency) root.attrs["resolution"] = resolution @@ -194,26 +209,42 @@ def create_zarr( root.attrs["missing_dates"] = [d.isoformat() for d in missing_dates] - root.create_dataset( + Zarr2AndZarr3.create_array( + root, "mean", - data=np.mean(data, axis=0), compressor=None, - ) - root.create_dataset( + shape=data.shape[1:], + dtype=data.dtype, + )[ + ... + ] = np.mean(data, axis=0) + Zarr2AndZarr3.create_array( + root, "stdev", - data=np.std(data, axis=0), compressor=None, - ) - root.create_dataset( + shape=data.shape[1:], + dtype=data.dtype, + )[ + ... + ] = np.std(data, axis=0) + Zarr2AndZarr3.create_array( + root, "maximum", - data=np.max(data, axis=0), compressor=None, - ) - root.create_dataset( + shape=data.shape[1:], + dtype=data.dtype, + )[ + ... + ] = np.max(data, axis=0) + Zarr2AndZarr3.create_array( + root, "minimum", - data=np.min(data, axis=0), compressor=None, - ) + shape=data.shape[1:], + dtype=data.dtype, + )[ + ... + ] = np.min(data, axis=0) return root From 1459e83c987c2c1b477b6f52581dcc849a5ceb01 Mon Sep 17 00:00:00 2001 From: Florian Pinault Date: Thu, 27 Feb 2025 10:18:55 +0000 Subject: [PATCH 02/16] zarr3 --- pyproject.toml | 6 +++++- src/anemoi/datasets/create/__init__.py | 7 +++++++ 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 9457c8405..dfa151ea0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -56,7 +56,11 @@ dependencies = [ "pyyaml", "semantic-version", "tqdm", - "zarr<=2.17", + # anemoi-datasets supports zarr 2 and zarr 3, but we still use only zarr 2: + # - we don't want to create zarr 3 datasets yet, as they will no be readable by zarr 2 + # - anemoi-inference needs zarr 2 for patching + # - anemoi-registry needs zarr 2 + "zarr<3", ] optional-dependencies.all = [ diff --git a/src/anemoi/datasets/create/__init__.py b/src/anemoi/datasets/create/__init__.py index f9c7bfbaf..97e15f49b 100644 --- a/src/anemoi/datasets/create/__init__.py +++ b/src/anemoi/datasets/create/__init__.py @@ -172,6 +172,13 @@ def __init__(self, path: str): """ self.path = path + if Zarr2AndZarr3.version == "3" and not os.environ.get("ANEMOI_DATASETS_ALLOW_BUILDING_ZARR3_DATASETS"): + raise ValueError( + "zarr 3 is installed. anemoi-datasets supports zarr 3, but the datasets build with zarr 3 will " + "not be readable by zarr 2. It is likely that you do not want to create a dataset with zarr 3. " + "Please uninstall zarr 3 and install zarr 2." + ) + _, ext = os.path.splitext(self.path) if ext != ".zarr": raise ValueError(f"Unsupported extension={ext} for path={self.path}") From e43f0f661d0bd5abaaba7ca035a5e020ce7bd0e2 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 24 Jun 2025 11:43:26 +0000 Subject: [PATCH 03/16] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/anemoi/datasets/commands/copy.py | 2 -- src/anemoi/datasets/data/misc.py | 2 -- 2 files changed, 4 deletions(-) diff --git a/src/anemoi/datasets/commands/copy.py b/src/anemoi/datasets/commands/copy.py index a43588312..4ee111306 100644 --- a/src/anemoi/datasets/commands/copy.py +++ b/src/anemoi/datasets/commands/copy.py @@ -305,7 +305,6 @@ def copy_group(self, source: Any, target: Any, _copy: Any, verbosity: int) -> No LOG.info(f"Copying attribute {k} = {textwrap.shorten(str(v), 40)}") target.attrs[k] = v - source_keys = list(source.keys()) if not source_keys: @@ -426,7 +425,6 @@ def open_target() -> Any: assert target is not None, target - if self.verbosity > 0: LOG.info(f"Open source: {self.source}") diff --git a/src/anemoi/datasets/data/misc.py b/src/anemoi/datasets/data/misc.py index f0dcb9ed7..a1dd1c02c 100644 --- a/src/anemoi/datasets/data/misc.py +++ b/src/anemoi/datasets/data/misc.py @@ -22,10 +22,8 @@ from typing import Union import numpy as np - import zarr from anemoi.utils.config import load_any_dict_format - from anemoi.utils.config import load_config as load_settings from numpy.typing import NDArray From 59d3a9e02922fb54fedea7233a893daf00d0b16c Mon Sep 17 00:00:00 2001 From: Baudouin Raoult Date: Wed, 25 Jun 2025 10:47:58 +0000 Subject: [PATCH 04/16] update --- src/anemoi/datasets/__init__.py | 2 - src/anemoi/datasets/add_zarr_support.py | 153 ----------- src/anemoi/datasets/commands/copy.py | 8 +- src/anemoi/datasets/create/__init__.py | 12 +- src/anemoi/datasets/data/misc.py | 4 +- src/anemoi/datasets/data/stores.py | 149 +---------- src/anemoi/datasets/zarr_versions/__init__.py | 23 ++ src/anemoi/datasets/zarr_versions/zarr2.py | 242 ++++++++++++++++++ src/anemoi/datasets/zarr_versions/zarr3.py | 60 +++++ 9 files changed, 349 insertions(+), 304 deletions(-) delete mode 100644 src/anemoi/datasets/add_zarr_support.py create mode 100644 src/anemoi/datasets/zarr_versions/__init__.py create mode 100644 src/anemoi/datasets/zarr_versions/zarr2.py create mode 100644 src/anemoi/datasets/zarr_versions/zarr3.py diff --git a/src/anemoi/datasets/__init__.py b/src/anemoi/datasets/__init__.py index 4c875f0f7..b7689466d 100644 --- a/src/anemoi/datasets/__init__.py +++ b/src/anemoi/datasets/__init__.py @@ -9,7 +9,6 @@ from typing import List -from .add_zarr_support import Zarr2AndZarr3 from .data import MissingDateError from .data import add_dataset_path from .data import add_named_dataset @@ -31,5 +30,4 @@ "MissingDateError", "open_dataset", "__version__", - "Zarr2AndZarr3", ] diff --git a/src/anemoi/datasets/add_zarr_support.py b/src/anemoi/datasets/add_zarr_support.py deleted file mode 100644 index 28a469c06..000000000 --- a/src/anemoi/datasets/add_zarr_support.py +++ /dev/null @@ -1,153 +0,0 @@ -# (C) Copyright 2025 Anemoi contributors. -# -# This software is licensed under the terms of the Apache Licence Version 2.0 -# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. -# -# In applying this licence, ECMWF does not waive the privileges and immunities -# granted to it by virtue of its status as an intergovernmental organisation -# nor does it submit to any jurisdiction. -import logging - -import zarr - -LOG = logging.getLogger(__name__) - - -class Zarr2: - @classmethod - def base_store(cls): - return zarr.storage.BaseStore - - @classmethod - def is_zarr_group(cls, obj): - return isinstance(obj, zarr.hierarchy.Group) - - @classmethod - def create_array(cls, zarr_root, *args, **kwargs): - return zarr_root.create_dataset(*args, **kwargs) - - @classmethod - def change_dtype_datetime64(cls, dtype): - return dtype - - @classmethod - def cast_dtype_datetime64(cls, array, dtype): - return array, dtype - - @classmethod - def get_not_found_exception(cls): - return zarr.errors.PathNotFoundError - - @classmethod - def zarr_open_mode_append(cls): - return "w+" - - @classmethod - def zarr_open_to_patch_in_tests(cls): - return "zarr.convenience.open" - - @classmethod - def zarr_open(cls, *args, **kwargs): - return zarr.convenience.open(*args, **kwargs) - - @classmethod - def get_read_only_store_class(cls): - class ReadOnlyStore(zarr.storage.BaseStore): - """A base class for read-only stores.""" - - def __delitem__(self, key: str) -> None: - """Prevent deletion of items.""" - raise NotImplementedError() - - def __setitem__(self, key: str, value: bytes) -> None: - """Prevent setting of items.""" - raise NotImplementedError() - - def __len__(self) -> int: - """Return the number of items in the store.""" - raise NotImplementedError() - - def __iter__(self) -> iter: - """Return an iterator over the store.""" - raise NotImplementedError() - - return ReadOnlyStore - - @classmethod - def raise_if_not_supported(cls, msg): - pass - - -class Zarr3: - @classmethod - def base_store(cls): - return zarr.abc.store.Store - - @classmethod - def is_zarr_group(cls, obj): - return isinstance(obj, zarr.Group) - - @classmethod - def create_array(cls, zarr_root, *args, **kwargs): - if "compressor" in kwargs and kwargs["compressor"] is None: - # compressor is deprecated, use compressors instead - kwargs.pop("compressor") - kwargs["compressors"] = () - return zarr_root.create_array(*args, **kwargs) - - @classmethod - def get_not_found_exception(cls): - return FileNotFoundError - - @classmethod - def zarr_open_mode_append(cls): - return "a" - - @classmethod - def change_dtype_datetime64(cls, dtype): - # remove this flag (and the relevant code) when Zarr 3 supports datetime64 - # https://github.com/zarr-developers/zarr-python/issues/2616 - import numpy as np - - if dtype == "datetime64[s]": - dtype = np.dtype("int64") - return dtype - - @classmethod - def cast_dtype_datetime64(cls, array, dtype): - # remove this flag (and the relevant code) when Zarr 3 supports datetime64 - # https://github.com/zarr-developers/zarr-python/issues/2616 - import numpy as np - - if dtype == np.dtype("datetime64[s]"): - dtype = "int64" - array = array.astype(dtype) - - return array, dtype - - @classmethod - def zarr_open_to_patch_in_tests(cls): - return "zarr.open" - - @classmethod - def zarr_open(cls, *args, **kwargs): - return zarr.open(*args, **kwargs) - - @classmethod - def get_read_only_store_class(cls): - class ReadOnlyStore(zarr.abc.store.Store): - def __init__(self, *args, **kwargs): - raise NotImplementedError("Zarr 3 is not for this kind of store : {}".format(args)) - - return ReadOnlyStore - - @classmethod - def raise_if_not_supported(cls, msg="Zarr 3 is not supported in this context"): - raise NotImplementedError(msg) - - -if zarr.__version__.startswith("3"): - Zarr2AndZarr3 = Zarr3 -else: - LOG.warning("Using Zarr 2 : only zarr datasets build with zarr 2 are supported") - Zarr2AndZarr3 = Zarr2 diff --git a/src/anemoi/datasets/commands/copy.py b/src/anemoi/datasets/commands/copy.py index 4ee111306..da2076ffa 100644 --- a/src/anemoi/datasets/commands/copy.py +++ b/src/anemoi/datasets/commands/copy.py @@ -20,8 +20,8 @@ from anemoi.utils.remote import Transfer from anemoi.utils.remote import TransferMethodNotImplementedError -from anemoi.datasets import Zarr2AndZarr3 from anemoi.datasets.check import check_zarr +from anemoi.datasets.zarr_versions import zarr_2_or_3 from . import Command @@ -213,7 +213,7 @@ def copy_data(self, source: Any, target: Any, _copy: Any, verbosity: int) -> Non target_data = ( target["data"] if "data" in target - else Zarr2AndZarr3.create_array( + else zarr_2_or_3.create_array( target, "data", shape=source_data.shape, @@ -319,7 +319,7 @@ def copy_group(self, source: Any, target: Any, _copy: Any, verbosity: int) -> No LOG.info(f"Skipping {name}") continue - if Zarr2AndZarr3.is_zarr_group(source[name]): + if zarr_2_or_3.is_zarr_group(source[name]): group = target[name] if name in target else target.create_group(name) self.copy_group( source[name], @@ -413,7 +413,7 @@ def open_target() -> Any: sys.exit(0) LOG.error("Target already exists, resuming copy.") - return zarr.open(self.target, mode=Zarr2AndZarr3.zarr_open_mode_append()) + return zarr.open(self.target, mode=zarr_2_or_3.open_mode_append) LOG.error("Target already exists, use either --overwrite or --resume.") sys.exit(1) diff --git a/src/anemoi/datasets/create/__init__.py b/src/anemoi/datasets/create/__init__.py index 1a7f36775..d4e7113ce 100644 --- a/src/anemoi/datasets/create/__init__.py +++ b/src/anemoi/datasets/create/__init__.py @@ -32,13 +32,13 @@ from earthkit.data.core.order import build_remapping from anemoi.datasets import MissingDateError -from anemoi.datasets import Zarr2AndZarr3 from anemoi.datasets import open_dataset from anemoi.datasets.create.input.trace import enable_trace from anemoi.datasets.create.persistent import build_storage from anemoi.datasets.data.misc import as_first_date from anemoi.datasets.data.misc import as_last_date from anemoi.datasets.dates.groups import Groups +from anemoi.datasets.zarr_versions import zarr_2_or_3 from .check import DatasetName from .check import check_data_values @@ -157,7 +157,7 @@ def _path_readable(path: str) -> bool: try: zarr.open(path, "r") return True - except Zarr2AndZarr3.get_not_found_exception(): + except zarr_2_or_3.FileNotFoundException: return False @@ -174,11 +174,9 @@ def __init__(self, path: str): """ self.path = path - if Zarr2AndZarr3.version == "3" and not os.environ.get("ANEMOI_DATASETS_ALLOW_BUILDING_ZARR3_DATASETS"): + if zarr_2_or_3.version != 2: raise ValueError( - "zarr 3 is installed. anemoi-datasets supports zarr 3, but the datasets build with zarr 3 will " - "not be readable by zarr 2. It is likely that you do not want to create a dataset with zarr 3. " - "Please uninstall zarr 3 and install zarr 2." + f"Only zarr version 2 is supported when creating datasets, found version: {zarr.__version__}" ) _, ext = os.path.splitext(self.path) @@ -218,7 +216,7 @@ def update_metadata(self, **kwargs: Any) -> None: import zarr LOG.debug(f"Updating metadata {kwargs}") - z = zarr.open(self.path, mode=Zarr2AndZarr3.zarr_open_mode_append()) + z = zarr.open(self.path, mode=zarr_2_or_3.open_mode_append) for k, v in kwargs.items(): if isinstance(v, np.datetime64): v = v.astype(datetime.datetime) diff --git a/src/anemoi/datasets/data/misc.py b/src/anemoi/datasets/data/misc.py index a1dd1c02c..54b6ba382 100644 --- a/src/anemoi/datasets/data/misc.py +++ b/src/anemoi/datasets/data/misc.py @@ -27,7 +27,7 @@ from anemoi.utils.config import load_config as load_settings from numpy.typing import NDArray -from anemoi.datasets import Zarr2AndZarr3 +from anemoi.datasets.zarr_versions import zarr_2_or_3 if TYPE_CHECKING: from .dataset import Dataset @@ -373,7 +373,7 @@ def _open(a: Union[str, PurePath, Dict[str, Any], List[Any], Tuple[Any, ...]]) - if isinstance(a, Dataset): return a.mutate() - if Zarr2AndZarr3.is_zarr_group(a): + if isinstance(a, zarr_2_or_3.Group): return Zarr(a).mutate() if isinstance(a, str): diff --git a/src/anemoi/datasets/data/stores.py b/src/anemoi/datasets/data/stores.py index c2f1e9f0e..f49d199f8 100644 --- a/src/anemoi/datasets/data/stores.py +++ b/src/anemoi/datasets/data/stores.py @@ -12,7 +12,6 @@ import logging import os import tempfile -import warnings from functools import cached_property from typing import Any from typing import Dict @@ -27,8 +26,7 @@ from anemoi.utils.dates import frequency_to_timedelta from numpy.typing import NDArray -from anemoi.datasets import Zarr2AndZarr3 - +from ..zarr_versions import zarr_2_or_3 from . import MissingDateError from .dataset import Dataset from .dataset import FullIndex @@ -44,134 +42,12 @@ LOG = logging.getLogger(__name__) -ReadOnlyStore = Zarr2AndZarr3.get_read_only_store_class() - - -class HTTPStore(ReadOnlyStore): - """A read-only store for HTTP(S) resources.""" - - def __init__(self, url: str) -> None: - """Initialize the HTTPStore with a URL.""" - self.url = url - - def __getitem__(self, key: str) -> bytes: - """Retrieve an item from the store.""" - import requests - - r = requests.get(self.url + "/" + key) - - if r.status_code == 404: - raise KeyError(key) - - r.raise_for_status() - return r.content - - -class S3Store(ReadOnlyStore): - """A read-only store for S3 resources.""" - - """We write our own S3Store because the one used by zarr (s3fs) - does not play well with fork(). We also get to control the s3 client - options using the anemoi configs. - """ - - def __init__(self, url: str, region: Optional[str] = None) -> None: - """Initialize the S3Store with a URL and optional region.""" - from anemoi.utils.remote.s3 import s3_client - - _, _, self.bucket, self.key = url.split("/", 3) - self.s3 = s3_client(self.bucket, region=region) - - def __getitem__(self, key: str) -> bytes: - """Retrieve an item from the store.""" - try: - response = self.s3.get_object(Bucket=self.bucket, Key=self.key + "/" + key) - except self.s3.exceptions.NoSuchKey: - raise KeyError(key) - - return response["Body"].read() - - -class PlanetaryComputerStore(ReadOnlyStore): - """We write our own Store to access catalogs on Planetary Computer, - as it requires some extra arguments to use xr.open_zarr. - """ - - def __init__(self, data_catalog_id: str) -> None: - """Initialize the PlanetaryComputerStore with a data catalog ID. - - Parameters - ---------- - data_catalog_id : str - The data catalog ID. - """ - self.data_catalog_id = data_catalog_id - - import planetary_computer - import pystac_client - - catalog = pystac_client.Client.open( - "https://planetarycomputer.microsoft.com/api/stac/v1/", - modifier=planetary_computer.sign_inplace, - ) - collection = catalog.get_collection(self.data_catalog_id) - - asset = collection.assets["zarr-abfs"] - - if "xarray:storage_options" in asset.extra_fields: - store = { - "store": asset.href, - "storage_options": asset.extra_fields["xarray:storage_options"], - **asset.extra_fields["xarray:open_kwargs"], - } - else: - store = { - "filename_or_obj": asset.href, - **asset.extra_fields["xarray:open_kwargs"], - } - - self.store = store - - def __getitem__(self, key: str) -> bytes: - """Retrieve an item from the store.""" - raise NotImplementedError() - - -class DebugStore(ReadOnlyStore): - """A store to debug the zarr loading.""" - - def __init__(self, store: Any) -> None: - """Initialize the DebugStore with another store.""" - assert not isinstance(store, DebugStore) - self.store = store - - def __getitem__(self, key: str) -> bytes: - """Retrieve an item from the store and print debug information.""" - # print() - print("GET", key, self) - # traceback.print_stack(file=sys.stdout) - return self.store[key] - - def __len__(self) -> int: - """Return the number of items in the store.""" - return len(self.store) - - def __iter__(self) -> iter: - """Return an iterator over the store.""" - warnings.warn("DebugStore: iterating over the store") - return iter(self.store) - - def __contains__(self, key: str) -> bool: - """Check if the store contains a key.""" - return key in self.store - - def name_to_zarr_store(path_or_url: str) -> Any: """Convert a path or URL to a zarr store.""" store = path_or_url if store.startswith("s3://"): - return S3Store(store) + return zarr_2_or_3.S3Store(store) if store.startswith("http://") or store.startswith("https://"): @@ -198,12 +74,12 @@ def name_to_zarr_store(path_or_url: str) -> Any: bits = parsed.netloc.split(".") if len(bits) == 5 and (bits[1], bits[3], bits[4]) == ("s3", "amazonaws", "com"): s3_url = f"s3://{bits[0]}{parsed.path}" - store = S3Store(s3_url, region=bits[2]) + store = zarr_2_or_3.S3Store(s3_url, region=bits[2]) elif store.startswith("https://planetarycomputer.microsoft.com/"): data_catalog_id = store.rsplit("/", 1)[-1] - store = PlanetaryComputerStore(data_catalog_id).store + store = zarr_2_or_3.PlanetaryComputerStore(data_catalog_id).store else: - store = HTTPStore(store) + store = zarr_2_or_3.HTTPStore(store) return store @@ -222,14 +98,15 @@ def open_zarr(path: str, dont_fail: bool = False, cache: int = None) -> Any: "DEBUG_ZARR_LOADING is only implemented for DirectoryStore. " "Please disable it for other backends." ) - store = zarr.storage.DirectoryStore(store) - store = DebugStore(store) + store = zarr_2_or_3.DirectoryStore(store) + store = zarr_2_or_3.DebugStore(store) if cache is not None: - store = zarr.LRUStoreCache(store, max_size=cache) + store = zarr_2_or_3.LRUStoreCache(store, max_size=cache) + + return zarr.open(store, mode="r") - return Zarr2AndZarr3.zarr_open(store, "r") - except Zarr2AndZarr3.get_not_found_exception(): + except zarr_2_or_3.FileNotFoundException: if not dont_fail: raise FileNotFoundError(f"Zarr store not found: {path}") @@ -239,7 +116,7 @@ class Zarr(Dataset): def __init__(self, path: Union[str, Any]) -> None: """Initialize the Zarr dataset with a path or zarr group.""" - if Zarr2AndZarr3.is_zarr_group(path): + if isinstance(path, zarr_2_or_3.Group): self.was_zarr = True self.path = str(id(path)) self.z = path @@ -593,7 +470,7 @@ def zarr_lookup(name: str, fail: bool = True) -> Optional[str]: LOG.info("Opening `%s` as `%s`", name, full) QUIET.add(name) return full - except Zarr2AndZarr3.get_not_found_exception(): + except zarr_2_or_3.FileNotFoundException: pass if fail: diff --git a/src/anemoi/datasets/zarr_versions/__init__.py b/src/anemoi/datasets/zarr_versions/__init__.py new file mode 100644 index 000000000..39fd71d91 --- /dev/null +++ b/src/anemoi/datasets/zarr_versions/__init__.py @@ -0,0 +1,23 @@ +# (C) Copyright 2025 Anemoi contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + + +import zarr + +version = zarr.__version__.split(".")[0] + +if version == "2": + from . import zarr2 as zarr_2_or_3 + +elif version == "3": + from . import zarr3 as zarr_2_or_3 +else: + raise ImportError(f"Unsupported Zarr version: {zarr.__version__}. Supported versions are 2 and 3.") + +__all__ = ["zarr_2_or_3"] diff --git a/src/anemoi/datasets/zarr_versions/zarr2.py b/src/anemoi/datasets/zarr_versions/zarr2.py new file mode 100644 index 000000000..87d2a8a83 --- /dev/null +++ b/src/anemoi/datasets/zarr_versions/zarr2.py @@ -0,0 +1,242 @@ +# (C) Copyright 2025 Anemoi contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. +import logging +import warnings +from typing import Any +from typing import Optional + +import zarr + +from ..zarr_versions import zarr_2_or_3 + +LOG = logging.getLogger(__name__) + + +version = 2 + +FileNotFoundException = zarr.errors.PathNotFoundError +Group = zarr.hierarchy.Group +open_mode_append = "w+" + + +class Zarr2: + @classmethod + def base_store(cls): + return zarr.storage.BaseStore + + @classmethod + def is_zarr_group(cls, obj): + return isinstance(obj, zarr.hierarchy.Group) + + @classmethod + def create_array(cls, zarr_root, *args, **kwargs): + return zarr_root.create_dataset(*args, **kwargs) + + @classmethod + def change_dtype_datetime64(cls, dtype): + return dtype + + @classmethod + def cast_dtype_datetime64(cls, array, dtype): + return array, dtype + + @classmethod + def get_not_found_exception(cls): + return zarr.errors.PathNotFoundError + + @classmethod + def zarr_open_mode_append(cls): + return "w+" + + @classmethod + def zarr_open_to_patch_in_tests(cls): + return "zarr.convenience.open" + + @classmethod + def zarr_open(cls, *args, **kwargs): + return zarr.convenience.open(*args, **kwargs) + + @classmethod + def get_read_only_store_class(cls): + + return ReadOnlyStore + + @classmethod + def raise_if_not_supported(cls, msg): + pass + + +class ReadOnlyStore(zarr.storage.BaseStore): + """A base class for read-only stores.""" + + def __delitem__(self, key: str) -> None: + """Prevent deletion of items.""" + raise NotImplementedError() + + def __setitem__(self, key: str, value: bytes) -> None: + """Prevent setting of items.""" + raise NotImplementedError() + + def __len__(self) -> int: + """Return the number of items in the store.""" + raise NotImplementedError() + + def __iter__(self) -> iter: + """Return an iterator over the store.""" + raise NotImplementedError() + + +class S3Store(ReadOnlyStore): + """A read-only store for S3 resources.""" + + """We write our own S3Store because the one used by zarr (s3fs) + does not play well with fork(). We also get to control the s3 client + options using the anemoi configs. + """ + + def __init__(self, url: str, region: Optional[str] = None) -> None: + """Initialize the S3Store with a URL and optional region.""" + from anemoi.utils.remote.s3 import s3_client + + super().__init__() + + _, _, self.bucket, self.key = url.split("/", 3) + self.s3 = s3_client(self.bucket, region=region) + + # Version 2 + def __getitem__(self, key: str) -> bytes: + """Retrieve an item from the store.""" + try: + response = self.s3.get_object(Bucket=self.bucket, Key=self.key + "/" + key) + except self.s3.exceptions.NoSuchKey: + raise KeyError(key) + + return response["Body"].read() + + +class HTTPStore(zarr_2_or_3.ReadOnlyStore): + """A read-only store for HTTP(S) resources.""" + + def __init__(self, url: str) -> None: + """Initialize the HTTPStore with a URL.""" + super().__init__() + self.url = url + + def __getitem__(self, key: str) -> bytes: + """Retrieve an item from the store.""" + import requests + + r = requests.get(self.url + "/" + key) + + if r.status_code == 404: + raise KeyError(key) + + r.raise_for_status() + return r.content + + # Version 3 + async def get(self, key: str, prototype, byte_range=None): + """Retrieve an item from the store.""" + assert byte_range is None, "S3Store does not support byte ranges" + try: + response = self.s3.get_object(Bucket=self.bucket, Key=self.key + "/" + key) + except self.s3.exceptions.NoSuchKey: + return None + + return prototype.buffer.from_bytes(response["Body"].read()) + + @property + def supports_listing(self): + return True + + async def list_dir(self, prefix: str): + from anemoi.utils.remote.s3 import list_folder + + path = "s3://" + self.bucket + "/" + self.key + "/" + prefix + print(path) + for n in list_folder(path): + print("------------", n) + yield n + # return [x[len(self.key) + 1:] for x in result if x.startswith(self.key + "/")] + + +class PlanetaryComputerStore(zarr_2_or_3.ReadOnlyStore): + """We write our own Store to access catalogs on Planetary Computer, + as it requires some extra arguments to use xr.open_zarr. + """ + + def __init__(self, data_catalog_id: str) -> None: + """Initialize the PlanetaryComputerStore with a data catalog ID. + + Parameters + ---------- + data_catalog_id : str + The data catalog ID. + """ + super().__init__() + self.data_catalog_id = data_catalog_id + + import planetary_computer + import pystac_client + + catalog = pystac_client.Client.open( + "https://planetarycomputer.microsoft.com/api/stac/v1/", + modifier=planetary_computer.sign_inplace, + ) + collection = catalog.get_collection(self.data_catalog_id) + + asset = collection.assets["zarr-abfs"] + + if "xarray:storage_options" in asset.extra_fields: + store = { + "store": asset.href, + "storage_options": asset.extra_fields["xarray:storage_options"], + **asset.extra_fields["xarray:open_kwargs"], + } + else: + store = { + "filename_or_obj": asset.href, + **asset.extra_fields["xarray:open_kwargs"], + } + + self.store = store + + def __getitem__(self, key: str) -> bytes: + """Retrieve an item from the store.""" + raise NotImplementedError() + + +class DebugStore(zarr_2_or_3.ReadOnlyStore): + """A store to debug the zarr loading.""" + + def __init__(self, store: Any) -> None: + super().__init__() + """Initialize the DebugStore with another store.""" + assert not isinstance(store, DebugStore) + self.store = store + + def __getitem__(self, key: str) -> bytes: + """Retrieve an item from the store and print debug information.""" + # print() + print("GET", key, self) + # traceback.print_stack(file=sys.stdout) + return self.store[key] + + def __len__(self) -> int: + """Return the number of items in the store.""" + return len(self.store) + + def __iter__(self) -> iter: + """Return an iterator over the store.""" + warnings.warn("DebugStore: iterating over the store") + return iter(self.store) + + def __contains__(self, key: str) -> bool: + """Check if the store contains a key.""" + return key in self.store diff --git a/src/anemoi/datasets/zarr_versions/zarr3.py b/src/anemoi/datasets/zarr_versions/zarr3.py new file mode 100644 index 000000000..2d327044e --- /dev/null +++ b/src/anemoi/datasets/zarr_versions/zarr3.py @@ -0,0 +1,60 @@ +# (C) Copyright 2025 Anemoi contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. +import logging + +import zarr + +LOG = logging.getLogger(__name__) + +version = 3 +FileNotFoundException = FileNotFoundError +Group = zarr.Group +open_mode_append = "a" + + +class S3Store(zarr.storage.ObjectStore): + """We use our class to manage per bucket credentials""" + + def __init__(self, url): + + import boto3 + from anemoi.utils.remote.s3 import s3_options + from obstore.auth.boto3 import Boto3CredentialProvider + from obstore.store import from_url + + options = s3_options(url) + + credential_provider = Boto3CredentialProvider( + session=boto3.session.Session( + aws_access_key_id=options["aws_access_key_id"], + aws_secret_access_key=options["aws_secret_access_key"], + ), + ) + + objectstore = from_url( + url, + credential_provider=credential_provider, + endpoint=options["endpoint_url"], + ) + + super().__init__(objectstore, read_only=True) + + +class HTTPStore(zarr.storage.ObjectStore): + + def __init__(self, url): + + from obstore.store import from_url + + objectstore = from_url(url) + + super().__init__(objectstore, read_only=True) + + +DebugStore = zarr.storage.LoggingStore From 5c870dcc52163a196ac54c813070dcbcfabac65d Mon Sep 17 00:00:00 2001 From: Baudouin Raoult Date: Wed, 25 Jun 2025 10:57:12 +0000 Subject: [PATCH 05/16] update --- src/anemoi/datasets/zarr_versions/zarr2.py | 33 ++-------------------- 1 file changed, 3 insertions(+), 30 deletions(-) diff --git a/src/anemoi/datasets/zarr_versions/zarr2.py b/src/anemoi/datasets/zarr_versions/zarr2.py index 87d2a8a83..0dd2028a1 100644 --- a/src/anemoi/datasets/zarr_versions/zarr2.py +++ b/src/anemoi/datasets/zarr_versions/zarr2.py @@ -13,8 +13,6 @@ import zarr -from ..zarr_versions import zarr_2_or_3 - LOG = logging.getLogger(__name__) @@ -120,7 +118,7 @@ def __getitem__(self, key: str) -> bytes: return response["Body"].read() -class HTTPStore(zarr_2_or_3.ReadOnlyStore): +class HTTPStore(ReadOnlyStore): """A read-only store for HTTP(S) resources.""" def __init__(self, url: str) -> None: @@ -140,33 +138,8 @@ def __getitem__(self, key: str) -> bytes: r.raise_for_status() return r.content - # Version 3 - async def get(self, key: str, prototype, byte_range=None): - """Retrieve an item from the store.""" - assert byte_range is None, "S3Store does not support byte ranges" - try: - response = self.s3.get_object(Bucket=self.bucket, Key=self.key + "/" + key) - except self.s3.exceptions.NoSuchKey: - return None - - return prototype.buffer.from_bytes(response["Body"].read()) - - @property - def supports_listing(self): - return True - - async def list_dir(self, prefix: str): - from anemoi.utils.remote.s3 import list_folder - - path = "s3://" + self.bucket + "/" + self.key + "/" + prefix - print(path) - for n in list_folder(path): - print("------------", n) - yield n - # return [x[len(self.key) + 1:] for x in result if x.startswith(self.key + "/")] - -class PlanetaryComputerStore(zarr_2_or_3.ReadOnlyStore): +class PlanetaryComputerStore(ReadOnlyStore): """We write our own Store to access catalogs on Planetary Computer, as it requires some extra arguments to use xr.open_zarr. """ @@ -212,7 +185,7 @@ def __getitem__(self, key: str) -> bytes: raise NotImplementedError() -class DebugStore(zarr_2_or_3.ReadOnlyStore): +class DebugStore(ReadOnlyStore): """A store to debug the zarr loading.""" def __init__(self, store: Any) -> None: From f6c2683dff5b455f83b64ccbdea468b363146dd2 Mon Sep 17 00:00:00 2001 From: Baudouin Raoult Date: Wed, 25 Jun 2025 11:30:40 +0000 Subject: [PATCH 06/16] update --- src/anemoi/datasets/create/patch.py | 4 +- src/anemoi/datasets/create/zarr.py | 10 ++--- src/anemoi/datasets/zarr_versions/zarr2.py | 51 ++-------------------- src/anemoi/datasets/zarr_versions/zarr3.py | 4 ++ tests/test_data.py | 24 +++++----- tests/test_data_gridded.py | 2 +- 6 files changed, 28 insertions(+), 67 deletions(-) diff --git a/src/anemoi/datasets/create/patch.py b/src/anemoi/datasets/create/patch.py index 7cd9e0ff4..3a96d5324 100755 --- a/src/anemoi/datasets/create/patch.py +++ b/src/anemoi/datasets/create/patch.py @@ -14,7 +14,7 @@ import zarr -from anemoi.datasets import Zarr2AndZarr3 +from anemoi.datasets.zarr_versions import zarr_2_or_3 LOG = logging.getLogger(__name__) @@ -136,7 +136,7 @@ def apply_patch(path: str, verbose: bool = True, dry_run: bool = False) -> None: try: attrs = zarr.open(path, mode="r").attrs.asdict() - except Zarr2AndZarr3.get_not_found_exception() as e: + except zarr_2_or_3.get_not_found_exception() as e: LOG.error(f"Failed to open {path}") LOG.error(e) exit(0) diff --git a/src/anemoi/datasets/create/zarr.py b/src/anemoi/datasets/create/zarr.py index f6c9c966d..c48b67c9f 100644 --- a/src/anemoi/datasets/create/zarr.py +++ b/src/anemoi/datasets/create/zarr.py @@ -15,7 +15,7 @@ import zarr from numpy.typing import NDArray -from anemoi.datasets import Zarr2AndZarr3 +from anemoi.datasets.zarr_versions import zarr_2_or_3 from .synchronise import NoSynchroniser from .synchronise import Synchroniser @@ -73,10 +73,10 @@ def add_zarr_dataset( shape = array.shape if array is not None: - array, dtype = Zarr2AndZarr3.cast_dtype_datetime64(array, dtype) + array, dtype = zarr_2_or_3.cast_dtype_datetime64(array, dtype) assert array.shape == shape, (array.shape, shape) - a = Zarr2AndZarr3.create_array( + a = zarr_2_or_3.create_array( zarr_root, name, shape=shape, @@ -104,8 +104,8 @@ def add_zarr_dataset( else: raise ValueError(f"No fill_value for dtype={dtype}") - dtype = Zarr2AndZarr3.change_dtype_datetime64(dtype) - a = Zarr2AndZarr3.create_array( + dtype = zarr_2_or_3.change_dtype_datetime64(dtype) + a = zarr_2_or_3.create_array( zarr_root, name, shape=shape, diff --git a/src/anemoi/datasets/zarr_versions/zarr2.py b/src/anemoi/datasets/zarr_versions/zarr2.py index 0dd2028a1..c22d0045f 100644 --- a/src/anemoi/datasets/zarr_versions/zarr2.py +++ b/src/anemoi/datasets/zarr_versions/zarr2.py @@ -23,53 +23,6 @@ open_mode_append = "w+" -class Zarr2: - @classmethod - def base_store(cls): - return zarr.storage.BaseStore - - @classmethod - def is_zarr_group(cls, obj): - return isinstance(obj, zarr.hierarchy.Group) - - @classmethod - def create_array(cls, zarr_root, *args, **kwargs): - return zarr_root.create_dataset(*args, **kwargs) - - @classmethod - def change_dtype_datetime64(cls, dtype): - return dtype - - @classmethod - def cast_dtype_datetime64(cls, array, dtype): - return array, dtype - - @classmethod - def get_not_found_exception(cls): - return zarr.errors.PathNotFoundError - - @classmethod - def zarr_open_mode_append(cls): - return "w+" - - @classmethod - def zarr_open_to_patch_in_tests(cls): - return "zarr.convenience.open" - - @classmethod - def zarr_open(cls, *args, **kwargs): - return zarr.convenience.open(*args, **kwargs) - - @classmethod - def get_read_only_store_class(cls): - - return ReadOnlyStore - - @classmethod - def raise_if_not_supported(cls, msg): - pass - - class ReadOnlyStore(zarr.storage.BaseStore): """A base class for read-only stores.""" @@ -213,3 +166,7 @@ def __iter__(self) -> iter: def __contains__(self, key: str) -> bool: """Check if the store contains a key.""" return key in self.store + + +def change_dtype_datetime64(a): + return a diff --git a/src/anemoi/datasets/zarr_versions/zarr3.py b/src/anemoi/datasets/zarr_versions/zarr3.py index 2d327044e..cd2c7f76f 100644 --- a/src/anemoi/datasets/zarr_versions/zarr3.py +++ b/src/anemoi/datasets/zarr_versions/zarr3.py @@ -58,3 +58,7 @@ def __init__(self, url): DebugStore = zarr.storage.LoggingStore + + +def change_dtype_datetime64(a): + return a diff --git a/tests/test_data.py b/tests/test_data.py index 99a2b4685..72b57dd60 100644 --- a/tests/test_data.py +++ b/tests/test_data.py @@ -24,7 +24,6 @@ from anemoi.utils.dates import frequency_to_string from anemoi.utils.dates import frequency_to_timedelta -from anemoi.datasets import Zarr2AndZarr3 from anemoi.datasets import open_dataset from anemoi.datasets.data.concat import Concat from anemoi.datasets.data.ensemble import Ensemble @@ -38,6 +37,7 @@ from anemoi.datasets.data.statistics import Statistics from anemoi.datasets.data.stores import Zarr from anemoi.datasets.data.subset import Subset +from anemoi.datasets.zarr_versions import zarr_2_or_3 VALUES = 10 @@ -58,7 +58,7 @@ def mockup_open_zarr(func: Callable) -> Callable: @wraps(func) def wrapper(*args, **kwargs): - with patch(Zarr2AndZarr3.zarr_open_to_patch_in_tests(), zarr_from_str): + with patch("zarr.open", zarr_from_str): with patch("anemoi.datasets.data.stores.zarr_lookup", lambda name: name): return func(*args, **kwargs) @@ -105,7 +105,7 @@ def create_zarr( ensemble: Optional[int] = None, grids: Optional[int] = None, missing: bool = False, -) -> zarr.Group: +) -> zarr_2_or_3.Group: """Create a Zarr dataset. Parameters @@ -155,7 +155,7 @@ def create_zarr( for e in range(ensembles): data[i, j, e] = _(date.astype(object), var, k, e, values) - Zarr2AndZarr3.create_array( + zarr_2_or_3.create_array( root, "data", dtype=data.dtype, @@ -164,9 +164,9 @@ def create_zarr( shape=data.shape, )[...] = data - dates, dtype_ = Zarr2AndZarr3.cast_dtype_datetime64(dates, dates.dtype) + dates, dtype_ = zarr_2_or_3.cast_dtype_datetime64(dates, dates.dtype) del dtype_ - Zarr2AndZarr3.create_array( + zarr_2_or_3.create_array( root, "dates", compressor=None, @@ -175,7 +175,7 @@ def create_zarr( )[...] = dates latitudes = np.array([x + values for x in range(values)]) - Zarr2AndZarr3.create_array( + zarr_2_or_3.create_array( root, "latitudes", compressor=None, @@ -184,7 +184,7 @@ def create_zarr( )[...] = latitudes longitudes = np.array([x + values for x in range(values)]) - Zarr2AndZarr3.create_array( + zarr_2_or_3.create_array( root, "longitudes", compressor=None, @@ -211,7 +211,7 @@ def create_zarr( root.attrs["missing_dates"] = [d.isoformat() for d in missing_dates] - Zarr2AndZarr3.create_array( + zarr_2_or_3.create_array( root, "mean", compressor=None, @@ -220,7 +220,7 @@ def create_zarr( )[ ... ] = np.mean(data, axis=0) - Zarr2AndZarr3.create_array( + zarr_2_or_3.create_array( root, "stdev", compressor=None, @@ -229,7 +229,7 @@ def create_zarr( )[ ... ] = np.std(data, axis=0) - Zarr2AndZarr3.create_array( + zarr_2_or_3.create_array( root, "maximum", compressor=None, @@ -238,7 +238,7 @@ def create_zarr( )[ ... ] = np.max(data, axis=0) - Zarr2AndZarr3.create_array( + zarr_2_or_3.create_array( root, "minimum", compressor=None, diff --git a/tests/test_data_gridded.py b/tests/test_data_gridded.py index 5e4738980..e56c3f117 100644 --- a/tests/test_data_gridded.py +++ b/tests/test_data_gridded.py @@ -44,7 +44,7 @@ def mockup_open_zarr(func: Callable) -> Callable: @wraps(func) def wrapper(*args, **kwargs): - with patch("zarr.convenience.open", zarr_from_str): + with patch("zarr.open", zarr_from_str): with patch("anemoi.datasets.data.stores.zarr_lookup", lambda name: name): return func(*args, **kwargs) From a7a52e28353fb302e05b20d8eda0eff1d22ea946 Mon Sep 17 00:00:00 2001 From: Baudouin Raoult Date: Wed, 25 Jun 2025 11:35:57 +0000 Subject: [PATCH 07/16] update --- src/anemoi/datasets/zarr_versions/zarr2.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/src/anemoi/datasets/zarr_versions/zarr2.py b/src/anemoi/datasets/zarr_versions/zarr2.py index c22d0045f..7ae1e8c11 100644 --- a/src/anemoi/datasets/zarr_versions/zarr2.py +++ b/src/anemoi/datasets/zarr_versions/zarr2.py @@ -168,5 +168,13 @@ def __contains__(self, key: str) -> bool: return key in self.store -def change_dtype_datetime64(a): - return a +def create_array(zarr_root, *args, **kwargs): + return zarr_root.create_dataset(*args, **kwargs) + + +def change_dtype_datetime64(dtype): + return dtype + + +def cast_dtype_datetime64(array, dtype): + return array, dtype From 5ed4442be9e47ca876c15d58631ddc27b2dc7f59 Mon Sep 17 00:00:00 2001 From: Baudouin Raoult Date: Wed, 25 Jun 2025 11:38:50 +0000 Subject: [PATCH 08/16] update --- src/anemoi/datasets/zarr_versions/zarr3.py | 30 ++++++++++++++++++++-- 1 file changed, 28 insertions(+), 2 deletions(-) diff --git a/src/anemoi/datasets/zarr_versions/zarr3.py b/src/anemoi/datasets/zarr_versions/zarr3.py index cd2c7f76f..1181f0cdc 100644 --- a/src/anemoi/datasets/zarr_versions/zarr3.py +++ b/src/anemoi/datasets/zarr_versions/zarr3.py @@ -60,5 +60,31 @@ def __init__(self, url): DebugStore = zarr.storage.LoggingStore -def change_dtype_datetime64(a): - return a +def create_array(zarr_root, *args, **kwargs): + if "compressor" in kwargs and kwargs["compressor"] is None: + # compressor is deprecated, use compressors instead + kwargs.pop("compressor") + kwargs["compressors"] = () + return zarr_root.create_array(*args, **kwargs) + + +def change_dtype_datetime64(dtype): + # remove this flag (and the relevant code) when Zarr 3 supports datetime64 + # https://github.com/zarr-developers/zarr-python/issues/2616 + import numpy as np + + if dtype == "datetime64[s]": + dtype = np.dtype("int64") + return dtype + + +def cast_dtype_datetime64(array, dtype): + # remove this flag (and the relevant code) when Zarr 3 supports datetime64 + # https://github.com/zarr-developers/zarr-python/issues/2616 + import numpy as np + + if dtype == np.dtype("datetime64[s]"): + dtype = "int64" + array = array.astype(dtype) + + return array, dtype From 68cccb3d21348b18b0bd5c3582ba054a3d27bce6 Mon Sep 17 00:00:00 2001 From: Baudouin Raoult Date: Thu, 26 Jun 2025 09:55:01 +0000 Subject: [PATCH 09/16] add cli options --- pyproject.toml | 7 +- src/anemoi/datasets/commands/create.py | 1 + src/anemoi/datasets/commands/init.py | 1 + src/anemoi/datasets/commands/inspect.py | 6 +- src/anemoi/datasets/create/__init__.py | 42 ++++++++-- .../datasets/create/{zarr.py => misc.py} | 0 src/anemoi/datasets/zarr_versions/zarr2.py | 4 + src/anemoi/datasets/zarr_versions/zarr3.py | 77 ++++++++++++++++++- tests/test_data_gridded.py | 28 ++++--- 9 files changed, 140 insertions(+), 26 deletions(-) rename src/anemoi/datasets/create/{zarr.py => misc.py} (100%) diff --git a/pyproject.toml b/pyproject.toml index 4742c42dc..edd06c053 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -57,11 +57,7 @@ dependencies = [ "pyyaml", "semantic-version", "tqdm", - # anemoi-datasets supports zarr 2 and zarr 3, but we still use only zarr 2: - # - we don't want to create zarr 3 datasets yet, as they will no be readable by zarr 2 - # - anemoi-inference needs zarr 2 for patching - # - anemoi-registry needs zarr 2 - "zarr<3", + "zarr", ] optional-dependencies.all = [ @@ -98,6 +94,7 @@ optional-dependencies.docs = [ optional-dependencies.remote = [ "boto3", + "obstore", "requests", ] diff --git a/src/anemoi/datasets/commands/create.py b/src/anemoi/datasets/commands/create.py index 86332cfcc..219a977cb 100644 --- a/src/anemoi/datasets/commands/create.py +++ b/src/anemoi/datasets/commands/create.py @@ -86,6 +86,7 @@ def add_arguments(self, command_parser: Any) -> None: group.add_argument("--threads", help="Use `n` parallel thread workers.", type=int, default=0) group.add_argument("--processes", help="Use `n` parallel process workers.", type=int, default=0) command_parser.add_argument("--trace", action="store_true") + command_parser.add_argument("--force-zarr3", action="store_true") def run(self, args: Any) -> None: """Execute the create command. diff --git a/src/anemoi/datasets/commands/init.py b/src/anemoi/datasets/commands/init.py index 0ca540b86..73152facb 100644 --- a/src/anemoi/datasets/commands/init.py +++ b/src/anemoi/datasets/commands/init.py @@ -63,6 +63,7 @@ def add_arguments(self, subparser: Any) -> None: subparser.add_argument("--cache", help="Location to store the downloaded data.", metavar="DIR") subparser.add_argument("--trace", action="store_true") + subparser.add_argument("--force-zarr3", action="store_true", help="Force the use of Zarr v3 format.") def run(self, args: Any) -> None: """Execute the command with the provided arguments. diff --git a/src/anemoi/datasets/commands/inspect.py b/src/anemoi/datasets/commands/inspect.py index 400cdcf98..112eb5407 100644 --- a/src/anemoi/datasets/commands/inspect.py +++ b/src/anemoi/datasets/commands/inspect.py @@ -655,7 +655,7 @@ def ready(self) -> bool: if "_build_flags" not in self.zarr: return False - build_flags = self.zarr["_build_flags"] + build_flags = self.zarr["_build_flags"][:] return all(build_flags) @property @@ -711,7 +711,7 @@ def build_flags(self) -> Optional[NDArray]: if "_build" not in self.zarr: return None build = self.zarr["_build"] - return build.get("flags") + return build.get("flags")[:] @property def build_lengths(self) -> Optional[NDArray]: @@ -719,7 +719,7 @@ def build_lengths(self) -> Optional[NDArray]: if "_build" not in self.zarr: return None build = self.zarr["_build"] - return build.get("lengths") + return build.get("lengths")[:] VERSIONS = { diff --git a/src/anemoi/datasets/create/__init__.py b/src/anemoi/datasets/create/__init__.py index d4e7113ce..5a2f7564f 100644 --- a/src/anemoi/datasets/create/__init__.py +++ b/src/anemoi/datasets/create/__init__.py @@ -174,10 +174,10 @@ def __init__(self, path: str): """ self.path = path - if zarr_2_or_3.version != 2: - raise ValueError( - f"Only zarr version 2 is supported when creating datasets, found version: {zarr.__version__}" - ) + # if zarr_2_or_3.version != 2: + # raise ValueError( + # f"Only zarr version 2 is supported when creating datasets, found version: {zarr.__version__}" + # ) _, ext = os.path.splitext(self.path) if ext != ".zarr": @@ -198,10 +198,9 @@ def add_dataset(self, mode: str = "r+", **kwargs: Any) -> zarr.Array: zarr.Array The added dataset. """ - import zarr z = zarr.open(self.path, mode=mode) - from .zarr import add_zarr_dataset + from .misc import add_zarr_dataset return add_zarr_dataset(zarr_root=z, **kwargs) @@ -451,7 +450,7 @@ def check_missing_dates(expected: list[np.datetime64]) -> None: """ import zarr - z = zarr.open(path, "r") + z = zarr.open(path, mode="r") missing_dates = z.attrs.get("missing_dates", []) missing_dates = sorted([np.datetime64(d) for d in missing_dates]) if missing_dates != expected: @@ -523,7 +522,7 @@ class HasRegistryMixin: @cached_property def registry(self) -> Any: """Get the registry.""" - from .zarr import ZarrBuiltRegistry + from .misc import ZarrBuiltRegistry return ZarrBuiltRegistry(self.path, use_threads=self.use_threads) @@ -587,6 +586,7 @@ def __init__( progress: Any = None, test: bool = False, cache: Optional[str] = None, + force_zarr3: bool = False, **kwargs: Any, ): """Initialize an Init instance. @@ -615,6 +615,32 @@ def __init__( if _path_readable(path) and not overwrite: raise Exception(f"{path} already exists. Use overwrite=True to overwrite.") + version = zarr_2_or_3.version + if not zarr_2_or_3.supports_datetime64(): + LOG.warning("⚠️" * 80) + LOG.warning(f"This version of Zarr ({zarr.__version__}) does not support datetime64.") + LOG.warning("⚠️" * 80) + + if version != 2: + + pytesting = "PYTEST_CURRENT_TEST" in os.environ + + if pytesting or force_zarr3: + LOG.warning("⚠️" * 80) + LOG.warning("Zarr version 3 is used, but this is an unsupported feature.") + LOG.warning("⚠️" * 80) + else: + LOG.warning("⚠️" * 80) + LOG.warning( + f"Only Zarr version 2 is supported when creating datasets, found version: {zarr.__version__}" + ) + LOG.warning("If you want to use Zarr version 3, please set --force-zarr3 option.") + LOG.warning("Please note that this is an unsupported feature.") + LOG.warning("⚠️" * 80) + raise ValueError( + f"Only Zarr version 2 is supported when creating datasets, found version: {zarr.__version__}" + ) + super().__init__(path, cache=cache) self.config = config self.check_name = check_name diff --git a/src/anemoi/datasets/create/zarr.py b/src/anemoi/datasets/create/misc.py similarity index 100% rename from src/anemoi/datasets/create/zarr.py rename to src/anemoi/datasets/create/misc.py diff --git a/src/anemoi/datasets/zarr_versions/zarr2.py b/src/anemoi/datasets/zarr_versions/zarr2.py index 7ae1e8c11..f6a6e1ec6 100644 --- a/src/anemoi/datasets/zarr_versions/zarr2.py +++ b/src/anemoi/datasets/zarr_versions/zarr2.py @@ -178,3 +178,7 @@ def change_dtype_datetime64(dtype): def cast_dtype_datetime64(array, dtype): return array, dtype + + +def supports_datetime64(): + return True diff --git a/src/anemoi/datasets/zarr_versions/zarr3.py b/src/anemoi/datasets/zarr_versions/zarr3.py index 1181f0cdc..ccdc631ef 100644 --- a/src/anemoi/datasets/zarr_versions/zarr3.py +++ b/src/anemoi/datasets/zarr_versions/zarr3.py @@ -60,12 +60,72 @@ def __init__(self, url): DebugStore = zarr.storage.LoggingStore +class PlanetaryComputerStore: + """We write our own Store to access catalogs on Planetary Computer, + as it requires some extra arguments to use xr.open_zarr. + """ + + def __init__(self, data_catalog_id: str) -> None: + """Initialize the PlanetaryComputerStore with a data catalog ID. + + Parameters + ---------- + data_catalog_id : str + The data catalog ID. + """ + super().__init__() + self.data_catalog_id = data_catalog_id + + import planetary_computer + import pystac_client + + catalog = pystac_client.Client.open( + "https://planetarycomputer.microsoft.com/api/stac/v1/", + modifier=planetary_computer.sign_inplace, + ) + collection = catalog.get_collection(self.data_catalog_id) + + asset = collection.assets["zarr-abfs"] + + if "xarray:storage_options" in asset.extra_fields: + store = { + "store": asset.href, + "storage_options": asset.extra_fields["xarray:storage_options"], + **asset.extra_fields["xarray:open_kwargs"], + } + else: + store = { + "filename_or_obj": asset.href, + **asset.extra_fields["xarray:open_kwargs"], + } + + self.store = store + + def create_array(zarr_root, *args, **kwargs): if "compressor" in kwargs and kwargs["compressor"] is None: # compressor is deprecated, use compressors instead kwargs.pop("compressor") kwargs["compressors"] = () - return zarr_root.create_array(*args, **kwargs) + + data = kwargs.pop("data", None) + if data is not None: + kwargs.setdefault("dtype", change_dtype_datetime64(data.dtype)) + kwargs.setdefault("shape", data.shape) + + try: + z = zarr_root.create_array(*args, **kwargs) + if data is not None: + z[:] = data + return z + except Exception: + LOG.exception("Failed to create array in Zarr store") + LOG.error( + "Failed to create array in Zarr store with args: %s, kwargs: %s", + args, + kwargs, + ) + raise def change_dtype_datetime64(dtype): @@ -88,3 +148,18 @@ def cast_dtype_datetime64(array, dtype): array = array.astype(dtype) return array, dtype + + +def supports_datetime64(): + store = zarr.storage.MemoryStore() + try: + zarr.create_array(store=store, shape=(10,), dtype="datetime64[s]") + return True + except KeyError: + # If a KeyError is raised, it means datetime64 is not supported + return False + + +if __name__ == "__main__": + print("Zarr version:", version) + print("Zarr supports datetime64:", supports_datetime64()) diff --git a/tests/test_data_gridded.py b/tests/test_data_gridded.py index e56c3f117..1a19ddf40 100644 --- a/tests/test_data_gridded.py +++ b/tests/test_data_gridded.py @@ -24,6 +24,7 @@ from anemoi.utils.dates import frequency_to_timedelta from anemoi.datasets import open_dataset +from anemoi.datasets.zarr_versions import zarr_2_or_3 VALUES = 20 @@ -144,24 +145,29 @@ def create_zarr( for e in range(ensembles): data[i, j, e] = _(date.astype(object), var, k, e, values) - root.create_dataset( + zarr_2_or_3.create_array( + root, "data", data=data, dtype=data.dtype, chunks=data.shape, compressor=None, ) - root.create_dataset( + # Store dates as ISO strings to avoid unsupported dtype in Zarr v3 + zarr_2_or_3.create_array( + root, "dates", - data=dates, + data=np.array([str(d) for d in dates], dtype="U32"), compressor=None, ) - root.create_dataset( + zarr_2_or_3.create_array( + root, "latitudes", data=np.array([x + values for x in range(values)]), compressor=None, ) - root.create_dataset( + zarr_2_or_3.create_array( + root, "longitudes", data=np.array([x + values for x in range(values)]), compressor=None, @@ -186,22 +192,26 @@ def create_zarr( root.attrs["missing_dates"] = [d.isoformat() for d in missing_dates] - root.create_dataset( + zarr_2_or_3.create_array( + root, "mean", data=np.mean(data, axis=0), compressor=None, ) - root.create_dataset( + zarr_2_or_3.create_array( + root, "stdev", data=np.std(data, axis=0), compressor=None, ) - root.create_dataset( + zarr_2_or_3.create_array( + root, "maximum", data=np.max(data, axis=0), compressor=None, ) - root.create_dataset( + zarr_2_or_3.create_array( + root, "minimum", data=np.min(data, axis=0), compressor=None, From b844c42d573ec73044cf7bdf184ccf36f7ee492b Mon Sep 17 00:00:00 2001 From: Baudouin Raoult Date: Thu, 26 Jun 2025 10:02:55 +0000 Subject: [PATCH 10/16] fix doc --- docs/datasets/building/sources/xarray-zarr.rst | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/docs/datasets/building/sources/xarray-zarr.rst b/docs/datasets/building/sources/xarray-zarr.rst index 2e225f863..0f9ce62c8 100644 --- a/docs/datasets/building/sources/xarray-zarr.rst +++ b/docs/datasets/building/sources/xarray-zarr.rst @@ -1,4 +1,5 @@ .. _xarray-zarr: + ############# xarray-zarr ############# @@ -18,7 +19,8 @@ it is necessary to use the :ref:`join ` operation to join separate lists containing 2D variables and 3D variables. If all vertical levels are desired, then it is acceptable to specify a single source. -Also, an ``xarray-zarr`` source uses the ``url`` keyword, and cannot be used for accessing local datasets. -For using local zarr datasets as sources, use instead :ref:`anemoi-dataset_source`. +Also, an ``xarray-zarr`` source uses the ``url`` keyword, and cannot be +used for accessing local datasets. For using local zarr datasets as +sources, use instead :ref:`anemoi-dataset_source`. See :ref:`create-cf-data` for more information. From deca4063abdeeed35aebf83fdc0edfeefd7ee81e Mon Sep 17 00:00:00 2001 From: Baudouin Raoult Date: Thu, 26 Jun 2025 16:42:29 +0000 Subject: [PATCH 11/16] feat: add planetary planetary source --- .../create/sources/planetary_computer.py | 44 +++++++++++++++++ .../create/sources/xarray_support/__init__.py | 28 +++-------- .../create/sources/xarray_support/field.py | 5 +- .../create/sources/xarray_support/flavour.py | 6 +-- .../create/sources/xarray_support/patch.py | 23 ++++++++- src/anemoi/datasets/data/stores.py | 3 -- tests/create/test_sources.py | 49 +++++++++++++++---- tests/xarray/test_zarr.py | 28 ----------- 8 files changed, 116 insertions(+), 70 deletions(-) create mode 100644 src/anemoi/datasets/create/sources/planetary_computer.py diff --git a/src/anemoi/datasets/create/sources/planetary_computer.py b/src/anemoi/datasets/create/sources/planetary_computer.py new file mode 100644 index 000000000..b710bcbbe --- /dev/null +++ b/src/anemoi/datasets/create/sources/planetary_computer.py @@ -0,0 +1,44 @@ +# (C) Copyright 2024 Anemoi contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + + +from . import source_registry +from .xarray import XarraySourceBase + + +@source_registry.register("planetary_computer") +class PlanetaryComputerSource(XarraySourceBase): + """An Xarray data source for the planetary_computer.""" + + emoji = "🪐" + + def __init__(self, context, data_catalog_id, version="v1", *args, **kwargs: dict): + + import planetary_computer + import pystac_client + + self.data_catalog_id = data_catalog_id + self.flavour = kwargs.pop("flavour", None) + self.patch = kwargs.pop("patch", None) + self.options = kwargs.pop("options", {}) + + catalog = pystac_client.Client.open( + f"https://planetarycomputer.microsoft.com/api/stac/{version}/", + modifier=planetary_computer.sign_inplace, + ) + collection = catalog.get_collection(self.data_catalog_id) + + asset = collection.assets["zarr-abfs"] + + if "xarray:storage_options" in asset.extra_fields: + self.options["storage_options"] = asset.extra_fields["xarray:storage_options"] + + self.options.update(asset.extra_fields["xarray:open_kwargs"]) + + super().__init__(context, url=asset.href, *args, **kwargs) diff --git a/src/anemoi/datasets/create/sources/xarray_support/__init__.py b/src/anemoi/datasets/create/sources/xarray_support/__init__.py index 4f4edb46f..665cfdad3 100644 --- a/src/anemoi/datasets/create/sources/xarray_support/__init__.py +++ b/src/anemoi/datasets/create/sources/xarray_support/__init__.py @@ -20,7 +20,6 @@ from earthkit.data.core.fieldlist import MultiFieldList from anemoi.datasets.create.sources.patterns import iterate_patterns -from anemoi.datasets.data.stores import name_to_zarr_store from ..legacy import legacy_source from .fieldlist import XarrayFieldList @@ -89,37 +88,22 @@ def load_one( The loaded dataset. """ - """ - We manage the S3 client ourselves, bypassing fsspec and s3fs layers, because sometimes something on the stack - zarr/fsspec/s3fs/boto3 (?) seem to flags files as missing when they actually are not (maybe when S3 reports some sort of - connection error). In that case, Zarr will silently fill the chunks that could not be downloaded with NaNs. - See https://github.com/pydata/xarray/issues/8842 - - We have seen this bug triggered when we run many clients in parallel, for example, when we create a new dataset using `xarray-zarr`. - """ - if options is None: options = {} context.trace(emoji, dataset, options, kwargs) - if isinstance(dataset, str) and ".zarr" in dataset: - data = xr.open_zarr(name_to_zarr_store(dataset), **options) - elif "planetarycomputer" in dataset: - store = name_to_zarr_store(dataset) - if "store" in store: - data = xr.open_zarr(**store) - if "filename_or_obj" in store: - data = xr.open_dataset(**store) - else: - data = xr.open_dataset(dataset, **options) + if isinstance(dataset, str) and dataset.endswith(".zarr"): + # If the dataset is a zarr store, we need to use the zarr engine + options["engine"] = "zarr" + + data = xr.open_dataset(dataset, **options) fs = XarrayFieldList.from_xarray(data, flavour=flavour, patch=patch) if len(dates) == 0: result = fs.sel(**kwargs) else: - print("dates", dates, kwargs) result = MultiFieldList([fs.sel(valid_datetime=date, **kwargs) for date in dates]) if len(result) == 0: @@ -130,7 +114,7 @@ def load_one( a = ["valid_datetime", k.metadata("valid_datetime", default=None)] for n in kwargs.keys(): a.extend([n, k.metadata(n, default=None)]) - print([str(x) for x in a]) + LOG.warning(f"{[str(x) for x in a]}") if i > 16: break diff --git a/src/anemoi/datasets/create/sources/xarray_support/field.py b/src/anemoi/datasets/create/sources/xarray_support/field.py index 663aeab54..9fdd93246 100644 --- a/src/anemoi/datasets/create/sources/xarray_support/field.py +++ b/src/anemoi/datasets/create/sources/xarray_support/field.py @@ -87,13 +87,10 @@ def __init__(self, owner: Any, selection: Any) -> None: coordinate = owner.by_name[coord_name] self._md[coord_name] = coordinate.normalise(extract_single_value(coord_value)) - # print(values.ndim, values.shape, selection.dims) # By now, the only dimensions should be latitude and longitude self._shape = tuple(list(self.selection.shape)[-2:]) if math.prod(self._shape) != math.prod(self.selection.shape): - print(self.selection.ndim, self.selection.shape) - print(self.selection) - raise ValueError("Invalid shape for selection") + raise ValueError(f"Invalid shape for selection {self._shape=}, {self.selection.shape=} {self.selection=}") @property def shape(self) -> Tuple[int, int]: diff --git a/src/anemoi/datasets/create/sources/xarray_support/flavour.py b/src/anemoi/datasets/create/sources/xarray_support/flavour.py index 4df374148..02b30d7bb 100644 --- a/src/anemoi/datasets/create/sources/xarray_support/flavour.py +++ b/src/anemoi/datasets/create/sources/xarray_support/flavour.py @@ -308,9 +308,9 @@ def _x_y_provided(self, x: Any, y: Any, variable: Any) -> Any: return self._grid_cache[(x.name, y.name, dim_vars)] grid_mapping = variable.attrs.get("grid_mapping", None) - if grid_mapping is not None: - print(f"grid_mapping: {grid_mapping}") - print(self.ds[grid_mapping]) + # if grid_mapping is not None: + # print(f"grid_mapping: {grid_mapping}") + # print(self.ds[grid_mapping]) if grid_mapping is None: LOG.warning(f"No 'grid_mapping' attribute provided for '{variable.name}'") diff --git a/src/anemoi/datasets/create/sources/xarray_support/patch.py b/src/anemoi/datasets/create/sources/xarray_support/patch.py index 29ea620dd..a84fccc14 100644 --- a/src/anemoi/datasets/create/sources/xarray_support/patch.py +++ b/src/anemoi/datasets/create/sources/xarray_support/patch.py @@ -61,9 +61,28 @@ def patch_coordinates(ds: xr.Dataset, coordinates: List[str]) -> Any: return ds +def patch_rename(ds: xr.Dataset, renames: dict[str, str]) -> Any: + """Rename variables in the dataset. + + Parameters + ---------- + ds : xr.Dataset + The dataset to patch. + renames : dict[str, str] + Mapping from old variable names to new variable names. + + Returns + ------- + Any + The patched dataset. + """ + return ds.rename(renames) + + PATCHES = { "attributes": patch_attributes, "coordinates": patch_coordinates, + "rename": patch_rename, } @@ -82,7 +101,9 @@ def patch_dataset(ds: xr.Dataset, patch: Dict[str, Dict[str, Any]]) -> Any: Any The patched dataset. """ - for what, values in patch.items(): + + ORDER = ["coordinates", "attributes", "rename"] + for what, values in sorted(patch.items(), key=lambda x: ORDER.index(x[0])): if what not in PATCHES: raise ValueError(f"Unknown patch type {what!r}") diff --git a/src/anemoi/datasets/data/stores.py b/src/anemoi/datasets/data/stores.py index e31d4cfb9..632bc36cd 100644 --- a/src/anemoi/datasets/data/stores.py +++ b/src/anemoi/datasets/data/stores.py @@ -214,9 +214,6 @@ def name_to_zarr_store(path_or_url: str) -> ReadOnlyStore: if len(bits) == 5 and (bits[1], bits[3], bits[4]) == ("s3", "amazonaws", "com"): s3_url = f"s3://{bits[0]}{parsed.path}" store = S3Store(s3_url, region=bits[2]) - elif store.startswith("https://planetarycomputer.microsoft.com/"): - data_catalog_id = store.rsplit("/", 1)[-1] - store = PlanetaryComputerStore(data_catalog_id).store else: store = HTTPStore(store) diff --git a/tests/create/test_sources.py b/tests/create/test_sources.py index 6e6098800..05281f16e 100644 --- a/tests/create/test_sources.py +++ b/tests/create/test_sources.py @@ -7,7 +7,6 @@ # granted to it by virtue of its status as an intergovernmental organisation # nor does it submit to any jurisdiction. -import logging import os import sys @@ -245,12 +244,44 @@ def test_kerchunk(get_test_data: callable) -> None: assert ds.shape == (4, 1, 1, 1038240) +@skip_if_offline +@skip_missing_packages("planetary_computer", "adlfs") +def test_planetary_computer_conus404() -> None: + """Test loading and validating the planetary_computer_conus404 dataset.""" + + config = { + "dates": { + "start": "2022-01-01", + "end": "2022-01-02", + "frequency": "1d", + }, + "input": { + "planetary_computer": { + "data_catalog_id": "conus404", + "param": ["Z"], + "level": [1], + "patch": { + "coordinates": ["bottom_top_stag"], + "rename": { + "bottom_top_stag": "level", + }, + "attributes": { + "lon": {"standard_name": "longitude", "long_name": "Longitude"}, + "lat": {"standard_name": "latitude", "long_name": "Latitude"}, + }, + }, + } + }, + } + + created = create_dataset(config=config, output=None) + ds = open_dataset(created) + assert ds.shape == (2, 1, 1, 1387505), ds.shape + + if __name__ == "__main__": - logging.basicConfig(level=logging.INFO) - test_kerchunk() - exit() - """Run all test functions that start with 'test_'.""" - for name, obj in list(globals().items()): - if name.startswith("test_") and callable(obj): - print(f"Running {name}...") - obj() + test_planetary_computer_conus404() + exit(0) + from anemoi.utils.testing import run_tests + + run_tests(globals()) diff --git a/tests/xarray/test_zarr.py b/tests/xarray/test_zarr.py index c0a8f8ed2..81d36b923 100644 --- a/tests/xarray/test_zarr.py +++ b/tests/xarray/test_zarr.py @@ -13,7 +13,6 @@ from anemoi.utils.testing import skip_missing_packages from anemoi.datasets.create.sources.xarray import XarrayFieldList -from anemoi.datasets.data.stores import name_to_zarr_store from anemoi.datasets.testing import assert_field_list @@ -133,33 +132,6 @@ def test_noaa_replay() -> None: ) -@skip_if_offline -@skip_missing_packages("planetary_computer", "adlfs") -def test_planetary_computer_conus404() -> None: - """Test loading and validating the planetary_computer_conus404 dataset.""" - url = "https://planetarycomputer.microsoft.com/api/stac/v1/collections/conus404" - ds = xr.open_zarr(**name_to_zarr_store(url)) - - flavour = { - "rules": { - "latitude": {"name": "lat"}, - "longitude": {"name": "lon"}, - "x": {"name": "west_east"}, - "y": {"name": "south_north"}, - "time": {"name": "time"}, - }, - } - - fs = XarrayFieldList.from_xarray(ds, flavour=flavour) - - assert_field_list( - fs, - 74634912, - "1979-10-01T00:00:00", - "2022-09-30T23:00:00", - ) - - if __name__ == "__main__": for name, obj in list(globals().items()): if name.startswith("test_") and callable(obj): From c72fe274e1a8bd082a0d000440424fd605407d71 Mon Sep 17 00:00:00 2001 From: Baudouin Raoult Date: Thu, 26 Jun 2025 17:25:24 +0000 Subject: [PATCH 12/16] add tests --- src/anemoi/datasets/zarr_versions/zarr2.py | 46 ---------------------- src/anemoi/datasets/zarr_versions/zarr3.py | 42 -------------------- tests/xarray/test_zarr.py | 33 ++++++++++++++++ 3 files changed, 33 insertions(+), 88 deletions(-) diff --git a/src/anemoi/datasets/zarr_versions/zarr2.py b/src/anemoi/datasets/zarr_versions/zarr2.py index f6a6e1ec6..69980e31b 100644 --- a/src/anemoi/datasets/zarr_versions/zarr2.py +++ b/src/anemoi/datasets/zarr_versions/zarr2.py @@ -92,52 +92,6 @@ def __getitem__(self, key: str) -> bytes: return r.content -class PlanetaryComputerStore(ReadOnlyStore): - """We write our own Store to access catalogs on Planetary Computer, - as it requires some extra arguments to use xr.open_zarr. - """ - - def __init__(self, data_catalog_id: str) -> None: - """Initialize the PlanetaryComputerStore with a data catalog ID. - - Parameters - ---------- - data_catalog_id : str - The data catalog ID. - """ - super().__init__() - self.data_catalog_id = data_catalog_id - - import planetary_computer - import pystac_client - - catalog = pystac_client.Client.open( - "https://planetarycomputer.microsoft.com/api/stac/v1/", - modifier=planetary_computer.sign_inplace, - ) - collection = catalog.get_collection(self.data_catalog_id) - - asset = collection.assets["zarr-abfs"] - - if "xarray:storage_options" in asset.extra_fields: - store = { - "store": asset.href, - "storage_options": asset.extra_fields["xarray:storage_options"], - **asset.extra_fields["xarray:open_kwargs"], - } - else: - store = { - "filename_or_obj": asset.href, - **asset.extra_fields["xarray:open_kwargs"], - } - - self.store = store - - def __getitem__(self, key: str) -> bytes: - """Retrieve an item from the store.""" - raise NotImplementedError() - - class DebugStore(ReadOnlyStore): """A store to debug the zarr loading.""" diff --git a/src/anemoi/datasets/zarr_versions/zarr3.py b/src/anemoi/datasets/zarr_versions/zarr3.py index ccdc631ef..c7d18b121 100644 --- a/src/anemoi/datasets/zarr_versions/zarr3.py +++ b/src/anemoi/datasets/zarr_versions/zarr3.py @@ -60,48 +60,6 @@ def __init__(self, url): DebugStore = zarr.storage.LoggingStore -class PlanetaryComputerStore: - """We write our own Store to access catalogs on Planetary Computer, - as it requires some extra arguments to use xr.open_zarr. - """ - - def __init__(self, data_catalog_id: str) -> None: - """Initialize the PlanetaryComputerStore with a data catalog ID. - - Parameters - ---------- - data_catalog_id : str - The data catalog ID. - """ - super().__init__() - self.data_catalog_id = data_catalog_id - - import planetary_computer - import pystac_client - - catalog = pystac_client.Client.open( - "https://planetarycomputer.microsoft.com/api/stac/v1/", - modifier=planetary_computer.sign_inplace, - ) - collection = catalog.get_collection(self.data_catalog_id) - - asset = collection.assets["zarr-abfs"] - - if "xarray:storage_options" in asset.extra_fields: - store = { - "store": asset.href, - "storage_options": asset.extra_fields["xarray:storage_options"], - **asset.extra_fields["xarray:open_kwargs"], - } - else: - store = { - "filename_or_obj": asset.href, - **asset.extra_fields["xarray:open_kwargs"], - } - - self.store = store - - def create_array(zarr_root, *args, **kwargs): if "compressor" in kwargs and kwargs["compressor"] is None: # compressor is deprecated, use compressors instead diff --git a/tests/xarray/test_zarr.py b/tests/xarray/test_zarr.py index 81d36b923..8bd718101 100644 --- a/tests/xarray/test_zarr.py +++ b/tests/xarray/test_zarr.py @@ -132,6 +132,39 @@ def test_noaa_replay() -> None: ) +@skip_if_offline +@skip_missing_packages("s3fs") +def test_aws_s3() -> None: + """Test loading and validating an AWS S3 dataset.""" + url = "s3://aodn-cloud-optimised/model_sea_level_anomaly_gridded_realtime.zarr" + ds = xr.open_zarr(url, consolidated=True, storage_options={"anon": True}) + + fs = XarrayFieldList.from_xarray(ds) + + assert_field_list( + fs, + 400, + "2011-09-01T00:00:00", + "2011-12-12T00:00:00", + ) + + +@skip_if_offline +def test_aws_s3_https() -> None: + """Test loading and validating an AWS S3 dataset via HTTPS.""" + url = "https://aodn-cloud-optimised.s3.amazonaws.com/model_sea_level_anomaly_gridded_realtime.zarr" + ds = xr.open_zarr(url, consolidated=True) + + fs = XarrayFieldList.from_xarray(ds) + + assert_field_list( + fs, + 400, + "2011-09-01T00:00:00", + "2011-12-12T00:00:00", + ) + + if __name__ == "__main__": for name, obj in list(globals().items()): if name.startswith("test_") and callable(obj): From fc1dfa78a4ada78cf5b30f8f97e1000f4245a789 Mon Sep 17 00:00:00 2001 From: Baudouin Raoult Date: Wed, 8 Oct 2025 18:39:57 +0000 Subject: [PATCH 13/16] use zarr3 --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 8e873aff9..0f5158f2e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -59,7 +59,7 @@ dependencies = [ "ruamel-yaml", "semantic-version", "tqdm", - "zarr", + "zarr>=3", ] optional-dependencies.all = [ From 7331aa08fd19656f82bb8fb212ea189381348c1c Mon Sep 17 00:00:00 2001 From: Baudouin Raoult Date: Wed, 8 Oct 2025 18:53:18 +0000 Subject: [PATCH 14/16] update --- src/anemoi/datasets/create/__init__.py | 2 +- src/anemoi/datasets/data/stores.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/anemoi/datasets/create/__init__.py b/src/anemoi/datasets/create/__init__.py index e2e429883..b62982bac 100644 --- a/src/anemoi/datasets/create/__init__.py +++ b/src/anemoi/datasets/create/__init__.py @@ -153,7 +153,7 @@ def _path_readable(path: str) -> bool: import zarr try: - zarr.open(path, "r") + zarr.open(path, mode="r") return True except zarr_2_or_3.FileNotFoundException: return False diff --git a/src/anemoi/datasets/data/stores.py b/src/anemoi/datasets/data/stores.py index e21804aec..f0d47b8ba 100644 --- a/src/anemoi/datasets/data/stores.py +++ b/src/anemoi/datasets/data/stores.py @@ -106,7 +106,7 @@ def open_zarr(path: str, dont_fail: bool = False, cache: int = None) -> Any: class Zarr(Dataset): """A zarr dataset.""" - def __init__(self, path: str | zarr.hierarchy.Group) -> None: + def __init__(self, path: str | zarr_2_or_3.Group) -> None: """Initialize the Zarr dataset with a path or zarr group.""" if isinstance(path, zarr_2_or_3.Group): self.was_zarr = True @@ -346,7 +346,7 @@ def collect_input_sources(self, collected: set) -> None: class ZarrWithMissingDates(Zarr): """A zarr dataset with missing dates.""" - def __init__(self, path: str | zarr.hierarchy.Group) -> None: + def __init__(self, path: str | zarr_2_or_3.Group) -> None: """Initialize the ZarrWithMissingDates dataset with a path or zarr group.""" super().__init__(path) From 38f3094ec7ff3526454de31de1d61aa877ec59d7 Mon Sep 17 00:00:00 2001 From: Baudouin Raoult Date: Thu, 9 Oct 2025 08:21:30 +0000 Subject: [PATCH 15/16] update with new version of zarr3 --- pyproject.toml | 2 +- src/anemoi/datasets/commands/copy.py | 27 ++-- src/anemoi/datasets/commands/create.py | 1 - src/anemoi/datasets/commands/init.py | 1 - src/anemoi/datasets/create/__init__.py | 39 +---- src/anemoi/datasets/create/misc.py | 10 +- src/anemoi/datasets/create/patch.py | 4 +- src/anemoi/datasets/data/misc.py | 4 +- src/anemoi/datasets/data/stores.py | 91 +++++++----- src/anemoi/datasets/zarr_versions/__init__.py | 23 --- src/anemoi/datasets/zarr_versions/zarr2.py | 138 ------------------ src/anemoi/datasets/zarr_versions/zarr3.py | 123 ---------------- tests/test_data.py | 84 ++++------- tests/test_data_gridded.py | 29 ++-- 14 files changed, 114 insertions(+), 462 deletions(-) delete mode 100644 src/anemoi/datasets/zarr_versions/__init__.py delete mode 100644 src/anemoi/datasets/zarr_versions/zarr2.py delete mode 100644 src/anemoi/datasets/zarr_versions/zarr3.py diff --git a/pyproject.toml b/pyproject.toml index 0f5158f2e..6f6164ae7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -59,7 +59,7 @@ dependencies = [ "ruamel-yaml", "semantic-version", "tqdm", - "zarr>=3", + "zarr>=3.1.3", ] optional-dependencies.all = [ diff --git a/src/anemoi/datasets/commands/copy.py b/src/anemoi/datasets/commands/copy.py index c79abf1fc..e4b11523b 100644 --- a/src/anemoi/datasets/commands/copy.py +++ b/src/anemoi/datasets/commands/copy.py @@ -20,7 +20,6 @@ from anemoi.utils.remote import TransferMethodNotImplementedError from anemoi.datasets.check import check_zarr -from anemoi.datasets.zarr_versions import zarr_2_or_3 from . import Command @@ -51,6 +50,8 @@ class ZarrCopier: Flag to resume copying an existing dataset. verbosity : int Verbosity level of logging. + nested : bool + Flag to use ZARR's nested directory backend. rechunk : str Rechunk size for the target data array. """ @@ -64,6 +65,7 @@ def __init__( overwrite: bool, resume: bool, verbosity: int, + nested: bool, rechunk: str, **kwargs: Any, ) -> None: @@ -85,6 +87,8 @@ def __init__( Flag to resume copying an existing dataset. verbosity : int Verbosity level of logging. + nested : bool + Flag to use ZARR's nested directory backend. rechunk : str Rechunk size for the target data array. **kwargs : Any @@ -97,6 +101,7 @@ def __init__( self.overwrite = overwrite self.resume = resume self.verbosity = verbosity + self.nested = nested self.rechunk = rechunk self.rechunking = rechunk.split(",") if rechunk else [] @@ -233,12 +238,10 @@ def copy_data(self, source: Any, target: Any, _copy: Any, verbosity: int) -> Non target_data = ( target["data"] if "data" in target - else zarr_2_or_3.create_array( - target, + else target.create_array( "data", shape=source_data.shape, chunks=self.data_chunks, - dtype=source_data.dtype, fill_value=source_data.fill_value, ) ) @@ -314,6 +317,7 @@ def copy_group(self, source: Any, target: Any, _copy: Any, verbosity: int) -> No verbosity : int Verbosity level of logging. """ + import zarr if self.verbosity > 0: LOG.info(f"Copying group {source} to {target}") @@ -339,7 +343,7 @@ def copy_group(self, source: Any, target: Any, _copy: Any, verbosity: int) -> No LOG.info(f"Skipping {name}") continue - if zarr_2_or_3.is_zarr_group(source[name]): + if isinstance(source[name], zarr.Group): group = target[name] if name in target else target.create_group(name) self.copy_group( source[name], @@ -397,13 +401,13 @@ def run(self) -> None: def target_exists() -> bool: try: - zarr.open(self.target, mode="r") + zarr.open(self._store(self.target), mode="r") return True - except ValueError: + except FileNotFoundError: return False def target_finished() -> bool: - target = zarr.open(self.target, mode="r") + target = zarr.open(self._store(self.target), mode="r") if "_copy" in target: done = sum(1 if x else 0 for x in target["_copy"]) todo = len(target["_copy"]) @@ -421,11 +425,11 @@ def target_finished() -> bool: def open_target() -> Any: if not target_exists(): - return zarr.open(self.target, mode="w") + return zarr.open(self._store(self.target, self.nested), mode="w") if self.overwrite: LOG.error("Target already exists, overwriting.") - return zarr.open(self.target, mode="w") + return zarr.open(self._store(self.target, self.nested), mode="w") if self.resume: if target_finished(): @@ -433,7 +437,7 @@ def open_target() -> Any: sys.exit(0) LOG.error("Target already exists, resuming copy.") - return zarr.open(self.target, mode=zarr_2_or_3.open_mode_append) + return zarr.open(self._store(self.target, self.nested), mode="w+") LOG.error("Target already exists, use either --overwrite or --resume.") sys.exit(1) @@ -489,6 +493,7 @@ def add_arguments(self, command_parser: Any) -> None: help="Verbosity level. 0 is silent, 1 is normal, 2 is verbose.", default=1, ) + command_parser.add_argument("--nested", action="store_true", help="Use ZARR's nested directpry backend.") command_parser.add_argument( "--rechunk", help="Rechunk the target data array. Rechunk size should be a diviser of the block size." ) diff --git a/src/anemoi/datasets/commands/create.py b/src/anemoi/datasets/commands/create.py index 7a9d8a559..3f6bbe7dd 100644 --- a/src/anemoi/datasets/commands/create.py +++ b/src/anemoi/datasets/commands/create.py @@ -86,7 +86,6 @@ def add_arguments(self, command_parser: Any) -> None: group.add_argument("--threads", help="Use `n` parallel thread workers.", type=int, default=0) group.add_argument("--processes", help="Use `n` parallel process workers.", type=int, default=0) command_parser.add_argument("--trace", action="store_true") - command_parser.add_argument("--force-zarr3", action="store_true") def run(self, args: Any) -> None: """Execute the create command. diff --git a/src/anemoi/datasets/commands/init.py b/src/anemoi/datasets/commands/init.py index 73152facb..0ca540b86 100644 --- a/src/anemoi/datasets/commands/init.py +++ b/src/anemoi/datasets/commands/init.py @@ -63,7 +63,6 @@ def add_arguments(self, subparser: Any) -> None: subparser.add_argument("--cache", help="Location to store the downloaded data.", metavar="DIR") subparser.add_argument("--trace", action="store_true") - subparser.add_argument("--force-zarr3", action="store_true", help="Force the use of Zarr v3 format.") def run(self, args: Any) -> None: """Execute the command with the provided arguments. diff --git a/src/anemoi/datasets/create/__init__.py b/src/anemoi/datasets/create/__init__.py index b62982bac..21f8ecc34 100644 --- a/src/anemoi/datasets/create/__init__.py +++ b/src/anemoi/datasets/create/__init__.py @@ -36,7 +36,6 @@ from anemoi.datasets.data.misc import as_first_date from anemoi.datasets.data.misc import as_last_date from anemoi.datasets.dates.groups import Groups -from anemoi.datasets.zarr_versions import zarr_2_or_3 from .check import DatasetName from .check import check_data_values @@ -155,7 +154,7 @@ def _path_readable(path: str) -> bool: try: zarr.open(path, mode="r") return True - except zarr_2_or_3.FileNotFoundException: + except FileNotFoundError: return False @@ -172,11 +171,6 @@ def __init__(self, path: str): """ self.path = path - # if zarr_2_or_3.version != 2: - # raise ValueError( - # f"Only zarr version 2 is supported when creating datasets, found version: {zarr.__version__}" - # ) - _, ext = os.path.splitext(self.path) if ext != ".zarr": raise ValueError(f"Unsupported extension={ext} for path={self.path}") @@ -213,7 +207,7 @@ def update_metadata(self, **kwargs: Any) -> None: import zarr LOG.debug(f"Updating metadata {kwargs}") - z = zarr.open(self.path, mode=zarr_2_or_3.open_mode_append) + z = zarr.open(self.path, mode="a") for k, v in kwargs.items(): if isinstance(v, np.datetime64): v = v.astype(datetime.datetime) @@ -584,7 +578,6 @@ def __init__( progress: Any = None, test: bool = False, cache: str | None = None, - force_zarr3: bool = False, **kwargs: Any, ): """Initialize an Init instance. @@ -609,38 +602,10 @@ def __init__( Whether this is a test. cache : Optional[str], optional The cache directory. - force_zarr3 : bool, optional - Whether to force the use of Zarr version 3. """ if _path_readable(path) and not overwrite: raise Exception(f"{path} already exists. Use overwrite=True to overwrite.") - version = zarr_2_or_3.version - if not zarr_2_or_3.supports_datetime64(): - LOG.warning("⚠️" * 80) - LOG.warning(f"This version of Zarr ({zarr.__version__}) does not support datetime64.") - LOG.warning("⚠️" * 80) - - if version != 2: - - pytesting = "PYTEST_CURRENT_TEST" in os.environ - - if pytesting or force_zarr3: - LOG.warning("⚠️" * 80) - LOG.warning("Zarr version 3 is used, but this is an unsupported feature.") - LOG.warning("⚠️" * 80) - else: - LOG.warning("⚠️" * 80) - LOG.warning( - f"Only Zarr version 2 is supported when creating datasets, found version: {zarr.__version__}" - ) - LOG.warning("If you want to use Zarr version 3, please set --force-zarr3 option.") - LOG.warning("Please note that this is an unsupported feature.") - LOG.warning("⚠️" * 80) - raise ValueError( - f"Only Zarr version 2 is supported when creating datasets, found version: {zarr.__version__}" - ) - super().__init__(path, cache=cache) self.config = config self.check_name = check_name diff --git a/src/anemoi/datasets/create/misc.py b/src/anemoi/datasets/create/misc.py index 111a72f12..55946a201 100644 --- a/src/anemoi/datasets/create/misc.py +++ b/src/anemoi/datasets/create/misc.py @@ -15,8 +15,6 @@ import zarr from numpy.typing import NDArray -from anemoi.datasets.zarr_versions import zarr_2_or_3 - from .synchronise import NoSynchroniser from .synchronise import Synchroniser @@ -73,11 +71,9 @@ def add_zarr_dataset( shape = array.shape if array is not None: - array, dtype = zarr_2_or_3.cast_dtype_datetime64(array, dtype) assert array.shape == shape, (array.shape, shape) - a = zarr_2_or_3.create_array( - zarr_root, + a = zarr_root.create_array( name, shape=shape, dtype=dtype, @@ -104,9 +100,7 @@ def add_zarr_dataset( else: raise ValueError(f"No fill_value for dtype={dtype}") - dtype = zarr_2_or_3.change_dtype_datetime64(dtype) - a = zarr_2_or_3.create_array( - zarr_root, + a = zarr_root.create_array( name, shape=shape, dtype=dtype, diff --git a/src/anemoi/datasets/create/patch.py b/src/anemoi/datasets/create/patch.py index 278b737ce..a24dca6bf 100755 --- a/src/anemoi/datasets/create/patch.py +++ b/src/anemoi/datasets/create/patch.py @@ -13,8 +13,6 @@ import zarr -from anemoi.datasets.zarr_versions import zarr_2_or_3 - LOG = logging.getLogger(__name__) @@ -135,7 +133,7 @@ def apply_patch(path: str, verbose: bool = True, dry_run: bool = False) -> None: try: attrs = zarr.open(path, mode="r").attrs.asdict() - except zarr_2_or_3.get_not_found_exception() as e: + except FileNotFoundError as e: LOG.error(f"Failed to open {path}") LOG.error(e) exit(0) diff --git a/src/anemoi/datasets/data/misc.py b/src/anemoi/datasets/data/misc.py index 61a6ea3cb..08004b408 100644 --- a/src/anemoi/datasets/data/misc.py +++ b/src/anemoi/datasets/data/misc.py @@ -22,8 +22,6 @@ from anemoi.utils.config import load_config as load_settings from numpy.typing import NDArray -from anemoi.datasets.zarr_versions import zarr_2_or_3 - if TYPE_CHECKING: from .dataset import Dataset @@ -368,7 +366,7 @@ def _open(a: str | PurePath | dict[str, Any] | list[Any] | tuple[Any, ...]) -> " if isinstance(a, Dataset): return a.mutate() - if isinstance(a, zarr_2_or_3.Group): + if isinstance(a, zarr.Group): return Zarr(a).mutate() if isinstance(a, str): diff --git a/src/anemoi/datasets/data/stores.py b/src/anemoi/datasets/data/stores.py index f0d47b8ba..a2310816f 100644 --- a/src/anemoi/datasets/data/stores.py +++ b/src/anemoi/datasets/data/stores.py @@ -21,7 +21,6 @@ from anemoi.utils.dates import frequency_to_timedelta from numpy.typing import NDArray -from ..zarr_versions import zarr_2_or_3 from . import MissingDateError from .dataset import Dataset from .dataset import FullIndex @@ -37,12 +36,54 @@ LOG = logging.getLogger(__name__) +class S3Store(zarr.storage.ObjectStore): + """We use our class to manage per bucket credentials""" + + def __init__(self, url): + + import boto3 + from anemoi.utils.remote.s3 import s3_options + from obstore.auth.boto3 import Boto3CredentialProvider + from obstore.store import from_url + + options = s3_options(url) + + credential_provider = Boto3CredentialProvider( + session=boto3.session.Session( + aws_access_key_id=options["aws_access_key_id"], + aws_secret_access_key=options["aws_secret_access_key"], + ), + ) + + objectstore = from_url( + url, + credential_provider=credential_provider, + endpoint=options["endpoint_url"], + ) + + super().__init__(objectstore, read_only=True) + + +class HTTPStore(zarr.storage.ObjectStore): + + def __init__(self, url): + + from obstore.store import from_url + + objectstore = from_url(url) + + super().__init__(objectstore, read_only=True) + + +DebugStore = zarr.storage.LoggingStore + + def name_to_zarr_store(path_or_url: str) -> Any: """Convert a path or URL to a zarr store.""" store = path_or_url if store.startswith("s3://"): - return zarr_2_or_3.S3Store(store) + return S3Store(store) if store.startswith("http://") or store.startswith("https://"): @@ -66,17 +107,12 @@ def name_to_zarr_store(path_or_url: str) -> Any: os.rename(path + ".tmp", path) return name_to_zarr_store(path) - bits = parsed.netloc.split(".") - if len(bits) == 5 and (bits[1], bits[3], bits[4]) == ("s3", "amazonaws", "com"): - s3_url = f"s3://{bits[0]}{parsed.path}" - store = zarr_2_or_3.S3Store(s3_url, region=bits[2]) - else: - store = zarr_2_or_3.HTTPStore(store) + return HTTPStore(store) return store -def open_zarr(path: str, dont_fail: bool = False, cache: int = None) -> Any: +def open_zarr(path: str, dont_fail: bool = False, cache: int = None) -> zarr.Group: """Open a zarr store from a path.""" try: store = name_to_zarr_store(path) @@ -90,15 +126,14 @@ def open_zarr(path: str, dont_fail: bool = False, cache: int = None) -> Any: "DEBUG_ZARR_LOADING is only implemented for DirectoryStore. " "Please disable it for other backends." ) - store = zarr_2_or_3.DirectoryStore(store) - store = zarr_2_or_3.DebugStore(store) + store = zarr.storage.DirectoryStore(store) + store = DebugStore(store) if cache is not None: - store = zarr_2_or_3.LRUStoreCache(store, max_size=cache) + store = zarr.LRUStoreCache(store, max_size=cache) return zarr.open(store, mode="r") - - except zarr_2_or_3.FileNotFoundException: + except FileNotFoundError: if not dont_fail: raise FileNotFoundError(f"Zarr store not found: {path}") @@ -106,9 +141,9 @@ def open_zarr(path: str, dont_fail: bool = False, cache: int = None) -> Any: class Zarr(Dataset): """A zarr dataset.""" - def __init__(self, path: str | zarr_2_or_3.Group) -> None: + def __init__(self, path: str | zarr.Group) -> None: """Initialize the Zarr dataset with a path or zarr group.""" - if isinstance(path, zarr_2_or_3.Group): + if isinstance(path, zarr.Group): self.was_zarr = True self.path = str(id(path)) self.z = path @@ -184,31 +219,17 @@ def dtype(self) -> np.dtype: @cached_property def dates(self) -> NDArray[np.datetime64]: """Return the dates of the dataset.""" - dates = self.z["dates"][:] - if not dates.dtype == np.dtype("datetime64[s]"): - # The datasets created with zarr3 will have the dates as int64 as long - # as zarr3 does not support datetime64 - LOG.warning("Converting dates to 'datetime64[s]'") - dates = dates.astype("datetime64[s]") - return dates + return self.z["dates"][:] # Convert to numpy @property def latitudes(self) -> NDArray[Any]: """Return the latitudes of the dataset.""" - try: - return self.z["latitudes"][:] - except AttributeError: - LOG.warning("No 'latitudes' in %r, trying 'latitude'", self) - return self.z["latitude"][:] + return self.z["latitudes"][:] @property def longitudes(self) -> NDArray[Any]: """Return the longitudes of the dataset.""" - try: - return self.z["longitudes"][:] - except AttributeError: - LOG.warning("No 'longitudes' in %r, trying 'longitude'", self) - return self.z["longitude"][:] + return self.z["longitudes"][:] @property def statistics(self) -> dict[str, NDArray[Any]]: @@ -346,7 +367,7 @@ def collect_input_sources(self, collected: set) -> None: class ZarrWithMissingDates(Zarr): """A zarr dataset with missing dates.""" - def __init__(self, path: str | zarr_2_or_3.Group) -> None: + def __init__(self, path: str | zarr.Group) -> None: """Initialize the ZarrWithMissingDates dataset with a path or zarr group.""" super().__init__(path) @@ -462,7 +483,7 @@ def zarr_lookup(name: str, fail: bool = True) -> str | None: LOG.info("Opening `%s` as `%s`", name, full) QUIET.add(name) return full - except zarr_2_or_3.FileNotFoundException: + except FileNotFoundError: pass if fail: diff --git a/src/anemoi/datasets/zarr_versions/__init__.py b/src/anemoi/datasets/zarr_versions/__init__.py deleted file mode 100644 index 39fd71d91..000000000 --- a/src/anemoi/datasets/zarr_versions/__init__.py +++ /dev/null @@ -1,23 +0,0 @@ -# (C) Copyright 2025 Anemoi contributors. -# -# This software is licensed under the terms of the Apache Licence Version 2.0 -# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. -# -# In applying this licence, ECMWF does not waive the privileges and immunities -# granted to it by virtue of its status as an intergovernmental organisation -# nor does it submit to any jurisdiction. - - -import zarr - -version = zarr.__version__.split(".")[0] - -if version == "2": - from . import zarr2 as zarr_2_or_3 - -elif version == "3": - from . import zarr3 as zarr_2_or_3 -else: - raise ImportError(f"Unsupported Zarr version: {zarr.__version__}. Supported versions are 2 and 3.") - -__all__ = ["zarr_2_or_3"] diff --git a/src/anemoi/datasets/zarr_versions/zarr2.py b/src/anemoi/datasets/zarr_versions/zarr2.py deleted file mode 100644 index 69980e31b..000000000 --- a/src/anemoi/datasets/zarr_versions/zarr2.py +++ /dev/null @@ -1,138 +0,0 @@ -# (C) Copyright 2025 Anemoi contributors. -# -# This software is licensed under the terms of the Apache Licence Version 2.0 -# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. -# -# In applying this licence, ECMWF does not waive the privileges and immunities -# granted to it by virtue of its status as an intergovernmental organisation -# nor does it submit to any jurisdiction. -import logging -import warnings -from typing import Any -from typing import Optional - -import zarr - -LOG = logging.getLogger(__name__) - - -version = 2 - -FileNotFoundException = zarr.errors.PathNotFoundError -Group = zarr.hierarchy.Group -open_mode_append = "w+" - - -class ReadOnlyStore(zarr.storage.BaseStore): - """A base class for read-only stores.""" - - def __delitem__(self, key: str) -> None: - """Prevent deletion of items.""" - raise NotImplementedError() - - def __setitem__(self, key: str, value: bytes) -> None: - """Prevent setting of items.""" - raise NotImplementedError() - - def __len__(self) -> int: - """Return the number of items in the store.""" - raise NotImplementedError() - - def __iter__(self) -> iter: - """Return an iterator over the store.""" - raise NotImplementedError() - - -class S3Store(ReadOnlyStore): - """A read-only store for S3 resources.""" - - """We write our own S3Store because the one used by zarr (s3fs) - does not play well with fork(). We also get to control the s3 client - options using the anemoi configs. - """ - - def __init__(self, url: str, region: Optional[str] = None) -> None: - """Initialize the S3Store with a URL and optional region.""" - from anemoi.utils.remote.s3 import s3_client - - super().__init__() - - _, _, self.bucket, self.key = url.split("/", 3) - self.s3 = s3_client(self.bucket, region=region) - - # Version 2 - def __getitem__(self, key: str) -> bytes: - """Retrieve an item from the store.""" - try: - response = self.s3.get_object(Bucket=self.bucket, Key=self.key + "/" + key) - except self.s3.exceptions.NoSuchKey: - raise KeyError(key) - - return response["Body"].read() - - -class HTTPStore(ReadOnlyStore): - """A read-only store for HTTP(S) resources.""" - - def __init__(self, url: str) -> None: - """Initialize the HTTPStore with a URL.""" - super().__init__() - self.url = url - - def __getitem__(self, key: str) -> bytes: - """Retrieve an item from the store.""" - import requests - - r = requests.get(self.url + "/" + key) - - if r.status_code == 404: - raise KeyError(key) - - r.raise_for_status() - return r.content - - -class DebugStore(ReadOnlyStore): - """A store to debug the zarr loading.""" - - def __init__(self, store: Any) -> None: - super().__init__() - """Initialize the DebugStore with another store.""" - assert not isinstance(store, DebugStore) - self.store = store - - def __getitem__(self, key: str) -> bytes: - """Retrieve an item from the store and print debug information.""" - # print() - print("GET", key, self) - # traceback.print_stack(file=sys.stdout) - return self.store[key] - - def __len__(self) -> int: - """Return the number of items in the store.""" - return len(self.store) - - def __iter__(self) -> iter: - """Return an iterator over the store.""" - warnings.warn("DebugStore: iterating over the store") - return iter(self.store) - - def __contains__(self, key: str) -> bool: - """Check if the store contains a key.""" - return key in self.store - - -def create_array(zarr_root, *args, **kwargs): - return zarr_root.create_dataset(*args, **kwargs) - - -def change_dtype_datetime64(dtype): - return dtype - - -def cast_dtype_datetime64(array, dtype): - return array, dtype - - -def supports_datetime64(): - return True diff --git a/src/anemoi/datasets/zarr_versions/zarr3.py b/src/anemoi/datasets/zarr_versions/zarr3.py deleted file mode 100644 index c7d18b121..000000000 --- a/src/anemoi/datasets/zarr_versions/zarr3.py +++ /dev/null @@ -1,123 +0,0 @@ -# (C) Copyright 2025 Anemoi contributors. -# -# This software is licensed under the terms of the Apache Licence Version 2.0 -# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. -# -# In applying this licence, ECMWF does not waive the privileges and immunities -# granted to it by virtue of its status as an intergovernmental organisation -# nor does it submit to any jurisdiction. -import logging - -import zarr - -LOG = logging.getLogger(__name__) - -version = 3 -FileNotFoundException = FileNotFoundError -Group = zarr.Group -open_mode_append = "a" - - -class S3Store(zarr.storage.ObjectStore): - """We use our class to manage per bucket credentials""" - - def __init__(self, url): - - import boto3 - from anemoi.utils.remote.s3 import s3_options - from obstore.auth.boto3 import Boto3CredentialProvider - from obstore.store import from_url - - options = s3_options(url) - - credential_provider = Boto3CredentialProvider( - session=boto3.session.Session( - aws_access_key_id=options["aws_access_key_id"], - aws_secret_access_key=options["aws_secret_access_key"], - ), - ) - - objectstore = from_url( - url, - credential_provider=credential_provider, - endpoint=options["endpoint_url"], - ) - - super().__init__(objectstore, read_only=True) - - -class HTTPStore(zarr.storage.ObjectStore): - - def __init__(self, url): - - from obstore.store import from_url - - objectstore = from_url(url) - - super().__init__(objectstore, read_only=True) - - -DebugStore = zarr.storage.LoggingStore - - -def create_array(zarr_root, *args, **kwargs): - if "compressor" in kwargs and kwargs["compressor"] is None: - # compressor is deprecated, use compressors instead - kwargs.pop("compressor") - kwargs["compressors"] = () - - data = kwargs.pop("data", None) - if data is not None: - kwargs.setdefault("dtype", change_dtype_datetime64(data.dtype)) - kwargs.setdefault("shape", data.shape) - - try: - z = zarr_root.create_array(*args, **kwargs) - if data is not None: - z[:] = data - return z - except Exception: - LOG.exception("Failed to create array in Zarr store") - LOG.error( - "Failed to create array in Zarr store with args: %s, kwargs: %s", - args, - kwargs, - ) - raise - - -def change_dtype_datetime64(dtype): - # remove this flag (and the relevant code) when Zarr 3 supports datetime64 - # https://github.com/zarr-developers/zarr-python/issues/2616 - import numpy as np - - if dtype == "datetime64[s]": - dtype = np.dtype("int64") - return dtype - - -def cast_dtype_datetime64(array, dtype): - # remove this flag (and the relevant code) when Zarr 3 supports datetime64 - # https://github.com/zarr-developers/zarr-python/issues/2616 - import numpy as np - - if dtype == np.dtype("datetime64[s]"): - dtype = "int64" - array = array.astype(dtype) - - return array, dtype - - -def supports_datetime64(): - store = zarr.storage.MemoryStore() - try: - zarr.create_array(store=store, shape=(10,), dtype="datetime64[s]") - return True - except KeyError: - # If a KeyError is raised, it means datetime64 is not supported - return False - - -if __name__ == "__main__": - print("Zarr version:", version) - print("Zarr supports datetime64:", supports_datetime64()) diff --git a/tests/test_data.py b/tests/test_data.py index 2bb3b4895..2541278a1 100644 --- a/tests/test_data.py +++ b/tests/test_data.py @@ -39,7 +39,6 @@ from anemoi.datasets.data.stores import Zarr from anemoi.datasets.data.subset import Subset from anemoi.datasets.testing import default_test_indexing -from anemoi.datasets.zarr_versions import zarr_2_or_3 VALUES = 10 @@ -107,7 +106,7 @@ def create_zarr( ensemble: int | None = None, grids: int | None = None, missing: bool = False, -) -> zarr_2_or_3.Group: +) -> zarr.Group: """Create a Zarr dataset. Parameters @@ -145,7 +144,7 @@ def create_zarr( dates.append(date) date += frequency - dates = np.array(dates, dtype="datetime64[s]") + dates = np.array(dates, dtype="datetime64") ensembles = ensemble if ensemble is not None else 1 values = grids if grids is not None else VALUES @@ -157,42 +156,27 @@ def create_zarr( for e in range(ensembles): data[i, j, e] = _(date.astype(object), var, k, e, values) - zarr_2_or_3.create_array( - root, + root.create_array( "data", - dtype=data.dtype, + data=data, chunks=data.shape, compressor=None, - shape=data.shape, - )[...] = data - - dates, dtype_ = zarr_2_or_3.cast_dtype_datetime64(dates, dates.dtype) - del dtype_ - zarr_2_or_3.create_array( - root, + ) + root.create_array( "dates", + data=dates, compressor=None, - dtype=dates.dtype, - shape=dates.shape, - )[...] = dates - - latitudes = np.array([x + values for x in range(values)]) - zarr_2_or_3.create_array( - root, + ) + root.create_array( "latitudes", + data=np.array([x + values for x in range(values)]), compressor=None, - dtype=latitudes.dtype, - shape=latitudes.shape, - )[...] = latitudes - - longitudes = np.array([x + values for x in range(values)]) - zarr_2_or_3.create_array( - root, + ) + root.create_array( "longitudes", + data=np.array([x + values for x in range(values)]), compressor=None, - dtype=longitudes.dtype, - shape=longitudes.shape, - )[...] = longitudes + ) root.attrs["frequency"] = frequency_to_string(frequency) root.attrs["resolution"] = resolution @@ -213,42 +197,26 @@ def create_zarr( root.attrs["missing_dates"] = [d.isoformat() for d in missing_dates] - zarr_2_or_3.create_array( - root, + root.create_array( "mean", + data=np.mean(data, axis=0), compressor=None, - shape=data.shape[1:], - dtype=data.dtype, - )[ - ... - ] = np.mean(data, axis=0) - zarr_2_or_3.create_array( - root, + ) + root.create_array( "stdev", + data=np.std(data, axis=0), compressor=None, - shape=data.shape[1:], - dtype=data.dtype, - )[ - ... - ] = np.std(data, axis=0) - zarr_2_or_3.create_array( - root, + ) + root.create_array( "maximum", + data=np.max(data, axis=0), compressor=None, - shape=data.shape[1:], - dtype=data.dtype, - )[ - ... - ] = np.max(data, axis=0) - zarr_2_or_3.create_array( - root, + ) + root.create_array( "minimum", + data=np.min(data, axis=0), compressor=None, - shape=data.shape[1:], - dtype=data.dtype, - )[ - ... - ] = np.min(data, axis=0) + ) return root diff --git a/tests/test_data_gridded.py b/tests/test_data_gridded.py index b5995de55..e53311f3b 100644 --- a/tests/test_data_gridded.py +++ b/tests/test_data_gridded.py @@ -21,7 +21,6 @@ from anemoi.utils.dates import frequency_to_timedelta from anemoi.datasets import open_dataset -from anemoi.datasets.zarr_versions import zarr_2_or_3 VALUES = 20 @@ -142,29 +141,23 @@ def create_zarr( for e in range(ensembles): data[i, j, e] = _(date.astype(object), var, k, e, values) - zarr_2_or_3.create_array( - root, + root.create_array( "data", data=data, - dtype=data.dtype, chunks=data.shape, compressor=None, ) - # Store dates as ISO strings to avoid unsupported dtype in Zarr v3 - zarr_2_or_3.create_array( - root, + root.create_array( "dates", - data=np.array([str(d) for d in dates], dtype="U32"), + data=dates, compressor=None, ) - zarr_2_or_3.create_array( - root, + root.create_array( "latitudes", data=np.array([x + values for x in range(values)]), compressor=None, ) - zarr_2_or_3.create_array( - root, + root.create_array( "longitudes", data=np.array([x + values for x in range(values)]), compressor=None, @@ -189,26 +182,22 @@ def create_zarr( root.attrs["missing_dates"] = [d.isoformat() for d in missing_dates] - zarr_2_or_3.create_array( - root, + root.create_array( "mean", data=np.mean(data, axis=0), compressor=None, ) - zarr_2_or_3.create_array( - root, + root.create_array( "stdev", data=np.std(data, axis=0), compressor=None, ) - zarr_2_or_3.create_array( - root, + root.create_array( "maximum", data=np.max(data, axis=0), compressor=None, ) - zarr_2_or_3.create_array( - root, + root.create_array( "minimum", data=np.min(data, axis=0), compressor=None, From e4a85b9e42455a0d449ad74a2b15366ecdbada17 Mon Sep 17 00:00:00 2001 From: Baudouin Raoult Date: Thu, 9 Oct 2025 09:25:26 +0000 Subject: [PATCH 16/16] update with new version of zarr3 --- src/anemoi/datasets/create/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/anemoi/datasets/create/__init__.py b/src/anemoi/datasets/create/__init__.py index 21f8ecc34..d0bec0b02 100644 --- a/src/anemoi/datasets/create/__init__.py +++ b/src/anemoi/datasets/create/__init__.py @@ -1517,7 +1517,7 @@ def run(self) -> None: LOG.info(stats) - if not all(self.registry.get_flags()): + if not all(self.registry.get_flags(sync=False)): raise Exception(f"❗Zarr {self.path} is not fully built, not writing statistics into dataset.") for k in [