diff --git a/src/anemoi/datasets/data/misc.py b/src/anemoi/datasets/data/misc.py
index b5523ef8..58b1e31e 100644
--- a/src/anemoi/datasets/data/misc.py
+++ b/src/anemoi/datasets/data/misc.py
@@ -339,13 +339,17 @@ def _concat_or_join(datasets: List["Dataset"], kwargs: Dict[str, Any]) -> Tuple[
     return Concat(datasets), kwargs
 
 
-def _open(a: Union[str, PurePath, Dict[str, Any], List[Any], Tuple[Any, ...]]) -> "Dataset":
+def _open(
+    a: Union[str, PurePath, Dict[str, Any], List[Any], Tuple[Any, ...]], options: Optional[Dict[str, Any]] = None
+) -> "Dataset":
     """Open a dataset from various input types.
 
     Parameters
     ----------
     a : Union[str, PurePath, Dict[str, Any], List[Any], Tuple[Any, ...]]
         The input to open.
+    options : Optional[Dict[str, Any]]
+        Additional options for opening the dataset.
 
     Returns
     -------
@@ -372,19 +376,19 @@ def _open(a: Union[str, PurePath, Dict[str, Any], List[Any], Tuple[Any, ...]]) -
         return a.mutate()
 
     if isinstance(a, zarr.hierarchy.Group):
-        return Zarr(a).mutate()
+        return Zarr(a, options=options).mutate()
 
     if isinstance(a, str):
-        return Zarr(zarr_lookup(a)).mutate()
+        return Zarr(zarr_lookup(a), options=options).mutate()
 
     if isinstance(a, PurePath):
-        return _open(str(a)).mutate()
+        return _open(str(a), options=options).mutate()
 
     if isinstance(a, dict):
-        return _open_dataset(**a).mutate()
+        return _open_dataset(**a, options=options).mutate()
 
     if isinstance(a, (list, tuple)):
-        return _open_dataset(*a).mutate()
+        return _open_dataset(*a, options=options).mutate()
 
     raise NotImplementedError(f"Unsupported argument: {type(a)}")
 
@@ -502,8 +506,9 @@ def _open_dataset(*args: Any, **kwargs: Any) -> "Dataset":
         The opened dataset.
     """
     sets = []
+    options = kwargs.pop("options", None)
     for a in args:
-        sets.append(_open(a))
+        sets.append(_open(a, options=options))
 
     if "observations" in kwargs:
         from .observations import observations_factory
diff --git a/src/anemoi/datasets/data/stores.py b/src/anemoi/datasets/data/stores.py
index 3c744248..b7a72404 100644
--- a/src/anemoi/datasets/data/stores.py
+++ b/src/anemoi/datasets/data/stores.py
@@ -12,6 +12,7 @@
 import logging
 import os
 import tempfile
+import threading
 import warnings
 from functools import cached_property
 from typing import Any
@@ -25,6 +26,7 @@
 import numpy as np
 import zarr
 from anemoi.utils.dates import frequency_to_timedelta
+from anemoi.utils.humanize import bytes_to_human
 from numpy.typing import NDArray
 
 from . import MissingDateError
@@ -45,6 +47,8 @@
 class ReadOnlyStore(zarr.storage.BaseStore):
     """A base class for read-only stores."""
 
+    _store_version = 2
+
     def __delitem__(self, key: str) -> None:
         """Prevent deletion of items."""
         raise NotImplementedError()
@@ -136,6 +140,87 @@ def __contains__(self, key: str) -> bool:
         return key in self.store
 
 
+class CopyToSSDStore(ReadOnlyStore):
+    """A store that copies data to SSD before reading."""
+
+    def __init__(self, store: ReadOnlyStore, options: Optional[Dict[str, Any]] = None) -> None:
+        """Initialize the CopyToSSDStore with another store and options."""
+        self.store = store
+        self.options = options or {}
+        self.total_size = 0
+        self.copied_objects = 0
+        self.reused_objects = 0
+        self.key_cache = set()
+        self.path_cache = set()
+        self.lock = threading.Lock()
+
+        self.tmpdir = tempfile.TemporaryDirectory(
+            prefix="anemoi-datasets-ssd-",
+            dir=self.options.get(
+                "ssd-path",
+                os.getenv("TMPDIR", "/tmp"),
+            ),
+            delete=self.options.get("ssd-delete", True),
+        )
+        print("CopyToSSDStore: using temporary directory", self.tmpdir.name)
+
+    def __del__(self) -> None:
+        print(f"CopyToSSDStore: total size copied: {bytes_to_human(self.total_size)}")
+        print(f"CopyToSSDStore: copied {self.copied_objects:,} objects, reused {self.reused_objects:,} objects")
+
+    def __getitem__(self, key: str) -> bytes:
+
+        with self.lock:
+
+            path = os.path.join(self.tmpdir.name, key)
+
+            if key in self.key_cache or os.path.exists(path):
+                self.key_cache.add(key)
+                self.reused_objects += 1
+                with open(path, "rb") as f:
+                    return f.read()
+
+            self.copied_objects += 1
+            value = self.store[key]
+
+            parent = os.path.dirname(path)
+            if parent not in self.path_cache:
+                os.makedirs(parent, exist_ok=True)
+                self.path_cache.add(parent)
+
+            with open(path, "wb") as f:
+                f.write(value)
+
+            self.total_size += len(value)
+            self.key_cache.add(key)
+
+            return value
+
+    def __len__(self) -> int:
+        """Return the number of items in the store."""
+        return len(self.store)
+
+    def __iter__(self):
+        """Return an iterator over the store."""
+        warnings.warn("CopyToSSDStore: iterating over the store")
+        return iter(self.store)
+
+    def __contains__(self, key: str) -> bool:
+        """Check if the store contains a key."""
+
+        with self.lock:
+
+            if key in self.key_cache:
+                return True
+
+            path = os.path.join(self.tmpdir.name, key)
+            if os.path.exists(path):
+                self.key_cache.add(key)
+                return True
+
+            return key in self.store
+
+
 def name_to_zarr_store(path_or_url: str) -> ReadOnlyStore:
     """Convert a path or URL to a zarr store."""
     store = path_or_url
@@ -170,27 +255,24 @@ def name_to_zarr_store(path_or_url: str) -> ReadOnlyStore:
     return store
 
 
-def open_zarr(path: str, dont_fail: bool = False, cache: int = None) -> zarr.hierarchy.Group:
+def open_zarr(path: str, dont_fail: bool = False, options: Optional[Dict[str, Any]] = None) -> zarr.hierarchy.Group:
     """Open a zarr store from a path."""
+
+    options = options or {}
+
     try:
         store = name_to_zarr_store(path)
 
-        if DEBUG_ZARR_LOADING:
-            if isinstance(store, str):
-                import os
+        if options.get("copy-to-ssd", False):
+            store = CopyToSSDStore(zarr.open(store, "r").store, options)
 
-                if not os.path.isdir(store):
-                    raise NotImplementedError(
-                        "DEBUG_ZARR_LOADING is only implemented for DirectoryStore. "
-                        "Please disable it for other backends."
-                    )
-                store = zarr.storage.DirectoryStore(store)
-                store = DebugStore(store)
+        if DEBUG_ZARR_LOADING:
+            store = DebugStore(
+                zarr.open(store, "r").store,
+            )
 
-        if cache is not None:
-            store = zarr.LRUStoreCache(store, max_size=cache)
+        return zarr.open(store, "r")
 
-        return zarr.convenience.open(store, "r")
     except zarr.errors.PathNotFoundError:
         if not dont_fail:
             raise zarr.errors.PathNotFoundError(path)
@@ -199,16 +281,17 @@ def open_zarr(path: str, dont_fail: bool = False, cache: int = None) -> zarr.hie
 class Zarr(Dataset):
     """A zarr dataset."""
 
-    def __init__(self, path: Union[str, zarr.hierarchy.Group]) -> None:
+    def __init__(self, path: Union[str, zarr.hierarchy.Group], options: Optional[Dict[str, Any]] = None) -> None:
         """Initialize the Zarr dataset with a path or zarr group."""
         if isinstance(path, zarr.hierarchy.Group):
             self.was_zarr = True
             self.path = str(id(path))
             self.z = path
+            assert not options, "Options are not supported for zarr groups"
         else:
             self.was_zarr = False
             self.path = str(path)
-            self.z = open_zarr(self.path)
+            self.z = open_zarr(self.path, options=options)
 
         # This seems to speed up the reading of the data a lot
         self.data = self.z.data
diff --git a/tests/test_data.py b/tests/test_data.py
index 07b35887..a482bf7a 100644
--- a/tests/test_data.py
+++ b/tests/test_data.py
@@ -57,7 +57,7 @@ def mockup_open_zarr(func: Callable) -> Callable:
 
     @wraps(func)
     def wrapper(*args, **kwargs):
-        with patch("zarr.convenience.open", zarr_from_str):
+        with patch("zarr.open", zarr_from_str):
             with patch("anemoi.datasets.data.stores.zarr_lookup", lambda name: name):
                 return func(*args, **kwargs)
 
diff --git a/tests/test_data_gridded.py b/tests/test_data_gridded.py
index 5e473898..e56c3f11 100644
--- a/tests/test_data_gridded.py
+++ b/tests/test_data_gridded.py
@@ -44,7 +44,7 @@ def mockup_open_zarr(func: Callable) -> Callable:
 
     @wraps(func)
     def wrapper(*args, **kwargs):
-        with patch("zarr.convenience.open", zarr_from_str):
+        with patch("zarr.open", zarr_from_str):
             with patch("anemoi.datasets.data.stores.zarr_lookup", lambda name: name):
                 return func(*args, **kwargs)
 