Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ dependencies = [
"ruamel-yaml",
"semantic-version",
"tqdm",
"zarr<=2.18.4",
"zarr>=3.1.3",
]

optional-dependencies.all = [
Expand Down Expand Up @@ -96,6 +96,7 @@ optional-dependencies.docs = [

optional-dependencies.remote = [
"boto3",
"obstore",
"requests",
]

Expand Down
7 changes: 3 additions & 4 deletions src/anemoi/datasets/commands/copy.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,11 +238,10 @@ def copy_data(self, source: Any, target: Any, _copy: Any, verbosity: int) -> Non
target_data = (
target["data"]
if "data" in target
else target.create_dataset(
else target.create_array(
"data",
shape=source_data.shape,
chunks=self.data_chunks,
dtype=source_data.dtype,
fill_value=source_data.fill_value,
)
)
Expand Down Expand Up @@ -344,7 +343,7 @@ def copy_group(self, source: Any, target: Any, _copy: Any, verbosity: int) -> No
LOG.info(f"Skipping {name}")
continue

if isinstance(source[name], zarr.hierarchy.Group):
if isinstance(source[name], zarr.Group):
group = target[name] if name in target else target.create_group(name)
self.copy_group(
source[name],
Expand Down Expand Up @@ -404,7 +403,7 @@ def target_exists() -> bool:
try:
zarr.open(self._store(self.target), mode="r")
return True
except ValueError:
except FileNotFoundError:
return False

def target_finished() -> bool:
Expand Down
6 changes: 3 additions & 3 deletions src/anemoi/datasets/commands/inspect.py
Original file line number Diff line number Diff line change
Expand Up @@ -647,7 +647,7 @@ def ready(self) -> bool:
if "_build_flags" not in self.zarr:
return False

build_flags = self.zarr["_build_flags"]
build_flags = self.zarr["_build_flags"][:]
return all(build_flags)

@property
Expand Down Expand Up @@ -703,15 +703,15 @@ def build_flags(self) -> NDArray | None:
if "_build" not in self.zarr:
return None
build = self.zarr["_build"]
return build.get("flags")
return build.get("flags")[:]

@property
def build_lengths(self) -> NDArray | None:
"""Get the build lengths for the dataset."""
if "_build" not in self.zarr:
return None
build = self.zarr["_build"]
return build.get("lengths")
return build.get("lengths")[:]


VERSIONS = {
Expand Down
13 changes: 6 additions & 7 deletions src/anemoi/datasets/create/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,9 +152,9 @@ def _path_readable(path: str) -> bool:
import zarr

try:
zarr.open(path, "r")
zarr.open(path, mode="r")
return True
except zarr.errors.PathNotFoundError:
except FileNotFoundError:
return False


Expand Down Expand Up @@ -190,10 +190,9 @@ def add_dataset(self, mode: str = "r+", **kwargs: Any) -> zarr.Array:
zarr.Array
The added dataset.
"""
import zarr

z = zarr.open(self.path, mode=mode)
from .zarr import add_zarr_dataset
from .misc import add_zarr_dataset

return add_zarr_dataset(zarr_root=z, **kwargs)

Expand All @@ -208,7 +207,7 @@ def update_metadata(self, **kwargs: Any) -> None:
import zarr

LOG.debug(f"Updating metadata {kwargs}")
z = zarr.open(self.path, mode="w+")
z = zarr.open(self.path, mode="a")
for k, v in kwargs.items():
if isinstance(v, np.datetime64):
v = v.astype(datetime.datetime)
Expand Down Expand Up @@ -443,7 +442,7 @@ def check_missing_dates(expected: list[np.datetime64]) -> None:
"""
import zarr

z = zarr.open(path, "r")
z = zarr.open(path, mode="r")
missing_dates = z.attrs.get("missing_dates", [])
missing_dates = sorted([np.datetime64(d) for d in missing_dates])
if missing_dates != expected:
Expand Down Expand Up @@ -515,7 +514,7 @@ class HasRegistryMixin:
@cached_property
def registry(self) -> Any:
"""Get the registry."""
from .zarr import ZarrBuiltRegistry
from .misc import ZarrBuiltRegistry

return ZarrBuiltRegistry(self.path, use_threads=self.use_threads)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,15 @@
# nor does it submit to any jurisdiction.

import datetime
import logging
import shutil
from typing import Any

import numpy as np
import zarr
from numpy.typing import NDArray

LOG = logging.getLogger(__name__)
from .synchronise import NoSynchroniser
from .synchronise import Synchroniser


def add_zarr_dataset(
Expand Down Expand Up @@ -71,8 +71,9 @@ def add_zarr_dataset(
shape = array.shape

if array is not None:

assert array.shape == shape, (array.shape, shape)
a = zarr_root.create_dataset(
a = zarr_root.create_array(
name,
shape=shape,
dtype=dtype,
Expand All @@ -99,7 +100,7 @@ def add_zarr_dataset(
else:
raise ValueError(f"No fill_value for dtype={dtype}")

a = zarr_root.create_dataset(
a = zarr_root.create_array(
name,
shape=shape,
dtype=dtype,
Expand Down Expand Up @@ -131,27 +132,15 @@ def __init__(self, path: str, synchronizer_path: str | None = None, use_threads:
use_threads : bool
Whether to use thread-based synchronization.
"""
import zarr

assert isinstance(path, str), path
self.zarr_path = path

if use_threads:
self.synchronizer = zarr.ThreadSynchronizer()
self.synchronizer_path = None
else:
if synchronizer_path is None:
synchronizer_path = self.zarr_path + ".sync"
self.synchronizer_path = synchronizer_path
self.synchronizer = zarr.ProcessSynchronizer(self.synchronizer_path)
self.synchronizer = Synchroniser(synchronizer_path) if synchronizer_path else NoSynchroniser()

def clean(self) -> None:
"""Clean up the synchronizer path."""
if self.synchronizer_path is not None:
try:
shutil.rmtree(self.synchronizer_path)
except FileNotFoundError:
pass
self.synchronizer.clean()

_build = self.zarr_path + "/_build"
try:
Expand All @@ -161,9 +150,7 @@ def clean(self) -> None:

def _open_write(self) -> zarr.Group:
"""Open the Zarr store in write mode."""
import zarr

return zarr.open(self.zarr_path, mode="r+", synchronizer=self.synchronizer)
return zarr.open(self.zarr_path, mode="r+")

def _open_read(self, sync: bool = True) -> zarr.Group:
"""Open the Zarr store in read mode.
Expand All @@ -178,12 +165,7 @@ def _open_read(self, sync: bool = True) -> zarr.Group:
zarr.Group
The opened Zarr group.
"""
import zarr

if sync:
return zarr.open(self.zarr_path, mode="r", synchronizer=self.synchronizer)
else:
return zarr.open(self.zarr_path, mode="r")
return zarr.open(self.zarr_path, mode="r")

def new_dataset(self, *args, **kwargs) -> None:
"""Create a new dataset in the Zarr store.
Expand All @@ -195,9 +177,11 @@ def new_dataset(self, *args, **kwargs) -> None:
**kwargs
Keyword arguments for dataset creation.
"""
z = self._open_write()
zarr_root = z["_build"]
add_zarr_dataset(*args, zarr_root=zarr_root, overwrite=True, dimensions=("tmp",), **kwargs)
with self.synchronizer:
z = self._open_write()
zarr_root = z["_build"]
add_zarr_dataset(*args, zarr_root=zarr_root, overwrite=True, dimensions=("tmp",), **kwargs)
del z

def add_to_history(self, action: str, **kwargs) -> None:
"""Add an action to the history attribute of the Zarr store.
Expand All @@ -215,10 +199,12 @@ def add_to_history(self, action: str, **kwargs) -> None:
)
new.update(kwargs)

z = self._open_write()
history = z.attrs.get("history", [])
history.append(new)
z.attrs["history"] = history
with self.synchronizer:
z = self._open_write()
history = z.attrs.get("history", [])
history.append(new)
z.attrs["history"] = history
del z

def get_lengths(self) -> list[int]:
"""Get the lengths dataset.
Expand All @@ -228,8 +214,11 @@ def get_lengths(self) -> list[int]:
list[int]
The lengths dataset.
"""
z = self._open_read()
return list(z["_build"][self.name_lengths][:])
with self.synchronizer:
z = self._open_read()
lengths = list(z["_build"][self.name_lengths][:])
del z
return lengths

def get_flags(self, **kwargs) -> list[bool]:
"""Get the flags dataset.
Expand All @@ -244,8 +233,11 @@ def get_flags(self, **kwargs) -> list[bool]:
list[bool]
The flags dataset.
"""
z = self._open_read(**kwargs)
return list(z["_build"][self.name_flags][:])
with self.synchronizer:
z = self._open_read(**kwargs)
flags = list(z["_build"][self.name_flags][:])
del z
return flags

def get_flag(self, i: int) -> bool:
"""Get a specific flag.
Expand All @@ -260,8 +252,11 @@ def get_flag(self, i: int) -> bool:
bool
The flag value.
"""
z = self._open_read()
return z["_build"][self.name_flags][i]
with self.synchronizer:
z = self._open_read()
flag = z["_build"][self.name_flags][i]
del z
return flag

def set_flag(self, i: int, value: bool = True) -> None:
"""Set a specific flag.
Expand All @@ -273,11 +268,13 @@ def set_flag(self, i: int, value: bool = True) -> None:
value : bool
Value to set the flag to.
"""
z = self._open_write()
z.attrs["latest_write_timestamp"] = (
datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None).isoformat()
)
z["_build"][self.name_flags][i] = value
with self.synchronizer:
z = self._open_write()
z.attrs["latest_write_timestamp"] = (
datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None).isoformat()
)
z["_build"][self.name_flags][i] = value
del z

def ready(self) -> bool:
"""Check if all flags are set.
Expand Down Expand Up @@ -321,11 +318,13 @@ def add_provenance(self, name: str) -> None:
name : str
Name of the provenance attribute.
"""
z = self._open_write()
with self.synchronizer:
z = self._open_write()

if name in z.attrs:
return
if name in z.attrs:
return

from anemoi.utils.provenance import gather_provenance_info
from anemoi.utils.provenance import gather_provenance_info

z.attrs[name] = gather_provenance_info()
z.attrs[name] = gather_provenance_info()
del z
2 changes: 1 addition & 1 deletion src/anemoi/datasets/create/patch.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ def apply_patch(path: str, verbose: bool = True, dry_run: bool = False) -> None:

try:
attrs = zarr.open(path, mode="r").attrs.asdict()
except zarr.errors.PathNotFoundError as e:
except FileNotFoundError as e:
LOG.error(f"Failed to open {path}")
LOG.error(e)
exit(0)
Expand Down
Loading
Loading