src/anemoi/datasets/data/misc.py (19 changes: 12 additions & 7 deletions)
@@ -339,13 +339,17 @@ def _concat_or_join(datasets: List["Dataset"], kwargs: Dict[str, Any]) -> Tuple[
return Concat(datasets), kwargs


def _open(a: Union[str, PurePath, Dict[str, Any], List[Any], Tuple[Any, ...]]) -> "Dataset":
def _open(
a: Union[str, PurePath, Dict[str, Any], List[Any], Tuple[Any, ...]], options: Optional[Dict[str, Any]] = None
) -> "Dataset":
"""Open a dataset from various input types.

Parameters
----------
a : Union[str, PurePath, Dict[str, Any], List[Any], Tuple[Any, ...]]
The input to open.
options : Optional[Dict[str, Any]]
Additional options for opening the dataset.

Returns
-------
@@ -372,19 +376,19 @@ def _open(a: Union[str, PurePath, Dict[str, Any], List[Any], Tuple[Any, ...]]) -
return a.mutate()

if isinstance(a, zarr.hierarchy.Group):
return Zarr(a).mutate()
return Zarr(a, options=options).mutate()

if isinstance(a, str):
return Zarr(zarr_lookup(a)).mutate()
return Zarr(zarr_lookup(a), options=options).mutate()

if isinstance(a, PurePath):
return _open(str(a)).mutate()
return _open(str(a), options=options).mutate()

if isinstance(a, dict):
return _open_dataset(**a).mutate()
return _open_dataset(**a, options=options).mutate()

if isinstance(a, (list, tuple)):
return _open_dataset(*a).mutate()
return _open_dataset(*a, options=options).mutate()

raise NotImplementedError(f"Unsupported argument: {type(a)}")
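Each branch of `_open` now threads the same `options` mapping through, so whichever form the argument takes, the options reach the underlying `Zarr` store. A hypothetical illustration of the dispatch, with placeholder dataset names (the function is private, so this is for orientation only):

```python
from pathlib import PurePath

from anemoi.datasets.data.misc import _open

opts = {"copy-to-ssd": True}          # forwarded untouched through every branch

_open("dataset-name", options=opts)                     # str: resolved via zarr_lookup(), then Zarr(..., options=opts)
_open(PurePath("/path/to/dataset.zarr"), options=opts)  # PurePath: recurses with str(path)
_open({"dataset": "dataset-name"}, options=opts)        # dict: expanded into _open_dataset(**a, options=opts)
_open(["name-a", "name-b"], options=opts)               # list/tuple: _open_dataset(*a, options=opts)
```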

@@ -502,8 +506,9 @@ def _open_dataset(*args: Any, **kwargs: Any) -> "Dataset":
The opened dataset.
"""
sets = []
options = kwargs.pop("options", None)
for a in args:
sets.append(_open(a))
sets.append(_open(a, options=options))

if "observations" in kwargs:
from .observations import observations_factory
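`_open_dataset` pops `options` out of the keyword arguments before dispatching, so it is not mistaken for a dataset-combining keyword and is instead forwarded to every `_open` call. Assuming the public `open_dataset` entry point passes its keyword arguments straight through (not shown in this diff), an end-user call could look like this sketch, with placeholder names and paths:

```python
from anemoi.datasets import open_dataset

ds = open_dataset(
    "some-dataset-name",             # placeholder dataset name
    options={
        "copy-to-ssd": True,         # wrap the underlying store in CopyToSSDStore
        "ssd-path": "/scratch/ssd",  # defaults to $TMPDIR or /tmp when omitted
        "ssd-delete": True,          # remove the local copy when the store is torn down
    },
)
print(ds[0].shape)  # chunks are copied to the SSD directory on first read
```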
src/anemoi/datasets/data/stores.py (114 changes: 98 additions & 16 deletions)
@@ -12,6 +12,7 @@
import logging
import os
import tempfile
import threading
import warnings
from functools import cached_property
from typing import Any
@@ -25,6 +26,7 @@
import numpy as np
import zarr
from anemoi.utils.dates import frequency_to_timedelta
from anemoi.utils.humanize import bytes_to_human
from numpy.typing import NDArray

from . import MissingDateError
@@ -45,6 +47,8 @@
class ReadOnlyStore(zarr.storage.BaseStore):
"""A base class for read-only stores."""

_store_version = 2

def __delitem__(self, key: str) -> None:
"""Prevent deletion of items."""
raise NotImplementedError()
@@ -136,6 +140,86 @@ def __contains__(self, key: str) -> bool:
return key in self.store


class CopyToSSDStore(ReadOnlyStore):
"""A store that copies data to SSD before reading."""

def __init__(self, store: ReadOnlyStore, options: Optional[Dict[str, Any]] = None) -> None:
"""Initialize the CopyToSSDStore with another store and options."""
self.store = store
self.options = options or {}
self.total_size = 0
self.copied_objects = 0
self.reused_objects = 0
self.key_cache = set()
self.path_cache = set()
self.lock = threading.Lock()

self.tmpdir = tempfile.TemporaryDirectory(
prefix="anemoi-datasets-ssd-",
dir=self.options.get(
"ssd-path",
os.getenv("TMPDIR", "/tmp"),
),
delete=self.options.get("ssd-delete", True),
)
print("CopyToSSDStore: using temporary directory", self.tmpdir.name)

def __del__(self) -> None:
print(f"CopyToSSDStore: total size copied: {bytes_to_human(self.total_size)}")
print(f"CopyToSSDStore: copied {self.copied_objects:,} objects, reused {self.reused_objects:,} objects")

def __getitem__(self, key: str) -> bytes:

with self.lock:

path = os.path.join(self.tmpdir.name, key)

if key in self.key_cache or os.path.exists(path):
self.key_cache.add(key)
self.reused_objects += 1
return open(path, "rb").read()

self.copied_objects += 1
value = self.store[key]

parent = os.path.dirname(path)
if parent not in self.path_cache:
os.makedirs(parent, exist_ok=True)
self.path_cache.add(parent)

with open(path, "wb") as f:
f.write(value)

self.total_size += len(value)
self.key_cache.add(key)

return value

def __len__(self) -> int:
"""Return the number of items in the store."""
return len(self.store)

def __iter__(self) -> iter:
"""Return an iterator over the store."""
warnings.warn("DebugStore: iterating over the store")
return iter(self.store)

def __contains__(self, key: str) -> bool:
"""Check if the store contains a key."""

with self.lock:

if key in self.key_cache:
return True

path = os.path.join(self.tmpdir.name, key)
if os.path.exists(path):
self.key_cache.add(key)
return True

return key in self.store


def name_to_zarr_store(path_or_url: str) -> ReadOnlyStore:
"""Convert a path or URL to a zarr store."""
store = path_or_url
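The `CopyToSSDStore` above wraps any read-only zarr store: on the first access of a key the chunk is fetched from the wrapped store and written under a temporary directory on local (ideally SSD) storage, later accesses are served from that local copy, and the copy/reuse statistics are printed when the object is deleted. A self-contained sketch of how it might be exercised directly, assuming zarr v2 store semantics (which the `ReadOnlyStore` base class targets):

```python
import numpy as np
import zarr

from anemoi.datasets.data.stores import CopyToSSDStore

# Build a small on-disk zarr store purely for illustration.
root = zarr.open("./example.zarr", mode="w")
root.create_dataset("data", data=np.arange(100.0).reshape(10, 10), chunks=(5, 5))

# Wrap its store: each chunk is copied to the temporary SSD directory on first
# read and served from the local copy afterwards.
store = CopyToSSDStore(
    zarr.open("./example.zarr", mode="r").store,
    options={"ssd-path": "/tmp", "ssd-delete": True},
)
group = zarr.open(store, mode="r")
print(group["data"][:2, :2])  # first read: chunks copied to the SSD directory
print(group["data"][:2, :2])  # second read: reused from the SSD copy
```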
@@ -170,27 +254,24 @@ def name_to_zarr_store(path_or_url: str) -> ReadOnlyStore:
return store


def open_zarr(path: str, dont_fail: bool = False, cache: int = None) -> zarr.hierarchy.Group:
def open_zarr(path: str, dont_fail: bool = False, options=None) -> zarr.hierarchy.Group:
"""Open a zarr store from a path."""

options = options or {}

try:
store = name_to_zarr_store(path)

if DEBUG_ZARR_LOADING:
if isinstance(store, str):
import os
if options.get("copy-to-ssd", False):
store = CopyToSSDStore(zarr.open(store, "r").store, options)

if not os.path.isdir(store):
raise NotImplementedError(
"DEBUG_ZARR_LOADING is only implemented for DirectoryStore. "
"Please disable it for other backends."
)
store = zarr.storage.DirectoryStore(store)
store = DebugStore(store)
if DEBUG_ZARR_LOADING:
store = DebugStore(
zarr.open(store, "r").store,
)

if cache is not None:
store = zarr.LRUStoreCache(store, max_size=cache)
return zarr.open(store, "r")

return zarr.convenience.open(store, "r")
except zarr.errors.PathNotFoundError:
if not dont_fail:
raise zarr.errors.PathNotFoundError(path)
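`open_zarr` now takes an `options` mapping in place of the old `cache` size: the resolved store is wrapped in `CopyToSSDStore` when `copy-to-ssd` is set, and in `DebugStore` when `DEBUG_ZARR_LOADING` is enabled, before being reopened read-only. A sketch of a direct call, with a placeholder path:

```python
from anemoi.datasets.data.stores import open_zarr

z = open_zarr(
    "/path/to/dataset.zarr",       # placeholder; anything name_to_zarr_store() resolves
    options={
        "copy-to-ssd": True,       # route chunk reads through a local SSD copy
        "ssd-path": "/scratch",    # base directory for the temporary copy
    },
)
print(z.data.shape)                # assumes the usual anemoi layout with a "data" array
```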
@@ -199,16 +280,17 @@ def open_zarr(path: str, dont_fail: bool = False, cache: int = None) -> zarr.hie
class Zarr(Dataset):
"""A zarr dataset."""

def __init__(self, path: Union[str, zarr.hierarchy.Group]) -> None:
def __init__(self, path: Union[str, zarr.hierarchy.Group], options: Optional[Dict[str, Any]] = None) -> None:
"""Initialize the Zarr dataset with a path or zarr group."""
if isinstance(path, zarr.hierarchy.Group):
self.was_zarr = True
self.path = str(id(path))
self.z = path
assert not options, "Options are not supported for zarr groups"
else:
self.was_zarr = False
self.path = str(path)
self.z = open_zarr(self.path)
self.z = open_zarr(self.path, options=options)

# This seems to speed up the reading of the data a lot
self.data = self.z.data
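Note the two construction paths of `Zarr`: a path or URL goes through `open_zarr` and therefore honours `options`, whereas an already-open zarr group cannot be re-wrapped, so the constructor asserts that no options are passed alongside one. A small sketch with illustrative paths:

```python
import zarr

from anemoi.datasets.data.stores import Zarr

# Path-based construction: options are forwarded to open_zarr().
ds = Zarr("/path/to/dataset.zarr", options={"copy-to-ssd": True})

# Group-based construction: the store is already open, so options must be empty.
group = zarr.open("/path/to/dataset.zarr", mode="r")
ds_from_group = Zarr(group)
# Zarr(group, options={"copy-to-ssd": True})  # would trip the assertion above
```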
tests/test_data.py (2 changes: 1 addition & 1 deletion)
@@ -57,7 +57,7 @@ def mockup_open_zarr(func: Callable) -> Callable:

@wraps(func)
def wrapper(*args, **kwargs):
with patch("zarr.convenience.open", zarr_from_str):
with patch("zarr.open", zarr_from_str):
with patch("anemoi.datasets.data.stores.zarr_lookup", lambda name: name):
return func(*args, **kwargs)

tests/test_data_gridded.py (2 changes: 1 addition & 1 deletion)
@@ -44,7 +44,7 @@ def mockup_open_zarr(func: Callable) -> Callable:

@wraps(func)
def wrapper(*args, **kwargs):
with patch("zarr.convenience.open", zarr_from_str):
with patch("zarr.open", zarr_from_str):
with patch("anemoi.datasets.data.stores.zarr_lookup", lambda name: name):
return func(*args, **kwargs)

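Both test fixtures patch `zarr.open` instead of `zarr.convenience.open` because the reworked `open_zarr` now calls `zarr.open` directly. A minimal sketch of the same patching pattern, with a toy stand-in for the `zarr_from_str` fixture:

```python
from unittest.mock import patch

import numpy as np
import zarr


def fake_zarr_open(name, mode="r"):
    # Toy stand-in for zarr_from_str: build an in-memory group instead of
    # touching the filesystem or the network.
    group = zarr.group()
    group.create_dataset("data", data=np.zeros((3, 3)))
    return group


with patch("zarr.open", fake_zarr_open):
    z = zarr.open("any-name", "r")   # resolved by the fake, not by zarr itself
    assert z["data"].shape == (3, 3)
```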