diff --git a/benchmarks/benchmarks/readwrite.py b/benchmarks/benchmarks/readwrite.py index 825f84263..5e1541715 100644 --- a/benchmarks/benchmarks/readwrite.py +++ b/benchmarks/benchmarks/readwrite.py @@ -25,6 +25,7 @@ import sys import tempfile from pathlib import Path +from typing import TYPE_CHECKING import numpy as np import pooch @@ -35,6 +36,9 @@ from .utils import get_actualsize, get_peak_mem, sedate +if TYPE_CHECKING: + from collections.abc import Callable + PBMC_3K_URL = "https://falexwolf.de/data/pbmc3k_raw.h5ad" # PBMC_3K_PATH = Path(__file__).parent / "data/pbmc3k_raw.h5ad" @@ -43,93 +47,93 @@ # BM_43K_CSC_PATH = Path(__file__).parent.parent / "datasets/BM2_43k-cells_CSC.h5ad" -# class ZarrReadSuite: -# params = [] -# param_names = ["input_url"] - -# def setup(self, input_url): -# self.filepath = pooch.retrieve(url=input_url, known_hash=None) - -# def time_read_full(self, input_url): -# anndata.read_zarr(self.filepath) - -# def peakmem_read_full(self, input_url): -# anndata.read_zarr(self.filepath) - -# def mem_readfull_object(self, input_url): -# return anndata.read_zarr(self.filepath) +class TestSuite: + _urls = dict(pbmc3k=PBMC_3K_URL) + params = _urls.keys() + param_names = ["input_data"] + filepath: Path + read_func: Callable[[Path | str], anndata.AnnData] -# def track_read_full_memratio(self, input_url): -# mem_recording = memory_usage( -# (sedate(anndata.read_zarr, 0.005), (self.filepath,)), interval=0.001 -# ) -# adata = anndata.read_zarr(self.filepath) -# base_size = mem_recording[-1] - mem_recording[0] -# print(np.max(mem_recording) - np.min(mem_recording)) -# print(base_size) -# return (np.max(mem_recording) - np.min(mem_recording)) / base_size + def setup(self, input_data: str): + self.filepath = Path( + pooch.retrieve(url=self._urls[input_data], known_hash=None) + ) -# def peakmem_read_backed(self, input_url): -# anndata.read_zarr(self.filepath, backed="r") -# def mem_read_backed_object(self, input_url): -# return 
anndata.read_zarr(self.filepath, backed="r") +class ZarrMixin(TestSuite): + def setup(self, input_data: str): + super().setup(input_data) + zarr_path = self.filepath.with_suffix(".zarr") + anndata.read_h5ad(self.filepath).write_zarr(zarr_path) + self.filepath = zarr_path + @property + def read_func(self): + return anndata.read_zarr -class H5ADInMemorySizeSuite: - _urls = dict(pbmc3k=PBMC_3K_URL) - params = _urls.keys() - param_names = ["input_data"] - def setup(self, input_data: str): - self.filepath = pooch.retrieve(url=self._urls[input_data], known_hash=None) +class H5ADInMemorySizeSuite(TestSuite): + @property + def read_func(self): + return anndata.read_h5ad def track_in_memory_size(self, *_): - adata = anndata.read_h5ad(self.filepath) + adata = self.read_func(self.filepath) adata_size = sys.getsizeof(adata) return adata_size def track_actual_in_memory_size(self, *_): - adata = anndata.read_h5ad(self.filepath) + adata = self.read_func(self.filepath) adata_size = get_actualsize(adata) return adata_size -class H5ADReadSuite: - _urls = dict(pbmc3k=PBMC_3K_URL) - params = _urls.keys() - param_names = ["input_data"] +class ZarrInMemorySizeSuite(ZarrMixin, H5ADInMemorySizeSuite): + @property + def read_func(self): + return anndata.read_zarr - def setup(self, input_data: str): - self.filepath = pooch.retrieve(url=self._urls[input_data], known_hash=None) + +class H5ADReadSuite(TestSuite): + @property + def read_func(self): + return anndata.read_h5ad def time_read_full(self, *_): - anndata.read_h5ad(self.filepath) + self.read_func(self.filepath) def peakmem_read_full(self, *_): - anndata.read_h5ad(self.filepath) + self.read_func(self.filepath) def mem_readfull_object(self, *_): - return anndata.read_h5ad(self.filepath) + return self.read_func(self.filepath) def track_read_full_memratio(self, *_): mem_recording = memory_usage( - (sedate(anndata.read_h5ad, 0.005), (self.filepath,)), interval=0.001 + (sedate(self.read_func, 0.005), (self.filepath,)), interval=0.001 ) - # 
adata = anndata.read_h5ad(self.filepath)
+        # adata = self.read_func(self.filepath)
         base_size = mem_recording[-1] - mem_recording[0]
         print(np.max(mem_recording) - np.min(mem_recording))
         print(base_size)
         return (np.max(mem_recording) - np.min(mem_recording)) / base_size
 
+    # causes benchmarking to break from: https://github.com/pympler/pympler/issues/151
+    # def mem_read_backed_object(self, *_):
+    #     return self.read_func(self.filepath, backed="r")
+
+
+class BackedH5ADSuite(TestSuite):
     def peakmem_read_backed(self, *_):
         anndata.read_h5ad(self.filepath, backed="r")
 
-    # causes benchmarking to break from: https://github.com/pympler/pympler/issues/151
-    # def mem_read_backed_object(self, *_):
-    #     return anndata.read_h5ad(self.filepath, backed="r")
+
+class ZarrReadSuite(ZarrMixin, H5ADReadSuite):
+    @property
+    def read_func(self):
+        return anndata.read_zarr
 
 
 class H5ADWriteSuite:
@@ -137,6 +141,10 @@ class H5ADWriteSuite:
     params = _urls.keys()
     param_names = ["input_data"]
 
