diff --git a/ci/minimum_versions.py b/ci/minimum_versions.py index 08808d002d9..21123bffcd6 100644 --- a/ci/minimum_versions.py +++ b/ci/minimum_versions.py @@ -30,6 +30,7 @@ "coveralls", "pip", "pytest", + "pytest-asyncio", "pytest-cov", "pytest-env", "pytest-mypy-plugins", diff --git a/ci/requirements/all-but-dask.yml b/ci/requirements/all-but-dask.yml index 5f5db4a0f18..65780d91949 100644 --- a/ci/requirements/all-but-dask.yml +++ b/ci/requirements/all-but-dask.yml @@ -28,6 +28,7 @@ dependencies: - pip - pydap - pytest + - pytest-asyncio - pytest-cov - pytest-env - pytest-mypy-plugins diff --git a/ci/requirements/all-but-numba.yml b/ci/requirements/all-but-numba.yml index 712055a0ec2..23c38cc8267 100644 --- a/ci/requirements/all-but-numba.yml +++ b/ci/requirements/all-but-numba.yml @@ -41,6 +41,7 @@ dependencies: - pyarrow # pandas raises a deprecation warning without this, breaking doctests - pydap - pytest + - pytest-asyncio - pytest-cov - pytest-env - pytest-mypy-plugins diff --git a/ci/requirements/bare-min-and-scipy.yml b/ci/requirements/bare-min-and-scipy.yml index bb25af67651..d4a61586d82 100644 --- a/ci/requirements/bare-min-and-scipy.yml +++ b/ci/requirements/bare-min-and-scipy.yml @@ -7,6 +7,7 @@ dependencies: - coveralls - pip - pytest + - pytest-asyncio - pytest-cov - pytest-env - pytest-mypy-plugins diff --git a/ci/requirements/bare-minimum.yml b/ci/requirements/bare-minimum.yml index fafc1aa034a..777ff09b3e6 100644 --- a/ci/requirements/bare-minimum.yml +++ b/ci/requirements/bare-minimum.yml @@ -7,6 +7,7 @@ dependencies: - coveralls - pip - pytest + - pytest-asyncio - pytest-cov - pytest-env - pytest-mypy-plugins diff --git a/ci/requirements/environment-3.14.yml b/ci/requirements/environment-3.14.yml index 06c4df82663..d4d47d85536 100644 --- a/ci/requirements/environment-3.14.yml +++ b/ci/requirements/environment-3.14.yml @@ -37,6 +37,7 @@ dependencies: - pyarrow # pandas raises a deprecation warning without this, breaking doctests - pydap - pytest + - pytest-asyncio - pytest-cov - pytest-env - pytest-mypy-plugins diff --git a/ci/requirements/environment-windows-3.14.yml b/ci/requirements/environment-windows-3.14.yml index dd48add6b73..e86d57beb95 100644 --- a/ci/requirements/environment-windows-3.14.yml +++ b/ci/requirements/environment-windows-3.14.yml @@ -32,6 +32,7 @@ dependencies: - pyarrow # importing dask.dataframe raises an ImportError without this - pydap - pytest + - pytest-asyncio - pytest-cov - pytest-env - pytest-mypy-plugins diff --git a/ci/requirements/environment-windows.yml b/ci/requirements/environment-windows.yml index 3213ef687d3..7c0d4dd9231 100644 --- a/ci/requirements/environment-windows.yml +++ b/ci/requirements/environment-windows.yml @@ -32,6 +32,7 @@ dependencies: - pyarrow # importing dask.dataframe raises an ImportError without this - pydap - pytest + - pytest-asyncio - pytest-cov - pytest-env - pytest-mypy-plugins diff --git a/ci/requirements/environment.yml b/ci/requirements/environment.yml index cc33d8b4681..84441625e4c 100644 --- a/ci/requirements/environment.yml +++ b/ci/requirements/environment.yml @@ -39,6 +39,7 @@ dependencies: - pydap - pydap-server - pytest + - pytest-asyncio - pytest-cov - pytest-env - pytest-mypy-plugins diff --git a/ci/requirements/min-all-deps.yml b/ci/requirements/min-all-deps.yml index 9183433e801..add738630f1 100644 --- a/ci/requirements/min-all-deps.yml +++ b/ci/requirements/min-all-deps.yml @@ -40,6 +40,7 @@ dependencies: - pip - pydap=3.5.0 - pytest + - pytest-asyncio - pytest-cov - pytest-env - pytest-mypy-plugins diff --git a/doc/api-hidden.rst b/doc/api-hidden.rst index 5b9fa70d6b7..9f56fca1472 100644 --- a/doc/api-hidden.rst +++ b/doc/api-hidden.rst @@ -228,6 +228,7 @@ Variable.isnull Variable.item Variable.load + Variable.load_async Variable.max Variable.mean Variable.median diff --git a/doc/internals/how-to-add-new-backend.rst b/doc/internals/how-to-add-new-backend.rst index d3b5c3a9267..b5dfe3b5f8e 100644 --- a/doc/internals/how-to-add-new-backend.rst +++ b/doc/internals/how-to-add-new-backend.rst @@ -331,10 +331,12 @@ information on plugins. How to support lazy loading +++++++++++++++++++++++++++ -If you want to make your backend effective with big datasets, then you should -support lazy loading. -Basically, you shall replace the :py:class:`numpy.ndarray` inside the -variables with a custom class that supports lazy loading indexing. +If you want to make your backend effective with big datasets, then you should take advantage of xarray's +support for lazy loading and indexing. + +Basically, when your backend constructs the ``Variable`` objects, +you need to replace the :py:class:`numpy.ndarray` inside the +variables with a custom :py:class:`~xarray.backends.BackendArray` subclass that supports lazy loading and indexing. See the example below: .. code-block:: python @@ -345,25 +347,27 @@ See the example below: Where: -- :py:class:`~xarray.core.indexing.LazilyIndexedArray` is a class - provided by Xarray that manages the lazy loading. -- ``MyBackendArray`` shall be implemented by the backend and shall inherit +- :py:class:`~xarray.core.indexing.LazilyIndexedArray` is a wrapper class + provided by Xarray that manages the lazy loading and indexing. +- ``MyBackendArray`` should be implemented by the backend and must inherit from :py:class:`~xarray.backends.BackendArray`. BackendArray subclassing ^^^^^^^^^^^^^^^^^^^^^^^^ -The BackendArray subclass shall implement the following method and attributes: +The BackendArray subclass must implement the following method and attributes: -- the ``__getitem__`` method that takes in input an index and returns a - `NumPy `__ array -- the ``shape`` attribute +- the ``__getitem__`` method that takes an index as an input and returns a + `NumPy `__ array, +- the ``shape`` attribute, - the ``dtype`` attribute. -Xarray supports different type of :doc:`/user-guide/indexing`, that can be -grouped in three types of indexes +It may also optionally implement an additional ``async_getitem`` method. + +Xarray supports different types of :doc:`/user-guide/indexing`, that can be +grouped in three types of indexes: :py:class:`~xarray.core.indexing.BasicIndexer`, -:py:class:`~xarray.core.indexing.OuterIndexer` and +:py:class:`~xarray.core.indexing.OuterIndexer`, and :py:class:`~xarray.core.indexing.VectorizedIndexer`. This implies that the implementation of the method ``__getitem__`` can be tricky. In order to simplify this task, Xarray provides a helper function, @@ -419,8 +423,22 @@ input the ``key``, the array ``shape`` and the following parameters: For more details see :py:class:`~xarray.core.indexing.IndexingSupport` and :ref:`RST indexing`. +Async support +^^^^^^^^^^^^^ + +Backends can also optionally support loading data asynchronously via xarray's asynchronous loading methods +(e.g. ``~xarray.Dataset.load_async``). +To support async loading the ``BackendArray`` subclass must additionally implement the ``BackendArray.async_getitem`` method. + +Note that implementing this method is only necessary if you want to be able to load data from different xarray objects concurrently. +Even without this method your ``BackendArray`` implementation is still free to concurrently load chunks of data for a single ``Variable`` itself, +so long as it does so behind the synchronous ``__getitem__`` interface. + +Dask support +^^^^^^^^^^^^ + In order to support `Dask Distributed `__ and -:py:mod:`multiprocessing`, ``BackendArray`` subclass should be serializable +:py:mod:`multiprocessing`, the ``BackendArray`` subclass should be serializable either with :ref:`io.pickle` or `cloudpickle `__. That implies that all the reference to open files should be dropped. For diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 4df7c29c51c..19a47aba144 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -16,6 +16,9 @@ New Features Useful for cleaning up DataTree after time-based filtering operations (:issue:`10590`, :pull:`10598`). By `Alfonso Ladino `_. +- Added new asynchronous loading methods :py:meth:`Dataset.load_async`, :py:meth:`DataArray.load_async`, :py:meth:`Variable.load_async`. + Note that users are expected to limit concurrency themselves - xarray does not internally limit concurrency in any way. + (:issue:`10326`, :pull:`10327`) By `Tom Nicholas `_. - :py:meth:`DataTree.to_netcdf` can now write to a file-like object, or return bytes if called without a filepath. (:issue:`10570`) By `Matthew Willson `_. - Added exception handling for invalid files in :py:func:`open_mfdataset`. (:issue:`6736`) @@ -50,12 +53,10 @@ Deprecations Bug fixes ~~~~~~~~~ - - Fix Pydap Datatree backend testing. Testing now compares elements of (unordered) two sets (before, lists) (:pull:`10525`). By `Miguel Jimenez-Urias `_. - Fix ``KeyError`` when passing a ``dim`` argument different from the default to ``convert_calendar`` (:pull:`10544`). By `Eric Jansen `_. - - Fix transpose of boolean arrays read from disk. (:issue:`10536`) By `Deepak Cherian `_. - Fix detection of the ``h5netcdf`` backend. Xarray now selects ``h5netcdf`` if the default ``netCDF4`` engine is not available (:issue:`10401`, :pull:`10557`). diff --git a/pyproject.toml b/pyproject.toml index bc899596b4c..7426ff05518 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -80,6 +80,7 @@ dev = [ "pytest-mypy-plugins", "pytest-timeout", "pytest-xdist", + "pytest-asyncio", "ruff>=0.8.0", "sphinx", "sphinx_autosummary_accessors", diff --git a/xarray/backends/common.py b/xarray/backends/common.py index b94b11120ae..5e9be7571fb 100644 --- a/xarray/backends/common.py +++ b/xarray/backends/common.py @@ -311,10 +311,17 @@ def robust_getitem(array, key, catch=Exception, max_retries=6, initial_delay=500 class BackendArray(NdimSizeLenMixin, indexing.ExplicitlyIndexed): __slots__ = () + async def async_getitem(self, key: indexing.ExplicitIndexer) -> np.typing.ArrayLike: + raise NotImplementedError("Backend does not not support asynchronous loading") + def get_duck_array(self, dtype: np.typing.DTypeLike = None): key = indexing.BasicIndexer((slice(None),) * self.ndim) return self[key] # type: ignore[index] + async def async_get_duck_array(self, dtype: np.typing.DTypeLike = None): + key = indexing.BasicIndexer((slice(None),) * self.ndim) + return await self.async_getitem(key) + class AbstractDataStore: __slots__ = () diff --git a/xarray/backends/zarr.py b/xarray/backends/zarr.py index 425f72c591c..65ba5ac044b 100644 --- a/xarray/backends/zarr.py +++ b/xarray/backends/zarr.py @@ -180,12 +180,23 @@ def encode_zarr_attr_value(value): return encoded +def has_zarr_async_index() -> bool: + try: + import zarr + + return hasattr(zarr.AsyncArray, "oindex") + except (ImportError, AttributeError): + return False + + class ZarrArrayWrapper(BackendArray): __slots__ = ("_array", "dtype", "shape") def __init__(self, zarr_array): # some callers attempt to evaluate an array if an `array` property exists on the object. # we prefix with _ to avoid this inference. + + # TODO type hint this? self._array = zarr_array self.shape = self._array.shape @@ -213,6 +224,33 @@ def _vindex(self, key): def _getitem(self, key): return self._array[key] + async def _async_getitem(self, key): + if not _zarr_v3(): + raise NotImplementedError( + "For lazy basic async indexing with zarr, zarr-python=>v3.0.0 is required" + ) + + async_array = self._array._async_array + return await async_array.getitem(key) + + async def _async_oindex(self, key): + if not has_zarr_async_index(): + raise NotImplementedError( + "For lazy orthogonal async indexing with zarr, zarr-python=>v3.1.2 is required" + ) + + async_array = self._array._async_array + return await async_array.oindex.getitem(key) + + async def _async_vindex(self, key): + if not has_zarr_async_index(): + raise NotImplementedError( + "For lazy vectorized async indexing with zarr, zarr-python=>v3.1.2 is required" + ) + + async_array = self._array._async_array + return await async_array.vindex.getitem(key) + def __getitem__(self, key): array = self._array if isinstance(key, indexing.BasicIndexer): @@ -228,6 +266,18 @@ def __getitem__(self, key): # if self.ndim == 0: # could possibly have a work-around for 0d data here + async def async_getitem(self, key): + array = self._array + if isinstance(key, indexing.BasicIndexer): + method = self._async_getitem + elif isinstance(key, indexing.VectorizedIndexer): + method = self._async_vindex + elif isinstance(key, indexing.OuterIndexer): + method = self._async_oindex + return await indexing.async_explicit_indexing_adapter( + key, array.shape, indexing.IndexingSupport.VECTORIZED, method + ) + def _determine_zarr_chunks(enc_chunks, var_chunks, ndim, name): """ diff --git a/xarray/coding/common.py b/xarray/coding/common.py index 0e8d7e1955e..79e5e7502b3 100644 --- a/xarray/coding/common.py +++ b/xarray/coding/common.py @@ -79,6 +79,9 @@ def __getitem__(self, key): def get_duck_array(self): return self.func(self.array.get_duck_array()) + async def async_get_duck_array(self): + return self.func(await self.array.async_get_duck_array()) + def __repr__(self) -> str: return f"{type(self).__name__}({self.array!r}, func={self.func!r}, dtype={self.dtype!r})" diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index 98979ce05d7..b1833d3266f 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -1135,10 +1135,11 @@ def _dask_finalize(cls, results, name, func, *args, **kwargs) -> Self: return cls(variable, coords, name=name, indexes=indexes, fastpath=True) def load(self, **kwargs) -> Self: - """Manually trigger loading of this array's data from disk or a - remote source into memory and return this array. + """Trigger loading data into memory and return this dataarray. - Unlike compute, the original dataset is modified and returned. + Data will be computed and/or loaded from disk or a remote source. + + Unlike ``.compute``, the original dataarray is modified and returned. Normally, it should not be necessary to call this method in user code, because all xarray functions should either work on deferred data or @@ -1150,9 +1151,18 @@ def load(self, **kwargs) -> Self: **kwargs : dict Additional keyword arguments passed on to ``dask.compute``. + Returns + ------- + object : DataArray + Same object but with lazy data and coordinates as in-memory arrays. + See Also -------- dask.compute + DataArray.load_async + DataArray.compute + Dataset.load + Variable.load """ ds = self._to_temp_dataset().load(**kwargs) new = self._from_temp_dataset(ds) @@ -1160,11 +1170,49 @@ def load(self, **kwargs) -> Self: self._coords = new._coords return self + async def load_async(self, **kwargs) -> Self: + """Trigger and await asynchronous loading of data into memory and return this dataarray. + + Data will be computed and/or loaded from disk or a remote source. + + Unlike ``.compute``, the original dataarray is modified and returned. + + Only works when opening data lazily from IO storage backends which support lazy asynchronous loading. + Otherwise will raise a NotImplementedError. + + Note users are expected to limit concurrency themselves - xarray does not internally limit concurrency in any way. + + Parameters + ---------- + **kwargs : dict + Additional keyword arguments passed on to ``dask.compute``. + + Returns + ------- + object : Dataarray + Same object but with lazy data and coordinates as in-memory arrays. + + See Also + -------- + dask.compute + DataArray.compute + DataArray.load + Dataset.load_async + Variable.load_async + """ + temp_ds = self._to_temp_dataset() + ds = await temp_ds.load_async(**kwargs) + new = self._from_temp_dataset(ds) + self._variable = new._variable + self._coords = new._coords + return self + def compute(self, **kwargs) -> Self: - """Manually trigger loading of this array's data from disk or a - remote source into memory and return a new array. + """Trigger loading data into memory and return a new dataarray. + + Data will be computed and/or loaded from disk or a remote source. - Unlike load, the original is left unaltered. + Unlike ``.load``, the original dataarray is left unaltered. Normally, it should not be necessary to call this method in user code, because all xarray functions should either work on deferred data or @@ -1184,6 +1232,10 @@ def compute(self, **kwargs) -> Self: See Also -------- dask.compute + DataArray.load + DataArray.load_async + Dataset.compute + Variable.compute """ new = self.copy(deep=False) return new.load(**kwargs) diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 0b1d9835cf5..f51853fc8fd 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -1,5 +1,6 @@ from __future__ import annotations +import asyncio import copy import datetime import io @@ -520,9 +521,11 @@ def dtypes(self) -> Frozen[Hashable, np.dtype]: ) def load(self, **kwargs) -> Self: - """Manually trigger loading and/or computation of this dataset's data - from disk or a remote source into memory and return this dataset. - Unlike compute, the original dataset is modified and returned. + """Trigger loading data into memory and return this dataset. + + Data will be computed and/or loaded from disk or a remote source. + + Unlike ``.compute``, the original dataset is modified and returned. Normally, it should not be necessary to call this method in user code, because all xarray functions should either work on deferred data or @@ -534,29 +537,93 @@ def load(self, **kwargs) -> Self: **kwargs : dict Additional keyword arguments passed on to ``dask.compute``. + Returns + ------- + object : Dataset + Same object but with lazy data variables and coordinates as in-memory arrays. + See Also -------- dask.compute + Dataset.compute + Dataset.load_async + DataArray.load + Variable.load """ # access .data to coerce everything to numpy or dask arrays - lazy_data = { + chunked_data = { k: v._data for k, v in self.variables.items() if is_chunked_array(v._data) } - if lazy_data: - chunkmanager = get_chunked_array_type(*lazy_data.values()) + if chunked_data: + chunkmanager = get_chunked_array_type(*chunked_data.values()) # evaluate all the chunked arrays simultaneously evaluated_data: tuple[np.ndarray[Any, Any], ...] = chunkmanager.compute( - *lazy_data.values(), **kwargs + *chunked_data.values(), **kwargs ) - for k, data in zip(lazy_data, evaluated_data, strict=False): + for k, data in zip(chunked_data, evaluated_data, strict=False): self.variables[k].data = data # load everything else sequentially - for k, v in self.variables.items(): - if k not in lazy_data: - v.load() + [v.load() for k, v in self.variables.items() if k not in chunked_data] + + return self + + async def load_async(self, **kwargs) -> Self: + """Trigger and await asynchronous loading of data into memory and return this dataset. + + Data will be computed and/or loaded from disk or a remote source. + + Unlike ``.compute``, the original dataset is modified and returned. + + Only works when opening data lazily from IO storage backends which support lazy asynchronous loading. + Otherwise will raise a NotImplementedError. + + Note users are expected to limit concurrency themselves - xarray does not internally limit concurrency in any way. + + Parameters + ---------- + **kwargs : dict + Additional keyword arguments passed on to ``dask.compute``. + + Returns + ------- + object : Dataset + Same object but with lazy data variables and coordinates as in-memory arrays. + + See Also + -------- + dask.compute + Dataset.compute + Dataset.load + DataArray.load_async + Variable.load_async + """ + # TODO refactor this to pull out the common chunked_data codepath + + # this blocks on chunked arrays but not on lazily indexed arrays + + # access .data to coerce everything to numpy or dask arrays + chunked_data = { + k: v._data for k, v in self.variables.items() if is_chunked_array(v._data) + } + if chunked_data: + chunkmanager = get_chunked_array_type(*chunked_data.values()) + + # evaluate all the chunked arrays simultaneously + evaluated_data: tuple[np.ndarray[Any, Any], ...] = chunkmanager.compute( + *chunked_data.values(), **kwargs + ) + + for k, data in zip(chunked_data, evaluated_data, strict=False): + self.variables[k].data = data + + # load everything else concurrently + coros = [ + v.load_async() for k, v in self.variables.items() if k not in chunked_data + ] + await asyncio.gather(*coros) return self @@ -695,9 +762,11 @@ def _dask_postpersist( ) def compute(self, **kwargs) -> Self: - """Manually trigger loading and/or computation of this dataset's data - from disk or a remote source into memory and return a new dataset. - Unlike load, the original dataset is left unaltered. + """Trigger loading data into memory and return a new dataset. + + Data will be computed and/or loaded from disk or a remote source. + + Unlike ``.load``, the original dataset is left unaltered. Normally, it should not be necessary to call this method in user code, because all xarray functions should either work on deferred data or @@ -717,6 +786,10 @@ def compute(self, **kwargs) -> Self: See Also -------- dask.compute + Dataset.load + Dataset.load_async + DataArray.compute + Variable.compute """ new = self.copy(deep=False) return new.load(**kwargs) diff --git a/xarray/core/indexing.py b/xarray/core/indexing.py index 7d7f9335cb2..16276b5b090 100644 --- a/xarray/core/indexing.py +++ b/xarray/core/indexing.py @@ -586,8 +586,10 @@ class ExplicitlyIndexedNDArrayMixin(NDArrayMixin, ExplicitlyIndexed): __slots__ = () def get_duck_array(self): - key = BasicIndexer((slice(None),) * self.ndim) - return self[key] + raise NotImplementedError + + async def async_get_duck_array(self): + raise NotImplementedError def _oindex_get(self, indexer: OuterIndexer): raise NotImplementedError( @@ -625,6 +627,22 @@ def vindex(self) -> IndexCallable: return IndexCallable(self._vindex_get, self._vindex_set) +class IndexingAdapter(ExplicitlyIndexedNDArrayMixin): + """Marker class for indexing adapters. + + These classes translate between Xarray's indexing semantics and the underlying array's + indexing semantics. + """ + + def get_duck_array(self): + key = BasicIndexer((slice(None),) * self.ndim) + return self[key] + + async def async_get_duck_array(self): + """These classes are applied to in-memory arrays, so specific async support isn't needed.""" + return self.get_duck_array() + + class ImplicitToExplicitIndexingAdapter(NDArrayMixin): """Wrap an array, converting tuples into the indicated explicit indexer.""" @@ -713,19 +731,25 @@ def shape(self) -> _Shape: return self._shape def get_duck_array(self): - if isinstance(self.array, ExplicitlyIndexedNDArrayMixin): - array = apply_indexer(self.array, self.key) - else: - # If the array is not an ExplicitlyIndexedNDArrayMixin, - # it may wrap a BackendArray so use its __getitem__ + from xarray.backends.common import BackendArray + + if isinstance(self.array, BackendArray): array = self.array[self.key] + else: + array = apply_indexer(self.array, self.key) + if isinstance(array, ExplicitlyIndexed): + array = array.get_duck_array() + return _wrap_numpy_scalars(array) - # self.array[self.key] is now a numpy array when - # self.array is a BackendArray subclass - # and self.key is BasicIndexer((slice(None, None, None),)) - # so we need the explicit check for ExplicitlyIndexed - if isinstance(array, ExplicitlyIndexed): - array = array.get_duck_array() + async def async_get_duck_array(self): + from xarray.backends.common import BackendArray + + if isinstance(self.array, BackendArray): + array = await self.array.async_getitem(self.key) + else: + array = apply_indexer(self.array, self.key) + if isinstance(array, ExplicitlyIndexed): + array = await array.async_get_duck_array() return _wrap_numpy_scalars(array) def transpose(self, order): @@ -789,18 +813,25 @@ def shape(self) -> _Shape: return np.broadcast(*self.key.tuple).shape def get_duck_array(self): - if isinstance(self.array, ExplicitlyIndexedNDArrayMixin): + from xarray.backends.common import BackendArray + + if isinstance(self.array, BackendArray): + array = self.array[self.key] + else: array = apply_indexer(self.array, self.key) + if isinstance(array, ExplicitlyIndexed): + array = array.get_duck_array() + return _wrap_numpy_scalars(array) + + async def async_get_duck_array(self): + from xarray.backends.common import BackendArray + + if isinstance(self.array, BackendArray): + array = await self.array.async_getitem(self.key) else: - # If the array is not an ExplicitlyIndexedNDArrayMixin, - # it may wrap a BackendArray so use its __getitem__ - array = self.array[self.key] - # self.array[self.key] is now a numpy array when - # self.array is a BackendArray subclass - # and self.key is BasicIndexer((slice(None, None, None),)) - # so we need the explicit check for ExplicitlyIndexed - if isinstance(array, ExplicitlyIndexed): - array = array.get_duck_array() + array = apply_indexer(self.array, self.key) + if isinstance(array, ExplicitlyIndexed): + array = await array.async_get_duck_array() return _wrap_numpy_scalars(array) def _updated_key(self, new_key: ExplicitIndexer): @@ -865,6 +896,9 @@ def _ensure_copied(self): def get_duck_array(self): return self.array.get_duck_array() + async def async_get_duck_array(self): + return await self.array.async_get_duck_array() + def _oindex_get(self, indexer: OuterIndexer): return type(self)(_wrap_numpy_scalars(self.array.oindex[indexer])) @@ -905,12 +939,17 @@ class MemoryCachedArray(ExplicitlyIndexedNDArrayMixin): def __init__(self, array): self.array = _wrap_numpy_scalars(as_indexable(array)) - def _ensure_cached(self): - self.array = as_indexable(self.array.get_duck_array()) - def get_duck_array(self): - self._ensure_cached() - return self.array.get_duck_array() + duck_array = self.array.get_duck_array() + # ensure the array object is cached in-memory + self.array = as_indexable(duck_array) + return duck_array + + async def async_get_duck_array(self): + duck_array = await self.array.async_get_duck_array() + # ensure the array object is cached in-memory + self.array = as_indexable(duck_array) + return duck_array def _oindex_get(self, indexer: OuterIndexer): return type(self)(_wrap_numpy_scalars(self.array.oindex[indexer])) @@ -1095,6 +1134,21 @@ def explicit_indexing_adapter( return result +async def async_explicit_indexing_adapter( + key: ExplicitIndexer, + shape: _Shape, + indexing_support: IndexingSupport, + raw_indexing_method: Callable[..., Any], +) -> Any: + raw_key, numpy_indices = decompose_indexer(key, shape, indexing_support) + result = await raw_indexing_method(raw_key.tuple) + if numpy_indices.tuple: + # index the loaded duck array + indexable = as_indexable(result) + result = apply_indexer(indexable, numpy_indices) + return result + + def apply_indexer(indexable, indexer: ExplicitIndexer): """Apply an indexer to an indexable object.""" if isinstance(indexer, VectorizedIndexer): @@ -1594,7 +1648,7 @@ def is_fancy_indexer(indexer: Any) -> bool: return True -class NumpyIndexingAdapter(ExplicitlyIndexedNDArrayMixin): +class NumpyIndexingAdapter(IndexingAdapter): """Wrap a NumPy array to use explicit indexing.""" __slots__ = ("array",) @@ -1673,7 +1727,7 @@ def __init__(self, array): self.array = array -class ArrayApiIndexingAdapter(ExplicitlyIndexedNDArrayMixin): +class ArrayApiIndexingAdapter(IndexingAdapter): """Wrap an array API array to use explicit indexing.""" __slots__ = ("array",) @@ -1738,7 +1792,7 @@ def _assert_not_chunked_indexer(idxr: tuple[Any, ...]) -> None: ) -class DaskIndexingAdapter(ExplicitlyIndexedNDArrayMixin): +class DaskIndexingAdapter(IndexingAdapter): """Wrap a dask array to support explicit indexing.""" __slots__ = ("array",) @@ -1814,7 +1868,7 @@ def transpose(self, order): return self.array.transpose(order) -class PandasIndexingAdapter(ExplicitlyIndexedNDArrayMixin): +class PandasIndexingAdapter(IndexingAdapter): """Wrap a pandas.Index to preserve dtypes and handle explicit indexing.""" __slots__ = ("_dtype", "array") @@ -2071,7 +2125,7 @@ def copy(self, deep: bool = True) -> Self: return type(self)(array, self._dtype, self.level) -class CoordinateTransformIndexingAdapter(ExplicitlyIndexedNDArrayMixin): +class CoordinateTransformIndexingAdapter(IndexingAdapter): """Wrap a CoordinateTransform as a lazy coordinate array. Supports explicit indexing (both outer and vectorized). diff --git a/xarray/core/variable.py b/xarray/core/variable.py index 06d7218fe7c..cc502e17d2e 100644 --- a/xarray/core/variable.py +++ b/xarray/core/variable.py @@ -50,6 +50,7 @@ from xarray.namedarray.core import NamedArray, _raise_if_any_duplicate_dimensions from xarray.namedarray.parallelcompat import get_chunked_array_type from xarray.namedarray.pycompat import ( + async_to_duck_array, integer_types, is_0d_dask_array, is_chunked_array, @@ -975,9 +976,12 @@ def _replace( encoding = copy.copy(self._encoding) return type(self)(dims, data, attrs, encoding, fastpath=True) - def load(self, **kwargs): - """Manually trigger loading of this variable's data from disk or a - remote source into memory and return this variable. + def load(self, **kwargs) -> Self: + """Trigger loading data into memory and return this variable. + + Data will be computed and/or loaded from disk or a remote source. + + Unlike ``.compute``, the original variable is modified and returned. Normally, it should not be necessary to call this method in user code, because all xarray functions should either work on deferred data or @@ -988,17 +992,61 @@ def load(self, **kwargs): **kwargs : dict Additional keyword arguments passed on to ``dask.array.compute``. + Returns + ------- + object : Variable + Same object but with lazy data as an in-memory array. + See Also -------- dask.array.compute + Variable.compute + Variable.load_async + DataArray.load + Dataset.load """ self._data = to_duck_array(self._data, **kwargs) return self - def compute(self, **kwargs): - """Manually trigger loading of this variable's data from disk or a - remote source into memory and return a new variable. The original is - left unaltered. + async def load_async(self, **kwargs) -> Self: + """Trigger and await asynchronous loading of data into memory and return this variable. + + Data will be computed and/or loaded from disk or a remote source. + + Unlike ``.compute``, the original variable is modified and returned. + + Only works when opening data lazily from IO storage backends which support lazy asynchronous loading. + Otherwise will raise a NotImplementedError. + + Note users are expected to limit concurrency themselves - xarray does not internally limit concurrency in any way. + + Parameters + ---------- + **kwargs : dict + Additional keyword arguments passed on to ``dask.array.compute``. + + Returns + ------- + object : Variable + Same object but with lazy data as an in-memory array. + + See Also + -------- + dask.array.compute + Variable.load + Variable.compute + DataArray.load_async + Dataset.load_async + """ + self._data = await async_to_duck_array(self._data, **kwargs) + return self + + def compute(self, **kwargs) -> Self: + """Trigger loading data into memory and return a new variable. + + Data will be computed and/or loaded from disk or a remote source. + + The original variable is left unaltered. Normally, it should not be necessary to call this method in user code, because all xarray functions should either work on deferred data or @@ -1009,9 +1057,18 @@ def compute(self, **kwargs): **kwargs : dict Additional keyword arguments passed on to ``dask.array.compute``. + Returns + ------- + object : Variable + New object with the data as an in-memory array. + See Also -------- dask.array.compute + Variable.load + Variable.load_async + DataArray.compute + Dataset.compute """ new = self.copy(deep=False) return new.load(**kwargs) @@ -2700,6 +2757,10 @@ def load(self): # data is already loaded into memory for IndexVariable return self + async def load_async(self): + # data is already loaded into memory for IndexVariable + return self + # https://github.com/python/mypy/issues/1465 @Variable.data.setter # type: ignore[attr-defined] def data(self, data): diff --git a/xarray/namedarray/pycompat.py b/xarray/namedarray/pycompat.py index 68b6a7853bf..5832f7cc9e7 100644 --- a/xarray/namedarray/pycompat.py +++ b/xarray/namedarray/pycompat.py @@ -145,3 +145,17 @@ def to_duck_array(data: Any, **kwargs: dict[str, Any]) -> duckarray[_ShapeType, return data else: return np.asarray(data) # type: ignore[return-value] + + +async def async_to_duck_array( + data: Any, **kwargs: dict[str, Any] +) -> duckarray[_ShapeType, _DType]: + from xarray.core.indexing import ( + ExplicitlyIndexed, + ImplicitToExplicitIndexingAdapter, + ) + + if isinstance(data, ExplicitlyIndexed | ImplicitToExplicitIndexingAdapter): + return await data.async_get_duck_array() # type: ignore[union-attr, no-any-return] + else: + return to_duck_array(data, **kwargs) diff --git a/xarray/tests/__init__.py b/xarray/tests/__init__.py index 787c01eaf62..3b4e49c64d8 100644 --- a/xarray/tests/__init__.py +++ b/xarray/tests/__init__.py @@ -132,6 +132,7 @@ def _importorskip( has_zarr, requires_zarr = _importorskip("zarr") has_zarr_v3, requires_zarr_v3 = _importorskip("zarr", "3.0.0") has_zarr_v3_dtypes, requires_zarr_v3_dtypes = _importorskip("zarr", "3.1.0") +has_zarr_v3_async_oindex, requires_zarr_v3_async_oindex = _importorskip("zarr", "3.1.2") if has_zarr_v3: import zarr @@ -140,10 +141,15 @@ def _importorskip( # installing from git main is giving me a lower version than the # most recently released zarr has_zarr_v3_dtypes = hasattr(zarr.core, "dtype") + has_zarr_v3_async_oindex = hasattr(zarr.AsyncArray, "oindex") requires_zarr_v3_dtypes = pytest.mark.skipif( not has_zarr_v3_dtypes, reason="requires zarr>3.1.0" ) + requires_zarr_v3_async_oindex = pytest.mark.skipif( + not has_zarr_v3_async_oindex, reason="requires zarr>3.1.1" + ) + has_fsspec, requires_fsspec = _importorskip("fsspec") has_iris, requires_iris = _importorskip("iris") diff --git a/xarray/tests/test_backends.py b/xarray/tests/test_backends.py index eb98df5229f..c336fe7bd0d 100644 --- a/xarray/tests/test_backends.py +++ b/xarray/tests/test_backends.py @@ -1,5 +1,6 @@ from __future__ import annotations +import asyncio import contextlib import gzip import itertools @@ -16,6 +17,7 @@ from collections import ChainMap from collections.abc import Generator, Iterator, Mapping from contextlib import ExitStack +from importlib import import_module from io import BytesIO from pathlib import Path from typing import TYPE_CHECKING, Any, Final, Literal, cast @@ -28,6 +30,7 @@ from pandas.errors import OutOfBoundsDatetime import xarray as xr +import xarray.testing as xrt from xarray import ( DataArray, Dataset, @@ -74,9 +77,11 @@ has_scipy, has_zarr, has_zarr_v3, + has_zarr_v3_async_oindex, has_zarr_v3_dtypes, mock, network, + parametrize_zarr_format, requires_cftime, requires_dask, requires_fsspec, @@ -348,6 +353,11 @@ def __getitem__(self, key): class NetCDF3Only: netcdf3_formats: tuple[T_NetcdfTypes, ...] = ("NETCDF3_CLASSIC", "NETCDF3_64BIT") + @pytest.mark.asyncio + @pytest.mark.skip(reason="NetCDF backends don't support async loading") + async def test_load_async(self) -> None: + pass + @requires_scipy def test_dtype_coercion_error(self) -> None: """Failing dtype coercion should lead to an error""" @@ -458,6 +468,7 @@ def test_roundtrip_test_data(self) -> None: assert_identical(expected, actual) def test_load(self) -> None: + # Note: please keep this in sync with test_load_async below as much as possible! expected = create_test_data() @contextlib.contextmanager @@ -490,6 +501,43 @@ def assert_loads(vars=None): actual = ds.load() assert_identical(expected, actual) + @pytest.mark.asyncio + async def test_load_async(self) -> None: + # Note: please keep this in sync with test_load above as much as possible! + + # Copied from `test_load` on the base test class, but won't work for netcdf + expected = create_test_data() + + @contextlib.contextmanager + def assert_loads(vars=None): + if vars is None: + vars = expected + with self.roundtrip(expected) as actual: + for k, v in actual.variables.items(): + # IndexVariables are eagerly loaded into memory + assert v._in_memory == (k in actual.dims) + yield actual + for k, v in actual.variables.items(): + if k in vars: + assert v._in_memory + assert_identical(expected, actual) + + with pytest.raises(AssertionError): + # make sure the contextmanager works! + with assert_loads() as ds: + pass + + with assert_loads() as ds: + await ds.load_async() + + with assert_loads(["var1", "dim1", "dim2"]) as ds: + await ds["var1"].load_async() + + # verify we can read data even after closing the file + with self.roundtrip(expected) as ds: + actual = await ds.load_async() + assert_identical(expected, actual) + def test_dataset_compute(self) -> None: expected = create_test_data() @@ -1521,6 +1569,11 @@ def test_indexing_roundtrip(self, indexer) -> None: class NetCDFBase(CFEncodedBase): """Tests for all netCDF3 and netCDF4 backends.""" + @pytest.mark.asyncio + @pytest.mark.skip(reason="NetCDF backends don't support async loading") + async def test_load_async(self) -> None: + await super().test_load_async() + @pytest.mark.skipif( ON_WINDOWS, reason="Windows does not allow modifying open files" ) @@ -2461,6 +2514,14 @@ def roundtrip( with self.open(store_target, **open_kwargs) as ds: yield ds + @pytest.mark.asyncio + @pytest.mark.skipif( + not has_zarr_v3, + reason="zarr-python <3 did not support async loading", + ) + async def test_load_async(self) -> None: + await super().test_load_async() + def test_roundtrip_bytes_with_fill_value(self): pytest.xfail("Broken by Zarr 3.0.7") @@ -3840,6 +3901,239 @@ def test_chunk_key_encoding_v2(self) -> None: # Verify chunks are preserved assert actual["var1"].encoding["chunks"] == (2, 2) + @pytest.mark.asyncio + @requires_zarr_v3 + async def test_async_load_multiple_variables(self) -> None: + target_class = zarr.AsyncArray + method_name = "getitem" + original_method = getattr(target_class, method_name) + + # the indexed coordinate variables is not lazy, so the create_test_dataset has 4 lazy variables in total + N_LAZY_VARS = 4 + + original = create_test_data() + with self.create_zarr_target() as store: + original.to_zarr(store, zarr_format=3, consolidated=False) + + with patch.object( + target_class, method_name, wraps=original_method, autospec=True + ) as mocked_meth: + # blocks upon loading the coordinate variables here + ds = xr.open_zarr(store, consolidated=False, chunks=None) + + # TODO we're not actually testing that these indexing methods are not blocking... + result_ds = await ds.load_async() + + mocked_meth.assert_called() + assert mocked_meth.call_count == N_LAZY_VARS + mocked_meth.assert_awaited() + + xrt.assert_identical(result_ds, ds.load()) + + @pytest.mark.asyncio + @requires_zarr_v3 + @pytest.mark.parametrize("cls_name", ["Variable", "DataArray", "Dataset"]) + async def test_concurrent_load_multiple_objects( + self, + cls_name, + ) -> None: + N_OBJECTS = 5 + N_LAZY_VARS = { + "Variable": 1, + "DataArray": 1, + "Dataset": 4, + } # specific to the create_test_data() used + + target_class = zarr.AsyncArray + method_name = "getitem" + original_method = getattr(target_class, method_name) + + original = create_test_data() + with self.create_zarr_target() as store: + original.to_zarr(store, consolidated=False, zarr_format=3) + + with patch.object( + target_class, method_name, wraps=original_method, autospec=True + ) as mocked_meth: + xr_obj = get_xr_obj(store, cls_name) + + # TODO we're not actually testing that these indexing methods are not blocking... + coros = [xr_obj.load_async() for _ in range(N_OBJECTS)] + results = await asyncio.gather(*coros) + + mocked_meth.assert_called() + assert mocked_meth.call_count == N_OBJECTS * N_LAZY_VARS[cls_name] + mocked_meth.assert_awaited() + + for result in results: + xrt.assert_identical(result, xr_obj.load()) + + @pytest.mark.asyncio + @requires_zarr_v3 + @pytest.mark.parametrize("cls_name", ["Variable", "DataArray", "Dataset"]) + @pytest.mark.parametrize( + "indexer, method, target_zarr_class", + [ + pytest.param({}, "sel", "zarr.AsyncArray", id="no-indexing-sel"), + pytest.param({}, "isel", "zarr.AsyncArray", id="no-indexing-isel"), + pytest.param({"dim2": 1.0}, "sel", "zarr.AsyncArray", id="basic-int-sel"), + pytest.param({"dim2": 2}, "isel", "zarr.AsyncArray", id="basic-int-isel"), + pytest.param( + {"dim2": slice(1.0, 3.0)}, + "sel", + "zarr.AsyncArray", + id="basic-slice-sel", + ), + pytest.param( + {"dim2": slice(1, 3)}, "isel", "zarr.AsyncArray", id="basic-slice-isel" + ), + pytest.param( + {"dim2": [1.0, 3.0]}, + "sel", + "zarr.core.indexing.AsyncOIndex", + id="outer-sel", + ), + pytest.param( + {"dim2": [1, 3]}, + "isel", + "zarr.core.indexing.AsyncOIndex", + id="outer-isel", + ), + pytest.param( + { + "dim1": xr.Variable(data=[2, 3], dims="points"), + "dim2": xr.Variable(data=[1.0, 2.0], dims="points"), + }, + "sel", + "zarr.core.indexing.AsyncVIndex", + id="vectorized-sel", + ), + pytest.param( + { + "dim1": xr.Variable(data=[2, 3], dims="points"), + "dim2": xr.Variable(data=[1, 3], dims="points"), + }, + "isel", + "zarr.core.indexing.AsyncVIndex", + id="vectorized-isel", + ), + ], + ) + async def test_indexing( + self, + cls_name, + method, + indexer, + target_zarr_class, + ) -> None: + if not has_zarr_v3_async_oindex and target_zarr_class in ( + "zarr.core.indexing.AsyncOIndex", + "zarr.core.indexing.AsyncVIndex", + ): + pytest.skip( + "current version of zarr does not support orthogonal or vectorized async indexing" + ) + + if cls_name == "Variable" and method == "sel": + pytest.skip("Variable doesn't have a .sel method") + + # Each type of indexing ends up calling a different zarr indexing method + # They all use a method named .getitem, but on a different internal zarr class + def _resolve_class_from_string(class_path: str) -> type[Any]: + """Resolve a string class path like 'zarr.AsyncArray' to the actual class.""" + module_path, class_name = class_path.rsplit(".", 1) + module = import_module(module_path) + return getattr(module, class_name) + + target_class = _resolve_class_from_string(target_zarr_class) + method_name = "getitem" + original_method = getattr(target_class, method_name) + + original = create_test_data() + with self.create_zarr_target() as store: + original.to_zarr(store, consolidated=False, zarr_format=3) + + with patch.object( + target_class, method_name, wraps=original_method, autospec=True + ) as mocked_meth: + xr_obj = get_xr_obj(store, cls_name) + + # TODO we're not actually testing that these indexing methods are not blocking... + result = await getattr(xr_obj, method)(**indexer).load_async() + + mocked_meth.assert_called() + mocked_meth.assert_awaited() + assert mocked_meth.call_count > 0 + + expected = getattr(xr_obj, method)(**indexer).load() + xrt.assert_identical(result, expected) + + @pytest.mark.asyncio + @pytest.mark.parametrize( + ("indexer", "expected_err_msg"), + [ + pytest.param( + {"dim2": 2}, + "basic async indexing", + marks=pytest.mark.skipif( + has_zarr_v3, + reason="current version of zarr has basic async indexing", + ), + ), # tests basic indexing + pytest.param( + {"dim2": [1, 3]}, + "orthogonal async indexing", + marks=pytest.mark.skipif( + has_zarr_v3_async_oindex, + reason="current version of zarr has async orthogonal indexing", + ), + ), # tests oindexing + pytest.param( + { + "dim1": xr.Variable(data=[2, 3], dims="points"), + "dim2": xr.Variable(data=[1, 3], dims="points"), + }, + "vectorized async indexing", + marks=pytest.mark.skipif( + has_zarr_v3_async_oindex, + reason="current version of zarr has async vectorized indexing", + ), + ), # tests vindexing + ], + ) + @parametrize_zarr_format + async def test_raise_on_older_zarr_version( + self, + indexer, + expected_err_msg, + zarr_format, + ): + """Test that trying to use async load with insufficiently new version of zarr raises a clear error""" + + original = create_test_data() + with self.create_zarr_target() as store: + original.to_zarr(store, consolidated=False, zarr_format=zarr_format) + + ds = xr.open_zarr(store, consolidated=False, chunks=None) + var = ds["var1"].variable + + with pytest.raises(NotImplementedError, match=expected_err_msg): + await var.isel(**indexer).load_async() + + +def get_xr_obj( + store: zarr.abc.store.Store, cls_name: Literal["Variable", "DataArray", "Dataset"] +): + ds = xr.open_zarr(store, consolidated=False, chunks=None) + + match cls_name: + case "Variable": + return ds["var1"].variable + case "DataArray": + return ds["var1"] + case "Dataset": + return ds + class NoConsolidatedMetadataSupportStore(WrapperStore): """ @@ -4090,7 +4384,7 @@ def test_zarr_version_deprecated() -> None: @requires_scipy -class TestScipyInMemoryData(CFEncodedBase, NetCDF3Only): +class TestScipyInMemoryData(NetCDF3Only, CFEncodedBase): engine: T_NetcdfEngine = "scipy" @contextlib.contextmanager @@ -4098,6 +4392,11 @@ def create_store(self): fobj = BytesIO() yield backends.ScipyDataStore(fobj, "w") + @pytest.mark.asyncio + @pytest.mark.skip(reason="NetCDF backends don't support async loading") + async def test_load_async(self) -> None: + await super().test_load_async() + def test_to_netcdf_explicit_engine(self) -> None: with pytest.warns( FutureWarning, @@ -4128,7 +4427,7 @@ def test_bytes_pickle(self) -> None: @requires_scipy -class TestScipyFileObject(CFEncodedBase, NetCDF3Only): +class TestScipyFileObject(NetCDF3Only, CFEncodedBase): # TODO: Consider consolidating some of these cases (e.g., # test_file_remains_open) with TestH5NetCDFFileObject engine: T_NetcdfEngine = "scipy" @@ -4197,7 +4496,7 @@ def test_create_default_indexes(self, tmp_path, create_default_indexes) -> None: @requires_scipy -class TestScipyFilePath(CFEncodedBase, NetCDF3Only): +class TestScipyFilePath(NetCDF3Only, CFEncodedBase): engine: T_NetcdfEngine = "scipy" @contextlib.contextmanager @@ -4234,7 +4533,7 @@ def test_nc4_scipy(self) -> None: @requires_netCDF4 -class TestNetCDF3ViaNetCDF4Data(CFEncodedBase, NetCDF3Only): +class TestNetCDF3ViaNetCDF4Data(NetCDF3Only, CFEncodedBase): engine: T_NetcdfEngine = "netcdf4" file_format: T_NetcdfTypes = "NETCDF3_CLASSIC" @@ -4255,7 +4554,7 @@ def test_encoding_kwarg_vlen_string(self) -> None: @requires_netCDF4 -class TestNetCDF4ClassicViaNetCDF4Data(CFEncodedBase, NetCDF3Only): +class TestNetCDF4ClassicViaNetCDF4Data(NetCDF3Only, CFEncodedBase): engine: T_NetcdfEngine = "netcdf4" file_format: T_NetcdfTypes = "NETCDF4_CLASSIC" @@ -4269,7 +4568,7 @@ def create_store(self): @requires_scipy_or_netCDF4 -class TestGenericNetCDFData(CFEncodedBase, NetCDF3Only): +class TestGenericNetCDFData(NetCDF3Only, CFEncodedBase): # verify that we can read and write netCDF3 files as long as we have scipy # or netCDF4-python installed file_format: T_NetcdfTypes = "NETCDF3_64BIT" diff --git a/xarray/tests/test_indexing.py b/xarray/tests/test_indexing.py index db4f6aaf0bd..67bb23cfe51 100644 --- a/xarray/tests/test_indexing.py +++ b/xarray/tests/test_indexing.py @@ -506,6 +506,25 @@ def test_sub_array(self) -> None: assert isinstance(child.array, indexing.NumpyIndexingAdapter) assert isinstance(wrapped.array, indexing.LazilyIndexedArray) + @pytest.mark.asyncio + async def test_async_wrapper(self) -> None: + original = indexing.LazilyIndexedArray(np.arange(10)) + wrapped = indexing.MemoryCachedArray(original) + await wrapped.async_get_duck_array() + assert_array_equal(wrapped, np.arange(10)) + assert isinstance(wrapped.array, indexing.NumpyIndexingAdapter) + + @pytest.mark.asyncio + async def test_async_sub_array(self) -> None: + original = indexing.LazilyIndexedArray(np.arange(10)) + wrapped = indexing.MemoryCachedArray(original) + child = wrapped[B[:5]] + assert isinstance(child, indexing.MemoryCachedArray) + await child.async_get_duck_array() + assert_array_equal(child, np.arange(5)) + assert isinstance(child.array, indexing.NumpyIndexingAdapter) + assert isinstance(wrapped.array, indexing.LazilyIndexedArray) + def test_setitem(self) -> None: original = np.arange(10) wrapped = indexing.MemoryCachedArray(original) diff --git a/xarray/tests/test_variable.py b/xarray/tests/test_variable.py index e2f4a3154f3..de77ec00c40 100644 --- a/xarray/tests/test_variable.py +++ b/xarray/tests/test_variable.py @@ -2897,18 +2897,34 @@ def setUp(self): self.d = np.random.random((10, 3)).astype(np.float64) self.cat = PandasExtensionArray(pd.Categorical(["a", "b"] * 5)) - def check_orthogonal_indexing(self, v): - assert np.allclose(v.isel(x=[8, 3], y=[2, 1]), self.d[[8, 3]][:, [2, 1]]) + async def check_orthogonal_indexing(self, v, load_async): + expected = self.d[[8, 3]][:, [2, 1]] - def check_vectorized_indexing(self, v): + if load_async: + result = await v.isel(x=[8, 3], y=[2, 1]).load_async() + else: + result = v.isel(x=[8, 3], y=[2, 1]) + + assert np.allclose(result, expected) + + async def check_vectorized_indexing(self, v, load_async): ind_x = Variable("z", [0, 2]) ind_y = Variable("z", [2, 1]) - assert np.allclose(v.isel(x=ind_x, y=ind_y), self.d[ind_x, ind_y]) + expected = self.d[ind_x, ind_y] + + if load_async: + result = await v.isel(x=ind_x, y=ind_y).load_async() + else: + result = v.isel(x=ind_x, y=ind_y).load() + + assert np.allclose(result, expected) - def test_NumpyIndexingAdapter(self): + @pytest.mark.asyncio + @pytest.mark.parametrize("load_async", [True, False]) + async def test_NumpyIndexingAdapter(self, load_async): v = Variable(dims=("x", "y"), data=NumpyIndexingAdapter(self.d)) - self.check_orthogonal_indexing(v) - self.check_vectorized_indexing(v) + await self.check_orthogonal_indexing(v, load_async) + await self.check_vectorized_indexing(v, load_async) # could not doubly wrapping with pytest.raises(TypeError, match=r"NumpyIndexingAdapter only wraps "): v = Variable( @@ -2923,54 +2939,62 @@ def test_extension_array_duck_indexed(self): lazy = Variable(dims=("x"), data=LazilyIndexedArray(self.cat)) assert (lazy[[0, 1, 5]] == ["a", "b", "b"]).all() - def test_LazilyIndexedArray(self): + @pytest.mark.asyncio + @pytest.mark.parametrize("load_async", [True, False]) + async def test_LazilyIndexedArray(self, load_async): v = Variable(dims=("x", "y"), data=LazilyIndexedArray(self.d)) - self.check_orthogonal_indexing(v) - self.check_vectorized_indexing(v) + await self.check_orthogonal_indexing(v, load_async) + await self.check_vectorized_indexing(v, load_async) # doubly wrapping v = Variable( dims=("x", "y"), data=LazilyIndexedArray(LazilyIndexedArray(self.d)), ) - self.check_orthogonal_indexing(v) + await self.check_orthogonal_indexing(v, load_async) # hierarchical wrapping v = Variable( dims=("x", "y"), data=LazilyIndexedArray(NumpyIndexingAdapter(self.d)) ) - self.check_orthogonal_indexing(v) + await self.check_orthogonal_indexing(v, load_async) - def test_CopyOnWriteArray(self): + @pytest.mark.asyncio + @pytest.mark.parametrize("load_async", [True, False]) + async def test_CopyOnWriteArray(self, load_async): v = Variable(dims=("x", "y"), data=CopyOnWriteArray(self.d)) - self.check_orthogonal_indexing(v) - self.check_vectorized_indexing(v) + await self.check_orthogonal_indexing(v, load_async) + await self.check_vectorized_indexing(v, load_async) # doubly wrapping v = Variable(dims=("x", "y"), data=CopyOnWriteArray(LazilyIndexedArray(self.d))) - self.check_orthogonal_indexing(v) - self.check_vectorized_indexing(v) + await self.check_orthogonal_indexing(v, load_async) + await self.check_vectorized_indexing(v, load_async) - def test_MemoryCachedArray(self): + @pytest.mark.asyncio + @pytest.mark.parametrize("load_async", [True, False]) + async def test_MemoryCachedArray(self, load_async): v = Variable(dims=("x", "y"), data=MemoryCachedArray(self.d)) - self.check_orthogonal_indexing(v) - self.check_vectorized_indexing(v) + await self.check_orthogonal_indexing(v, load_async) + await self.check_vectorized_indexing(v, load_async) # doubly wrapping v = Variable(dims=("x", "y"), data=CopyOnWriteArray(MemoryCachedArray(self.d))) - self.check_orthogonal_indexing(v) - self.check_vectorized_indexing(v) + await self.check_orthogonal_indexing(v, load_async) + await self.check_vectorized_indexing(v, load_async) @requires_dask - def test_DaskIndexingAdapter(self): + @pytest.mark.asyncio + @pytest.mark.parametrize("load_async", [True, False]) + async def test_DaskIndexingAdapter(self, load_async): import dask.array as da dask_array = da.asarray(self.d) v = Variable(dims=("x", "y"), data=DaskIndexingAdapter(dask_array)) - self.check_orthogonal_indexing(v) - self.check_vectorized_indexing(v) + await self.check_orthogonal_indexing(v, load_async) + await self.check_vectorized_indexing(v, load_async) # doubly wrapping v = Variable( dims=("x", "y"), data=CopyOnWriteArray(DaskIndexingAdapter(dask_array)) ) - self.check_orthogonal_indexing(v) - self.check_vectorized_indexing(v) + await self.check_orthogonal_indexing(v, load_async) + await self.check_vectorized_indexing(v, load_async) def test_clip(var):