+    @property
+    def write_func(self):
+        return anndata.write_h5ad
+
     def setup(self, input_data: str):
         mem_recording, adata = memory_usage(
             (
@@ -155,31 +163,66 @@ def teardown(self, *_):
         self.tmpdir.cleanup()
 
     def time_write_full(self, *_):
-        self.adata.write_h5ad(self.writepth, compression=None)
+        self.write_func(self.writepth, self.adata, compression=None)
 
     def peakmem_write_full(self, *_):
-        self.adata.write_h5ad(self.writepth)
+        self.write_func(self.writepth, self.adata)
 
     def track_peakmem_write_full(self, *_):
-        return get_peak_mem((sedate(self.adata.write_h5ad), (self.writepth,)))
+        return get_peak_mem((sedate(self.write_func), (self.writepth, self.adata)))
 
     def time_write_compressed(self, *_):
-        self.adata.write_h5ad(self.writepth, compression="gzip")
+        self.write_func(self.writepth, self.adata, compression="gzip")
 
     def peakmem_write_compressed(self, *_):
-        self.adata.write_h5ad(self.writepth, compression="gzip")
+        self.write_func(self.writepth, self.adata, compression="gzip")
 
     def track_peakmem_write_compressed(self, *_):
         return 
get_peak_mem( - (sedate(self.adata.write_h5ad), (self.writepth,), {"compression": "gzip"}) + ( + sedate(self.write_func), + (self.writepth, self.adata), + {"compression": "gzip"}, + ) ) -class H5ADBackedWriteSuite(H5ADWriteSuite): - _urls = dict(pbmc3k=PBMC_3K_URL) - params = _urls.keys() - param_names = ["input_data"] +class ZarrWriteSizeSuite(H5ADWriteSuite): + write_func_str = "write_zarr" + + @property + def write_func(self): + return anndata.write_zarr + + def setup(self, input_data: str): + h5_path = Path(pooch.retrieve(self._urls[input_data], known_hash=None)) + zarr_path = h5_path.with_suffix(".zarr") + anndata.read_h5ad(h5_path).write_zarr(zarr_path) + + mem_recording, adata = memory_usage( + ( + sedate(anndata.read_zarr, 0.005), + (zarr_path,), + ), + retval=True, + interval=0.001, + ) + self.adata = adata + self.base_size = mem_recording[-1] - mem_recording[0] + self.tmpdir = tempfile.TemporaryDirectory() + self.writepth = Path(self.tmpdir.name) / "out.zarr" + + def track_peakmem_write_compressed(self, *_): + return get_peak_mem( + ( + sedate(self.write_func), + (self.writepth, self.adata), + {"compression": "gzip"}, + ) + ) + +class BackedH5ADWriteSuite(H5ADWriteSuite): def setup(self, input_data): mem_recording, adata = memory_usage( ( diff --git a/docs/api.md b/docs/api.md index 951786f81..1aeb60b66 100644 --- a/docs/api.md +++ b/docs/api.md @@ -154,6 +154,7 @@ Types used by the former: experimental.IOSpec experimental.Read experimental.Write + experimental.ReadAsync experimental.ReadCallback experimental.WriteCallback experimental.StorageType diff --git a/docs/conf.py b/docs/conf.py index 58a21cc3b..3e9f4f5ad 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -130,6 +130,7 @@ def setup(app: Sphinx): "anndata._types.WriteCallback": "anndata.experimental.WriteCallback", "anndata._types.Read": "anndata.experimental.Read", "anndata._types.Write": "anndata.experimental.Write", + "anndata._types.ReadAsync": "anndata.experimental.ReadAsync", 
"zarr.core.array.Array": "zarr.Array", "zarr.core.group.Group": "zarr.Group", "anndata.compat.DaskArray": "dask.array.Array", diff --git a/pyproject.toml b/pyproject.toml index fe344ab1e..da4315b95 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -87,6 +87,7 @@ test-full = [ "anndata[test,lazy]" ] test = [ "loompy>=3.0.5", "pytest>=8.2,<8.3.4", + "pytest-asyncio", "pytest-cov", "matplotlib", "scikit-learn", @@ -143,6 +144,7 @@ addopts = [ "--pyargs", "-ptesting.anndata._pytest", ] +asyncio_mode = "auto" filterwarnings = [ "ignore::anndata._warnings.OldFormatWarning", "ignore::anndata._warnings.ExperimentalFeatureWarning", diff --git a/src/anndata/_core/sparse_dataset.py b/src/anndata/_core/sparse_dataset.py index 556bcf372..11246a0a3 100644 --- a/src/anndata/_core/sparse_dataset.py +++ b/src/anndata/_core/sparse_dataset.py @@ -12,6 +12,7 @@ # - think about supporting the COO format from __future__ import annotations +import asyncio import warnings from abc import ABC from collections.abc import Iterable @@ -586,7 +587,12 @@ def append(self, sparse_matrix: CSMatrix | CSArray) -> None: append_data = sparse_matrix.data append_indices = sparse_matrix.indices if isinstance(sparse_matrix.data, ZarrArray) and not is_zarr_v2(): - data[orig_data_size:] = append_data[...] + from .._io.specs.methods import _iter_chunks_for_copy + + for chunk in _iter_chunks_for_copy(append_data, data): + data[(chunk.start + orig_data_size) : (chunk.stop + orig_data_size)] = ( + append_data[chunk] + ) else: data[orig_data_size:] = append_data # indptr @@ -598,12 +604,16 @@ def append(self, sparse_matrix: CSMatrix | CSArray) -> None: ) # indices - if isinstance(sparse_matrix.data, ZarrArray) and not is_zarr_v2(): - append_indices = append_indices[...] 
indices = self.group["indices"] orig_data_size = indices.shape[0] indices.resize((orig_data_size + sparse_matrix.indices.shape[0],)) - indices[orig_data_size:] = append_indices + if isinstance(sparse_matrix.data, ZarrArray) and not is_zarr_v2(): + for chunk in _iter_chunks_for_copy(append_indices, indices): + indices[ + (chunk.start + orig_data_size) : (chunk.stop + orig_data_size) + ] = append_indices[chunk] + else: + indices[orig_data_size:] = append_indices # Clear cached property for attr in ["_indptr", "_indices", "_data"]: @@ -652,6 +662,31 @@ def to_memory(self) -> CSMatrix | CSArray: mtx.indptr = self._indptr return mtx + async def to_memory_async(self) -> CSMatrix | CSArray: + format_class = get_memory_class( + self.format, use_sparray_in_io=settings.use_sparse_array_on_read + ) + mtx = format_class(self.shape, dtype=self.dtype) + mtx.indptr = self._indptr + if isinstance(self._data, ZarrArray) and not is_zarr_v2(): + await asyncio.gather( + *( + self.set_memory_async_from_zarr(mtx, attr) + for attr in ["data", "indices"] + ) + ) + else: + mtx.data = self._data[...] + mtx.indices = self._indices[...] 
+ return mtx + + async def set_memory_async_from_zarr( + self, mtx: CSMatrix | CSArray, attr: Literal["indptr", "data", "indices"] + ) -> None: + setattr( + mtx, attr, await getattr(self, f"_{attr}")._async_array.getitem(()) + ) # TODO: better way to asyncify + class _CSRDataset(BaseCompressedSparseDataset, abc.CSRDataset): """Internal concrete version of :class:`anndata.abc.CSRDataset`.""" diff --git a/src/anndata/_io/h5ad.py b/src/anndata/_io/h5ad.py index 43a390ac0..47410b74b 100644 --- a/src/anndata/_io/h5ad.py +++ b/src/anndata/_io/h5ad.py @@ -1,5 +1,6 @@ from __future__ import annotations +import asyncio import re from functools import partial from pathlib import Path @@ -24,8 +25,8 @@ _from_fixed_length_strings, ) from ..experimental import read_dispatched -from .specs import read_elem, write_elem -from .specs.registry import IOSpec, write_spec +from .specs.methods import sync_async_to_async +from .specs.registry import IOSpec, read_elem_async, write_elem_async, write_spec from .utils import ( H5PY_V3, _read_legacy_raw, @@ -81,39 +82,76 @@ def write_h5ad( f = f["/"] f.attrs.setdefault("encoding-type", "anndata") f.attrs.setdefault("encoding-version", "0.1.0") - if "X" in as_dense and isinstance( adata.X, CSMatrix | BaseCompressedSparseDataset ): - write_sparse_as_dense(f, "X", adata.X, dataset_kwargs=dataset_kwargs) + asyncio.run( + write_sparse_as_dense(f, "X", adata.X, dataset_kwargs=dataset_kwargs) + ) elif not (adata.isbacked and Path(adata.filename) == Path(filepath)): # If adata.isbacked, X should already be up to date - write_elem(f, "X", adata.X, dataset_kwargs=dataset_kwargs) + asyncio.run( + write_elem_async(f, "X", adata.X, dataset_kwargs=dataset_kwargs) + ) if "raw/X" in as_dense and isinstance( adata.raw.X, CSMatrix | BaseCompressedSparseDataset ): - write_sparse_as_dense( - f, "raw/X", adata.raw.X, dataset_kwargs=dataset_kwargs + asyncio.run( + write_sparse_as_dense( + f, "raw/X", adata.raw.X, dataset_kwargs=dataset_kwargs + ) + ) + 
asyncio.run( + write_elem_async( + f, "raw/var", adata.raw.var, dataset_kwargs=dataset_kwargs + ) ) - write_elem(f, "raw/var", adata.raw.var, dataset_kwargs=dataset_kwargs) - write_elem( - f, "raw/varm", dict(adata.raw.varm), dataset_kwargs=dataset_kwargs + asyncio.run( + write_elem_async( + f, "raw/varm", dict(adata.raw.varm), dataset_kwargs=dataset_kwargs + ) ) elif adata.raw is not None: - write_elem(f, "raw", adata.raw, dataset_kwargs=dataset_kwargs) - write_elem(f, "obs", adata.obs, dataset_kwargs=dataset_kwargs) - write_elem(f, "var", adata.var, dataset_kwargs=dataset_kwargs) - write_elem(f, "obsm", dict(adata.obsm), dataset_kwargs=dataset_kwargs) - write_elem(f, "varm", dict(adata.varm), dataset_kwargs=dataset_kwargs) - write_elem(f, "obsp", dict(adata.obsp), dataset_kwargs=dataset_kwargs) - write_elem(f, "varp", dict(adata.varp), dataset_kwargs=dataset_kwargs) - write_elem(f, "layers", dict(adata.layers), dataset_kwargs=dataset_kwargs) - write_elem(f, "uns", dict(adata.uns), dataset_kwargs=dataset_kwargs) + asyncio.run( + write_elem_async(f, "raw", adata.raw, dataset_kwargs=dataset_kwargs) + ) + + async def gather(): + await asyncio.gather( + *[ + write_elem_async( + f, "obs", adata.obs, dataset_kwargs=dataset_kwargs + ), + write_elem_async( + f, "var", adata.var, dataset_kwargs=dataset_kwargs + ), + write_elem_async( + f, "obsm", dict(adata.obsm), dataset_kwargs=dataset_kwargs + ), + write_elem_async( + f, "varm", dict(adata.varm), dataset_kwargs=dataset_kwargs + ), + write_elem_async( + f, "obsp", dict(adata.obsp), dataset_kwargs=dataset_kwargs + ), + write_elem_async( + f, "varp", dict(adata.varp), dataset_kwargs=dataset_kwargs + ), + write_elem_async( + f, "layers", dict(adata.layers), dataset_kwargs=dataset_kwargs + ), + write_elem_async( + f, "uns", dict(adata.uns), dataset_kwargs=dataset_kwargs + ), + ] + ) + + asyncio.run(gather()) @report_write_key_on_error @write_spec(IOSpec("array", "0.2.0")) -def write_sparse_as_dense( +async def 
write_sparse_as_dense( f: h5py.Group, key: str, value: CSMatrix | BaseCompressedSparseDataset, @@ -140,7 +178,7 @@ def write_sparse_as_dense( del f[key] -def read_h5ad_backed(filename: str | Path, mode: Literal["r", "r+"]) -> AnnData: +async def read_h5ad_backed(filename: str | Path, mode: Literal["r", "r+"]) -> AnnData: d = dict(filename=filename, filemode=mode) f = h5py.File(filename, mode) @@ -153,11 +191,11 @@ def read_h5ad_backed(filename: str | Path, mode: Literal["r", "r+"]) -> AnnData: else: for k in df_attributes: if k in f: # Backwards compat - d[k] = read_dataframe(f[k]) + d[k] = await read_dataframe(f[k]) - d.update({k: read_elem(f[k]) for k in attributes if k in f}) + d.update({k: await read_elem_async(f[k]) for k in attributes if k in f}) - d["raw"] = _read_raw(f, attrs={"var", "varm"}) + d["raw"] = await _read_raw(f, attrs={"var", "varm"}) adata = AnnData(**d) @@ -211,8 +249,8 @@ def read_h5ad( mode = backed if mode is True: mode = "r+" - assert mode in {"r", "r+"} - return read_h5ad_backed(filename, mode) + assert mode in {"r", "r+"}, mode + return asyncio.run(read_h5ad_backed(filename, mode)) if as_sparse_fmt not in (sparse.csr_matrix, sparse.csc_matrix): msg = "Dense formats can only be read to CSR or CSC matrices at this time." 
@@ -234,33 +272,36 @@ def read_h5ad( with h5py.File(filename, "r") as f: - def callback(func, elem_name: str, elem, iospec): + async def callback(func, elem_name: str, elem, iospec): if iospec.encoding_type == "anndata" or elem_name.endswith("/"): - return AnnData( - **{ - # This is covering up backwards compat in the anndata initializer - # In most cases we should be able to call `func(elen[k])` instead - k: read_dispatched(elem[k], callback) - for k in elem.keys() - if not k.startswith("raw.") - } + args = dict( + await asyncio.gather( + *( + # This is covering up backwards compat in the anndata initializer + # In most cases we should be able to call `func(elen[k])` instead + sync_async_to_async(k, read_dispatched(elem[k], callback)) + for k in elem.keys() + if not k.startswith("raw.") + ) + ) ) + return AnnData(**args) elif elem_name.startswith("/raw."): return None elif elem_name == "/X" and "X" in as_sparse: return rdasp(elem) elif elem_name == "/raw": - return _read_raw(f, as_sparse, rdasp) + return await _read_raw(f, as_sparse, rdasp) elif elem_name in {"/obs", "/var"}: # Backwards compat - return read_dataframe(elem) - return func(elem) + return await read_dataframe(elem) + return await func(elem) - adata = read_dispatched(f, callback=callback) + adata = asyncio.run(read_dispatched(f, callback=callback)) # Backwards compat (should figure out which version) if "raw.X" in f: - raw = AnnData(**_read_raw(f, as_sparse, rdasp)) + raw = AnnData(**asyncio.run(_read_raw(f, as_sparse, rdasp))) raw.obs_names = adata.obs_names adata.raw = raw @@ -271,7 +312,7 @@ def callback(func, elem_name: str, elem, iospec): return adata -def _read_raw( +async def _read_raw( f: h5py.File | AnnDataFileManager, as_sparse: Collection[str] = (), rdasp: Callable[[h5py.Dataset], CSMatrix] | None = None, @@ -282,12 +323,15 @@ def _read_raw( assert rdasp is not None, "must supply rdasp if as_sparse is supplied" raw = {} if "X" in attrs and "raw/X" in f: - read_x = rdasp if "raw/X" in 
as_sparse else read_elem - raw["X"] = read_x(f["raw/X"]) + raw["X"] = ( + (await read_elem_async(f["raw/X"])) + if "raw/X" not in as_sparse + else rdasp(f["raw/X"]) + ) for v in ("var", "varm"): if v in attrs and f"raw/{v}" in f: - raw[v] = read_elem(f[f"raw/{v}"]) - return _read_legacy_raw(f, raw, read_dataframe, read_elem, attrs=attrs) + raw[v] = await read_elem_async(f[f"raw/{v}"]) + return await _read_legacy_raw(f, raw, read_dataframe, read_elem_async, attrs=attrs) @report_read_key_on_error @@ -310,12 +354,12 @@ def read_dataframe_legacy(dataset: h5py.Dataset) -> pd.DataFrame: return df -def read_dataframe(group: h5py.Group | h5py.Dataset) -> pd.DataFrame: +async def read_dataframe(group: h5py.Group | h5py.Dataset) -> pd.DataFrame: """Backwards compat function""" if not isinstance(group, h5py.Group): return read_dataframe_legacy(group) else: - return read_elem(group) + return await read_elem_async(group) @report_read_key_on_error diff --git a/src/anndata/_io/specs/__init__.py b/src/anndata/_io/specs/__init__.py index 8fd9898a3..87349d836 100644 --- a/src/anndata/_io/specs/__init__.py +++ b/src/anndata/_io/specs/__init__.py @@ -9,16 +9,20 @@ Writer, get_spec, read_elem, + read_elem_async, read_elem_lazy, write_elem, + write_elem_async, ) __all__ = [ "methods", "lazy_methods", "write_elem", + "write_elem_async", "get_spec", "read_elem", + "read_elem_async", "read_elem_lazy", "Reader", "Writer", diff --git a/src/anndata/_io/specs/methods.py b/src/anndata/_io/specs/methods.py index 71ded5531..da367bf32 100644 --- a/src/anndata/_io/specs/methods.py +++ b/src/anndata/_io/specs/methods.py @@ -1,5 +1,6 @@ from __future__ import annotations +import asyncio import warnings from collections.abc import Mapping from copy import copy @@ -18,8 +19,6 @@ import anndata as ad from anndata import AnnData, Raw from anndata._core import views -from anndata._core.index import _normalize_indices -from anndata._core.merge import intersect_keys from anndata._core.sparse_dataset import 
_CSCDataset, _CSRDataset, sparse_dataset from anndata._io.utils import H5PY_V3, check_key, zero_dim_array_as_scalar from anndata._warnings import OldFormatWarning @@ -42,12 +41,11 @@ from ..._settings import settings from ...compat import is_zarr_v2 -from .registry import _REGISTRY, IOSpec, read_elem, read_elem_partial +from .registry import _REGISTRY, IOSpec, read_elem if TYPE_CHECKING: from collections.abc import Callable, Iterator - from os import PathLike - from typing import Any, Literal + from typing import Any, Literal, TypeVar from numpy import typing as npt from numpy.typing import NDArray @@ -58,6 +56,9 @@ from .registry import Reader, Writer + T = TypeVar("T") + C = TypeVar("C") + #################### # Dask utils # #################### @@ -116,6 +117,12 @@ def wrapper( return wrapper +async def sync_async_to_async( + s: T, a: asyncio.Future[C] +) -> asyncio.Future[tuple[T, C]]: + return s, await a + + ################################ # Fallbacks / backwards compat # ################################ @@ -126,7 +133,7 @@ def wrapper( @_REGISTRY.register_read(H5File, IOSpec("", "")) @_REGISTRY.register_read(H5Group, IOSpec("", "")) @_REGISTRY.register_read(H5Array, IOSpec("", "")) -def read_basic( +async def read_basic( elem: H5File | H5Group | H5Array, *, _reader: Reader ) -> dict[str, InMemoryArrayOrScalarType] | npt.NDArray | CSMatrix | CSArray: from anndata._io import h5ad @@ -141,14 +148,21 @@ def read_basic( # Backwards compat sparse arrays if "h5sparse_format" in elem.attrs: return sparse_dataset(elem).to_memory() - return {k: _reader.read_elem(v) for k, v in dict(elem).items()} + return dict( + await asyncio.gather( + *( + sync_async_to_async(k, _reader.read_elem_async(v)) + for k, v in dict(elem).items() + ) + ) + ) elif isinstance(elem, h5py.Dataset): return h5ad.read_dataset(elem) # TODO: Handle legacy @_REGISTRY.register_read(ZarrGroup, IOSpec("", "")) @_REGISTRY.register_read(ZarrArray, IOSpec("", "")) -def read_basic_zarr( +async def 
read_basic_zarr( elem: ZarrGroup | ZarrArray, *, _reader: Reader ) -> dict[str, InMemoryArrayOrScalarType] | npt.NDArray | CSMatrix | CSArray: from anndata._io import zarr @@ -162,19 +176,16 @@ def read_basic_zarr( # Backwards compat sparse arrays if "h5sparse_format" in elem.attrs: return sparse_dataset(elem).to_memory() - return {k: _reader.read_elem(v) for k, v in dict(elem).items()} + return dict( + await asyncio.gather( + *( + sync_async_to_async(k, _reader.read_elem_async(v)) + for k, v in dict(elem).items() + ) + ) + ) elif isinstance(elem, ZarrArray): - return zarr.read_dataset(elem) # TODO: Handle legacy - - -# @_REGISTRY.register_read_partial(IOSpec("", "")) -# def read_basic_partial(elem, *, items=None, indices=(slice(None), slice(None))): -# if isinstance(elem, Mapping): -# return _read_partial(elem, items=items, indices=indices) -# elif indices != (slice(None), slice(None)): -# return elem[indices] -# else: -# return elem[()] + return await zarr.read_dataset(elem) # TODO: Handle legacy ########### @@ -192,80 +203,9 @@ def read_indices(group): return obs_idx, var_idx -def read_partial( - pth: PathLike, - *, - obs_idx=slice(None), - var_idx=slice(None), - X=True, - obs=None, - var=None, - obsm=None, - varm=None, - obsp=None, - varp=None, - layers=None, - uns=None, -) -> ad.AnnData: - result = {} - with h5py.File(pth, "r") as f: - obs_idx, var_idx = _normalize_indices((obs_idx, var_idx), *read_indices(f)) - result["obs"] = read_elem_partial( - f["obs"], items=obs, indices=(obs_idx, slice(None)) - ) - result["var"] = read_elem_partial( - f["var"], items=var, indices=(var_idx, slice(None)) - ) - if X: - result["X"] = read_elem_partial(f["X"], indices=(obs_idx, var_idx)) - else: - result["X"] = sparse.csr_matrix((len(result["obs"]), len(result["var"]))) - if "obsm" in f: - result["obsm"] = _read_partial( - f["obsm"], items=obsm, indices=(obs_idx, slice(None)) - ) - if "varm" in f: - result["varm"] = _read_partial( - f["varm"], items=varm, indices=(var_idx, 
slice(None)) - ) - if "obsp" in f: - result["obsp"] = _read_partial( - f["obsp"], items=obsp, indices=(obs_idx, obs_idx) - ) - if "varp" in f: - result["varp"] = _read_partial( - f["varp"], items=varp, indices=(var_idx, var_idx) - ) - if "layers" in f: - result["layers"] = _read_partial( - f["layers"], items=layers, indices=(obs_idx, var_idx) - ) - if "uns" in f: - result["uns"] = _read_partial(f["uns"], items=uns) - - return ad.AnnData(**result) - - -def _read_partial(group, *, items=None, indices=(slice(None), slice(None))): - if group is None: - return None - if items is None: - keys = intersect_keys((group,)) - else: - keys = intersect_keys((group, items)) - result = {} - for k in keys: - if isinstance(items, Mapping): - next_items = items.get(k, None) - else: - next_items = None - result[k] = read_elem_partial(group[k], items=next_items, indices=indices) - return result - - @_REGISTRY.register_write(ZarrGroup, AnnData, IOSpec("anndata", "0.1.0")) @_REGISTRY.register_write(H5Group, AnnData, IOSpec("anndata", "0.1.0")) -def write_anndata( +async def write_anndata( f: GroupStorageType, k: str, adata: AnnData, @@ -274,16 +214,30 @@ def write_anndata( dataset_kwargs: Mapping[str, Any] = MappingProxyType({}), ): g = f.require_group(k) - _writer.write_elem(g, "X", adata.X, dataset_kwargs=dataset_kwargs) - _writer.write_elem(g, "obs", adata.obs, dataset_kwargs=dataset_kwargs) - _writer.write_elem(g, "var", adata.var, dataset_kwargs=dataset_kwargs) - _writer.write_elem(g, "obsm", dict(adata.obsm), dataset_kwargs=dataset_kwargs) - _writer.write_elem(g, "varm", dict(adata.varm), dataset_kwargs=dataset_kwargs) - _writer.write_elem(g, "obsp", dict(adata.obsp), dataset_kwargs=dataset_kwargs) - _writer.write_elem(g, "varp", dict(adata.varp), dataset_kwargs=dataset_kwargs) - _writer.write_elem(g, "layers", dict(adata.layers), dataset_kwargs=dataset_kwargs) - _writer.write_elem(g, "uns", dict(adata.uns), dataset_kwargs=dataset_kwargs) - _writer.write_elem(g, "raw", adata.raw, 
dataset_kwargs=dataset_kwargs) + await asyncio.gather( + _writer.write_elem_async(g, "X", adata.X, dataset_kwargs=dataset_kwargs), + _writer.write_elem_async(g, "obs", adata.obs, dataset_kwargs=dataset_kwargs), + _writer.write_elem_async(g, "var", adata.var, dataset_kwargs=dataset_kwargs), + _writer.write_elem_async( + g, "obsm", dict(adata.obsm), dataset_kwargs=dataset_kwargs + ), + _writer.write_elem_async( + g, "varm", dict(adata.varm), dataset_kwargs=dataset_kwargs + ), + _writer.write_elem_async( + g, "obsp", dict(adata.obsp), dataset_kwargs=dataset_kwargs + ), + _writer.write_elem_async( + g, "varp", dict(adata.varp), dataset_kwargs=dataset_kwargs + ), + _writer.write_elem_async( + g, "layers", dict(adata.layers), dataset_kwargs=dataset_kwargs + ), + _writer.write_elem_async( + g, "uns", dict(adata.uns), dataset_kwargs=dataset_kwargs + ), + _writer.write_elem_async(g, "raw", adata.raw, dataset_kwargs=dataset_kwargs), + ) @_REGISTRY.register_read(H5Group, IOSpec("anndata", "0.1.0")) @@ -292,9 +246,8 @@ def write_anndata( @_REGISTRY.register_read(H5File, IOSpec("raw", "0.1.0")) @_REGISTRY.register_read(ZarrGroup, IOSpec("anndata", "0.1.0")) @_REGISTRY.register_read(ZarrGroup, IOSpec("raw", "0.1.0")) -def read_anndata(elem: GroupStorageType | H5File, *, _reader: Reader) -> AnnData: - d = {} - for k in [ +async def read_anndata(elem: GroupStorageType | H5File, *, _reader: Reader) -> AnnData: + elems = [ "X", "obs", "var", @@ -305,15 +258,22 @@ def read_anndata(elem: GroupStorageType | H5File, *, _reader: Reader) -> AnnData "layers", "uns", "raw", - ]: - if k in elem: - d[k] = _reader.read_elem(elem[k]) + ] + d = dict( + await asyncio.gather( + *( + sync_async_to_async(k, _reader.read_elem_async(elem[k])) + for k in elems + if k in elem + ) + ) + ) return AnnData(**d) @_REGISTRY.register_write(H5Group, Raw, IOSpec("raw", "0.1.0")) @_REGISTRY.register_write(ZarrGroup, Raw, IOSpec("raw", "0.1.0")) -def write_raw( +async def write_raw( f: GroupStorageType, k: str, 
raw: Raw, @@ -322,9 +282,13 @@ def write_raw( dataset_kwargs: Mapping[str, Any] = MappingProxyType({}), ): g = f.require_group(k) - _writer.write_elem(g, "X", raw.X, dataset_kwargs=dataset_kwargs) - _writer.write_elem(g, "var", raw.var, dataset_kwargs=dataset_kwargs) - _writer.write_elem(g, "varm", dict(raw.varm), dataset_kwargs=dataset_kwargs) + await asyncio.gather( + _writer.write_elem_async(g, "X", raw.X, dataset_kwargs=dataset_kwargs), + _writer.write_elem_async(g, "var", raw.var, dataset_kwargs=dataset_kwargs), + _writer.write_elem_async( + g, "varm", dict(raw.varm), dataset_kwargs=dataset_kwargs + ), + ) ######## @@ -334,17 +298,17 @@ def write_raw( @_REGISTRY.register_read(H5Array, IOSpec("null", "0.1.0")) @_REGISTRY.register_read(ZarrArray, IOSpec("null", "0.1.0")) -def read_null(_elem, _reader) -> None: +async def read_null(_elem, _reader) -> None: return None @_REGISTRY.register_write(H5Group, type(None), IOSpec("null", "0.1.0")) -def write_null_h5py(f, k, _v, _writer, dataset_kwargs=MappingProxyType({})): +async def write_null_h5py(f, k, _v, _writer, dataset_kwargs=MappingProxyType({})): f.create_dataset(k, data=h5py.Empty("f"), **dataset_kwargs) @_REGISTRY.register_write(ZarrGroup, type(None), IOSpec("null", "0.1.0")) -def write_null_zarr(f, k, _v, _writer, dataset_kwargs=MappingProxyType({})): +async def write_null_zarr(f, k, _v, _writer, dataset_kwargs=MappingProxyType({})): # zarr has no first-class null dataset if is_zarr_v2(): import zarr @@ -364,13 +328,22 @@ def write_null_zarr(f, k, _v, _writer, dataset_kwargs=MappingProxyType({})): @_REGISTRY.register_read(H5Group, IOSpec("dict", "0.1.0")) @_REGISTRY.register_read(ZarrGroup, IOSpec("dict", "0.1.0")) -def read_mapping(elem: GroupStorageType, *, _reader: Reader) -> dict[str, AxisStorable]: - return {k: _reader.read_elem(v) for k, v in dict(elem).items()} +async def read_mapping( + elem: GroupStorageType, *, _reader: Reader +) -> dict[str, AxisStorable]: + return dict( + await asyncio.gather( + 
*( + sync_async_to_async(k, _reader.read_elem_async(v)) + for k, v in dict(elem).items() + ) + ) + ) @_REGISTRY.register_write(H5Group, dict, IOSpec("dict", "0.1.0")) @_REGISTRY.register_write(ZarrGroup, dict, IOSpec("dict", "0.1.0")) -def write_mapping( +async def write_mapping( f: GroupStorageType, k: str, v: dict[str, AxisStorable], @@ -379,8 +352,12 @@ def write_mapping( dataset_kwargs: Mapping[str, Any] = MappingProxyType({}), ): g = f.require_group(k) - for sub_k, sub_v in v.items(): - _writer.write_elem(g, sub_k, sub_v, dataset_kwargs=dataset_kwargs) + await asyncio.gather( + *[ + _writer.write_elem_async(g, sub_k, sub_v, dataset_kwargs=dataset_kwargs) + for sub_k, sub_v in v.items() + ] + ) ############## @@ -390,7 +367,7 @@ def write_mapping( @_REGISTRY.register_write(H5Group, list, IOSpec("array", "0.2.0")) @_REGISTRY.register_write(ZarrGroup, list, IOSpec("array", "0.2.0")) -def write_list( +async def write_list( f: GroupStorageType, k: str, elem: list[AxisStorable], @@ -398,7 +375,7 @@ def write_list( _writer: Writer, dataset_kwargs: Mapping[str, Any] = MappingProxyType({}), ): - _writer.write_elem(f, k, np.array(elem), dataset_kwargs=dataset_kwargs) + await _writer.write_elem_async(f, k, np.array(elem), dataset_kwargs=dataset_kwargs) # TODO: Is this the right behavior for MaskedArrays? @@ -412,7 +389,7 @@ def write_list( @_REGISTRY.register_write(ZarrGroup, ZarrArray, IOSpec("array", "0.2.0")) @_REGISTRY.register_write(ZarrGroup, H5Array, IOSpec("array", "0.2.0")) @zero_dim_array_as_scalar -def write_basic( +async def write_basic( f: GroupStorageType, k: str, elem: views.ArrayView | np.ndarray | h5py.Dataset | np.ma.MaskedArray | ZarrArray, @@ -429,7 +406,7 @@ def write_basic( f.create_array(k, shape=elem.shape, dtype=dtype, **dataset_kwargs) # see https://github.com/zarr-developers/zarr-python/discussions/2712 if isinstance(elem, ZarrArray): - f[k][...] = elem[...] + await f[k]._async_array.setitem(Ellipsis, elem[...]) else: f[k][...] 
= elem @@ -457,7 +434,7 @@ def _iter_chunks_for_copy( @_REGISTRY.register_write(H5Group, H5Array, IOSpec("array", "0.2.0")) @_REGISTRY.register_write(H5Group, ZarrArray, IOSpec("array", "0.2.0")) -def write_chunked_dense_array_to_group( +async def write_chunked_dense_array_to_group( f: GroupStorageType, k: str, elem: ArrayStorageType, @@ -474,9 +451,17 @@ def write_chunked_dense_array_to_group( dtype = dataset_kwargs.get("dtype", elem.dtype) kwargs = {**dataset_kwargs, "dtype": dtype} dest = f.create_dataset(k, shape=elem.shape, **kwargs) - - for chunk in _iter_chunks_for_copy(elem, dest): - dest[chunk] = elem[chunk] + chunk_iter = _iter_chunks_for_copy(elem, dest) + if isinstance(dest, ZarrArray) and not is_zarr_v2(): + await asyncio.gather( + *( + dest._async_array.setitem(chunk, elem[chunk]) + for chunk in _iter_chunks_for_copy(elem, dest) + ) + ) + else: + for chunk in chunk_iter: + dest[chunk] = elem[chunk] _REGISTRY.register_write(H5Group, CupyArray, IOSpec("array", "0.2.0"))( @@ -488,7 +473,7 @@ def write_chunked_dense_array_to_group( @_REGISTRY.register_write(ZarrGroup, DaskArray, IOSpec("array", "0.2.0")) -def write_basic_dask_zarr( +async def write_basic_dask_zarr( f: ZarrGroup, k: str, elem: DaskArray, @@ -508,7 +493,7 @@ def write_basic_dask_zarr( # Adding this separately because h5py isn't serializable # https://github.com/pydata/xarray/issues/4242 @_REGISTRY.register_write(H5Group, DaskArray, IOSpec("array", "0.2.0")) -def write_basic_dask_h5( +async def write_basic_dask_h5( f: H5Group, k: str, elem: DaskArray, @@ -530,30 +515,16 @@ def write_basic_dask_h5( @_REGISTRY.register_read(H5Array, IOSpec("array", "0.2.0")) @_REGISTRY.register_read(ZarrArray, IOSpec("array", "0.2.0")) @_REGISTRY.register_read(ZarrArray, IOSpec("string-array", "0.2.0")) -def read_array(elem: ArrayStorageType, *, _reader: Reader) -> npt.NDArray: +async def read_array(elem: ArrayStorageType, *, _reader: Reader) -> npt.NDArray: + if not is_zarr_v2() and isinstance(elem, 
ZarrArray): + return await elem._async_array.getitem(()) return elem[()] -@_REGISTRY.register_read_partial(H5Array, IOSpec("array", "0.2.0")) -@_REGISTRY.register_read_partial(ZarrArray, IOSpec("string-array", "0.2.0")) -def read_array_partial(elem, *, items=None, indices=(slice(None, None))): - return elem[indices] - - -@_REGISTRY.register_read_partial(ZarrArray, IOSpec("array", "0.2.0")) -def read_zarr_array_partial(elem, *, items=None, indices=(slice(None, None))): - return elem.oindex[indices] - - # arrays of strings @_REGISTRY.register_read(H5Array, IOSpec("string-array", "0.2.0")) -def read_string_array(d: H5Array, *, _reader: Reader): - return read_array(d.asstr(), _reader=_reader) - - -@_REGISTRY.register_read_partial(H5Array, IOSpec("string-array", "0.2.0")) -def read_string_array_partial(d, items=None, indices=slice(None)): - return read_array_partial(d.asstr(), items=items, indices=indices) +async def read_string_array(d: H5Array, *, _reader: Reader): + return await read_array(d.asstr(), _reader=_reader) @_REGISTRY.register_write( @@ -565,7 +536,7 @@ def read_string_array_partial(d, items=None, indices=slice(None)): @_REGISTRY.register_write(H5Group, (np.ndarray, "U"), IOSpec("string-array", "0.2.0")) @_REGISTRY.register_write(H5Group, (np.ndarray, "O"), IOSpec("string-array", "0.2.0")) @zero_dim_array_as_scalar -def write_vlen_string_array( +async def write_vlen_string_array( f: H5Group, k: str, elem: np.ndarray, @@ -587,7 +558,7 @@ def write_vlen_string_array( @_REGISTRY.register_write(ZarrGroup, (np.ndarray, "U"), IOSpec("string-array", "0.2.0")) @_REGISTRY.register_write(ZarrGroup, (np.ndarray, "O"), IOSpec("string-array", "0.2.0")) @zero_dim_array_as_scalar -def write_vlen_string_array_zarr( +async def write_vlen_string_array_zarr( f: ZarrGroup, k: str, elem: np.ndarray, @@ -634,7 +605,7 @@ def write_vlen_string_array_zarr( compressor=compressor, **dataset_kwargs, ) - f[k][:] = elem + await f[k]._async_array.setitem(slice(None), elem) 
############### @@ -655,8 +626,13 @@ def _to_hdf5_vlen_strings(value: np.ndarray) -> np.ndarray: @_REGISTRY.register_read(H5Array, IOSpec("rec-array", "0.2.0")) @_REGISTRY.register_read(ZarrArray, IOSpec("rec-array", "0.2.0")) -def read_recarray(d: ArrayStorageType, *, _reader: Reader) -> np.recarray | npt.NDArray: - value = d[()] +async def read_recarray( + d: ArrayStorageType, *, _reader: Reader +) -> np.recarray | npt.NDArray: + if not is_zarr_v2() and isinstance(d, ZarrArray): + value = await d._async_array.getitem(()) + else: + value = d[()] dtype = value.dtype value = _from_fixed_length_strings(value) if H5PY_V3: @@ -666,7 +642,7 @@ def read_recarray(d: ArrayStorageType, *, _reader: Reader) -> np.recarray | npt. @_REGISTRY.register_write(H5Group, (np.ndarray, "V"), IOSpec("rec-array", "0.2.0")) @_REGISTRY.register_write(H5Group, np.recarray, IOSpec("rec-array", "0.2.0")) -def write_recarray( +async def write_recarray( f: H5Group, k: str, elem: np.ndarray | np.recarray, @@ -679,7 +655,7 @@ def write_recarray( @_REGISTRY.register_write(ZarrGroup, (np.ndarray, "V"), IOSpec("rec-array", "0.2.0")) @_REGISTRY.register_write(ZarrGroup, np.recarray, IOSpec("rec-array", "0.2.0")) -def write_recarray_zarr( +async def write_recarray_zarr( f: ZarrGroup, k: str, elem: np.ndarray | np.recarray, @@ -695,7 +671,7 @@ def write_recarray_zarr( else: # TODO: zarr’s on-disk format v3 doesn’t support this dtype f.create_array(k, shape=elem.shape, dtype=elem.dtype, **dataset_kwargs) - f[k][...] 
= elem + await f[k]._async_array.setitem(Ellipsis, elem) ################# @@ -703,7 +679,7 @@ def write_recarray_zarr( ################# -def write_sparse_compressed( +async def write_sparse_compressed( f: GroupStorageType, key: str, value: CSMatrix | CSArray, @@ -720,7 +696,7 @@ def write_sparse_compressed( # Allow resizing for hdf5 if isinstance(f, H5Group): dataset_kwargs = dict(maxshape=(None,), **dataset_kwargs) - + awaitables = [] for attr_name in ["data", "indices", "indptr"]: attr = getattr(value, attr_name) dtype = indptr_dtype if attr_name == "indptr" else attr.dtype @@ -733,7 +709,9 @@ def write_sparse_compressed( attr_name, shape=attr.shape, dtype=dtype, **dataset_kwargs ) # see https://github.com/zarr-developers/zarr-python/discussions/2712 - arr[...] = attr[...] + awaitables.append(arr._async_array.setitem(Ellipsis, attr[...])) + if len(awaitables) > 0: + await asyncio.gather(*awaitables) write_csr = partial(write_sparse_compressed, fmt="csr") @@ -774,7 +752,7 @@ def write_sparse_compressed( @_REGISTRY.register_write(H5Group, _CSCDataset, IOSpec("", "0.1.0")) @_REGISTRY.register_write(ZarrGroup, _CSRDataset, IOSpec("", "0.1.0")) @_REGISTRY.register_write(ZarrGroup, _CSCDataset, IOSpec("", "0.1.0")) -def write_sparse_dataset( +async def write_sparse_dataset( f: GroupStorageType, k: str, elem: _CSCDataset | _CSRDataset, @@ -782,7 +760,7 @@ def write_sparse_dataset( _writer: Writer, dataset_kwargs: Mapping[str, Any] = MappingProxyType({}), ): - write_sparse_compressed( + await write_sparse_compressed( f, k, elem._to_backed(), @@ -809,8 +787,10 @@ def write_sparse_dataset( @_REGISTRY.register_write( ZarrGroup, (DaskArray, CupyCSCMatrix), IOSpec("csc_matrix", "0.1.0") ) -def write_cupy_dask_sparse(f, k, elem, _writer, dataset_kwargs=MappingProxyType({})): - _writer.write_elem( +async def write_cupy_dask_sparse( + f, k, elem, _writer, dataset_kwargs=MappingProxyType({}) +): + await _writer.write_elem_async( f, k, elem.map_blocks(lambda x: x.get(), 
dtype=elem.dtype, meta=elem._meta.get()), @@ -830,7 +810,7 @@ def write_cupy_dask_sparse(f, k, elem, _writer, dataset_kwargs=MappingProxyType( @_REGISTRY.register_write( ZarrGroup, (DaskArray, sparse.csc_matrix), IOSpec("csc_matrix", "0.1.0") ) -def write_dask_sparse( +async def write_dask_sparse( f: GroupStorageType, k: str, elem: DaskArray, @@ -862,7 +842,7 @@ def chunk_slice(start: int, stop: int) -> tuple[slice | None, slice | None]: chunk_start = 0 chunk_stop = axis_chunks[0] - _writer.write_elem( + await _writer.write_elem_async( f, k, as_int64_indices(elem[chunk_slice(chunk_start, chunk_stop)].compute()), @@ -882,16 +862,8 @@ def chunk_slice(start: int, stop: int) -> tuple[slice | None, slice | None]: @_REGISTRY.register_read(H5Group, IOSpec("csr_matrix", "0.1.0")) @_REGISTRY.register_read(ZarrGroup, IOSpec("csc_matrix", "0.1.0")) @_REGISTRY.register_read(ZarrGroup, IOSpec("csr_matrix", "0.1.0")) -def read_sparse(elem: GroupStorageType, *, _reader: Reader) -> CSMatrix | CSArray: - return sparse_dataset(elem).to_memory() - - -@_REGISTRY.register_read_partial(H5Group, IOSpec("csc_matrix", "0.1.0")) -@_REGISTRY.register_read_partial(H5Group, IOSpec("csr_matrix", "0.1.0")) -@_REGISTRY.register_read_partial(ZarrGroup, IOSpec("csc_matrix", "0.1.0")) -@_REGISTRY.register_read_partial(ZarrGroup, IOSpec("csr_matrix", "0.1.0")) -def read_sparse_partial(elem, *, items=None, indices=(slice(None), slice(None))): - return sparse_dataset(elem)[indices] +async def read_sparse(elem: GroupStorageType, *, _reader: Reader) -> CSMatrix | CSArray: + return await sparse_dataset(elem).to_memory_async() ################# @@ -907,7 +879,7 @@ def read_sparse_partial(elem, *, items=None, indices=(slice(None), slice(None))) @_REGISTRY.register_write( ZarrGroup, views.AwkwardArrayView, IOSpec("awkward-array", "0.1.0") ) -def write_awkward( +async def write_awkward( f: GroupStorageType, k: str, v: views.AwkwardArrayView | AwkArray, @@ -924,18 +896,29 @@ def write_awkward( form, length, 
container = ak.to_buffers(ak.to_packed(v)) group.attrs["length"] = length group.attrs["form"] = form.to_json() - for k, v in container.items(): - _writer.write_elem(group, k, v, dataset_kwargs=dataset_kwargs) + await asyncio.gather( + *[ + _writer.write_elem_async(group, k, v, dataset_kwargs=dataset_kwargs) + for k, v in container.items() + ] + ) @_REGISTRY.register_read(H5Group, IOSpec("awkward-array", "0.1.0")) @_REGISTRY.register_read(ZarrGroup, IOSpec("awkward-array", "0.1.0")) -def read_awkward(elem: GroupStorageType, *, _reader: Reader) -> AwkArray: +async def read_awkward(elem: GroupStorageType, *, _reader: Reader) -> AwkArray: from anndata.compat import awkward as ak form = _read_attr(elem.attrs, "form") length = _read_attr(elem.attrs, "length") - container = {k: _reader.read_elem(elem[k]) for k in elem.keys()} + container = dict( + await asyncio.gather( + *( + sync_async_to_async(k, _reader.read_elem_async(elem[k])) + for k in elem.keys() + ) + ) + ) return ak.from_buffers(form, int(length), container) @@ -949,7 +932,7 @@ def read_awkward(elem: GroupStorageType, *, _reader: Reader) -> AwkArray: @_REGISTRY.register_write(H5Group, pd.DataFrame, IOSpec("dataframe", "0.2.0")) @_REGISTRY.register_write(ZarrGroup, views.DataFrameView, IOSpec("dataframe", "0.2.0")) @_REGISTRY.register_write(ZarrGroup, pd.DataFrame, IOSpec("dataframe", "0.2.0")) -def write_dataframe( +async def write_dataframe( f: GroupStorageType, key: str, df: views.DataFrameView | pd.DataFrame, @@ -988,47 +971,36 @@ def write_dataframe( # ._values is "the best" array representation. It's the true array backing the # object, where `.values` is always a np.ndarray and .array is always a pandas # array. - _writer.write_elem( - group, index_name, df.index._values, dataset_kwargs=dataset_kwargs - ) - for colname, series in df.items(): - # TODO: this should write the "true" representation of the series (i.e. 
the underlying array or ndarray depending) - _writer.write_elem( + awaitables = [ + _writer.write_elem_async( + group, index_name, df.index._values, dataset_kwargs=dataset_kwargs + ) + ] + # TODO: this should write the "true" representation of the series (i.e. the underlying array or ndarray depending) + awaitables += [ + _writer.write_elem_async( group, colname, series._values, dataset_kwargs=dataset_kwargs ) + for colname, series in df.items() + ] + await asyncio.gather(*awaitables) @_REGISTRY.register_read(H5Group, IOSpec("dataframe", "0.2.0")) @_REGISTRY.register_read(ZarrGroup, IOSpec("dataframe", "0.2.0")) -def read_dataframe(elem: GroupStorageType, *, _reader: Reader) -> pd.DataFrame: +async def read_dataframe(elem: GroupStorageType, *, _reader: Reader) -> pd.DataFrame: columns = list(_read_attr(elem.attrs, "column-order")) idx_key = _read_attr(elem.attrs, "_index") df = pd.DataFrame( - {k: _reader.read_elem(elem[k]) for k in columns}, - index=_reader.read_elem(elem[idx_key]), - columns=columns if len(columns) else None, - ) - if idx_key != "_index": - df.index.name = idx_key - return df - - -# TODO: Figure out what indices is allowed to be at each element -@_REGISTRY.register_read_partial(H5Group, IOSpec("dataframe", "0.2.0")) -@_REGISTRY.register_read_partial(ZarrGroup, IOSpec("dataframe", "0.2.0")) -def read_dataframe_partial( - elem, *, items=None, indices=(slice(None, None), slice(None, None)) -): - if items is not None: - columns = [ - col for col in _read_attr(elem.attrs, "column-order") if col in items - ] - else: - columns = list(_read_attr(elem.attrs, "column-order")) - idx_key = _read_attr(elem.attrs, "_index") - df = pd.DataFrame( - {k: read_elem_partial(elem[k], indices=indices[0]) for k in columns}, - index=read_elem_partial(elem[idx_key], indices=indices[0]), + dict( + await asyncio.gather( + *( + sync_async_to_async(k, read_series(elem[k], _reader)) + for k in columns + ) + ) + ), + index=await _reader.read_elem_async(elem[idx_key]), 
columns=columns if len(columns) else None, ) if idx_key != "_index": @@ -1041,12 +1013,21 @@ def read_dataframe_partial( @_REGISTRY.register_read(H5Group, IOSpec("dataframe", "0.1.0")) @_REGISTRY.register_read(ZarrGroup, IOSpec("dataframe", "0.1.0")) -def read_dataframe_0_1_0(elem: GroupStorageType, *, _reader: Reader) -> pd.DataFrame: +async def read_dataframe_0_1_0( + elem: GroupStorageType, *, _reader: Reader +) -> pd.DataFrame: columns = _read_attr(elem.attrs, "column-order") idx_key = _read_attr(elem.attrs, "_index") df = pd.DataFrame( - {k: read_series(elem[k]) for k in columns}, - index=read_series(elem[idx_key]), + dict( + await asyncio.gather( + *( + sync_async_to_async(k, read_series(elem[k], _reader)) + for k in columns + ) + ) + ), + index=await read_series(elem[idx_key], _reader), columns=columns if len(columns) else None, ) if idx_key != "_index": @@ -1054,7 +1035,9 @@ def read_dataframe_0_1_0(elem: GroupStorageType, *, _reader: Reader) -> pd.DataF return df -def read_series(dataset: h5py.Dataset) -> np.ndarray | pd.Categorical: +async def read_series( + dataset: h5py.Dataset, _reader: Reader +) -> np.ndarray | pd.Categorical: # For reading older dataframes if "categories" in dataset.attrs: if isinstance(dataset, ZarrArray): @@ -1065,25 +1048,16 @@ def read_series(dataset: h5py.Dataset) -> np.ndarray | pd.Categorical: else: parent = dataset.parent categories_dset = parent[_read_attr(dataset.attrs, "categories")] - categories = read_elem(categories_dset) - ordered = bool(_read_attr(categories_dset.attrs, "ordered", default=False)) - return pd.Categorical.from_codes( - read_elem(dataset), categories, ordered=ordered + categories, codes = await asyncio.gather( + *( + _reader.read_elem_async(categories_dset), + _reader.read_elem_async(dataset), + ) ) + ordered = bool(_read_attr(categories_dset.attrs, "ordered", default=False)) + return pd.Categorical.from_codes(codes, categories, ordered=ordered) else: - return read_elem(dataset) - - 
-@_REGISTRY.register_read_partial(H5Group, IOSpec("dataframe", "0.1.0")) -@_REGISTRY.register_read_partial(ZarrGroup, IOSpec("dataframe", "0.1.0")) -def read_partial_dataframe_0_1_0( - elem, *, items=None, indices=(slice(None), slice(None)) -): - if items is None: - items = slice(None) - else: - items = list(items) - return read_elem(elem)[items].iloc[indices[0]] + return await _reader.read_elem_async(dataset) ############### @@ -1093,7 +1067,7 @@ def read_partial_dataframe_0_1_0( @_REGISTRY.register_write(H5Group, pd.Categorical, IOSpec("categorical", "0.2.0")) @_REGISTRY.register_write(ZarrGroup, pd.Categorical, IOSpec("categorical", "0.2.0")) -def write_categorical( +async def write_categorical( f: GroupStorageType, k: str, v: pd.Categorical, @@ -1104,29 +1078,25 @@ def write_categorical( g = f.require_group(k) g.attrs["ordered"] = bool(v.ordered) - _writer.write_elem(g, "codes", v.codes, dataset_kwargs=dataset_kwargs) - _writer.write_elem( - g, "categories", v.categories._values, dataset_kwargs=dataset_kwargs + await asyncio.gather( + _writer.write_elem_async(g, "codes", v.codes, dataset_kwargs=dataset_kwargs), + _writer.write_elem_async( + g, "categories", v.categories._values, dataset_kwargs=dataset_kwargs + ), ) @_REGISTRY.register_read(H5Group, IOSpec("categorical", "0.2.0")) @_REGISTRY.register_read(ZarrGroup, IOSpec("categorical", "0.2.0")) -def read_categorical(elem: GroupStorageType, *, _reader: Reader) -> pd.Categorical: - return pd.Categorical.from_codes( - codes=_reader.read_elem(elem["codes"]), - categories=_reader.read_elem(elem["categories"]), - ordered=bool(_read_attr(elem.attrs, "ordered")), +async def read_categorical( + elem: GroupStorageType, *, _reader: Reader +) -> pd.Categorical: + codes, categories = await asyncio.gather( + *(_reader.read_elem_async(elem[k]) for k in ["codes", "categories"]) ) - - -@_REGISTRY.register_read_partial(H5Group, IOSpec("categorical", "0.2.0")) -@_REGISTRY.register_read_partial(ZarrGroup, IOSpec("categorical", 
"0.2.0")) -def read_partial_categorical(elem, *, items=None, indices=(slice(None),)): + ordered = bool(_read_attr(elem.attrs, "ordered")) return pd.Categorical.from_codes( - codes=read_elem_partial(elem["codes"], indices=indices), - categories=read_elem(elem["categories"]), - ordered=bool(_read_attr(elem.attrs, "ordered")), + codes=codes, categories=categories, ordered=ordered ) @@ -1153,7 +1123,7 @@ def read_partial_categorical(elem, *, items=None, indices=(slice(None),)): @_REGISTRY.register_write( ZarrGroup, pd.arrays.StringArray, IOSpec("nullable-string-array", "0.1.0") ) -def write_nullable( +async def write_nullable( f: GroupStorageType, k: str, v: pd.arrays.IntegerArray | pd.arrays.BooleanArray | pd.arrays.StringArray, @@ -1178,11 +1148,13 @@ def write_nullable( if isinstance(v, pd.arrays.StringArray) else v.to_numpy(na_value=0, dtype=v.dtype.numpy_dtype) ) - _writer.write_elem(g, "values", values, dataset_kwargs=dataset_kwargs) - _writer.write_elem(g, "mask", v.isna(), dataset_kwargs=dataset_kwargs) + await asyncio.gather( + _writer.write_elem_async(g, "values", values, dataset_kwargs=dataset_kwargs), + _writer.write_elem_async(g, "mask", v.isna(), dataset_kwargs=dataset_kwargs), + ) -def _read_nullable( +async def _read_nullable( elem: GroupStorageType, *, _reader: Reader, @@ -1191,10 +1163,10 @@ def _read_nullable( [NDArray[np.number], NDArray[np.bool_]], pd.api.extensions.ExtensionArray ], ) -> pd.api.extensions.ExtensionArray: - return array_type( - _reader.read_elem(elem["values"]), - mask=_reader.read_elem(elem["mask"]), + values, mask = await asyncio.gather( + *(_reader.read_elem_async(elem[k]) for k in ["values", "mask"]) ) + return array_type(values, mask=mask) def _string_array( @@ -1235,9 +1207,11 @@ def _string_array( @_REGISTRY.register_read(H5Array, IOSpec("numeric-scalar", "0.2.0")) @_REGISTRY.register_read(ZarrArray, IOSpec("numeric-scalar", "0.2.0")) -def read_scalar(elem: ArrayStorageType, *, _reader: Reader) -> np.number: +async def 
read_scalar(elem: ArrayStorageType, *, _reader: Reader) -> np.number: # TODO: `item` ensures the return is in fact a scalar (needed after zarr v3 which now returns a 1 elem array) # https://github.com/zarr-developers/zarr-python/issues/2713 + if not is_zarr_v2() and isinstance(elem, ZarrArray): + return (await elem._async_array.getitem(())).item() return elem[()].item() @@ -1257,7 +1231,7 @@ def _remove_scalar_compression_args(dataset_kwargs: Mapping[str, Any]) -> dict: return dataset_kwargs -def write_scalar_zarr( +async def write_scalar_zarr( f: ZarrGroup, key: str, value, @@ -1290,7 +1264,7 @@ def write_scalar_zarr( a[...] = np.array(value) -def write_hdf5_scalar( +async def write_hdf5_scalar( f: H5Group, key: str, value, @@ -1324,12 +1298,14 @@ def write_hdf5_scalar( @_REGISTRY.register_read(H5Array, IOSpec("string", "0.2.0")) -def read_hdf5_string(elem: H5Array, *, _reader: Reader) -> str: +async def read_hdf5_string(elem: H5Array, *, _reader: Reader) -> str: return elem.asstr()[()] @_REGISTRY.register_read(ZarrArray, IOSpec("string", "0.2.0")) -def read_zarr_string(elem: ZarrArray, *, _reader: Reader) -> str: +async def read_zarr_string(elem: ZarrArray, *, _reader: Reader) -> str: + if not is_zarr_v2() and isinstance(elem, ZarrArray): + return str(await elem._async_array.getitem(())) return str(elem[()]) @@ -1339,7 +1315,7 @@ def read_zarr_string(elem: ZarrArray, *, _reader: Reader) -> str: @_REGISTRY.register_write(H5Group, np.str_, IOSpec("string", "0.2.0")) @_REGISTRY.register_write(H5Group, str, IOSpec("string", "0.2.0")) -def write_string( +async def write_string( f: H5Group, k: str, v: np.str_ | str, @@ -1357,7 +1333,7 @@ def write_string( # @_REGISTRY.register_write(np.bytes_, IOSpec("bytes", "0.2.0")) # @_REGISTRY.register_write(bytes, IOSpec("bytes", "0.2.0")) -# def write_string(f, k, v, dataset_kwargs): +# async def write_string(f, k, v, dataset_kwargs): # if "compression" in dataset_kwargs: # dataset_kwargs = dict(dataset_kwargs) # 
dataset_kwargs.pop("compression") diff --git a/src/anndata/_io/specs/registry.py b/src/anndata/_io/specs/registry.py index aeda7775c..9a61277c7 100644 --- a/src/anndata/_io/specs/registry.py +++ b/src/anndata/_io/specs/registry.py @@ -1,5 +1,6 @@ from __future__ import annotations +import asyncio import inspect import warnings from collections.abc import Mapping @@ -9,7 +10,14 @@ from typing import TYPE_CHECKING, Generic, TypeVar from anndata._io.utils import report_read_key_on_error, report_write_key_on_error -from anndata._types import Read, ReadLazy, _ReadInternal, _ReadLazyInternal +from anndata._types import ( + Read, + ReadAsync, + ReadLazy, + _ReadAsyncInternal, + _ReadInternal, + _ReadLazyInternal, +) from anndata.compat import DaskArray, ZarrGroup, _read_attr, is_zarr_v2 if TYPE_CHECKING: @@ -71,8 +79,10 @@ def _from_read_parts( def write_spec(spec: IOSpec): def decorator(func: W) -> W: @wraps(func) - def wrapper(g: GroupStorageType, k: str, *args, **kwargs): - result = func(g, k, *args, **kwargs) + async def wrapper(g: GroupStorageType, k: str, *args, **kwargs): + result = await func(g, k, *args, **kwargs) + if k not in g and isinstance(g, ZarrGroup) and not is_zarr_v2(): + g.require_group(k) g[k].attrs.setdefault("encoding-type", spec.encoding_type) g[k].attrs.setdefault("encoding-version", spec.encoding_version) return result @@ -82,14 +92,13 @@ def wrapper(g: GroupStorageType, k: str, *args, **kwargs): return decorator -_R = TypeVar("_R", _ReadInternal, _ReadLazyInternal) -R = TypeVar("R", Read, ReadLazy) +_R = TypeVar("_R", _ReadInternal, _ReadAsyncInternal, _ReadLazyInternal) +R = TypeVar("R", Read, ReadAsync, ReadLazy) class IORegistry(Generic[_R, R]): def __init__(self): self.read: dict[tuple[type, IOSpec, frozenset[str]], _R] = {} - self.read_partial: dict[tuple[type, IOSpec, frozenset[str]], Callable] = {} self.write: dict[ tuple[type, type | tuple[type, str], frozenset[str]], _WriteInternal ] = {} @@ -181,29 +190,6 @@ def has_read( ) -> bool: 
return (src_type, spec, modifiers) in self.read - def register_read_partial( - self, - src_type: type, - spec: IOSpec | Mapping[str, str], - modifiers: Iterable[str] = frozenset(), - ): - spec = proc_spec(spec) - modifiers = frozenset(modifiers) - - def _register(func): - self.read_partial[(src_type, spec, modifiers)] = func - return func - - return _register - - def get_partial_read( - self, src_type: type, spec: IOSpec, modifiers: frozenset[str] = frozenset() - ): - if (src_type, spec, modifiers) in self.read_partial: - return self.read_partial[(src_type, spec, modifiers)] - name = "read_partial" - raise IORegistryError._from_read_parts(name, self.read_partial, src_type, spec) - def get_spec(self, elem: Any) -> IOSpec: if isinstance(elem, DaskArray): if (typ_meta := (DaskArray, type(elem._meta))) in self.write_specs: @@ -214,7 +200,9 @@ def get_spec(self, elem: Any) -> IOSpec: return self.write_specs[type(elem)] -_REGISTRY: IORegistry[_ReadInternal, Read] = IORegistry() +_REGISTRY: IORegistry[_ReadInternal | _ReadAsyncInternal, Read | ReadAsync] = ( + IORegistry() +) _LAZY_REGISTRY: IORegistry[_ReadLazyInternal, ReadLazy] = IORegistry() @@ -263,13 +251,15 @@ def _iter_patterns( class Reader: def __init__( - self, registry: IORegistry, callback: ReadCallback | None = None + self, + registry: IORegistry, + callback: ReadCallback | None = None, ) -> None: self.registry = registry self.callback = callback @report_read_key_on_error - def read_elem( + async def read_elem_async( self, elem: StorageType, modifiers: frozenset[str] = frozenset(), @@ -277,12 +267,12 @@ def read_elem( """Read an element from a store. 
See exported function for more details.""" iospec = get_spec(elem) - read_func: Read = self.registry.get_read( + read_func: ReadAsync = self.registry.get_read( type(elem), iospec, modifiers, reader=self ) if self.callback is None: - return read_func(elem) - return self.callback(read_func, elem.name, elem, iospec=iospec) + return await read_func(elem) + return await self.callback(read_func, elem.name, elem, iospec=iospec) class LazyReader(Reader): @@ -333,7 +323,7 @@ def find_write_func( return self.registry.get_write(dest_type, type(elem), modifiers, writer=self) @report_write_key_on_error - def write_elem( + async def write_elem_async( self, store: GroupStorageType, k: str, @@ -361,9 +351,7 @@ def write_elem( if k == "/": if isinstance(store, ZarrGroup) and not is_zarr_v2(): - import asyncio - - asyncio.run(store.store.clear()) + await store.store.clear() else: store.clear() elif k in store: @@ -372,8 +360,8 @@ def write_elem( write_func = self.find_write_func(dest_type, elem, modifiers) if self.callback is None: - return write_func(store, k, elem, dataset_kwargs=dataset_kwargs) - return self.callback( + return await write_func(store, k, elem, dataset_kwargs=dataset_kwargs) + return await self.callback( write_func, store, k, @@ -395,7 +383,22 @@ def read_elem(elem: StorageType) -> RWAble: elem The stored element. """ - return Reader(_REGISTRY).read_elem(elem) + return asyncio.run(read_elem_async(elem)) + + +async def read_elem_async(elem: StorageType) -> RWAble: + """ + Read an element from a store asynchronously. + + Assumes that the element is encoded using the anndata encoding. This function will + determine the encoded type using the encoding metadata stored in elem's attributes. + + Params + ------ + elem + The stored element. + """ + return await Reader(_REGISTRY).read_elem_async(elem) def read_elem_lazy( @@ -501,19 +504,33 @@ def write_elem( Keyword arguments to pass to the stores dataset creation function. E.g. 
for zarr this would be `chunks`, `compressor`. """ - Writer(_REGISTRY).write_elem(store, k, elem, dataset_kwargs=dataset_kwargs) + return asyncio.run(write_elem_async(store, k, elem, dataset_kwargs=dataset_kwargs)) -# TODO: If all items would be read, just call normal read method -def read_elem_partial( - elem, +async def write_elem_async( + store: GroupStorageType, + k: str, + elem: RWAble, *, - items=None, - indices=(slice(None), slice(None)), - modifiers: frozenset[str] = frozenset(), -): - """Read part of an element from an on disk store.""" - read_partial = _REGISTRY.get_partial_read( - type(elem), get_spec(elem), frozenset(modifiers) + dataset_kwargs: Mapping[str, Any] = MappingProxyType({}), +) -> None: + """ + Write an element to a storage group using anndata encoding. + + Params + ------ + store + The group to write to. + k + The key to write to in the group. Note that absolute paths will be written + from the root. + elem + The element to write. Typically an in-memory object, e.g. an AnnData, pandas + dataframe, scipy sparse matrix, etc. + dataset_kwargs + Keyword arguments to pass to the stores dataset creation function. + E.g. for zarr this would be `chunks`, `compressor`. 
+ """ + return await Writer(_REGISTRY).write_elem_async( + store, k, elem, dataset_kwargs=dataset_kwargs ) - return read_partial(elem, items=items, indices=indices) diff --git a/src/anndata/_io/utils.py b/src/anndata/_io/utils.py index 12b86bb27..75ee1166e 100644 --- a/src/anndata/_io/utils.py +++ b/src/anndata/_io/utils.py @@ -1,5 +1,6 @@ from __future__ import annotations +import inspect from functools import WRAPPER_ASSIGNMENTS, wraps from itertools import pairwise from typing import TYPE_CHECKING, cast @@ -198,25 +199,46 @@ def report_read_key_on_error(func): >>> z["X"] = np.array([1, 2, 3]) >>> read_arr(z["X"]) # doctest: +SKIP """ + if inspect.iscoroutinefunction(func): + + @wraps(func) + async def func_wrapper(*args, **kwargs): + from anndata._io.specs import Reader + + # Figure out signature (method vs function) by going through args + for arg in args: + if not isinstance(arg, Reader): + store = cast("Storage", arg) + break + else: + msg = "No element found in args." + raise ValueError(msg) + try: + return await func(*args, **kwargs) + except Exception as e: + path, key = _get_display_path(store).rsplit("/", 1) + add_key_note(e, store, path or "/", key, "read") + raise + else: - @wraps(func) - def func_wrapper(*args, **kwargs): - from anndata._io.specs import Reader - - # Figure out signature (method vs function) by going through args - for arg in args: - if not isinstance(arg, Reader): - store = cast("Storage", arg) - break - else: - msg = "No element found in args." - raise ValueError(msg) - try: - return func(*args, **kwargs) - except Exception as e: - path, key = _get_display_path(store).rsplit("/", 1) - add_key_note(e, store, path or "/", key, "read") - raise + @wraps(func) + def func_wrapper(*args, **kwargs): + from anndata._io.specs import Reader + + # Figure out signature (method vs function) by going through args + for arg in args: + if not isinstance(arg, Reader): + store = cast("Storage", arg) + break + else: + msg = "No element found in args." 
+ raise ValueError(msg) + try: + return func(*args, **kwargs) + except Exception as e: + path, key = _get_display_path(store).rsplit("/", 1) + add_key_note(e, store, path or "/", key, "read") + raise return func_wrapper @@ -237,7 +259,7 @@ def report_write_key_on_error(func): """ @wraps(func) - def func_wrapper(*args, **kwargs): + async def func_wrapper(*args, **kwargs): from anndata._io.specs import Writer # Figure out signature (method vs function) by going through args @@ -249,7 +271,7 @@ def func_wrapper(*args, **kwargs): msg = "No element found in args." raise ValueError(msg) try: - return func(*args, **kwargs) + return await func(*args, **kwargs) except Exception as e: path = _get_display_path(store) add_key_note(e, store, path, key, "writ") @@ -263,7 +285,7 @@ def func_wrapper(*args, **kwargs): # ------------------------------------------------------------------------------- -def _read_legacy_raw( +async def _read_legacy_raw( f: ZarrGroup | H5Group, modern_raw, # TODO: type read_df: Callable, @@ -284,11 +306,11 @@ def _read_legacy_raw( raw = {} if "X" in attrs and "raw.X" in f: - raw["X"] = read_attr(f["raw.X"]) + raw["X"] = await read_attr(f["raw.X"]) if "var" in attrs and "raw.var" in f: - raw["var"] = read_df(f["raw.var"]) # Backwards compat + raw["var"] = await read_df(f["raw.var"]) # Backwards compat if "varm" in attrs and "raw.varm" in f: - raw["varm"] = read_attr(f["raw.varm"]) + raw["varm"] = await read_attr(f["raw.varm"]) return raw @@ -298,7 +320,7 @@ def zero_dim_array_as_scalar(func: _WriteInternal): """ @wraps(func, assigned=WRAPPER_ASSIGNMENTS + ("__defaults__", "__kwdefaults__")) - def func_wrapper( + async def func_wrapper( f: StorageType, k: str, elem: ContravariantRWAble, @@ -307,8 +329,10 @@ def func_wrapper( dataset_kwargs: Mapping[str, Any], ): if elem.shape == (): - _writer.write_elem(f, k, elem[()], dataset_kwargs=dataset_kwargs) + await _writer.write_elem_async( + f, k, elem[()], dataset_kwargs=dataset_kwargs + ) else: - func(f, k, 
elem, _writer=_writer, dataset_kwargs=dataset_kwargs) + await func(f, k, elem, _writer=_writer, dataset_kwargs=dataset_kwargs) return func_wrapper diff --git a/src/anndata/_io/zarr.py b/src/anndata/_io/zarr.py index 4718a8940..579a2afd6 100644 --- a/src/anndata/_io/zarr.py +++ b/src/anndata/_io/zarr.py @@ -1,5 +1,6 @@ from __future__ import annotations +import asyncio from pathlib import Path from typing import TYPE_CHECKING, TypeVar from warnings import warn @@ -14,7 +15,8 @@ from .._warnings import OldFormatWarning from ..compat import _clean_uns, _from_fixed_length_strings, is_zarr_v2 from ..experimental import read_dispatched, write_dispatched -from .specs import read_elem +from .specs import read_elem_async +from .specs.methods import sync_async_to_async from .utils import _read_legacy_raw, report_read_key_on_error if TYPE_CHECKING: @@ -45,16 +47,18 @@ def write_zarr( f.attrs.setdefault("encoding-type", "anndata") f.attrs.setdefault("encoding-version", "0.1.0") - def callback(func, s, k: str, elem, dataset_kwargs, iospec): + async def callback(func, s, k: str, elem, dataset_kwargs, iospec): if ( chunks is not None and not isinstance(elem, sparse.spmatrix) and k.lstrip("/") == "X" ): dataset_kwargs = dict(dataset_kwargs, chunks=chunks) - func(s, k, elem, dataset_kwargs=dataset_kwargs) + await func(s, k, elem, dataset_kwargs=dataset_kwargs) - write_dispatched(f, "/", adata, callback=callback, dataset_kwargs=ds_kwargs) + asyncio.run( + write_dispatched(f, "/", adata, callback=callback, dataset_kwargs=ds_kwargs) + ) if is_zarr_v2(): zarr.convenience.consolidate_metadata(f.store) else: @@ -79,29 +83,42 @@ def read_zarr(store: str | Path | MutableMapping | zarr.Group) -> AnnData: f = zarr.open(store, mode="r") # Read with handling for backwards compat - def callback(func, elem_name: str, elem, iospec): + async def callback(func, elem_name: str, elem, iospec): if iospec.encoding_type == "anndata" or elem_name.endswith("/"): - return AnnData( - **{ - k: 
read_dispatched(v, callback) - for k, v in dict(elem).items() - if not k.startswith("raw.") - } + args = dict( + await asyncio.gather( + *( + # This is covering up backwards compat in the anndata initializer + # In most cases we should be able to call `func(elen[k])` instead + sync_async_to_async(k, read_dispatched(elem[k], callback)) + for k in elem.keys() + if not k.startswith("raw.") + ) + ) ) + return AnnData(**args) elif elem_name.startswith("/raw."): return None elif elem_name in {"/obs", "/var"}: - return read_dataframe(elem) + return await read_dataframe(elem) elif elem_name == "/raw": # Backwards compat - return _read_legacy_raw(f, func(elem), read_dataframe, func) - return func(elem) + return await _read_legacy_raw( + f, await func(elem), read_dataframe, read_elem_async + ) + return await func(elem) - adata = read_dispatched(f, callback=callback) + adata = asyncio.run(read_dispatched(f, callback=callback)) # Backwards compat (should figure out which version) if "raw.X" in f: - raw = AnnData(**_read_legacy_raw(f, adata.raw, read_dataframe, read_elem)) + raw = AnnData( + **asyncio.run( + asyncio.gather( + _read_legacy_raw(f, adata.raw, read_dataframe, read_elem_async) + ) + ) + ) raw.obs_names = adata.obs_names adata.raw = raw @@ -113,9 +130,12 @@ def callback(func, elem_name: str, elem, iospec): @report_read_key_on_error -def read_dataset(dataset: zarr.Array): +async def read_dataset(dataset: zarr.Array): """Legacy method for reading datasets without encoding_type.""" - value = dataset[...] + if is_zarr_v2(): + value = dataset[...] 
+ else: + value = await dataset._async_array.getitem(()) if not hasattr(value, "dtype"): return value elif isinstance(value.dtype, str): @@ -131,7 +151,7 @@ def read_dataset(dataset: zarr.Array): @report_read_key_on_error -def read_dataframe_legacy(dataset: zarr.Array) -> pd.DataFrame: +async def read_dataframe_legacy(dataset: zarr.Array) -> pd.DataFrame: """Reads old format of dataframes""" # NOTE: Likely that categoricals need to be removed from uns warn( @@ -139,18 +159,22 @@ def read_dataframe_legacy(dataset: zarr.Array) -> pd.DataFrame: "Consider rewriting it.", OldFormatWarning, ) - df = pd.DataFrame(_from_fixed_length_strings(dataset[()])) + if is_zarr_v2(): + data = dataset[...] + else: + data = await dataset._async_array.getitem(()) + df = pd.DataFrame(_from_fixed_length_strings(data)) df.set_index(df.columns[0], inplace=True) return df @report_read_key_on_error -def read_dataframe(group: zarr.Group | zarr.Array) -> pd.DataFrame: +async def read_dataframe(group: zarr.Group | zarr.Array) -> pd.DataFrame: # Fast paths if isinstance(group, zarr.Array): - return read_dataframe_legacy(group) + return await read_dataframe_legacy(group) else: - return read_elem(group) + return await read_elem_async(group) def open_write_group( diff --git a/src/anndata/_types.py b/src/anndata/_types.py index 56105c57c..447ac76b8 100644 --- a/src/anndata/_types.py +++ b/src/anndata/_types.py @@ -6,12 +6,7 @@ from typing import TYPE_CHECKING, Literal, Protocol, TypeVar -from .compat import ( - H5Array, - H5Group, - ZarrArray, - ZarrGroup, -) +from .compat import H5Array, H5Group, ZarrArray, ZarrGroup from .typing import RWAble if TYPE_CHECKING: @@ -31,6 +26,7 @@ "GroupStorageType", "StorageType", "_ReadInternal", + "_ReadAsyncInternal", "_ReadLazyInternal", "_WriteInternal", ] @@ -52,6 +48,10 @@ class _ReadInternal(Protocol[SCon, CovariantRWAble]): def __call__(self, elem: SCon, *, _reader: Reader) -> CovariantRWAble: ... 
+class _ReadAsyncInternal(Protocol[SCon, CovariantRWAble]): + async def __call__(self, elem: SCon, *, _reader: Reader) -> CovariantRWAble: ... + + class _ReadLazyInternal(Protocol[SCon]): def __call__( self, elem: SCon, *, _reader: LazyReader, chunks: tuple[int, ...] | None = None @@ -59,7 +59,7 @@ def __call__( class Read(Protocol[SCon, CovariantRWAble]): - def __call__(self, elem: SCon) -> CovariantRWAble: + async def __call__(self, elem: SCon) -> CovariantRWAble: """Low-level reading function for an element. Parameters @@ -73,6 +73,21 @@ def __call__(self, elem: SCon) -> CovariantRWAble: ... +class ReadAsync(Protocol[SCon, CovariantRWAble]): + async def __call__(self, elem: SCon) -> CovariantRWAble: + """Low-level reading function for an element asynchronously. + + Parameters + ---------- + elem + The element to read from. + Returns + ------- + The element read from the store. + """ + ... + + class ReadLazy(Protocol[SCon]): def __call__( self, elem: SCon, *, chunks: tuple[int, ...] | None = None @@ -105,7 +120,7 @@ def __call__( class Write(Protocol[ContravariantRWAble]): - def __call__( + async def __call__( self, f: StorageType, k: str, @@ -130,10 +145,10 @@ def __call__( class ReadCallback(Protocol[SCo, InvariantRWAble]): - def __call__( + async def __call__( self, /, - read_func: Read[SCo, InvariantRWAble], + read_func: ReadAsync[SCo, InvariantRWAble], elem_name: str, elem: StorageType, *, @@ -155,13 +170,13 @@ def __call__( Returns ------- - The element read from the store. + The element read from the store. """ ... 
class WriteCallback(Protocol[InvariantRWAble]): - def __call__( + async def __call__( self, /, write_func: Write[InvariantRWAble], diff --git a/src/anndata/experimental/__init__.py b/src/anndata/experimental/__init__.py index 2c233c6b6..8f6d0fb77 100644 --- a/src/anndata/experimental/__init__.py +++ b/src/anndata/experimental/__init__.py @@ -4,7 +4,7 @@ from typing import TYPE_CHECKING from .._io.specs import IOSpec, read_elem_lazy -from .._types import Read, ReadCallback, StorageType, Write, WriteCallback +from .._types import Read, ReadAsync, ReadCallback, StorageType, Write, WriteCallback from ..utils import module_get_attr_redirect from ._dispatch_io import read_dispatched, write_dispatched from .backed import read_lazy @@ -47,6 +47,7 @@ def __getattr__(attr_name: str) -> Any: "IOSpec", "concat_on_disk", "Read", + "ReadAsync", "read_lazy", "Write", "ReadCallback", diff --git a/src/anndata/experimental/_dispatch_io.py b/src/anndata/experimental/_dispatch_io.py index 53f94c453..4d7605818 100644 --- a/src/anndata/experimental/_dispatch_io.py +++ b/src/anndata/experimental/_dispatch_io.py @@ -16,7 +16,7 @@ from anndata.typing import RWAble -def read_dispatched( +async def read_dispatched( elem: StorageType, callback: ReadCallback, ) -> RWAble: @@ -39,10 +39,10 @@ def read_dispatched( reader = Reader(_REGISTRY, callback=callback) - return reader.read_elem(elem) + return await reader.read_elem_async(elem) -def write_dispatched( +async def write_dispatched( store: GroupStorageType, key: str, elem: RWAble, @@ -74,4 +74,4 @@ def write_dispatched( writer = Writer(_REGISTRY, callback=callback) - writer.write_elem(store, key, elem, dataset_kwargs=dataset_kwargs) + await writer.write_elem_async(store, key, elem, dataset_kwargs=dataset_kwargs) diff --git a/src/anndata/experimental/backed/_io.py b/src/anndata/experimental/backed/_io.py index e64919dba..2208a5ca6 100644 --- a/src/anndata/experimental/backed/_io.py +++ b/src/anndata/experimental/backed/_io.py @@ -1,5 +1,6 @@ 
from __future__ import annotations +import asyncio import typing import warnings from pathlib import Path @@ -7,7 +8,8 @@ import h5py -from anndata._io.specs.registry import read_elem_lazy +from anndata._io.specs.methods import sync_async_to_async +from anndata._io.specs.registry import read_elem_async, read_elem_lazy from anndata._types import AnnDataElem from testing.anndata._doctest import doctest_needs @@ -114,7 +116,9 @@ def read_lazy( else: f = h5py.File(store, mode="r") - def callback(func: Read, /, elem_name: str, elem: StorageType, *, iospec: IOSpec): + async def callback( + func: Read, /, elem_name: str, elem: StorageType, *, iospec: IOSpec + ): if iospec.encoding_type in {"anndata", "raw"} or elem_name.endswith("/"): iter_object = ( dict(elem).items() @@ -128,7 +132,15 @@ def callback(func: Read, /, elem_name: str, elem: StorageType, *, iospec: IOSpec is not None # need to do this instead of `k in elem` to prevent unnecessary metadata accesses ) ) - return AnnData(**{k: read_dispatched(v, callback) for k, v in iter_object}) + args = dict( + await asyncio.gather( + *( + sync_async_to_async(k, read_dispatched(v, callback=callback)) + for k, v in iter_object + ) + ) + ) + return AnnData(**args) elif ( iospec.encoding_type in { @@ -145,14 +157,19 @@ def callback(func: Read, /, elem_name: str, elem: StorageType, *, iospec: IOSpec return read_elem_lazy(elem, use_range_index=not load_annotation_index) return read_elem_lazy(elem) elif iospec.encoding_type in {"awkward-array"}: - return read_dispatched(elem, None) + return await read_elem_async(elem) elif iospec.encoding_type == "dict": - return { - k: read_dispatched(v, callback=callback) for k, v in dict(elem).items() - } - return func(elem) + return dict( + await asyncio.gather( + *( + sync_async_to_async(k, read_dispatched(v, callback=callback)) + for k, v in dict(elem).items() + ) + ) + ) + return await func(elem) with settings.override(check_uniqueness=load_annotation_index): - adata = read_dispatched(f, 
callback=callback) + adata = asyncio.run(read_dispatched(f, callback=callback)) return adata diff --git a/src/anndata/experimental/merge.py b/src/anndata/experimental/merge.py index 1cb447f23..e62c79076 100644 --- a/src/anndata/experimental/merge.py +++ b/src/anndata/experimental/merge.py @@ -1,5 +1,6 @@ from __future__ import annotations +import asyncio import os import shutil from collections.abc import Mapping @@ -26,6 +27,8 @@ ) from .._core.sparse_dataset import BaseCompressedSparseDataset, sparse_dataset from .._io.specs import read_elem, write_elem +from .._io.specs.methods import sync_async_to_async +from .._io.specs.registry import read_elem_async from ..compat import H5Array, H5Group, ZarrArray, ZarrGroup, _map_cat_to_str from . import read_dispatched @@ -142,19 +145,26 @@ def read_as_backed(group: ZarrGroup | H5Group): BaseCompressedSparseDataset, Array or EAGER_TYPES are encountered. """ - def callback(func, elem_name: str, elem, iospec): + async def callback(func, elem_name: str, elem, iospec): if iospec.encoding_type in SPARSE_MATRIX: return sparse_dataset(elem) elif iospec.encoding_type in EAGER_TYPES: - return read_elem(elem) + return await read_elem_async(elem) elif iospec.encoding_type == "array": return elem elif iospec.encoding_type == "dict": - return {k: read_as_backed(v) for k, v in dict(elem).items()} + return dict( + await asyncio.gather( + *( + sync_async_to_async(k, read_dispatched(v, callback=callback)) + for k, v in dict(elem).items() + ) + ) + ) else: - return func(elem) + return await func(elem) - return read_dispatched(group, callback=callback) + return asyncio.run(read_dispatched(group, callback=callback)) def _df_index(df: ZarrGroup | H5Group) -> pd.Index: diff --git a/src/anndata/io.py b/src/anndata/io.py index 55430a699..6c5027519 100644 --- a/src/anndata/io.py +++ b/src/anndata/io.py @@ -11,7 +11,7 @@ read_text, read_umi_tools, ) -from ._io.specs import read_elem, write_elem +from ._io.specs import read_elem, read_elem_async, 
write_elem, write_elem_async from ._io.write import write_csvs, write_loom from ._io.zarr import read_zarr, write_zarr @@ -30,6 +30,8 @@ "write_loom", "write_zarr", "write_elem", + "write_elem_async", "read_elem", + "read_elem_async", "sparse_dataset", ] diff --git a/tests/test_backed_sparse.py b/tests/test_backed_sparse.py index a48af8071..a2627914f 100644 --- a/tests/test_backed_sparse.py +++ b/tests/test_backed_sparse.py @@ -13,15 +13,9 @@ import anndata as ad from anndata._core.anndata import AnnData from anndata._core.sparse_dataset import sparse_dataset -from anndata._io.specs.registry import read_elem_lazy +from anndata._io.specs.registry import read_elem_lazy, write_elem_async from anndata._io.zarr import open_write_group -from anndata.compat import ( - CSArray, - CSMatrix, - DaskArray, - ZarrGroup, - is_zarr_v2, -) +from anndata.compat import CSArray, CSMatrix, DaskArray, ZarrGroup, is_zarr_v2 from anndata.experimental import read_dispatched from anndata.tests.helpers import AccessTrackingStore, assert_equal, subset_func @@ -55,54 +49,62 @@ def diskfmt(request): @pytest.fixture -def ondisk_equivalent_adata( +async def ondisk_equivalent_adata( tmp_path: Path, diskfmt: Literal["h5ad", "zarr"] ) -> tuple[AnnData, AnnData, AnnData, AnnData]: csr_path = tmp_path / f"csr.{diskfmt}" csc_path = tmp_path / f"csc.{diskfmt}" dense_path = tmp_path / f"dense.{diskfmt}" - write = lambda x, pth, **kwargs: getattr(x, f"write_{diskfmt}")(pth, **kwargs) - csr_mem = ad.AnnData(X=sparse.random(M, N, format="csr", density=0.1)) csc_mem = ad.AnnData(X=csr_mem.X.tocsc()) dense_mem = ad.AnnData(X=csr_mem.X.toarray()) - write(csr_mem, csr_path) - write(csc_mem, csc_path) + await write_elem_async( + open_write_group(csr_path) + if diskfmt == "zarr" + else h5py.File(csr_path, mode="w"), + "/", + csr_mem, + ) + await write_elem_async( + open_write_group(csc_path) + if diskfmt == "zarr" + else h5py.File(csc_path, mode="w"), + "/", + csc_mem, + ) # write(csr_mem, dense_path, 
as_dense="X") - write(dense_mem, dense_path) - if diskfmt == "h5ad": - csr_disk = ad.read_h5ad(csr_path, backed="r") - csc_disk = ad.read_h5ad(csc_path, backed="r") - dense_disk = ad.read_h5ad(dense_path, backed="r") - else: + await write_elem_async( + open_write_group(dense_path) + if diskfmt == "zarr" + else h5py.File(dense_path, mode="w"), + "/", + dense_mem, + ) - def read_zarr_backed(path): - path = str(path) + async def read_backed(path): + path = str(path) - f = zarr.open(path, mode="r") + f = ( + h5py.File(path, mode="r") + if diskfmt == "h5ad" + else zarr.open(path, mode="r") + ) - # Read with handling for backwards compat - def callback(func, elem_name, elem, iospec): - if iospec.encoding_type == "anndata" or elem_name.endswith("/"): - return AnnData( - **{ - k: read_dispatched(v, callback) - for k, v in dict(elem).items() - } - ) - if iospec.encoding_type in {"csc_matrix", "csr_matrix"}: - return sparse_dataset(elem) - return func(elem) + # Read with handling for backwards compat + async def callback(func, elem_name, elem, iospec): + if iospec.encoding_type in {"csc_matrix", "csr_matrix"}: + return sparse_dataset(elem) + return await func(elem) - adata = read_dispatched(f, callback=callback) + adata = await read_dispatched(f, callback=callback) - return adata + return adata - csr_disk = read_zarr_backed(csr_path) - csc_disk = read_zarr_backed(csc_path) - dense_disk = read_zarr_backed(dense_path) + csr_disk = await read_backed(csr_path) + csc_disk = await read_backed(csc_path) + dense_disk = await read_backed(dense_path) return csr_mem, csr_disk, csc_disk, dense_disk @@ -110,7 +112,7 @@ def callback(func, elem_name, elem, iospec): @pytest.mark.parametrize( "empty_mask", [[], np.zeros(M, dtype=bool)], ids=["empty_list", "empty_bool_mask"] ) -def test_empty_backed_indexing( +async def test_empty_backed_indexing( ondisk_equivalent_adata: tuple[AnnData, AnnData, AnnData, AnnData], empty_mask, ): @@ -126,7 +128,7 @@ def test_empty_backed_indexing( # 
assert_equal(csr_mem.X[empty_mask, empty_mask], csc_disk.X[empty_mask, empty_mask]) -def test_backed_indexing( +async def test_backed_indexing( ondisk_equivalent_adata: tuple[AnnData, AnnData, AnnData, AnnData], subset_func, subset_func2, @@ -144,7 +146,7 @@ def test_backed_indexing( assert_equal(csr_mem[:, var_idx].X, dense_disk[:, var_idx].X) -def test_backed_ellipsis_indexing( +async def test_backed_ellipsis_indexing( ondisk_equivalent_adata: tuple[AnnData, AnnData, AnnData, AnnData], ellipsis_index: tuple[EllipsisType | slice, ...] | EllipsisType, equivalent_ellipsis_index: tuple[slice, slice], @@ -200,7 +202,7 @@ def make_one_elem_mask(size: int) -> np.ndarray: ], ids=["randomized", "alternating_15", "alternating_5", "one_group", "one_elem"], ) -def test_consecutive_bool( +async def test_consecutive_bool( mocker: MockerFixture, ondisk_equivalent_adata: tuple[AnnData, AnnData, AnnData, AnnData], make_bool_mask: Callable[[int], np.ndarray], @@ -635,7 +637,7 @@ def test_anndata_sparse_compat(tmp_path: Path, diskfmt: Literal["h5ad", "zarr"]) assert_equal(adata.X, base) -def test_backed_sizeof( +async def test_backed_sizeof( ondisk_equivalent_adata: tuple[AnnData, AnnData, AnnData, AnnData], ): csr_mem, csr_disk, csc_disk, _ = ondisk_equivalent_adata diff --git a/tests/test_io_dispatched.py b/tests/test_io_dispatched.py index c34b5ea3d..cd850987e 100644 --- a/tests/test_io_dispatched.py +++ b/tests/test_io_dispatched.py @@ -17,33 +17,33 @@ from pathlib import Path -def test_read_dispatched_w_regex(tmp_path: Path): - def read_only_axis_dfs(func, elem_name: str, elem, iospec): +async def test_read_dispatched_w_regex(tmp_path: Path): + async def read_only_axis_dfs(func, elem_name: str, elem, iospec): if iospec.encoding_type == "anndata": - return func(elem) + return await func(elem) elif re.match(r"^/((obs)|(var))?(/.*)?$", elem_name): - return func(elem) + return await func(elem) else: return None adata = gen_adata((1000, 100)) z = open_write_group(tmp_path) - 
ad.io.write_elem(z, "/", adata) + await ad.io.write_elem_async(z, "/", adata) # TODO: see https://github.com/zarr-developers/zarr-python/issues/2716 if not is_zarr_v2() and isinstance(z, ZarrGroup): z = zarr.open(z.store) expected = ad.AnnData(obs=adata.obs, var=adata.var) - actual = read_dispatched(z, read_only_axis_dfs) + actual = await read_dispatched(z, read_only_axis_dfs) assert_equal(expected, actual) -def test_read_dispatched_dask(tmp_path: Path): +async def test_read_dispatched_dask(tmp_path: Path): import dask.array as da - def read_as_dask_array(func, elem_name: str, elem, iospec): + async def read_as_dask_array(func, elem_name: str, elem, iospec): if iospec.encoding_type in { "dataframe", "csr_matrix", @@ -51,45 +51,49 @@ def read_as_dask_array(func, elem_name: str, elem, iospec): "awkward-array", }: # Preventing recursing inside of these types - return ad.io.read_elem(elem) + return await func(elem) elif iospec.encoding_type == "array": return da.from_zarr(elem) else: - return func(elem) + return await func(elem) adata = gen_adata((1000, 100)) z = open_write_group(tmp_path) - ad.io.write_elem(z, "/", adata) + await ad.io.write_elem_async(z, "/", adata) # TODO: see https://github.com/zarr-developers/zarr-python/issues/2716 if not is_zarr_v2() and isinstance(z, ZarrGroup): z = zarr.open(z.store) - dask_adata = read_dispatched(z, read_as_dask_array) + dask_adata = await read_dispatched(z, read_as_dask_array) assert isinstance(dask_adata.layers["array"], da.Array) assert isinstance(dask_adata.obsm["array"], da.Array) assert isinstance(dask_adata.uns["nested"]["nested_further"]["array"], da.Array) - expected = ad.io.read_elem(z) + expected = await ad.io.read_elem_async(z) actual = dask_adata.to_memory(copy=False) assert_equal(expected, actual) -def test_read_dispatched_null_case(tmp_path: Path): +async def test_read_dispatched_null_case(tmp_path: Path): adata = gen_adata((100, 100)) z = open_write_group(tmp_path) ad.io.write_elem(z, "/", adata) # TODO: see 
https://github.com/zarr-developers/zarr-python/issues/2716 if not is_zarr_v2() and isinstance(z, ZarrGroup): z = zarr.open(z.store) - expected = ad.io.read_elem(z) - actual = read_dispatched(z, lambda _, __, x, **___: ad.io.read_elem(x)) + expected = await ad.io.read_elem_async(z) + + async def callback(_, __, x, **___): + return await ad.io.read_elem_async(x) + + actual = await read_dispatched(z, callback) assert_equal(expected, actual) -def test_write_dispatched_chunks(tmp_path: Path): +async def test_write_dispatched_chunks(tmp_path: Path): from itertools import chain, repeat def determine_chunks(elem_shape, specified_chunks): @@ -98,7 +102,7 @@ def determine_chunks(elem_shape, specified_chunks): adata = gen_adata((1000, 100)) - def write_chunked(func, store, k, elem, dataset_kwargs, iospec): + async def write_chunked(func, store, k, elem, dataset_kwargs, iospec): M, N = 13, 42 def set_copy(d, **kwargs): @@ -123,7 +127,7 @@ def set_copy(d, **kwargs): chunks = (N,) else: chunks = dataset_kwargs.get("chunks", ()) - func( + await func( store, k, elem, @@ -132,11 +136,11 @@ def set_copy(d, **kwargs): ), ) else: - func(store, k, elem, dataset_kwargs=dataset_kwargs) + await func(store, k, elem, dataset_kwargs=dataset_kwargs) z = open_write_group(tmp_path) - write_dispatched(z, "/", adata, callback=write_chunked) + await write_dispatched(z, "/", adata, callback=write_chunked) def check_chunking(k: str, v: ZarrGroup | zarr.Array): if ( @@ -167,7 +171,7 @@ def visititems( visititems(z, check_chunking) -def test_io_dispatched_keys(tmp_path: Path): +async def test_io_dispatched_keys(tmp_path: Path): h5ad_write_keys = [] zarr_write_keys = [] h5ad_read_keys = [] @@ -176,33 +180,33 @@ def test_io_dispatched_keys(tmp_path: Path): h5ad_path = tmp_path / "test.h5ad" zarr_path = tmp_path / "test.zarr" - def h5ad_writer(func, store, k, elem, dataset_kwargs, iospec): + async def h5ad_writer(func, store, k, elem, dataset_kwargs, iospec): h5ad_write_keys.append(k if is_zarr_v2() else 
k.strip("/")) - func(store, k, elem, dataset_kwargs=dataset_kwargs) + await func(store, k, elem, dataset_kwargs=dataset_kwargs) - def zarr_writer(func, store, k, elem, dataset_kwargs, iospec): + async def zarr_writer(func, store, k, elem, dataset_kwargs, iospec): zarr_write_keys.append( k if is_zarr_v2() else f"{store.name.strip('/')}/{k.strip('/')}".strip("/") ) - func(store, k, elem, dataset_kwargs=dataset_kwargs) + await func(store, k, elem, dataset_kwargs=dataset_kwargs) - def h5ad_reader(func, elem_name: str, elem, iospec): + async def h5ad_reader(func, elem_name: str, elem, iospec): h5ad_read_keys.append(elem_name if is_zarr_v2() else elem_name.strip("/")) return func(elem) - def zarr_reader(func, elem_name: str, elem, iospec): + async def zarr_reader(func, elem_name: str, elem, iospec): zarr_read_keys.append(elem_name if is_zarr_v2() else elem_name.strip("/")) return func(elem) adata = gen_adata((50, 100)) with h5py.File(h5ad_path, "w") as f: - write_dispatched(f, "/", adata, callback=h5ad_writer) - _ = read_dispatched(f, h5ad_reader) + await write_dispatched(f, "/", adata, callback=h5ad_writer) + _ = await read_dispatched(f, h5ad_reader) f = open_write_group(zarr_path) - write_dispatched(f, "/", adata, callback=zarr_writer) - _ = read_dispatched(f, zarr_reader) + await write_dispatched(f, "/", adata, callback=zarr_writer) + _ = await read_dispatched(f, zarr_reader) assert sorted(h5ad_read_keys) == sorted(zarr_read_keys) assert sorted(h5ad_write_keys) == sorted(zarr_write_keys) diff --git a/tests/test_io_partial.py b/tests/test_io_partial.py deleted file mode 100644 index dc3582cdd..000000000 --- a/tests/test_io_partial.py +++ /dev/null @@ -1,96 +0,0 @@ -from __future__ import annotations - -import warnings -from importlib.util import find_spec -from pathlib import Path - -import h5py -import numpy as np -import pytest -import zarr -from scipy.sparse import csr_matrix - -from anndata import AnnData -from anndata._io.specs.registry import read_elem_partial 
-from anndata.io import read_elem, write_h5ad, write_zarr - -X = np.array([[1.0, 0.0, 3.0], [4.0, 0.0, 6.0], [0.0, 8.0, 0.0]], dtype="float32") -X_check = np.array([[4.0, 0.0], [0.0, 8.0]], dtype="float32") - -WRITER = dict(h5ad=write_h5ad, zarr=write_zarr) -READER = dict(h5ad=h5py.File, zarr=zarr.open) - - -@pytest.mark.parametrize("typ", [np.asarray, csr_matrix]) -@pytest.mark.parametrize("accessor", ["h5ad", "zarr"]) -def test_read_partial_X(tmp_path, typ, accessor): - adata = AnnData(X=typ(X)) - - path = Path(tmp_path) / ("test_tp_X." + accessor) - - WRITER[accessor](path, adata) - - store = READER[accessor](path, mode="r") - if accessor == "zarr": - X_part = read_elem_partial(store["X"], indices=([1, 2], [0, 1])) - else: - # h5py doesn't allow fancy indexing across multiple dimensions - X_part = read_elem_partial(store["X"], indices=([1, 2],)) - X_part = X_part[:, [0, 1]] - store.close() - - assert np.all(X_check == X_part) - - -@pytest.mark.skipif(not find_spec("scanpy"), reason="Scanpy is not installed") -@pytest.mark.parametrize("accessor", ["h5ad", "zarr"]) -def test_read_partial_adata(tmp_path, accessor): - with warnings.catch_warnings(): - warnings.filterwarnings( - "ignore", message=r"Importing read_.* from `anndata` is deprecated" - ) - import scanpy as sc - - adata = sc.datasets.pbmc68k_reduced() - - path = Path(tmp_path) / ("test_rp." 
+ accessor) - - WRITER[accessor](path, adata) - - storage = READER[accessor](path, mode="r") - - obs_idx = [1, 2] - var_idx = [0, 3] - adata_sbs = adata[obs_idx, var_idx] - - if accessor == "zarr": - part = read_elem_partial(storage["X"], indices=(obs_idx, var_idx)) - else: - # h5py doesn't allow fancy indexing across multiple dimensions - part = read_elem_partial(storage["X"], indices=(obs_idx,)) - part = part[:, var_idx] - assert np.all(part == adata_sbs.X) - - part = read_elem_partial(storage["obs"], indices=(obs_idx,)) - assert np.all(part.keys() == adata_sbs.obs.keys()) - assert np.all(part.index == adata_sbs.obs.index) - - part = read_elem_partial(storage["var"], indices=(var_idx,)) - assert np.all(part.keys() == adata_sbs.var.keys()) - assert np.all(part.index == adata_sbs.var.index) - - for key in storage["obsm"].keys(): - part = read_elem_partial(storage["obsm"][key], indices=(obs_idx,)) - assert np.all(part == adata_sbs.obsm[key]) - - for key in storage["varm"].keys(): - part = read_elem_partial(storage["varm"][key], indices=(var_idx,)) - np.testing.assert_equal(part, adata_sbs.varm[key]) - - for key in storage["obsp"].keys(): - part = read_elem_partial(storage["obsp"][key], indices=(obs_idx, obs_idx)) - part = part.toarray() - assert np.all(part == adata_sbs.obsp[key]) - - # check uns just in case - np.testing.assert_equal(read_elem(storage["uns"]).keys(), adata.uns.keys())