From 4479b94cf4366ae3181d34ecff876c7031bfe8f0 Mon Sep 17 00:00:00 2001 From: Ilia Kats Date: Wed, 9 Apr 2025 16:14:56 +0200 Subject: [PATCH 01/43] move core xarray support to _core it's being used in many places in _core already --- src/anndata/_core/merge.py | 2 +- src/anndata/_core/storage.py | 10 +------ .../backed/_xarray.py => _core/xarray.py} | 25 +++++++----------- src/anndata/_io/specs/lazy_methods.py | 20 +++++++------- src/anndata/_io/specs/registry.py | 2 +- src/anndata/compat/__init__.py | 22 ++++++++++++++++ src/anndata/experimental/backed/_compat.py | 26 +------------------ .../experimental/backed/_lazy_arrays.py | 4 +-- 8 files changed, 47 insertions(+), 64 deletions(-) rename src/anndata/{experimental/backed/_xarray.py => _core/xarray.py} (89%) diff --git a/src/anndata/_core/merge.py b/src/anndata/_core/merge.py index ce67b8142..d37081bf3 100644 --- a/src/anndata/_core/merge.py +++ b/src/anndata/_core/merge.py @@ -43,7 +43,7 @@ from pandas.api.extensions import ExtensionDtype from anndata._types import Join_T - from anndata.experimental.backed._compat import DataArray, Dataset2D + from .._compat import XArray, Dataset2D T = TypeVar("T") diff --git a/src/anndata/_core/storage.py b/src/anndata/_core/storage.py index b8c85ab6e..91422ec07 100644 --- a/src/anndata/_core/storage.py +++ b/src/anndata/_core/storage.py @@ -27,17 +27,9 @@ def coerce_array( allow_df: bool = False, allow_array_like: bool = False, ): - try: - from anndata.experimental.backed._compat import Dataset2D - except ImportError: - - class Dataset2D: - @staticmethod - def __repr__(): - return "mock anndata.experimental.backed._xarray." - """Coerce arrays stored in layers/X, and aligned arrays ({obs,var}{m,p}).""" from ..typing import ArrayDataStructureTypes + from .xarray import Dataset2D # If value is a scalar and we allow that, return it if allow_array_like and np.isscalar(value): diff --git a/src/anndata/experimental/backed/_xarray.py b/src/anndata/_core/xarray.py similarity index 89% rename from src/anndata/experimental/backed/_xarray.py rename to src/anndata/_core/xarray.py index e5420a45b..e96868d10 100644 --- a/src/anndata/experimental/backed/_xarray.py +++ b/src/anndata/_core/xarray.py @@ -4,36 +4,29 @@ import pandas as pd -from ..._core.anndata import AnnData, _gen_dataframe -from ..._core.file_backing import to_memory -from ..._core.index import _subset -from ..._core.views import as_view - -try: - from xarray import Dataset -except ImportError: - - class Dataset: - def __repr__(self) -> str: - return "mock Dataset" +from .anndata import AnnData, _gen_dataframe +from .file_backing import to_memory +from .index import _subset +from .views import as_view +from .._compat import XDataset if TYPE_CHECKING: from collections.abc import Hashable, Iterable from typing import Any, Literal - from ..._core.index import Index - from ._compat import xarray as xr + from .index import Index + from .._compat import XArray -def get_index_dim(ds: xr.DataArray) -> Hashable: +def get_index_dim(ds: XArray) -> Hashable: if len(ds.sizes) != 1: msg = f"xarray Dataset should not have more than 1 dims, found {len(ds.sizes)} {ds.sizes}, {ds}" raise ValueError(msg) return next(iter(ds.indexes.keys())) -class Dataset2D(Dataset): +class Dataset2D(XDataset): """ A wrapper class meant to enable working with lazy dataframe data. 
We do not guarantee the stability of this API beyond that guaranteed diff --git a/src/anndata/_io/specs/lazy_methods.py b/src/anndata/_io/specs/lazy_methods.py index 76e2614d1..14a8ff842 100644 --- a/src/anndata/_io/specs/lazy_methods.py +++ b/src/anndata/_io/specs/lazy_methods.py @@ -21,10 +21,10 @@ from collections.abc import Generator, Mapping, Sequence from typing import Literal, ParamSpec, TypeVar - from anndata.experimental.backed._compat import DataArray, Dataset2D from anndata.experimental.backed._lazy_arrays import CategoricalArray, MaskedArray - from ...compat import CSArray, CSMatrix, H5File + from ..._compat import CSArray, CSMatrix, H5File, XArray + from ..._core.xarray import Dataset2D from .registry import LazyDataStructures, LazyReader BlockInfo = Mapping[ @@ -220,19 +220,19 @@ def _gen_xarray_dict_iterator_from_elems( elem_dict: dict[str, LazyDataStructures], dim_name: str, index: np.NDArray, -) -> Generator[tuple[str, DataArray], None, None]: - from anndata.experimental.backed._compat import DataArray - from anndata.experimental.backed._compat import xarray as xr +) -> Generator[tuple[str, XArray], None, None]: + from ..._compat import XArray + from ..._compat import xarray as xr from anndata.experimental.backed._lazy_arrays import CategoricalArray, MaskedArray for k, v in elem_dict.items(): if isinstance(v, DaskArray) and k != dim_name: - data_array = DataArray(v, coords=[index], dims=[dim_name], name=k) + data_array = XArray(v, coords=[index], dims=[dim_name], name=k) elif isinstance(v, CategoricalArray | MaskedArray) and k != dim_name: variable = xr.Variable( data=xr.core.indexing.LazilyIndexedArray(v), dims=[dim_name] ) - data_array = DataArray( + data_array = XArray( variable, coords=[index], dims=[dim_name], @@ -243,7 +243,7 @@ def _gen_xarray_dict_iterator_from_elems( }, ) elif k == dim_name: - data_array = DataArray( + data_array = XArray( index, coords=[index], dims=[dim_name], name=dim_name ) else: @@ -263,7 +263,7 @@ def read_dataframe( _reader: LazyReader, use_range_index: bool = False, ) -> Dataset2D: - from anndata.experimental.backed._compat import DataArray, Dataset2D + from ..._compat import XArray, Dataset2D elem_dict = { k: _reader.read_elem(elem[k]) @@ -282,7 +282,7 @@ def read_dataframe( _gen_xarray_dict_iterator_from_elems(elem_dict, dim_name, index) ) if use_range_index: - elem_xarray_dict[DUMMY_RANGE_INDEX_KEY] = DataArray( + elem_xarray_dict[DUMMY_RANGE_INDEX_KEY] = XArray( index, coords=[index], dims=[DUMMY_RANGE_INDEX_KEY], diff --git a/src/anndata/_io/specs/registry.py b/src/anndata/_io/specs/registry.py index de158f2b1..3e40b2c66 100644 --- a/src/anndata/_io/specs/registry.py +++ b/src/anndata/_io/specs/registry.py @@ -24,7 +24,7 @@ WriteCallback, _WriteInternal, ) - from anndata.experimental.backed._compat import Dataset2D + from ..._compat import Dataset2D from anndata.experimental.backed._lazy_arrays import CategoricalArray, MaskedArray from anndata.typing import RWAble diff --git a/src/anndata/compat/__init__.py b/src/anndata/compat/__init__.py index 3cb719f90..aa09917ea 100644 --- a/src/anndata/compat/__init__.py +++ b/src/anndata/compat/__init__.py @@ -99,6 +99,28 @@ class DaskArray: def __repr__(): return "mock dask.array.core.Array" +if find_spec("xarray") or TYPE_CHECKING: + import xarray + from xarray import DataArray as XArray, Dataset as XDataset + from xarray.backends import BackendArray as XBackendArray + from xarray.backends.zarr import ZarrArrayWrapper as XZarrArrayWrapper +else: + xarray = None + class XArray: + def 
__repr__(self) -> str: + return "mock DataArray" + + class XDataset: + def __repr__(self) -> str: + return "mock Dataset" + + class XZarrArrayWrapper: + def __repr__(self) -> str: + return "mock ZarrArrayWrapper" + + class XBackendArray: + def __repr__(self) -> str: + return "mock BackendArray" # https://github.com/scverse/anndata/issues/1749 def is_cupy_importable() -> bool: diff --git a/src/anndata/experimental/backed/_compat.py b/src/anndata/experimental/backed/_compat.py index 7ea06e93b..3d7338fcd 100644 --- a/src/anndata/experimental/backed/_compat.py +++ b/src/anndata/experimental/backed/_compat.py @@ -3,31 +3,7 @@ from importlib.util import find_spec from typing import TYPE_CHECKING -if find_spec("xarray") or TYPE_CHECKING: - import xarray - from xarray import DataArray - from xarray.backends import BackendArray - from xarray.backends.zarr import ZarrArrayWrapper - - -else: - - class DataArray: - def __repr__(self) -> str: - return "mock DataArray" - - xarray = None - - class ZarrArrayWrapper: - def __repr__(self) -> str: - return "mock ZarrArrayWrapper" - - class BackendArray: - def __repr__(self) -> str: - return "mock BackendArray" - - -from ._xarray import Dataset, Dataset2D # noqa: F401 +from ..._core.xarray import Dataset2D if TYPE_CHECKING: from anndata import AnnData diff --git a/src/anndata/experimental/backed/_lazy_arrays.py b/src/anndata/experimental/backed/_lazy_arrays.py index 70b5ac9b6..3131ac953 100644 --- a/src/anndata/experimental/backed/_lazy_arrays.py +++ b/src/anndata/experimental/backed/_lazy_arrays.py @@ -11,8 +11,8 @@ from anndata.compat import H5Array, ZarrArray from ..._settings import settings -from ._compat import BackendArray, DataArray, ZarrArrayWrapper -from ._compat import xarray as xr +from ..._compat import XBackendArray, XArray, XZarrArrayWrapper +from ..._compat import xarray as xr if TYPE_CHECKING: from pathlib import Path From 098beb07cb45c08d5da2ca392fdecd4061a9089f Mon Sep 17 00:00:00 2001 From: Ilia Kats Date: Wed, 9 Apr 2025 16:23:35 +0200 Subject: [PATCH 02/43] allow XArray Datasets in obs, var, obsm, and varm --- src/anndata/_core/aligned_mapping.py | 7 +++++-- src/anndata/_core/xarray.py | 8 ++++++-- 2 files changed, 11 insertions(+), 4 deletions(-) diff --git a/src/anndata/_core/aligned_mapping.py b/src/anndata/_core/aligned_mapping.py index b5649d0f4..e7d609174 100644 --- a/src/anndata/_core/aligned_mapping.py +++ b/src/anndata/_core/aligned_mapping.py @@ -11,7 +11,7 @@ import pandas as pd from .._warnings import ExperimentalFeatureWarning, ImplicitModificationWarning -from ..compat import AwkArray, CSArray, CSMatrix, CupyArray +from ..compat import AwkArray, CSArray, CSMatrix, CupyArray, XDataset from ..utils import ( axis_len, convert_to_dict, @@ -75,8 +75,11 @@ def _validate_value(self, val: Value, key: str) -> Value: ExperimentalFeatureWarning, # stacklevel=3, ) - if isinstance(val, np.ndarray | CupyArray) and len(val.shape) == 1: + elif isinstance(val, np.ndarray | CupyArray) and len(val.shape) == 1: val = val.reshape((val.shape[0], 1)) + elif isinstance(val, XDataset): + from .xarray import Dataset2D + val = Dataset2D(data_vars=val.data_vars, coords=val.coords, attrs=val.attrs) for i, axis in enumerate(self.axes): if self.parent.shape[axis] == axis_len(val, i): continue diff --git a/src/anndata/_core/xarray.py b/src/anndata/_core/xarray.py index e96868d10..488bc1c3e 100644 --- a/src/anndata/_core/xarray.py +++ b/src/anndata/_core/xarray.py @@ -9,14 +9,14 @@ from .index import _subset from .views import as_view -from .._compat 
import XDataset +from ..compat import XDataset if TYPE_CHECKING: from collections.abc import Hashable, Iterable from typing import Any, Literal from .index import Index - from .._compat import XArray + from ..compat import XArray def get_index_dim(ds: XArray) -> Hashable: @@ -128,6 +128,10 @@ def _gen_dataframe_xr( ): return anno +@_gen_dataframe.register(XDataset) +def _gen_dataframe_xdataset(anno: Dataset, index_names: Iterable[str], *, source: Literal["X", "shape"], attr: Literal["obs", "var"], length: int | None=None): + return Dataset2D(anno) + @AnnData._remove_unused_categories.register(Dataset2D) @staticmethod From 8a603534f5d2339c1b3f1c23a1cab14ead9676bd Mon Sep 17 00:00:00 2001 From: Ilia Kats Date: Wed, 9 Apr 2025 16:57:19 +0200 Subject: [PATCH 03/43] reorganize to register all singledispatches at anndata import --- src/anndata/_core/aligned_df.py | 17 ++++++++++ src/anndata/_core/aligned_mapping.py | 2 +- src/anndata/_core/anndata.py | 7 +++++ src/anndata/_core/file_backing.py | 5 +++ src/anndata/_core/index.py | 9 ++++++ src/anndata/_core/merge.py | 3 +- src/anndata/_core/storage.py | 5 ++- src/anndata/_core/views.py | 5 +++ src/anndata/_core/xarray.py | 46 ---------------------------- src/anndata/typing.py | 3 +- 10 files changed, 52 insertions(+), 50 deletions(-) diff --git a/src/anndata/_core/aligned_df.py b/src/anndata/_core/aligned_df.py index efb0ab9bc..2aa83b37f 100644 --- a/src/anndata/_core/aligned_df.py +++ b/src/anndata/_core/aligned_df.py @@ -7,6 +7,8 @@ import pandas as pd from pandas.api.types import is_string_dtype +from ..compat import XDataset +from .xarray import Dataset2D from .._warnings import ImplicitModificationWarning @@ -119,3 +121,18 @@ def _mk_df_error( "({actual} {what}s instead of {expected})" ) return ValueError(msg) + +@_gen_dataframe.register(Dataset2D) +def _gen_dataframe_xr( + anno: Dataset2D, + index_names: Iterable[str], + *, + source: Literal["X", "shape"], + attr: Literal["obs", "var"], + length: int | None = None, +): + return anno + +@_gen_dataframe.register(XDataset) +def _gen_dataframe_xdataset(anno: Dataset, index_names: Iterable[str], *, source: Literal["X", "shape"], attr: Literal["obs", "var"], length: int | None=None): + return Dataset2D(anno) diff --git a/src/anndata/_core/aligned_mapping.py b/src/anndata/_core/aligned_mapping.py index e7d609174..606a2b94b 100644 --- a/src/anndata/_core/aligned_mapping.py +++ b/src/anndata/_core/aligned_mapping.py @@ -12,6 +12,7 @@ from .._warnings import ExperimentalFeatureWarning, ImplicitModificationWarning from ..compat import AwkArray, CSArray, CSMatrix, CupyArray, XDataset +from .xarray import Dataset2D from ..utils import ( axis_len, convert_to_dict, @@ -78,7 +79,6 @@ def _validate_value(self, val: Value, key: str) -> Value: elif isinstance(val, np.ndarray | CupyArray) and len(val.shape) == 1: val = val.reshape((val.shape[0], 1)) elif isinstance(val, XDataset): - from .xarray import Dataset2D val = Dataset2D(data_vars=val.data_vars, coords=val.coords, attrs=val.attrs) for i, axis in enumerate(self.axes): if self.parent.shape[axis] == axis_len(val, i): diff --git a/src/anndata/_core/anndata.py b/src/anndata/_core/anndata.py index 8421056aa..2e7cc93f3 100644 --- a/src/anndata/_core/anndata.py +++ b/src/anndata/_core/anndata.py @@ -42,6 +42,7 @@ from .raw import Raw from .sparse_dataset import BaseCompressedSparseDataset, sparse_dataset from .storage import coerce_array +from .xarray import Dataset2D from .views import ( DictView, _resolve_idxs, @@ -2078,6 +2079,12 @@ def 
_get_and_delete_multicol_field(self, a, key_multicol): getattr(self, a).drop(keys, axis=1, inplace=True) return values +@AnnData._remove_unused_categories.register(Dataset2D) +@staticmethod +def _remove_unused_categories_xr( + df_full: Dataset2D, df_sub: Dataset2D, uns: dict[str, Any] +): + pass # this is handled automatically by the categorical arrays themselves i.e., they dedup upon access. def _check_2d_shape(X): """\ diff --git a/src/anndata/_core/file_backing.py b/src/anndata/_core/file_backing.py index 45275e651..b51d14511 100644 --- a/src/anndata/_core/file_backing.py +++ b/src/anndata/_core/file_backing.py @@ -9,6 +9,7 @@ import h5py from ..compat import AwkArray, DaskArray, ZarrArray, ZarrGroup +from .xarray import Dataset2D from .sparse_dataset import BaseCompressedSparseDataset if TYPE_CHECKING: @@ -161,6 +162,10 @@ def _(x: AwkArray, *, copy: bool = False): else: return x +@to_memory.register(Dataset2D) +def _(x: Dataset2D, copy: bool = False): + return x.to_memory(copy=copy) + @singledispatch def filename(x): diff --git a/src/anndata/_core/index.py b/src/anndata/_core/index.py index f6435b64a..b2d6d94c6 100644 --- a/src/anndata/_core/index.py +++ b/src/anndata/_core/index.py @@ -11,6 +11,7 @@ from scipy.sparse import issparse from ..compat import AwkArray, CSArray, CSMatrix, DaskArray +from .xarray import Dataset2D if TYPE_CHECKING: from ..compat import Index, Index1D @@ -209,6 +210,14 @@ def _subset_awkarray(a: AwkArray, subset_idx: Index): subset_idx = np.ix_(*subset_idx) return a[subset_idx] +@_subset.register(Dataset2D) +def _(a: Dataset2D, subset_idx: Index): + key = get_index_dim(a) + # xarray seems to have some code looking for a second entry in tuples + if isinstance(subset_idx, tuple) and len(subset_idx) == 1: + subset_idx = subset_idx[0] + return a.isel(**{key: subset_idx}) + # Registration for SparseDataset occurs in sparse_dataset.py @_subset.register(h5py.Dataset) diff --git a/src/anndata/_core/merge.py b/src/anndata/_core/merge.py index d37081bf3..b3a7b872f 100644 --- a/src/anndata/_core/merge.py +++ b/src/anndata/_core/merge.py @@ -43,7 +43,8 @@ from pandas.api.extensions import ExtensionDtype from anndata._types import Join_T - from .._compat import XArray, Dataset2D + from ..compat import XArray + from .xarray import Dataset2D T = TypeVar("T") diff --git a/src/anndata/_core/storage.py b/src/anndata/_core/storage.py index 91422ec07..51292b8ed 100644 --- a/src/anndata/_core/storage.py +++ b/src/anndata/_core/storage.py @@ -10,6 +10,8 @@ from anndata.compat import CSArray, CSMatrix from .._warnings import ImplicitModificationWarning +from .xarray import Dataset2D +from ..compat import XDataset from ..utils import ( ensure_df_homogeneous, join_english, @@ -29,13 +31,14 @@ def coerce_array( ): """Coerce arrays stored in layers/X, and aligned arrays ({obs,var}{m,p}).""" from ..typing import ArrayDataStructureTypes - from .xarray import Dataset2D # If value is a scalar and we allow that, return it if allow_array_like and np.isscalar(value): return value # If value is one of the allowed types, return it array_data_structure_types = get_args(ArrayDataStructureTypes) + if isinstance(value, XDataset): + value = Dataset2D(value) if isinstance(value, (*array_data_structure_types, Dataset2D)): if isinstance(value, np.matrix): msg = f"{name} should not be a np.matrix, use np.ndarray instead." 
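Note, for context on the pattern this patch consolidates: the handlers above are ordinary functools.singledispatch registrations, which take effect as soon as the defining module is imported. A minimal self-contained sketch of the idea, using stand-in names rather than the real anndata internals:

    from functools import singledispatch

    @singledispatch
    def to_memory(x, copy: bool = False):
        # fallback: plain in-memory values pass through unchanged
        return x

    class Dataset2D:  # stand-in for the xarray-backed wrapper
        def to_memory(self, copy: bool = False):
            return self

    # Registering next to the class definition means the handler exists the
    # moment this module is imported -- the point of the reorganization.
    @to_memory.register(Dataset2D)
    def _(x: Dataset2D, *, copy: bool = False):
        return x.to_memory(copy=copy)
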
diff --git a/src/anndata/_core/views.py b/src/anndata/_core/views.py index 6fd341690..29e815cf4 100644 --- a/src/anndata/_core/views.py +++ b/src/anndata/_core/views.py @@ -22,6 +22,7 @@ DaskArray, ZappyArray, ) +from .xarray import Dataset2D from .access import ElementRef if TYPE_CHECKING: @@ -361,6 +362,10 @@ def as_view_cupy_csr(mtx, view_args): def as_view_cupy_csc(mtx, view_args): return CupySparseCSCView(mtx, view_args=view_args) +@as_view.register(Dataset2D) +def _(a: Dataset2D, view_args): + return a + try: import weakref diff --git a/src/anndata/_core/xarray.py b/src/anndata/_core/xarray.py index 488bc1c3e..fe7b4a4f9 100644 --- a/src/anndata/_core/xarray.py +++ b/src/anndata/_core/xarray.py @@ -4,11 +4,6 @@ import pandas as pd -from .anndata import AnnData, _gen_dataframe -from .file_backing import to_memory -from .index import _subset -from .views import as_view - from ..compat import XDataset if TYPE_CHECKING: @@ -101,44 +96,3 @@ def columns(self) -> pd.Index: """ columns_list = list(self.keys()) return pd.Index(columns_list) - - -@_subset.register(Dataset2D) -def _(a: Dataset2D, subset_idx: Index): - key = get_index_dim(a) - # xarray seems to have some code looking for a second entry in tuples - if isinstance(subset_idx, tuple) and len(subset_idx) == 1: - subset_idx = subset_idx[0] - return a.isel(**{key: subset_idx}) - - -@as_view.register(Dataset2D) -def _(a: Dataset2D, view_args): - return a - - -@_gen_dataframe.register(Dataset2D) -def _gen_dataframe_xr( - anno: Dataset2D, - index_names: Iterable[str], - *, - source: Literal["X", "shape"], - attr: Literal["obs", "var"], - length: int | None = None, -): - return anno - -@_gen_dataframe.register(XDataset) -def _gen_dataframe_xdataset(anno: Dataset, index_names: Iterable[str], *, source: Literal["X", "shape"], attr: Literal["obs", "var"], length: int | None=None): - return Dataset2D(anno) - - -@AnnData._remove_unused_categories.register(Dataset2D) -@staticmethod -def _remove_unused_categories_xr( - df_full: Dataset2D, df_sub: Dataset2D, uns: dict[str, Any] -): - pass # this is handled automatically by the categorical arrays themselves i.e., they dedup upon access. 
- - -to_memory.register(Dataset2D, Dataset2D.to_memory) diff --git a/src/anndata/typing.py b/src/anndata/typing.py index f0cf974b4..ab82926ca 100644 --- a/src/anndata/typing.py +++ b/src/anndata/typing.py @@ -18,6 +18,7 @@ H5Array, ZappyArray, ZarrArray, + XArray ) from .compat import Index as _Index @@ -45,7 +46,7 @@ | CupyArray | CupySparseMatrix ) -ArrayDataStructureTypes: TypeAlias = XDataType | AwkArray +ArrayDataStructureTypes: TypeAlias = XDataType | AwkArray | XArray InMemoryArrayOrScalarType: TypeAlias = ( From 3ce1624fa7483fcf9e5c325c53df2232e7dfda4b Mon Sep 17 00:00:00 2001 From: Ilia Kats Date: Thu, 10 Apr 2025 10:25:48 +0200 Subject: [PATCH 04/43] fix remaining bugs to make tests pass --- src/anndata/_core/index.py | 8 +++--- src/anndata/_core/merge.py | 21 ++++++++-------- src/anndata/_core/storage.py | 4 +-- src/anndata/_core/xarray.py | 25 +++++++++++-------- src/anndata/_io/specs/lazy_methods.py | 12 ++++----- src/anndata/_io/specs/registry.py | 2 +- .../experimental/backed/_lazy_arrays.py | 14 +++++------ src/anndata/tests/helpers.py | 6 +++-- tests/lazy/test_concat.py | 2 +- 9 files changed, 48 insertions(+), 46 deletions(-) diff --git a/src/anndata/_core/index.py b/src/anndata/_core/index.py index b2d6d94c6..1b04a198c 100644 --- a/src/anndata/_core/index.py +++ b/src/anndata/_core/index.py @@ -10,7 +10,7 @@ import pandas as pd from scipy.sparse import issparse -from ..compat import AwkArray, CSArray, CSMatrix, DaskArray +from ..compat import AwkArray, CSArray, CSMatrix, DaskArray, XArray from .xarray import Dataset2D if TYPE_CHECKING: @@ -45,8 +45,6 @@ def _normalize_index( # noqa: PLR0911, PLR0912 | pd.Index, index: pd.Index, ) -> slice | int | np.ndarray: # ndarray of int or bool - from ..experimental.backed._compat import DataArray - # TODO: why is this here? All tests pass without it and it seems at the minimum not strict enough. if not isinstance(index, pd.RangeIndex) and index.dtype in (np.float64, np.int64): msg = f"Don’t call _normalize_index with non-categorical/string names and non-range index {index}" @@ -113,7 +111,7 @@ def name_idx(i): ) raise KeyError(msg) return positions # np.ndarray[int] - elif isinstance(indexer, DataArray): + elif isinstance(indexer, XArray): if isinstance(indexer.data, DaskArray): return indexer.data.compute() return indexer.data @@ -212,7 +210,7 @@ def _subset_awkarray(a: AwkArray, subset_idx: Index): @_subset.register(Dataset2D) def _(a: Dataset2D, subset_idx: Index): - key = get_index_dim(a) + key = a.index_dim # xarray seems to have some code looking for a second entry in tuples if isinstance(subset_idx, tuple) and len(subset_idx) == 1: subset_idx = subset_idx[0] diff --git a/src/anndata/_core/merge.py b/src/anndata/_core/merge.py index b3a7b872f..29087387b 100644 --- a/src/anndata/_core/merge.py +++ b/src/anndata/_core/merge.py @@ -1125,7 +1125,7 @@ def concat_Xs(adatas, reindexers, axis, fill_value): def make_dask_col_from_extension_dtype( - col: DataArray, *, use_only_object_dtype: bool = False + col: XArray, *, use_only_object_dtype: bool = False ) -> DaskArray: """ Creates dask arrays from :class:`pandas.api.extensions.ExtensionArray` dtype :class:`xarray.DataArray`s. 
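Note: make_dask_col_from_extension_dtype, whose signature is retyped here, builds its dask columns with da.map_blocks plus block_info (its body is reworked in patch 05 below). The pattern reduced to a toy in-memory column -- the data and chunking are illustrative, not the real backed-reader logic:

    import dask.array as da
    import numpy as np
    import pandas as pd

    col = pd.array(["a", "b", "a", "c"], dtype="category")  # extension dtype

    def get_chunk(block_info=None):
        # block_info tells each task which slice of the output it owns
        start, stop = block_info[None]["array-location"][0]
        return np.asarray(col[start:stop], dtype=object)

    darr = da.map_blocks(
        get_chunk,
        chunks=((2, 2),),                 # two chunks of two rows each
        meta=np.array([], dtype=object),  # extension dtypes surface as object
        dtype=object,
    )
    assert list(darr.compute()) == ["a", "b", "a", "c"]
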
@@ -1149,8 +1149,7 @@ def make_dask_col_from_extension_dtype( maybe_open_h5, ) from anndata.experimental import read_elem_lazy - from anndata.experimental.backed._compat import DataArray - from anndata.experimental.backed._compat import xarray as xr + from anndata.compat import XArray, xarray as xr base_path_or_zarr_group = col.attrs.get("base_path_or_zarr_group") elem_name = col.attrs.get("elem_name") @@ -1171,7 +1170,7 @@ def get_chunk(block_info=None): variable = xr.Variable( data=xr.core.indexing.LazilyIndexedArray(v), dims=dims ) - data_array = DataArray( + data_array = XArray( variable, coords=coords, dims=dims, @@ -1253,16 +1252,18 @@ def concat_dataset2d_on_annot_axis( Concatenated :class:`~anndata.experimental.backed._xarray.Dataset2D` """ from anndata._io.specs.lazy_methods import DUMMY_RANGE_INDEX_KEY - from anndata.experimental.backed._compat import Dataset2D - from anndata.experimental.backed._compat import xarray as xr + from anndata._core.xarray import Dataset2D + from anndata.compat import xarray as xr annotations_re_indexed = [] for a in make_xarray_extension_dtypes_dask(annotations): - old_key = next(iter(a.coords.keys())) + old_key = a.index_dim + if "indexing_key" not in a.attrs: + a.attrs["indexing_key"] = old_key # First create a dummy index a.coords[DS_CONCAT_DUMMY_INDEX_NAME] = ( old_key, - pd.RangeIndex(a[a.attrs["indexing_key"]].shape[0]).astype("str"), + pd.RangeIndex(a.shape[0]).astype("str"), ) # Set all the dimensions to this new dummy index a = a.swap_dims({old_key: DS_CONCAT_DUMMY_INDEX_NAME}) @@ -1501,8 +1502,8 @@ def concat( # noqa: PLR0912, PLR0913, PLR0915 {'a': 1, 'b': 2, 'c': {'c.a': 3, 'c.b': 4, 'c.c': 5}} """ - from anndata.experimental.backed._compat import Dataset2D - from anndata.experimental.backed._compat import xarray as xr + from anndata._core.xarray import Dataset2D + from anndata.compat import xarray as xr # Argument normalization merge = resolve_merge_strategy(merge) diff --git a/src/anndata/_core/storage.py b/src/anndata/_core/storage.py index 51292b8ed..b601adbab 100644 --- a/src/anndata/_core/storage.py +++ b/src/anndata/_core/storage.py @@ -37,8 +37,8 @@ def coerce_array( return value # If value is one of the allowed types, return it array_data_structure_types = get_args(ArrayDataStructureTypes) - if isinstance(value, XDataset): - value = Dataset2D(value) + if isinstance(value, XDataset) and not isinstance(value, Dataset2D): + value = Dataset2D(value.data_vars, value.coords, value.attrs) if isinstance(value, (*array_data_structure_types, Dataset2D)): if isinstance(value, np.matrix): msg = f"{name} should not be a np.matrix, use np.ndarray instead." diff --git a/src/anndata/_core/xarray.py b/src/anndata/_core/xarray.py index fe7b4a4f9..a36acb9e6 100644 --- a/src/anndata/_core/xarray.py +++ b/src/anndata/_core/xarray.py @@ -14,13 +14,6 @@ from ..compat import XArray -def get_index_dim(ds: XArray) -> Hashable: - if len(ds.sizes) != 1: - msg = f"xarray Dataset should not have more than 1 dims, found {len(ds.sizes)} {ds.sizes}, {ds}" - raise ValueError(msg) - return next(iter(ds.indexes.keys())) - - class Dataset2D(XDataset): """ A wrapper class meant to enable working with lazy dataframe data. 
@@ -32,6 +25,17 @@ class Dataset2D(XDataset): __slots__ = () + @property + def index_dim(self) -> str: + if len(self.sizes) != 1: + msg = f"xarray Dataset should not have more than 1 dims, found {len(self.sizes)} {self.sizes}, {self}" + raise ValueError(msg) + return next(iter(self.coords.keys())) + + @property + def xr_index(self) -> XArray: + return self[self.index_dim] + @property def index(self) -> pd.Index: """:attr:`~anndata.AnnData` internally looks for :attr:`~pandas.DataFrame.index` so this ensures usability @@ -40,8 +44,7 @@ def index(self) -> pd.Index: ------- The index of the of the dataframe as resolved from :attr:`~xarray.Dataset.coords`. """ - coord = get_index_dim(self) - return self.indexes[coord] + return self.indexes[self.index_dim] @index.setter def index(self, val) -> None: @@ -56,7 +59,7 @@ def shape(self) -> tuple[int, int]: ------- The (2D) shape of the dataframe resolved from :attr:`~xarray.Dataset.sizes`. """ - return (self.sizes[get_index_dim(self)], len(self)) + return (self.sizes[self.index_dim], len(self)) @property def iloc(self): @@ -72,7 +75,7 @@ def __init__(self, ds): self._ds = ds def __getitem__(self, idx): - coord = get_index_dim(self._ds) + coord = self._ds.index_dim return self._ds.isel(**{coord: idx}) return IlocGetter(self) diff --git a/src/anndata/_io/specs/lazy_methods.py b/src/anndata/_io/specs/lazy_methods.py index 14a8ff842..7ab2576b1 100644 --- a/src/anndata/_io/specs/lazy_methods.py +++ b/src/anndata/_io/specs/lazy_methods.py @@ -13,7 +13,8 @@ import anndata as ad from anndata._core.file_backing import filename, get_elem_name from anndata.abc import CSCDataset, CSRDataset -from anndata.compat import DaskArray, H5Array, H5Group, ZarrArray, ZarrGroup +from anndata.compat import DaskArray, H5Array, H5Group, ZarrArray, ZarrGroup, XArray +from anndata._core.xarray import Dataset2D from .registry import _LAZY_REGISTRY, IOSpec @@ -23,8 +24,7 @@ from anndata.experimental.backed._lazy_arrays import CategoricalArray, MaskedArray - from ..._compat import CSArray, CSMatrix, H5File, XArray - from ..._core.xarray import Dataset2D + from ...compat import CSArray, CSMatrix, H5File from .registry import LazyDataStructures, LazyReader BlockInfo = Mapping[ @@ -221,8 +221,8 @@ def _gen_xarray_dict_iterator_from_elems( dim_name: str, index: np.NDArray, ) -> Generator[tuple[str, XArray], None, None]: - from ..._compat import XArray - from ..._compat import xarray as xr + from ...compat import XArray + from ...compat import xarray as xr from anndata.experimental.backed._lazy_arrays import CategoricalArray, MaskedArray for k, v in elem_dict.items(): @@ -263,8 +263,6 @@ def read_dataframe( _reader: LazyReader, use_range_index: bool = False, ) -> Dataset2D: - from ..._compat import XArray, Dataset2D - elem_dict = { k: _reader.read_elem(elem[k]) for k in [*elem.attrs["column-order"], elem.attrs["_index"]] diff --git a/src/anndata/_io/specs/registry.py b/src/anndata/_io/specs/registry.py index 3e40b2c66..0ac5ece3d 100644 --- a/src/anndata/_io/specs/registry.py +++ b/src/anndata/_io/specs/registry.py @@ -24,7 +24,7 @@ WriteCallback, _WriteInternal, ) - from ..._compat import Dataset2D + from ...compat import Dataset2D from anndata.experimental.backed._lazy_arrays import CategoricalArray, MaskedArray from anndata.typing import RWAble diff --git a/src/anndata/experimental/backed/_lazy_arrays.py b/src/anndata/experimental/backed/_lazy_arrays.py index 3131ac953..0aa367ce9 100644 --- a/src/anndata/experimental/backed/_lazy_arrays.py +++ 
b/src/anndata/experimental/backed/_lazy_arrays.py @@ -11,8 +11,8 @@ from anndata.compat import H5Array, ZarrArray from ..._settings import settings -from ..._compat import XBackendArray, XArray, XZarrArrayWrapper -from ..._compat import xarray as xr +from ...compat import XBackendArray, XArray, XZarrArrayWrapper +from ...compat import xarray as xr if TYPE_CHECKING: from pathlib import Path @@ -27,7 +27,7 @@ K = TypeVar("K", H5Array, ZarrArray) -class ZarrOrHDF5Wrapper(ZarrArrayWrapper, Generic[K]): +class ZarrOrHDF5Wrapper(XZarrArrayWrapper, Generic[K]): def __init__(self, array: K): self.chunks = array.chunks if isinstance(array, ZarrArray): @@ -48,7 +48,7 @@ def __getitem__(self, key: xr.core.indexing.ExplicitIndexer): ) -class CategoricalArray(BackendArray, Generic[K]): +class CategoricalArray(XBackendArray, Generic[K]): """ A wrapper class meant to enable working with lazy categorical data. We do not guarantee the stability of this API beyond that guaranteed @@ -103,7 +103,7 @@ def dtype(self): return pd.CategoricalDtype(categories=self.categories, ordered=self._ordered) -class MaskedArray(BackendArray, Generic[K]): +class MaskedArray(XBackendArray, Generic[K]): """ A wrapper class meant to enable working with lazy masked data. We do not guarantee the stability of this API beyond that guaranteed @@ -168,12 +168,12 @@ def dtype(self): raise RuntimeError(msg) -@_subset.register(DataArray) +@_subset.register(XArray) def _subset_masked(a: DataArray, subset_idx: Index): return a[subset_idx] -@as_view.register(DataArray) +@as_view.register(XArray) def _view_pd_boolean_array(a: DataArray, view_args): return a diff --git a/src/anndata/tests/helpers.py b/src/anndata/tests/helpers.py index 6f2bc09b8..33a36c3ad 100644 --- a/src/anndata/tests/helpers.py +++ b/src/anndata/tests/helpers.py @@ -25,6 +25,8 @@ from anndata._core.views import ArrayView from anndata.compat import ( AwkArray, + XArray, + XDataset, CSArray, CSMatrix, CupyArray, @@ -290,8 +292,8 @@ def gen_adata( # noqa: PLR0913 var_dtypes: Collection[ np.dtype | pd.api.extensions.ExtensionDtype ] = DEFAULT_COL_TYPES, - obsm_types: Collection[type] = (*DEFAULT_KEY_TYPES, AwkArray), - varm_types: Collection[type] = (*DEFAULT_KEY_TYPES, AwkArray), + obsm_types: Collection[type] = DEFAULT_KEY_TYPES + (AwkArray, XArray, XDataset), + varm_types: Collection[type] = DEFAULT_KEY_TYPES + (AwkArray, XArray, XDataset), layers_types: Collection[type] = DEFAULT_KEY_TYPES, random_state: np.random.Generator | None = None, sparse_fmt: Literal["csr", "csc"] = "csr", diff --git a/tests/lazy/test_concat.py b/tests/lazy/test_concat.py index f04db4046..8f9e70004 100644 --- a/tests/lazy/test_concat.py +++ b/tests/lazy/test_concat.py @@ -254,7 +254,7 @@ def test_concat_data_subsetting( join: Join_T, index: slice | NDArray | Literal["a"] | None, ): - from anndata.experimental.backed._compat import Dataset2D + from anndata._core.xarray import Dataset2D remote_concatenated = ad.concat([adata_remote, adata_remote], join=join) if index is not None: From 02c835503ad3f2a79e345b5f084d72447afe4e7e Mon Sep 17 00:00:00 2001 From: Ilia Kats Date: Thu, 10 Apr 2025 17:55:08 +0200 Subject: [PATCH 05/43] add in-memory xarray Dataset to tests and make tests pass --- src/anndata/_core/merge.py | 113 ++++++++++++++++++++++------------- src/anndata/_core/xarray.py | 7 ++- src/anndata/tests/helpers.py | 49 +++++++-------- tests/lazy/conftest.py | 4 ++ tests/lazy/test_concat.py | 4 +- tests/lazy/test_read.py | 4 +- tests/test_backed_hdf5.py | 3 +- tests/test_base.py | 4 +- 
tests/test_concatenate.py | 8 +-- tests/test_io_conversion.py | 8 +-- tests/test_io_dispatched.py | 12 ++-- tests/test_io_elementwise.py | 15 ++--- tests/test_io_warnings.py | 4 +- tests/test_readwrite.py | 26 ++++---- tests/test_x.py | 4 +- 15 files changed, 149 insertions(+), 116 deletions(-) diff --git a/src/anndata/_core/merge.py b/src/anndata/_core/merge.py index 29087387b..e31413c0e 100644 --- a/src/anndata/_core/merge.py +++ b/src/anndata/_core/merge.py @@ -18,6 +18,7 @@ from natsort import natsorted from packaging.version import Version from scipy import sparse +from pandas.api.types import is_extension_array_dtype from anndata._core.file_backing import to_memory from anndata._warnings import ExperimentalFeatureWarning @@ -35,6 +36,7 @@ from ..utils import asarray, axis_len, warn_once from .anndata import AnnData from .index import _subset, make_slice +from .xarray import Dataset2D if TYPE_CHECKING: from collections.abc import Collection, Generator, Iterable, Sequence @@ -44,7 +46,6 @@ from anndata._types import Join_T from ..compat import XArray - from .xarray import Dataset2D T = TypeVar("T") @@ -208,6 +209,10 @@ def equal_awkward(a, b) -> bool: return ak.almost_equal(a, b) +@equal.register(Dataset2D) +def equal_dataset2d(a, b) -> bool: + return a.equals(b) + def as_sparse(x, *, use_sparse_array: bool = False) -> CSMatrix | CSArray: if not isinstance(x, CSMatrix | CSArray): @@ -449,7 +454,7 @@ def _merge_nested( vals = [d[k] for d in ds if k in d] if len(vals) == 0: return MissingVal - elif all(isinstance(v, Mapping) for v in vals): + elif all(isinstance(v, Mapping) and not isinstance(v, Dataset2D) for v in vals): new_map = merge_nested(vals, keys_join, value_join) if len(new_map) == 0: return MissingVal @@ -557,6 +562,8 @@ def apply(self, el, *, axis, fill_value=None): # noqa: PLR0911 return self._apply_to_dask_array(el, axis=axis, fill_value=fill_value) elif isinstance(el, CupyArray): return self._apply_to_cupy_array(el, axis=axis, fill_value=fill_value) + elif isinstance(el, Dataset2D): + return self._apply_to_dataset2d(el, axis=axis, fill_value=fill_value) else: return self._apply_to_array(el, axis=axis, fill_value=fill_value) @@ -719,6 +726,28 @@ def _apply_to_awkward(self, el: AwkArray, *, axis, fill_value=None): el = ak.pad_none(el, 1, axis=axis) # axis == 0 return el[self.idx] + def _apply_to_dataset2d(self, el: Dataset2D, *, axis, fill_value=None): + if fill_value is None: + fill_value = np.nan + index_dim = el.index_dim + if axis == 0: + # Dataset.reindex() can't handle ExtensionArrays + extension_arrays = {col: arr for col, arr in el.items() if is_extension_array_dtype(arr)} + el = el.drop_vars(extension_arrays.keys()) + el = el.reindex({index_dim: self.new_idx}, method=None, fill_value=fill_value) + for col, arr in extension_arrays.items(): + el[col] = (index_dim, pd.Series(arr, index=self.old_idx).reindex(self.new_idx, fill_value=fill_value)) + return el + else: + cols = el.columns + tokeep = cols[cols.isin(self.new_idx)] + el = el[tokeep.to_list()] + newcols = self.new_idx[~self.new_idx.isin(cols)] + for col in newcols: + el[col] = (el.index_dim, np.broadcast_to(fill_value, el.shape[0])) + return el + + @property def idx(self): return self.old_idx.get_indexer(self.new_idx) @@ -1153,44 +1182,47 @@ def make_dask_col_from_extension_dtype( base_path_or_zarr_group = col.attrs.get("base_path_or_zarr_group") elem_name = col.attrs.get("elem_name") - dims = col.dims - coords = col.coords.copy() - with maybe_open_h5(base_path_or_zarr_group, elem_name) as f: - maybe_chunk_size 
= get_chunksize(read_elem_lazy(f)) - chunk_size = ( - compute_chunk_layout_for_axis_size( - 1000 if maybe_chunk_size is None else maybe_chunk_size[0], col.shape[0] - ), - ) - - def get_chunk(block_info=None): - # reopening is important to get around h5py's unserializable lock in processes + if base_path_or_zarr_group is not None and elem_name is not None: # lazy, backed by store + dims = col.dims + coords = col.coords.copy() with maybe_open_h5(base_path_or_zarr_group, elem_name) as f: - v = read_elem_lazy(f) - variable = xr.Variable( - data=xr.core.indexing.LazilyIndexedArray(v), dims=dims - ) - data_array = XArray( - variable, - coords=coords, - dims=dims, - ) - idx = tuple( - slice(start, stop) for start, stop in block_info[None]["array-location"] + maybe_chunk_size = get_chunksize(read_elem_lazy(f)) + chunk_size = ( + compute_chunk_layout_for_axis_size( + 1000 if maybe_chunk_size is None else maybe_chunk_size[0], col.shape[0] + ), ) - chunk = np.array(data_array.data[idx].array) - return chunk - if col.dtype in ("category", "string") or use_only_object_dtype: - dtype = "object" - else: - dtype = col.dtype.numpy_dtype - return da.map_blocks( - get_chunk, - chunks=chunk_size, - meta=np.array([], dtype=dtype), - dtype=dtype, - ) + def get_chunk(block_info=None): + # reopening is important to get around h5py's unserializable lock in processes + with maybe_open_h5(base_path_or_zarr_group, elem_name) as f: + v = read_elem_lazy(f) + variable = xr.Variable( + data=xr.core.indexing.LazilyIndexedArray(v), dims=dims + ) + data_array = XArray( + variable, + coords=coords, + dims=dims, + ) + idx = tuple( + slice(start, stop) for start, stop in block_info[None]["array-location"] + ) + chunk = np.array(data_array.data[idx].array) + return chunk + + if col.dtype == "category" or col.dtype == "string" or use_only_object_dtype: + dtype = "object" + else: + dtype = col.dtype.numpy_dtype + return da.map_blocks( + get_chunk, + chunks=chunk_size, + meta=np.array([], dtype=dtype), + dtype=dtype, + ) + else: # in-memory + return da.from_array(col.values, chunks=-1) def make_xarray_extension_dtypes_dask( @@ -1275,11 +1307,10 @@ def concat_dataset2d_on_annot_axis( # Concat along the dummy index ds = Dataset2D( xr.concat(annotations_re_indexed, join=join, dim=DS_CONCAT_DUMMY_INDEX_NAME), - attrs={"indexing_key": f"true_{DS_CONCAT_DUMMY_INDEX_NAME}"}, ) ds.coords[DS_CONCAT_DUMMY_INDEX_NAME] = pd.RangeIndex( ds.coords[DS_CONCAT_DUMMY_INDEX_NAME].shape[0] - ).astype("str") + ) # Drop any lingering dimensions (swap doesn't delete) ds = ds.drop_dims(d for d in ds.dims if d != DS_CONCAT_DUMMY_INDEX_NAME) # Create a new true index and then delete the columns resulting from the concatenation for each index. @@ -1290,8 +1321,8 @@ def concat_dataset2d_on_annot_axis( ) # prevent duplicate values index.coords[DS_CONCAT_DUMMY_INDEX_NAME] = ds.coords[DS_CONCAT_DUMMY_INDEX_NAME] - ds[f"true_{DS_CONCAT_DUMMY_INDEX_NAME}"] = index - for key in {a.attrs["indexing_key"] for a in annotations_re_indexed}: + ds.coords[DS_CONCAT_DUMMY_INDEX_NAME] = index + for key in set(a.attrs["indexing_key"] for a in annotations_re_indexed): del ds[key] if DUMMY_RANGE_INDEX_KEY in ds: del ds[DUMMY_RANGE_INDEX_KEY] diff --git a/src/anndata/_core/xarray.py b/src/anndata/_core/xarray.py index a36acb9e6..117cd4651 100644 --- a/src/anndata/_core/xarray.py +++ b/src/anndata/_core/xarray.py @@ -97,5 +97,8 @@ def columns(self) -> pd.Index: ------- :class:`pandas.Index` that represents the "columns." 
""" - columns_list = list(self.keys()) - return pd.Index(columns_list) + columns = set(self.keys()) + index_key = self.attrs.get("indexing_key", None) + if index_key is not None: + columns.discard(index_key) + return pd.Index(columns) diff --git a/src/anndata/tests/helpers.py b/src/anndata/tests/helpers.py index 33a36c3ad..4b7fbc1d4 100644 --- a/src/anndata/tests/helpers.py +++ b/src/anndata/tests/helpers.py @@ -60,32 +60,6 @@ ) -# Give this to gen_adata when dask array support is expected. -GEN_ADATA_DASK_ARGS = dict( - obsm_types=( - sparse.csr_matrix, - np.ndarray, - pd.DataFrame, - DaskArray, - sparse.csr_array, - ), - varm_types=( - sparse.csr_matrix, - np.ndarray, - pd.DataFrame, - DaskArray, - sparse.csr_array, - ), - layers_types=( - sparse.csr_matrix, - np.ndarray, - pd.DataFrame, - DaskArray, - sparse.csr_array, - ), -) - - DEFAULT_KEY_TYPES = ( sparse.csr_matrix, np.ndarray, @@ -106,6 +80,19 @@ ) +# Give this to gen_adata when dask array support is expected. +GEN_ADATA_DASK_ARGS = dict( + obsm_types=(*DEFAULT_KEY_TYPES, DaskArray), + varm_types=(*DEFAULT_KEY_TYPES, DaskArray), + layers_types=(*DEFAULT_KEY_TYPES, DaskArray), +) + +GEN_ADATA_NO_XARRAY_ARGS = dict( + obsm_types=(*DEFAULT_KEY_TYPES, AwkArray), + varm_types=(*DEFAULT_KEY_TYPES, AwkArray) +) + + def gen_vstr_recarray(m, n, dtype=None): size = m * n lengths = np.random.randint(3, 5, size) @@ -292,8 +279,8 @@ def gen_adata( # noqa: PLR0913 var_dtypes: Collection[ np.dtype | pd.api.extensions.ExtensionDtype ] = DEFAULT_COL_TYPES, - obsm_types: Collection[type] = DEFAULT_KEY_TYPES + (AwkArray, XArray, XDataset), - varm_types: Collection[type] = DEFAULT_KEY_TYPES + (AwkArray, XArray, XDataset), + obsm_types: Collection[type] = DEFAULT_KEY_TYPES + (AwkArray, XDataset), + varm_types: Collection[type] = DEFAULT_KEY_TYPES + (AwkArray, XDataset), layers_types: Collection[type] = DEFAULT_KEY_TYPES, random_state: np.random.Generator | None = None, sparse_fmt: Literal["csr", "csc"] = "csr", @@ -325,6 +312,7 @@ def gen_adata( # noqa: PLR0913 (csr, csc) """ import dask.array as da + import xarray as xr if random_state is None: random_state = np.random.default_rng() @@ -349,6 +337,7 @@ def gen_adata( # noqa: PLR0913 df=gen_typed_df(M, obs_names, dtypes=obs_dtypes), awk_2d_ragged=gen_awkward((M, None)), da=da.random.random((M, 50)), + xdataset=xr.Dataset.from_dataframe(gen_typed_df(M, obs_names, dtypes=obs_dtypes)) ) obsm = {k: v for k, v in obsm.items() if type(v) in obsm_types} obsm = maybe_add_sparse_array( @@ -364,6 +353,7 @@ def gen_adata( # noqa: PLR0913 df=gen_typed_df(N, var_names, dtypes=var_dtypes), awk_2d_ragged=gen_awkward((N, None)), da=da.random.random((N, 50)), + xdataset=xr.Dataset.from_dataframe(gen_typed_df(N, var_names, dtypes=var_dtypes)) ) varm = {k: v for k, v in varm.items() if type(v) in varm_types} varm = maybe_add_sparse_array( @@ -743,6 +733,9 @@ def assert_equal_extension_array( _elem_name=elem_name, ) +@assert_equal.register(XArray) +def assert_equal_xarray(a: XArray, b: object, *, exact: bool=False, elem_name: str | None = None): + report_name(a.equals)(b, _elem_name=elem_name) @assert_equal.register(Raw) def assert_equal_raw( diff --git a/tests/lazy/conftest.py b/tests/lazy/conftest.py index 4a153c25f..d5f7a69fb 100644 --- a/tests/lazy/conftest.py +++ b/tests/lazy/conftest.py @@ -14,6 +14,8 @@ from anndata.experimental import read_lazy from anndata.tests.helpers import ( DEFAULT_COL_TYPES, + DEFAULT_KEY_TYPES, + AwkArray, AccessTrackingStore, as_dense_dask_array, gen_adata, @@ -92,6 +94,8 @@ def 
adata_remote_orig_with_path( mtx_format, obs_dtypes=(*DEFAULT_COL_TYPES, pd.StringDtype), var_dtypes=(*DEFAULT_COL_TYPES, pd.StringDtype), + obsm_types=(*DEFAULT_KEY_TYPES, AwkArray), + varm_types=(*DEFAULT_KEY_TYPES, AwkArray), ) orig.raw = orig.copy() with ad.settings.override(allow_write_nullable_strings=True): diff --git a/tests/lazy/test_concat.py b/tests/lazy/test_concat.py index 8f9e70004..4c1205dc1 100644 --- a/tests/lazy/test_concat.py +++ b/tests/lazy/test_concat.py @@ -11,7 +11,7 @@ import anndata as ad from anndata._core.file_backing import to_memory from anndata.experimental import read_lazy -from anndata.tests.helpers import assert_equal, gen_adata +from anndata.tests.helpers import assert_equal, gen_adata, GEN_ADATA_NO_XARRAY_ARGS from .conftest import ANNDATA_ELEMS, get_key_trackers_for_columns_on_axis @@ -312,7 +312,7 @@ def with_elem_in_memory(adata: AnnData, attr: str, key: str | None) -> AnnData: def test_concat_bad_mixed_types(tmp_path: Path): - orig = gen_adata((100, 200), np.array) + orig = gen_adata((100, 200), np.array, **GEN_ADATA_NO_XARRAY_ARGS) orig.write_zarr(tmp_path) remote = read_lazy(tmp_path) orig.obsm["df"] = orig.obsm["array"] diff --git a/tests/lazy/test_read.py b/tests/lazy/test_read.py index 3755e3023..a63956dda 100644 --- a/tests/lazy/test_read.py +++ b/tests/lazy/test_read.py @@ -7,7 +7,7 @@ from anndata.compat import DaskArray from anndata.experimental import read_lazy -from anndata.tests.helpers import AccessTrackingStore, assert_equal, gen_adata +from anndata.tests.helpers import AccessTrackingStore, assert_equal, gen_adata, GEN_ADATA_NO_XARRAY_ARGS from .conftest import ANNDATA_ELEMS @@ -144,7 +144,7 @@ def test_view_of_view_to_memory(adata_remote: AnnData, adata_orig: AnnData): def test_unconsolidated(tmp_path: Path, mtx_format): - adata = gen_adata((1000, 1000), mtx_format) + adata = gen_adata((1000, 1000), mtx_format, **GEN_ADATA_NO_XARRAY_ARGS) orig_pth = tmp_path / "orig.zarr" adata.write_zarr(orig_pth) (orig_pth / ".zmetadata").unlink() diff --git a/tests/test_backed_hdf5.py b/tests/test_backed_hdf5.py index ecf2ef03e..75239adee 100644 --- a/tests/test_backed_hdf5.py +++ b/tests/test_backed_hdf5.py @@ -13,6 +13,7 @@ from anndata.compat import CSArray, CSMatrix from anndata.tests.helpers import ( GEN_ADATA_DASK_ARGS, + GEN_ADATA_NO_XARRAY_ARGS, as_dense_dask_array, assert_equal, gen_adata, @@ -196,7 +197,7 @@ def test_backed_raw(tmp_path): def test_backed_raw_subset(tmp_path, array_type, subset_func, subset_func2): backed_pth = tmp_path / "backed.h5ad" final_pth = tmp_path / "final.h5ad" - mem_adata = gen_adata((10, 10), X_type=array_type) + mem_adata = gen_adata((10, 10), X_type=array_type, **GEN_ADATA_NO_XARRAY_ARGS) mem_adata.raw = mem_adata obs_idx = subset_func(mem_adata.obs_names) var_idx = subset_func2(mem_adata.var_names) diff --git a/tests/test_base.py b/tests/test_base.py index f41a21049..11fed10f6 100644 --- a/tests/test_base.py +++ b/tests/test_base.py @@ -15,7 +15,7 @@ import anndata as ad from anndata import AnnData, ImplicitModificationWarning from anndata._settings import settings -from anndata.tests.helpers import assert_equal, gen_adata, get_multiindex_columns_df +from anndata.tests.helpers import assert_equal, gen_adata, get_multiindex_columns_df, GEN_ADATA_NO_XARRAY_ARGS if TYPE_CHECKING: from pathlib import Path @@ -724,7 +724,7 @@ def assert_eq_not_id(a, b): def test_to_memory_no_copy(): - adata = gen_adata((3, 5)) + adata = gen_adata((3, 5), **GEN_ADATA_NO_XARRAY_ARGS) mem = adata.to_memory() assert mem.X is adata.X 
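Note: several of the test changes in this patch exercise the _apply_to_dataset2d reindexer added to merge.py above. Its key trick is that Dataset.reindex() cannot fill pandas extension arrays, so those columns are routed through pandas instead. The trick in isolation, on toy data:

    import numpy as np
    import pandas as pd

    old_idx = pd.Index(["c1", "c2"])
    new_idx = pd.Index(["c1", "c2", "c3"])

    numeric = np.array([1.0, 2.0])
    cat = pd.Categorical(["a", "b"])  # extension dtype

    # plain columns can be filled positionally with a fill value ...
    filled_numeric = np.full(len(new_idx), np.nan)
    filled_numeric[new_idx.get_indexer(old_idx)] = numeric

    # ... while extension columns go through pandas, which knows how to
    # produce an NA-filled array of the same dtype
    filled_cat = pd.Series(cat, index=old_idx).reindex(new_idx)
    assert filled_cat.isna().tolist() == [False, False, True]
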
diff --git a/tests/test_concatenate.py b/tests/test_concatenate.py index f78b50bd2..44b713bf6 100644 --- a/tests/test_concatenate.py +++ b/tests/test_concatenate.py @@ -21,7 +21,7 @@ from anndata import AnnData, Raw, concat from anndata._core import merge from anndata._core.index import _subset -from anndata.compat import AwkArray, CSArray, CSMatrix, CupySparseMatrix, DaskArray +from anndata.compat import AwkArray, CSArray, CSMatrix, CupySparseMatrix, DaskArray, XDataset from anndata.tests import helpers from anndata.tests.helpers import ( BASE_MATRIX_PARAMS, @@ -497,19 +497,19 @@ def get_obs_els(adata): adata1.obsm = { k: v for k, v in adata1.obsm.items() - if not isinstance(v, pd.DataFrame | AwkArray) + if not isinstance(v, pd.DataFrame | AwkArray | XDataset) } adata2 = gen_adata((10, 5)) adata2.obsm = { k: v[:, : v.shape[1] // 2] for k, v in adata2.obsm.items() - if not isinstance(v, pd.DataFrame | AwkArray) + if not isinstance(v, pd.DataFrame | AwkArray | XDataset) } adata3 = gen_adata((7, 3)) adata3.obsm = { k: v[:, : v.shape[1] // 3] for k, v in adata3.obsm.items() - if not isinstance(v, pd.DataFrame | AwkArray) + if not isinstance(v, pd.DataFrame | AwkArray | XDataset) } # remove AwkArrays from adata.var, as outer joins are not yet implemented for them for tmp_ad in [adata1, adata2, adata3]: diff --git a/tests/test_io_conversion.py b/tests/test_io_conversion.py index 763c89233..60acf748a 100644 --- a/tests/test_io_conversion.py +++ b/tests/test_io_conversion.py @@ -11,7 +11,7 @@ import anndata as ad from anndata.compat import CSMatrix -from anndata.tests.helpers import assert_equal, gen_adata +from anndata.tests.helpers import assert_equal, gen_adata, GEN_ADATA_NO_XARRAY_ARGS @pytest.fixture( @@ -39,7 +39,7 @@ def test_sparse_to_dense_disk(tmp_path, mtx_format, to_convert): mem_pth = tmp_path / "orig.h5ad" dense_from_mem_pth = tmp_path / "dense_mem.h5ad" dense_from_disk_pth = tmp_path / "dense_disk.h5ad" - mem = gen_adata((50, 50), mtx_format) + mem = gen_adata((50, 50), mtx_format, **GEN_ADATA_NO_XARRAY_ARGS) mem.raw = mem.copy() mem.write_h5ad(mem_pth) @@ -66,7 +66,7 @@ def test_sparse_to_dense_disk(tmp_path, mtx_format, to_convert): def test_sparse_to_dense_inplace(tmp_path, spmtx_format): pth = tmp_path / "adata.h5ad" - orig = gen_adata((50, 50), spmtx_format) + orig = gen_adata((50, 50), spmtx_format, **GEN_ADATA_NO_XARRAY_ARGS) orig.raw = orig.copy() orig.write(pth) backed = ad.read_h5ad(pth, backed="r+") @@ -97,7 +97,7 @@ def test_sparse_to_dense_errors(tmp_path): def test_dense_to_sparse_memory(tmp_path, spmtx_format, to_convert): dense_path = tmp_path / "dense.h5ad" - orig = gen_adata((50, 50), np.array) + orig = gen_adata((50, 50), np.array, **GEN_ADATA_NO_XARRAY_ARGS) orig.raw = orig.copy() orig.write_h5ad(dense_path) assert not isinstance(orig.X, CSMatrix) diff --git a/tests/test_io_dispatched.py b/tests/test_io_dispatched.py index f421b8523..a0bdc7ed4 100644 --- a/tests/test_io_dispatched.py +++ b/tests/test_io_dispatched.py @@ -10,7 +10,7 @@ from anndata._io.zarr import open_write_group from anndata.compat import CSArray, CSMatrix, ZarrGroup, is_zarr_v2 from anndata.experimental import read_dispatched, write_dispatched -from anndata.tests.helpers import assert_equal, gen_adata +from anndata.tests.helpers import assert_equal, gen_adata, GEN_ADATA_NO_XARRAY_ARGS if TYPE_CHECKING: from collections.abc import Callable @@ -26,7 +26,7 @@ def read_only_axis_dfs(func, elem_name: str, elem, iospec): else: return None - adata = gen_adata((1000, 100)) + adata = gen_adata((1000, 
100), **GEN_ADATA_NO_XARRAY_ARGS) z = open_write_group(tmp_path) ad.io.write_elem(z, "/", adata) @@ -57,7 +57,7 @@ def read_as_dask_array(func, elem_name: str, elem, iospec): else: return func(elem) - adata = gen_adata((1000, 100)) + adata = gen_adata((1000, 100), **GEN_ADATA_NO_XARRAY_ARGS) z = open_write_group(tmp_path) ad.io.write_elem(z, "/", adata) # TODO: see https://github.com/zarr-developers/zarr-python/issues/2716 @@ -77,7 +77,7 @@ def read_as_dask_array(func, elem_name: str, elem, iospec): def test_read_dispatched_null_case(tmp_path: Path): - adata = gen_adata((100, 100)) + adata = gen_adata((100, 100), **GEN_ADATA_NO_XARRAY_ARGS) z = open_write_group(tmp_path) ad.io.write_elem(z, "/", adata) # TODO: see https://github.com/zarr-developers/zarr-python/issues/2716 @@ -96,7 +96,7 @@ def determine_chunks(elem_shape, specified_chunks): chunk_iterator = chain(specified_chunks, repeat(None)) return tuple(e if c is None else c for e, c in zip(elem_shape, chunk_iterator)) - adata = gen_adata((1000, 100)) + adata = gen_adata((1000, 100), **GEN_ADATA_NO_XARRAY_ARGS) def write_chunked(func, store, k, elem, dataset_kwargs, iospec): M, N = 13, 42 @@ -194,7 +194,7 @@ def zarr_reader(func, elem_name: str, elem, iospec): zarr_read_keys.append(elem_name if is_zarr_v2() else elem_name.strip("/")) return func(elem) - adata = gen_adata((50, 100)) + adata = gen_adata((50, 100), **GEN_ADATA_NO_XARRAY_ARGS) with h5py.File(h5ad_path, "w") as f: write_dispatched(f, "/", adata, callback=h5ad_writer) diff --git a/tests/test_io_elementwise.py b/tests/test_io_elementwise.py index 9c019cb0f..8be9d302b 100644 --- a/tests/test_io_elementwise.py +++ b/tests/test_io_elementwise.py @@ -29,6 +29,7 @@ as_dense_cupy_dask_array, assert_equal, gen_adata, + GEN_ADATA_NO_XARRAY_ARGS ) if TYPE_CHECKING: @@ -123,7 +124,7 @@ def create_sparse_store( pytest.param(True, "numeric-scalar", id="py_bool"), pytest.param(1.0, "numeric-scalar", id="py_float"), pytest.param({"a": 1}, "dict", id="py_dict"), - pytest.param(gen_adata((3, 2)), "anndata", id="anndata"), + pytest.param(gen_adata((3, 2), **GEN_ADATA_NO_XARRAY_ARGS), "anndata", id="anndata"), pytest.param( sparse.random(5, 3, format="csr", density=0.5), "csr_matrix", @@ -428,7 +429,7 @@ def test_write_indptr_dtype_override(store, sparse_format): def test_io_spec_raw(store): - adata = gen_adata((3, 2)) + adata = gen_adata((3, 2), **GEN_ADATA_NO_XARRAY_ARGS) adata.raw = adata.copy() write_elem(store, "adata", adata) @@ -440,7 +441,7 @@ def test_io_spec_raw(store): def test_write_anndata_to_root(store): - adata = gen_adata((3, 2)) + adata = gen_adata((3, 2), **GEN_ADATA_NO_XARRAY_ARGS) write_elem(store, "/", adata) # TODO: see https://github.com/zarr-developers/zarr-python/issues/2716 @@ -460,7 +461,7 @@ def test_write_anndata_to_root(store): ], ) def test_read_iospec_not_found(store, attribute, value): - adata = gen_adata((3, 2)) + adata = gen_adata((3, 2), **GEN_ADATA_NO_XARRAY_ARGS) write_elem(store, "/", adata) store["obs"].attrs.update({attribute: value}) @@ -527,7 +528,7 @@ def _(store, key, adata): "value", [ pytest.param({"a": 1}, id="dict"), - pytest.param(gen_adata((3, 2)), id="anndata"), + pytest.param(gen_adata((3, 2), **GEN_ADATA_NO_XARRAY_ARGS), id="anndata"), pytest.param(sparse.random(5, 3, format="csr", density=0.5), id="csr_matrix"), pytest.param(sparse.random(5, 3, format="csc", density=0.5), id="csc_matrix"), pytest.param(pd.DataFrame({"a": [1, 2, 3]}), id="dataframe"), @@ -578,7 +579,7 @@ def test_write_to_root(store, value): def 
test_read_zarr_from_group(tmp_path, consolidated): # https://github.com/scverse/anndata/issues/1056 pth = tmp_path / "test.zarr" - adata = gen_adata((3, 2)) + adata = gen_adata((3, 2), **GEN_ADATA_NO_XARRAY_ARGS) z = open_write_group(pth) write_elem(z, "table/table", adata) @@ -628,7 +629,7 @@ def test_io_pd_cow(store, copy_on_write): pytest.xfail("copy_on_write option is not available in pandas < 2") # https://github.com/zarr-developers/numcodecs/issues/514 with pd.option_context("mode.copy_on_write", copy_on_write): - orig = gen_adata((3, 2)) + orig = gen_adata((3, 2), **GEN_ADATA_NO_XARRAY_ARGS) write_elem(store, "adata", orig) from_store = read_elem(store["adata"]) assert_equal(orig, from_store) diff --git a/tests/test_io_warnings.py b/tests/test_io_warnings.py index 0e3848168..43c39ba45 100644 --- a/tests/test_io_warnings.py +++ b/tests/test_io_warnings.py @@ -10,7 +10,7 @@ from packaging.version import Version import anndata as ad -from anndata.tests.helpers import gen_adata +from anndata.tests.helpers import gen_adata, GEN_ADATA_NO_XARRAY_ARGS @pytest.mark.skipif(not find_spec("scanpy"), reason="Scanpy is not installed") @@ -44,7 +44,7 @@ def test_old_format_warning_thrown(): def test_old_format_warning_not_thrown(tmp_path): pth = tmp_path / "current.h5ad" - adata = gen_adata((20, 10)) + adata = gen_adata((20, 10), **GEN_ADATA_NO_XARRAY_ARGS) adata.write_h5ad(pth) with warnings.catch_warnings(record=True) as record: diff --git a/tests/test_readwrite.py b/tests/test_readwrite.py index 0a4bcc336..5b19c6e3f 100644 --- a/tests/test_readwrite.py +++ b/tests/test_readwrite.py @@ -29,7 +29,7 @@ _read_attr, is_zarr_v2, ) -from anndata.tests.helpers import as_dense_dask_array, assert_equal, gen_adata +from anndata.tests.helpers import as_dense_dask_array, assert_equal, gen_adata, GEN_ADATA_NO_XARRAY_ARGS if TYPE_CHECKING: from typing import Literal @@ -87,7 +87,7 @@ def dataset_kwargs(request): @pytest.fixture def rw(backing_h5ad): M, N = 100, 101 - orig = gen_adata((M, N)) + orig = gen_adata((M, N), **GEN_ADATA_NO_XARRAY_ARGS) orig.write(backing_h5ad) curr = ad.read_h5ad(backing_h5ad) return curr, orig @@ -255,7 +255,7 @@ def test_readwrite_equivalent_h5ad_zarr(tmp_path, typ): zarr_pth = tmp_path / "adata.zarr" M, N = 100, 101 - adata = gen_adata((M, N), X_type=typ) + adata = gen_adata((M, N), X_type=typ, **GEN_ADATA_NO_XARRAY_ARGS) adata.raw = adata.copy() adata.write_h5ad(h5ad_pth) @@ -286,7 +286,7 @@ def store_context(path: Path): ], ) def test_read_full_io_error(tmp_path, name, read, write): - adata = gen_adata((4, 3)) + adata = gen_adata((4, 3), **GEN_ADATA_NO_XARRAY_ARGS) path = tmp_path / name write(adata, path) with store_context(path) as store: @@ -325,7 +325,7 @@ def test_read_full_io_error(tmp_path, name, read, write): def test_hdf5_compression_opts(tmp_path, compression, compression_opts): # https://github.com/scverse/anndata/issues/497 pth = Path(tmp_path) / "adata.h5ad" - adata = gen_adata((10, 8)) + adata = gen_adata((10, 8), **GEN_ADATA_NO_XARRAY_ARGS) kwargs = {} if compression is not None: kwargs["compression"] = compression @@ -362,7 +362,7 @@ def check_compressed(key, value): def test_zarr_compression(tmp_path, zarr_write_format): ad.settings.zarr_write_format = zarr_write_format pth = str(Path(tmp_path) / "adata.zarr") - adata = gen_adata((10, 8)) + adata = gen_adata((10, 8), **GEN_ADATA_NO_XARRAY_ARGS) if zarr_write_format == 2 or is_zarr_v2(): from numcodecs import Blosc @@ -415,7 +415,7 @@ def check_compressed(value, key): def test_changed_obs_var_names(tmp_path, 
diskfmt): filepth = tmp_path / f"test.{diskfmt}" - orig = gen_adata((10, 10)) + orig = gen_adata((10, 10), **GEN_ADATA_NO_XARRAY_ARGS) orig.obs_names.name = "obs" orig.var_names.name = "var" modified = orig.copy() @@ -751,7 +751,7 @@ def test_zarr_chunk_X(tmp_path): import zarr zarr_pth = Path(tmp_path) / "test.zarr" - adata = gen_adata((100, 100), X_type=np.array) + adata = gen_adata((100, 100), X_type=np.array, **GEN_ADATA_NO_XARRAY_ARGS) adata.write_zarr(zarr_pth, chunks=(10, 10)) z = zarr.open(str(zarr_pth)) # As of v2.3.2 zarr won’t take a Path @@ -879,13 +879,13 @@ def test_backwards_compat_zarr(): def test_adata_in_uns(tmp_path, diskfmt, roundtrip): pth = tmp_path / f"adatas_in_uns.{diskfmt}" - orig = gen_adata((4, 5)) + orig = gen_adata((4, 5), **GEN_ADATA_NO_XARRAY_ARGS) orig.uns["adatas"] = { - "a": gen_adata((1, 2)), - "b": gen_adata((12, 8)), + "a": gen_adata((1, 2), **GEN_ADATA_NO_XARRAY_ARGS), + "b": gen_adata((12, 8), **GEN_ADATA_NO_XARRAY_ARGS), } - another_one = gen_adata((2, 5)) - another_one.raw = gen_adata((2, 7)) + another_one = gen_adata((2, 5), **GEN_ADATA_NO_XARRAY_ARGS) + another_one.raw = gen_adata((2, 7), **GEN_ADATA_NO_XARRAY_ARGS) orig.uns["adatas"]["b"].uns["another_one"] = another_one curr = roundtrip(orig, pth) diff --git a/tests/test_x.py b/tests/test_x.py index 4c0b62516..9a6282baa 100644 --- a/tests/test_x.py +++ b/tests/test_x.py @@ -10,7 +10,7 @@ import anndata as ad from anndata import AnnData from anndata._warnings import ImplicitModificationWarning -from anndata.tests.helpers import assert_equal, gen_adata +from anndata.tests.helpers import assert_equal, gen_adata, GEN_ADATA_NO_XARRAY_ARGS from anndata.utils import asarray UNLABELLED_ARRAY_TYPES = [ @@ -156,7 +156,7 @@ def test_io_missing_X(tmp_path, diskfmt): write = lambda obj, pth: getattr(obj, f"write_{diskfmt}")(pth) read = lambda pth: getattr(ad, f"read_{diskfmt}")(pth) - adata = gen_adata((20, 30)) + adata = gen_adata((20, 30), **GEN_ADATA_NO_XARRAY_ARGS) del adata.X write(adata, file_pth) From 3862745784035bc2c11017e13618be9f60761dd1 Mon Sep 17 00:00:00 2001 From: Ilia Kats Date: Fri, 11 Apr 2025 14:59:50 +0200 Subject: [PATCH 06/43] add tests for var/obs as in-memory XDatasets and make them pass --- src/anndata/_core/aligned_df.py | 4 +- src/anndata/_core/aligned_mapping.py | 1 + src/anndata/_core/anndata.py | 16 +++--- src/anndata/_core/merge.py | 19 ++++--- src/anndata/_core/xarray.py | 27 +++++++++- src/anndata/experimental/backed/_io.py | 2 +- src/anndata/tests/helpers.py | 7 +++ tests/test_base.py | 11 ++-- tests/test_concatenate.py | 74 +++++++++++++++++--------- 9 files changed, 109 insertions(+), 52 deletions(-) diff --git a/src/anndata/_core/aligned_df.py b/src/anndata/_core/aligned_df.py index 2aa83b37f..071b2a3fb 100644 --- a/src/anndata/_core/aligned_df.py +++ b/src/anndata/_core/aligned_df.py @@ -109,8 +109,8 @@ def _mk_df_error( expected: int, actual: int, ): + what = "row" if attr == "obs" else "column" if source == "X": - what = "row" if attr == "obs" else "column" msg = ( f"Observations annot. `{attr}` must have as many rows as `X` has {what}s " f"({expected}), but has {actual} rows." 
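[Editor's illustration, not part of the patch: the next hunk turns a plain string literal into an f-string. Without the `f` prefix Python performs no interpolation, so the old error message printed the literal braces. A minimal sketch using the same names as `_mk_df_error`'s locals:]

    # Hypothetical values mirroring _mk_df_error's locals.
    attr, actual, what, expected = "obs", 3, "row", 2
    broken = "({actual} {what}s instead of {expected})"   # old: placeholders printed verbatim
    fixed = f"({actual} {what}s instead of {expected})"   # new: interpolated
    assert broken == "({actual} {what}s instead of {expected})"
    assert fixed == "(3 rows instead of 2)"
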
@@ -118,7 +118,7 @@ def _mk_df_error( else: msg = ( f"`shape` is inconsistent with `{attr}` " - "({actual} {what}s instead of {expected})" + f"({actual} {what}s instead of {expected})" ) return ValueError(msg) diff --git a/src/anndata/_core/aligned_mapping.py b/src/anndata/_core/aligned_mapping.py index 606a2b94b..309e76b22 100644 --- a/src/anndata/_core/aligned_mapping.py +++ b/src/anndata/_core/aligned_mapping.py @@ -278,6 +278,7 @@ def _validate_value(self, val: Value, key: str) -> Value: else: msg = "Index.equals and pd.testing.assert_index_equal disagree" raise AssertionError(msg) + val.index.name = self.dim_names.name # this is consistent with AnnData.obsm.setter and AnnData.varm.setter return super()._validate_value(val, key) @property diff --git a/src/anndata/_core/anndata.py b/src/anndata/_core/anndata.py index 2e7cc93f3..68cf55f73 100644 --- a/src/anndata/_core/anndata.py +++ b/src/anndata/_core/anndata.py @@ -56,7 +56,7 @@ from zarr.storage import StoreLike - from ..compat import Index1D + from ..compat import Index1D, XDataset from ..typing import XDataType from .aligned_mapping import AxisArraysView, LayersView, PairwiseArraysView from .index import Index @@ -747,10 +747,8 @@ def n_vars(self) -> int: """Number of variables/features.""" return len(self.var_names) - def _set_dim_df(self, value: pd.DataFrame, attr: Literal["obs", "var"]): - if not isinstance(value, pd.DataFrame): - msg = f"Can only assign pd.DataFrame to {attr}." - raise ValueError(msg) + def _set_dim_df(self, value: pd.DataFrame | XDataset, attr: Literal["obs", "var"]): + value = _gen_dataframe(value, [f"{attr}_names", f"{'row' if attr == 'obs' else 'col'}_names"], source="shape", attr=attr, length=self.n_obs if attr == "obs" else self.n_vars) raise_value_error_if_multiindex_columns(value, attr) value_idx = self._prep_dim_index(value.index, attr) if self.is_view: @@ -805,12 +803,12 @@ def _set_dim_index(self, value: pd.Index, attr: str): v.index = value @property - def obs(self) -> pd.DataFrame: + def obs(self) -> pd.DataFrame | Dataset2D: """One-dimensional annotation of observations (`pd.DataFrame`).""" return self._obs @obs.setter - def obs(self, value: pd.DataFrame): + def obs(self, value: pd.DataFrame | XDataset): self._set_dim_df(value, "obs") @obs.deleter @@ -828,12 +826,12 @@ def obs_names(self, names: Sequence[str]): self._set_dim_index(names, "obs") @property - def var(self) -> pd.DataFrame: + def var(self) -> pd.DataFrame | Dataset2D: """One-dimensional annotation of variables/ features (`pd.DataFrame`).""" return self._var @var.setter - def var(self, value: pd.DataFrame): + def var(self, value: pd.DataFrame | XDataset): self._set_dim_df(value, "var") @var.deleter diff --git a/src/anndata/_core/merge.py b/src/anndata/_core/merge.py index e31413c0e..025913549 100644 --- a/src/anndata/_core/merge.py +++ b/src/anndata/_core/merge.py @@ -1126,7 +1126,11 @@ def _resolve_axis( def axis_indices(adata: AnnData, axis: Literal["obs", 0, "var", 1]) -> pd.Index: """Helper function to get adata.{dim}_names.""" _, axis_name = _resolve_axis(axis) - return getattr(adata, f"{axis_name}_names") + attr = getattr(adata, axis_name) + if isinstance(attr, Dataset2D) and "indexing_key" in attr.attrs: + return attr[attr.attrs["indexing_key"]].to_index() + else: + return attr.index # TODO: Resolve https://github.com/scverse/anndata/issues/678 and remove this function @@ -1295,7 +1299,7 @@ def concat_dataset2d_on_annot_axis( # First create a dummy index a.coords[DS_CONCAT_DUMMY_INDEX_NAME] = ( old_key, - 
pd.RangeIndex(a.shape[0]).astype("str"), + pd.RangeIndex(a.shape[0]), ) # Set all the dimensions to this new dummy index a = a.swap_dims({old_key: DS_CONCAT_DUMMY_INDEX_NAME}) @@ -1599,8 +1603,9 @@ def concat( # noqa: PLR0912, PLR0913, PLR0915 else: concat_annot = concat_dataset2d_on_annot_axis(annotations, join) concat_indices.name = DS_CONCAT_DUMMY_INDEX_NAME + concat_annot.index = concat_indices if label is not None: - concat_annot[label] = label_col + concat_annot[label] = label_col if not isinstance(concat_annot, Dataset2D) else (DS_CONCAT_DUMMY_INDEX_NAME, label_col) # Annotation for other axis alt_annotations = [getattr(a, alt_axis_name) for a in adatas] @@ -1622,13 +1627,11 @@ def concat( # noqa: PLR0912, PLR0913, PLR0915 ) ) annotations_with_only_dask = [ - a.rename({a.attrs["indexing_key"]: "merge_index"}) + a.rename({a.true_index_dim: "merge_index"}) for a in annotations_with_only_dask ] - alt_annot = Dataset2D( - xr.merge(annotations_with_only_dask, join=join, compat="override"), - attrs={"indexing_key": "merge_index"}, - ) + alt_annot = Dataset2D(xr.merge(annotations_with_only_dask, join=join, compat="override")) + alt_annot.true_index_dim = "merge_index" X = concat_Xs(adatas, reindexers, axis=axis, fill_value=fill_value) diff --git a/src/anndata/_core/xarray.py b/src/anndata/_core/xarray.py index 117cd4651..96d080591 100644 --- a/src/anndata/_core/xarray.py +++ b/src/anndata/_core/xarray.py @@ -32,6 +32,18 @@ def index_dim(self) -> str: raise ValueError(msg) return next(iter(self.coords.keys())) + @property + def true_index_dim(self) -> str: + index_dim = self.attrs.get("indexing_key", None) + return index_dim if index_dim is not None else self.index_dim + + @true_index_dim.setter + def true_index_dim(self, val: str): + if val not in self.dims: + if val not in self.data_vars: + raise ValueError(f"Unknown variable `{val}`.") + self.attrs["indexing_key"] = val + @property def xr_index(self) -> XArray: return self[self.index_dim] @@ -48,8 +60,13 @@ def index(self) -> pd.Index: @index.setter def index(self, val) -> None: - coord = get_index_dim(self) - self.coords[coord] = val + index_dim = self.index_dim + self.coords[index_dim] = (index_dim, val) + if isinstance(val, pd.Index) and val.name is not None and val.name != index_dim: + self.update(self.rename({self.index_dim: val.name})) + del self.coords[index_dim] + if "indexing_key" in self.attrs: + del self.attrs["indexing_key"] @property def shape(self) -> tuple[int, int]: @@ -80,6 +97,12 @@ def __getitem__(self, idx): return IlocGetter(self) + def __getitem__(self, idx) -> Dataset2D: + ret = super().__getitem__(idx) + if idx == []: # empty XDataset + ret.coords[self.index_dim] = self.xr_index + return ret + def to_memory(self, *, copy=False) -> pd.DataFrame: df = self.to_dataframe() index_key = self.attrs.get("indexing_key", None) diff --git a/src/anndata/experimental/backed/_io.py b/src/anndata/experimental/backed/_io.py index 83de3fde1..896d3ec74 100644 --- a/src/anndata/experimental/backed/_io.py +++ b/src/anndata/experimental/backed/_io.py @@ -141,7 +141,7 @@ def callback(func: Read, /, elem_name: str, elem: StorageType, *, iospec: IOSpec } or "nullable" in iospec.encoding_type ): - if iospec.encoding_type == "dataframe" and elem_name in {"/obs", "/var"}: + if "dataframe" == iospec.encoding_type and (elem_name[:4] in {"/obs", "/var"} or elem_name[:8] in {"/raw/obs", "/raw/var"}): return read_elem_lazy(elem, use_range_index=not load_annotation_index) return read_elem_lazy(elem) elif iospec.encoding_type in {"awkward-array"}: 
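[Editor's sketch, not part of the patch series: how the `Dataset2D` index accessors added above behave, assuming xarray is installed and these patches are applied; `real_index` and `x` are made-up column names.]

    import pandas as pd
    import xarray as xr

    from anndata._core.xarray import Dataset2D

    # Two data variables over xarray's default "index" dimension.
    df = pd.DataFrame({"real_index": ["a", "b"], "x": [1, 2]})
    ds = Dataset2D(xr.Dataset.from_dataframe(df))

    print(ds.index_dim)       # "index" -- the single dimension coordinate
    print(ds.true_index_dim)  # also "index": no "indexing_key" attr is set yet

    # Point the logical index at a data variable; unknown names raise ValueError.
    ds.true_index_dim = "real_index"
    print(ds.attrs["indexing_key"])  # "real_index"

[This "indexing_key" attr is the marker that `concat_dataset2d_on_annot_axis` and the lazy dataframe reader consult to locate the real index when the dimension coordinate is only a dummy range index.]
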
diff --git a/src/anndata/tests/helpers.py b/src/anndata/tests/helpers.py index 4b7fbc1d4..ec7fa6b8a 100644 --- a/src/anndata/tests/helpers.py +++ b/src/anndata/tests/helpers.py @@ -279,6 +279,8 @@ def gen_adata( # noqa: PLR0913 var_dtypes: Collection[ np.dtype | pd.api.extensions.ExtensionDtype ] = DEFAULT_COL_TYPES, + obs_xdataset: bool = False, + var_xdataset: bool = False, obsm_types: Collection[type] = DEFAULT_KEY_TYPES + (AwkArray, XDataset), varm_types: Collection[type] = DEFAULT_KEY_TYPES + (AwkArray, XDataset), layers_types: Collection[type] = DEFAULT_KEY_TYPES, @@ -326,6 +328,11 @@ def gen_adata( # noqa: PLR0913 obs.rename(columns=dict(cat="obs_cat"), inplace=True) var.rename(columns=dict(cat="var_cat"), inplace=True) + if obs_xdataset: + obs = XDataset.from_dataframe(obs) + if var_xdataset: + var = XDataset.from_dataframe(var) + if X_type is None: X = None else: diff --git a/tests/test_base.py b/tests/test_base.py index 11fed10f6..dfc37895e 100644 --- a/tests/test_base.py +++ b/tests/test_base.py @@ -276,11 +276,12 @@ def test_setting_index_names_error(attr): @pytest.mark.parametrize("dim", ["obs", "var"]) -def test_setting_dim_index(dim): +@pytest.mark.parametrize(("obs_xdataset", "var_xdataset"), [(False, False), (True, True)]) +def test_setting_dim_index(dim, obs_xdataset, var_xdataset): index_attr = f"{dim}_names" mapping_attr = f"{dim}m" - orig = gen_adata((5, 5)) + orig = gen_adata((5, 5), obs_xdataset=obs_xdataset, var_xdataset=var_xdataset) orig.raw = orig.copy() curr = orig.copy() view = orig[:, :] @@ -514,12 +515,10 @@ def test_set_obs(): adata = AnnData(np.array([[1, 2, 3], [4, 5, 6]])) adata.obs = pd.DataFrame(dict(a=[3, 4])) - assert adata.obs_names.tolist() == [0, 1] + assert adata.obs_names.tolist() == ["0", "1"] - with pytest.raises(ValueError, match="but this AnnData has shape"): + with pytest.raises(ValueError, match="`shape` is inconsistent with `obs`"): adata.obs = pd.DataFrame(dict(a=[3, 4, 5])) - with pytest.raises(ValueError, match="Can only assign pd.DataFrame"): - adata.obs = dict(a=[1, 2]) def test_multicol(): diff --git a/tests/test_concatenate.py b/tests/test_concatenate.py index 44b713bf6..7327c8545 100644 --- a/tests/test_concatenate.py +++ b/tests/test_concatenate.py @@ -140,6 +140,20 @@ def fix_known_differences( orig = orig.copy() result = result.copy() + for attrname in ("obs", "var"): + if isinstance(getattr(result, attrname), XDataset): + for adata in (orig, result): + df = getattr(adata, attrname).to_dataframe() + df.index.name = "index" + setattr(adata, attrname, df) + resattr = getattr(result, attrname) + origattr = getattr(orig, attrname) + for colname, col in resattr.items(): + # concatenation of XDatasets happens via Dask arrays and those don't know about Pandas Extension arrays + # so categoricals and nullable arrays are all converted to other dtypes + if col.dtype != origattr[colname].dtype and pd.api.types.is_extension_array_dtype(origattr[colname].dtype): + resattr[colname] = col.astype(origattr[colname].dtype) + result.strings_to_categoricals() # Should this be implicit in concatenation? 
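    # [Editorial aside, not part of the patch] The extension-dtype cast above is
    # needed because Dataset2D concatenation routes columns through dask, and
    # dask arrays have no notion of pandas extension dtypes. A minimal,
    # pandas-only illustration of the loss:
    #
    #     import numpy as np
    #     import pandas as pd
    #     col = pd.array([1, None, 3], dtype="Int64")
    #     col.to_numpy(dtype="float64", na_value=np.nan).dtype  # float64; Int64 lost
    #
    # Casting back with .astype(orig_dtype), as the loop above does, restores it.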
# TODO @@ -162,8 +176,9 @@ def fix_known_differences( return orig, result -def test_concat_interface_errors(): - adatas = [gen_adata((5, 10)), gen_adata((5, 10))] +@pytest.mark.parametrize(("obs_xdataset", "var_xdataset"), [(False, False), (True, True)]) +def test_concat_interface_errors(obs_xdataset, var_xdataset): + adatas = [gen_adata((5, 10), obs_xdataset=obs_xdataset, var_xdataset=var_xdataset), gen_adata((5, 10), obs_xdataset=obs_xdataset, var_xdataset=var_xdataset)] with pytest.raises(ValueError, match="`axis` must be.*0, 1, 'obs', or 'var'"): concat(adatas, axis=3) @@ -181,8 +196,12 @@ def test_concat_interface_errors(): (lambda x, **kwargs: x[0].concatenate(x[1:], **kwargs), True), ], ) -def test_concatenate_roundtrip(join_type, array_type, concat_func, backwards_compat): - adata = gen_adata((100, 10), X_type=array_type, **GEN_ADATA_DASK_ARGS) +@pytest.mark.parametrize(("obs_xdataset", "var_xdataset"), [(False, False), (True, True)]) +def test_concatenate_roundtrip(join_type, array_type, concat_func, backwards_compat, obs_xdataset, var_xdataset): + adata = gen_adata((100, 10), X_type=array_type, obs_xdataset=obs_xdataset, var_xdataset=var_xdataset, **GEN_ADATA_DASK_ARGS) + + if backwards_compat and (obs_xdataset or var_xdataset): + pytest.xfail("https://github.com/pydata/xarray/issues/10218") remaining = adata.obs_names subsets = [] @@ -1161,11 +1180,12 @@ def test_concatenate_uns(unss, merge_strategy, result, value_gen): assert_equal(merged, result, elem_name="uns") -def test_transposed_concat(array_type, axis_name, join_type, merge_strategy): +@pytest.mark.parametrize(("obs_xdataset", "var_xdataset"), [(False, False), (True, True)]) +def test_transposed_concat(array_type, axis_name, join_type, merge_strategy, obs_xdataset, var_xdataset): axis, axis_name = merge._resolve_axis(axis_name) alt_axis = 1 - axis - lhs = gen_adata((10, 10), X_type=array_type, **GEN_ADATA_DASK_ARGS) - rhs = gen_adata((10, 12), X_type=array_type, **GEN_ADATA_DASK_ARGS) + lhs = gen_adata((10, 10), X_type=array_type, obs_xdataset=obs_xdataset, var_xdataset=var_xdataset, **GEN_ADATA_DASK_ARGS) + rhs = gen_adata((10, 12), X_type=array_type, obs_xdataset=obs_xdataset, var_xdataset=var_xdataset, **GEN_ADATA_DASK_ARGS) a = concat([lhs, rhs], axis=axis, join=join_type, merge=merge_strategy) b = concat([lhs.T, rhs.T], axis=alt_axis, join=join_type, merge=merge_strategy).T @@ -1173,13 +1193,14 @@ def test_transposed_concat(array_type, axis_name, join_type, merge_strategy): assert_equal(a, b) -def test_batch_key(axis_name): +@pytest.mark.parametrize(("obs_xdataset", "var_xdataset"), [(False, False), (True, True)]) +def test_batch_key(axis_name, obs_xdataset, var_xdataset): """Test that concat only adds a label if the key is provided""" get_annot = attrgetter(axis_name) - lhs = gen_adata((10, 10), **GEN_ADATA_DASK_ARGS) - rhs = gen_adata((10, 12), **GEN_ADATA_DASK_ARGS) + lhs = gen_adata((10, 10), obs_xdataset=obs_xdataset, var_xdataset=var_xdataset, **GEN_ADATA_DASK_ARGS) + rhs = gen_adata((10, 12), obs_xdataset=obs_xdataset, var_xdataset=var_xdataset, **GEN_ADATA_DASK_ARGS) # There is probably a prettier way to do this annot = get_annot(concat([lhs, rhs], axis=axis_name)) @@ -1200,10 +1221,11 @@ def test_batch_key(axis_name): ) == ["batch"] -def test_concat_categories_from_mapping(): +@pytest.mark.parametrize(("obs_xdataset", "var_xdataset"), [(False, False), (True, True)]) +def test_concat_categories_from_mapping(obs_xdataset, var_xdataset): mapping = { - "a": gen_adata((10, 10)), - "b": gen_adata((10, 10)), + 
"a": gen_adata((10, 10), obs_xdataset=obs_xdataset, var_xdataset=var_xdataset), + "b": gen_adata((10, 10), obs_xdataset=obs_xdataset, var_xdataset=var_xdataset), } keys = list(mapping.keys()) adatas = list(mapping.values()) @@ -1339,11 +1361,12 @@ def test_bool_promotion(): assert result.obs["bool"].dtype == np.dtype(bool) -def test_concat_names(axis_name): +@pytest.mark.parametrize(("obs_xdataset", "var_xdataset"), [(False, False), (True, True)]) +def test_concat_names(axis_name, obs_xdataset, var_xdataset): get_annot = attrgetter(axis_name) - lhs = gen_adata((10, 10)) - rhs = gen_adata((10, 10)) + lhs = gen_adata((10, 10), obs_xdataset=obs_xdataset, var_xdataset=var_xdataset) + rhs = gen_adata((10, 10), obs_xdataset=obs_xdataset, var_xdataset=var_xdataset) assert not get_annot(concat([lhs, rhs], axis=axis_name)).index.is_unique assert get_annot( @@ -1376,13 +1399,14 @@ def expected_shape( @pytest.mark.parametrize( "shape", [pytest.param((8, 0), id="no_var"), pytest.param((0, 10), id="no_obs")] ) -def test_concat_size_0_axis(axis_name, join_type, merge_strategy, shape): +@pytest.mark.parametrize(("obs_xdataset", "var_xdataset"), [(False, False), (True, True)]) +def test_concat_size_0_axis(axis_name, join_type, merge_strategy, shape, obs_xdataset, var_xdataset): """Regression test for https://github.com/scverse/anndata/issues/526""" axis, axis_name = merge._resolve_axis(axis_name) alt_axis = 1 - axis col_dtypes = (*DEFAULT_COL_TYPES, pd.StringDtype) - a = gen_adata((5, 7), obs_dtypes=col_dtypes, var_dtypes=col_dtypes) - b = gen_adata(shape, obs_dtypes=col_dtypes, var_dtypes=col_dtypes) + a = gen_adata((5, 7), obs_dtypes=col_dtypes, var_dtypes=col_dtypes, obs_xdataset=obs_xdataset, var_xdataset=var_xdataset) + b = gen_adata(shape, obs_dtypes=col_dtypes, var_dtypes=col_dtypes, obs_xdataset=obs_xdataset, var_xdataset=var_xdataset) expected_size = expected_shape(a, b, axis=axis, join=join_type) @@ -1441,9 +1465,10 @@ def test_concat_size_0_axis(axis_name, join_type, merge_strategy, shape): @pytest.mark.parametrize("elem", ["sparse", "array", "df", "da"]) @pytest.mark.parametrize("axis", ["obs", "var"]) -def test_concat_outer_aligned_mapping(elem, axis): - a = gen_adata((5, 5), **GEN_ADATA_DASK_ARGS) - b = gen_adata((3, 5), **GEN_ADATA_DASK_ARGS) +@pytest.mark.parametrize(("obs_xdataset", "var_xdataset"), [(False, False), (True, True)]) +def test_concat_outer_aligned_mapping(elem, axis, obs_xdataset, var_xdataset): + a = gen_adata((5, 5), obs_xdataset=obs_xdataset, var_xdataset=var_xdataset, **GEN_ADATA_DASK_ARGS) + b = gen_adata((3, 5), obs_xdataset=obs_xdataset, var_xdataset=var_xdataset, **GEN_ADATA_DASK_ARGS) del getattr(b, f"{axis}m")[elem] concated = concat({"a": a, "b": b}, join="outer", label="group", axis=axis) @@ -1469,8 +1494,9 @@ def test_concatenate_size_0_axis(): b.concatenate([a]).shape == (10, 0) -def test_concat_null_X(): - adatas_orig = {k: gen_adata((20, 10)) for k in list("abc")} +@pytest.mark.parametrize(("obs_xdataset", "var_xdataset"), [(False, False), (True, True)]) +def test_concat_null_X(obs_xdataset, var_xdataset): + adatas_orig = {k: gen_adata((20, 10), obs_xdataset=obs_xdataset, var_xdataset=var_xdataset) for k in list("abc")} adatas_no_X = {} for k, v in adatas_orig.items(): v = v.copy() From 95922039ebc3c0e534d9dedff7bbcf7ed444f661 Mon Sep 17 00:00:00 2001 From: Ilia Kats Date: Fri, 11 Apr 2025 16:05:26 +0200 Subject: [PATCH 07/43] move fake index handling into Dataset2D as much as possible --- src/anndata/_core/merge.py | 13 +++++++------ 
src/anndata/_core/xarray.py | 8 ++++++++ src/anndata/_io/specs/lazy_methods.py | 3 ++- 3 files changed, 17 insertions(+), 7 deletions(-) diff --git a/src/anndata/_core/merge.py b/src/anndata/_core/merge.py index 025913549..c1a681297 100644 --- a/src/anndata/_core/merge.py +++ b/src/anndata/_core/merge.py @@ -1127,8 +1127,8 @@ def axis_indices(adata: AnnData, axis: Literal["obs", 0, "var", 1]) -> pd.Index: """Helper function to get adata.{dim}_names.""" _, axis_name = _resolve_axis(axis) attr = getattr(adata, axis_name) - if isinstance(attr, Dataset2D) and "indexing_key" in attr.attrs: - return attr[attr.attrs["indexing_key"]].to_index() + if isinstance(attr, Dataset2D): + return attr.true_index else: return attr.index @@ -1294,8 +1294,7 @@ def concat_dataset2d_on_annot_axis( annotations_re_indexed = [] for a in make_xarray_extension_dtypes_dask(annotations): old_key = a.index_dim - if "indexing_key" not in a.attrs: - a.attrs["indexing_key"] = old_key + is_fake_index = old_key != a.true_index_dim # First create a dummy index a.coords[DS_CONCAT_DUMMY_INDEX_NAME] = ( old_key, @@ -1307,6 +1306,8 @@ def concat_dataset2d_on_annot_axis( old_coord = a.coords[old_key] del a.coords[old_key] a[old_key] = old_coord + if not is_fake_index: + a.true_index_dim = old_key annotations_re_indexed.append(a) # Concat along the dummy index ds = Dataset2D( @@ -1320,13 +1321,13 @@ def concat_dataset2d_on_annot_axis( # Create a new true index and then delete the columns resulting from the concatenation for each index. # This includes the dummy column (which is neither a dimension nor a true indexing column) index = xr.concat( - [a[a.attrs["indexing_key"]] for a in annotations_re_indexed], + [a.true_xr_index for a in annotations_re_indexed], dim=DS_CONCAT_DUMMY_INDEX_NAME, ) # prevent duplicate values index.coords[DS_CONCAT_DUMMY_INDEX_NAME] = ds.coords[DS_CONCAT_DUMMY_INDEX_NAME] ds.coords[DS_CONCAT_DUMMY_INDEX_NAME] = index - for key in set(a.attrs["indexing_key"] for a in annotations_re_indexed): + for key in set(true_index for a in annotations_re_indexed if (true_index := a.true_index_dim) != a.index_dim): del ds[key] if DUMMY_RANGE_INDEX_KEY in ds: del ds[DUMMY_RANGE_INDEX_KEY] diff --git a/src/anndata/_core/xarray.py b/src/anndata/_core/xarray.py index 96d080591..130b59f8c 100644 --- a/src/anndata/_core/xarray.py +++ b/src/anndata/_core/xarray.py @@ -68,6 +68,14 @@ def index(self, val) -> None: if "indexing_key" in self.attrs: del self.attrs["indexing_key"] + @property + def true_xr_index(self) -> XArray: + return self[self.true_index_dim] + + @property + def true_index(self) -> pd.Index: + return self.true_xr_index.to_index() + @property def shape(self) -> tuple[int, int]: """:attr:`~anndata.AnnData` internally looks for :attr:`~pandas.DataFrame.shape` so this ensures usability diff --git a/src/anndata/_io/specs/lazy_methods.py b/src/anndata/_io/specs/lazy_methods.py index 7ab2576b1..58a83ff41 100644 --- a/src/anndata/_io/specs/lazy_methods.py +++ b/src/anndata/_io/specs/lazy_methods.py @@ -286,9 +286,10 @@ def read_dataframe( dims=[DUMMY_RANGE_INDEX_KEY], name=DUMMY_RANGE_INDEX_KEY, ) + ds = Dataset2D(elem_xarray_dict) # We ensure the indexing_key attr always points to the true index # so that the roundtrip works even for the `use_range_index` `True` case - ds = Dataset2D(elem_xarray_dict, attrs={"indexing_key": elem.attrs["_index"]}) + ds.true_index_dim = "_index" return ds From 1a0b3f92b65a2f385a044fb3f2008ba3eb164a78 Mon Sep 17 00:00:00 2001 From: Ilia Kats Date: Fri, 11 Apr 2025 17:40:49 +0200 Subject: 
[PATCH 08/43] add release notes --- docs/release-notes/1966.feature.md | 1 + 1 file changed, 1 insertion(+) create mode 100755 docs/release-notes/1966.feature.md diff --git a/docs/release-notes/1966.feature.md b/docs/release-notes/1966.feature.md new file mode 100755 index 000000000..a702d9d47 --- /dev/null +++ b/docs/release-notes/1966.feature.md @@ -0,0 +1 @@ +Allow xarray Datasets to be used for obs/var/obsm/varm. {user}`ilia-kats` From 3eb33bb8c3c695ac56c0964fa29453e446e7af58 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 11 Apr 2025 15:43:26 +0000 Subject: [PATCH 09/43] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/anndata/_core/aligned_df.py | 15 +- src/anndata/_core/aligned_mapping.py | 5 +- src/anndata/_core/anndata.py | 12 +- src/anndata/_core/file_backing.py | 3 +- src/anndata/_core/index.py | 1 + src/anndata/_core/merge.py | 53 +++++-- src/anndata/_core/storage.py | 2 +- src/anndata/_core/views.py | 3 +- src/anndata/_core/xarray.py | 6 +- src/anndata/_io/specs/lazy_methods.py | 11 +- src/anndata/_io/specs/registry.py | 3 +- src/anndata/compat/__init__.py | 6 +- src/anndata/experimental/backed/_compat.py | 1 - src/anndata/experimental/backed/_io.py | 5 +- .../experimental/backed/_lazy_arrays.py | 2 +- src/anndata/tests/helpers.py | 21 ++- src/anndata/typing.py | 2 +- tests/lazy/conftest.py | 2 +- tests/lazy/test_concat.py | 2 +- tests/lazy/test_read.py | 7 +- tests/test_base.py | 11 +- tests/test_concatenate.py | 141 ++++++++++++++---- tests/test_io_conversion.py | 2 +- tests/test_io_dispatched.py | 2 +- tests/test_io_elementwise.py | 6 +- tests/test_io_warnings.py | 2 +- tests/test_readwrite.py | 7 +- tests/test_x.py | 2 +- 28 files changed, 251 insertions(+), 84 deletions(-) diff --git a/src/anndata/_core/aligned_df.py b/src/anndata/_core/aligned_df.py index 071b2a3fb..5fa52130a 100644 --- a/src/anndata/_core/aligned_df.py +++ b/src/anndata/_core/aligned_df.py @@ -7,10 +7,10 @@ import pandas as pd from pandas.api.types import is_string_dtype -from ..compat import XDataset -from .xarray import Dataset2D from .._warnings import ImplicitModificationWarning +from ..compat import XDataset +from .xarray import Dataset2D if TYPE_CHECKING: from collections.abc import Iterable @@ -122,6 +122,7 @@ def _mk_df_error( ) return ValueError(msg) + @_gen_dataframe.register(Dataset2D) def _gen_dataframe_xr( anno: Dataset2D, @@ -133,6 +134,14 @@ def _gen_dataframe_xr( ): return anno + @_gen_dataframe.register(XDataset) -def _gen_dataframe_xdataset(anno: Dataset, index_names: Iterable[str], *, source: Literal["X", "shape"], attr: Literal["obs", "var"], length: int | None=None): +def _gen_dataframe_xdataset( + anno: Dataset, + index_names: Iterable[str], + *, + source: Literal["X", "shape"], + attr: Literal["obs", "var"], + length: int | None = None, +): return Dataset2D(anno) diff --git a/src/anndata/_core/aligned_mapping.py b/src/anndata/_core/aligned_mapping.py index 309e76b22..095d29c16 100644 --- a/src/anndata/_core/aligned_mapping.py +++ b/src/anndata/_core/aligned_mapping.py @@ -24,6 +24,7 @@ from .index import _subset from .storage import coerce_array from .views import as_view, view_update +from .xarray import Dataset2D if TYPE_CHECKING: from collections.abc import Callable, Iterable, Iterator, Mapping @@ -278,7 +279,9 @@ def _validate_value(self, val: Value, key: str) -> Value: else: msg = "Index.equals and pd.testing.assert_index_equal disagree" raise 
AssertionError(msg) - val.index.name = self.dim_names.name # this is consistent with AnnData.obsm.setter and AnnData.varm.setter + val.index.name = ( + self.dim_names.name + ) # this is consistent with AnnData.obsm.setter and AnnData.varm.setter return super()._validate_value(val, key) @property diff --git a/src/anndata/_core/anndata.py b/src/anndata/_core/anndata.py index 68cf55f73..d3873ae8b 100644 --- a/src/anndata/_core/anndata.py +++ b/src/anndata/_core/anndata.py @@ -42,12 +42,12 @@ from .raw import Raw from .sparse_dataset import BaseCompressedSparseDataset, sparse_dataset from .storage import coerce_array -from .xarray import Dataset2D from .views import ( DictView, _resolve_idxs, as_view, ) +from .xarray import Dataset2D if TYPE_CHECKING: from collections.abc import Iterable @@ -748,7 +748,13 @@ def n_vars(self) -> int: return len(self.var_names) def _set_dim_df(self, value: pd.DataFrame | XDataset, attr: Literal["obs", "var"]): - value = _gen_dataframe(value, [f"{attr}_names", f"{'row' if attr == 'obs' else 'col'}_names"], source="shape", attr=attr, length=self.n_obs if attr == "obs" else self.n_vars) + value = _gen_dataframe( + value, + [f"{attr}_names", f"{'row' if attr == 'obs' else 'col'}_names"], + source="shape", + attr=attr, + length=self.n_obs if attr == "obs" else self.n_vars, + ) raise_value_error_if_multiindex_columns(value, attr) value_idx = self._prep_dim_index(value.index, attr) if self.is_view: @@ -2077,6 +2083,7 @@ def _get_and_delete_multicol_field(self, a, key_multicol): getattr(self, a).drop(keys, axis=1, inplace=True) return values + @AnnData._remove_unused_categories.register(Dataset2D) @staticmethod def _remove_unused_categories_xr( @@ -2084,6 +2091,7 @@ def _remove_unused_categories_xr( ): pass # this is handled automatically by the categorical arrays themselves i.e., they dedup upon access. + def _check_2d_shape(X): """\ Check shape of array or sparse matrix. 
diff --git a/src/anndata/_core/file_backing.py b/src/anndata/_core/file_backing.py index b51d14511..dd2435638 100644 --- a/src/anndata/_core/file_backing.py +++ b/src/anndata/_core/file_backing.py @@ -9,8 +9,8 @@ import h5py from ..compat import AwkArray, DaskArray, ZarrArray, ZarrGroup -from .xarray import Dataset2D from .sparse_dataset import BaseCompressedSparseDataset +from .xarray import Dataset2D if TYPE_CHECKING: from collections.abc import Iterator @@ -162,6 +162,7 @@ def _(x: AwkArray, *, copy: bool = False): else: return x + @to_memory.register(Dataset2D) def _(x: Dataset2D, copy: bool = False): return x.to_memory(copy=copy) diff --git a/src/anndata/_core/index.py b/src/anndata/_core/index.py index 1b04a198c..12c266e24 100644 --- a/src/anndata/_core/index.py +++ b/src/anndata/_core/index.py @@ -208,6 +208,7 @@ def _subset_awkarray(a: AwkArray, subset_idx: Index): subset_idx = np.ix_(*subset_idx) return a[subset_idx] + @_subset.register(Dataset2D) def _(a: Dataset2D, subset_idx: Index): key = a.index_dim diff --git a/src/anndata/_core/merge.py b/src/anndata/_core/merge.py index c1a681297..821f43090 100644 --- a/src/anndata/_core/merge.py +++ b/src/anndata/_core/merge.py @@ -17,8 +17,8 @@ import scipy from natsort import natsorted from packaging.version import Version -from scipy import sparse from pandas.api.types import is_extension_array_dtype +from scipy import sparse from anndata._core.file_backing import to_memory from anndata._warnings import ExperimentalFeatureWarning @@ -45,6 +45,7 @@ from pandas.api.extensions import ExtensionDtype from anndata._types import Join_T + from ..compat import XArray T = TypeVar("T") @@ -209,6 +210,7 @@ def equal_awkward(a, b) -> bool: return ak.almost_equal(a, b) + @equal.register(Dataset2D) def equal_dataset2d(a, b) -> bool: return a.equals(b) @@ -732,11 +734,20 @@ def _apply_to_dataset2d(self, el: Dataset2D, *, axis, fill_value=None): index_dim = el.index_dim if axis == 0: # Dataset.reindex() can't handle ExtensionArrays - extension_arrays = {col: arr for col, arr in el.items() if is_extension_array_dtype(arr)} + extension_arrays = { + col: arr for col, arr in el.items() if is_extension_array_dtype(arr) + } el = el.drop_vars(extension_arrays.keys()) - el = el.reindex({index_dim: self.new_idx}, method=None, fill_value=fill_value) + el = el.reindex( + {index_dim: self.new_idx}, method=None, fill_value=fill_value + ) for col, arr in extension_arrays.items(): - el[col] = (index_dim, pd.Series(arr, index=self.old_idx).reindex(self.new_idx, fill_value=fill_value)) + el[col] = ( + index_dim, + pd.Series(arr, index=self.old_idx).reindex( + self.new_idx, fill_value=fill_value + ), + ) return el else: cols = el.columns @@ -747,7 +758,6 @@ def _apply_to_dataset2d(self, el: Dataset2D, *, axis, fill_value=None): el[col] = (el.index_dim, np.broadcast_to(fill_value, el.shape[0])) return el - @property def idx(self): return self.old_idx.get_indexer(self.new_idx) @@ -1181,19 +1191,23 @@ def make_dask_col_from_extension_dtype( get_chunksize, maybe_open_h5, ) + from anndata.compat import XArray + from anndata.compat import xarray as xr from anndata.experimental import read_elem_lazy - from anndata.compat import XArray, xarray as xr base_path_or_zarr_group = col.attrs.get("base_path_or_zarr_group") elem_name = col.attrs.get("elem_name") - if base_path_or_zarr_group is not None and elem_name is not None: # lazy, backed by store + if ( + base_path_or_zarr_group is not None and elem_name is not None + ): # lazy, backed by store dims = col.dims coords = 
col.coords.copy() with maybe_open_h5(base_path_or_zarr_group, elem_name) as f: maybe_chunk_size = get_chunksize(read_elem_lazy(f)) chunk_size = ( compute_chunk_layout_for_axis_size( - 1000 if maybe_chunk_size is None else maybe_chunk_size[0], col.shape[0] + 1000 if maybe_chunk_size is None else maybe_chunk_size[0], + col.shape[0], ), ) @@ -1210,7 +1224,8 @@ def get_chunk(block_info=None): dims=dims, ) idx = tuple( - slice(start, stop) for start, stop in block_info[None]["array-location"] + slice(start, stop) + for start, stop in block_info[None]["array-location"] ) chunk = np.array(data_array.data[idx].array) return chunk @@ -1225,7 +1240,7 @@ def get_chunk(block_info=None): meta=np.array([], dtype=dtype), dtype=dtype, ) - else: # in-memory + else: # in-memory return da.from_array(col.values, chunks=-1) @@ -1287,8 +1302,8 @@ def concat_dataset2d_on_annot_axis( ------- Concatenated :class:`~anndata.experimental.backed._xarray.Dataset2D` """ - from anndata._io.specs.lazy_methods import DUMMY_RANGE_INDEX_KEY from anndata._core.xarray import Dataset2D + from anndata._io.specs.lazy_methods import DUMMY_RANGE_INDEX_KEY from anndata.compat import xarray as xr annotations_re_indexed = [] @@ -1327,7 +1342,11 @@ def concat_dataset2d_on_annot_axis( # prevent duplicate values index.coords[DS_CONCAT_DUMMY_INDEX_NAME] = ds.coords[DS_CONCAT_DUMMY_INDEX_NAME] ds.coords[DS_CONCAT_DUMMY_INDEX_NAME] = index - for key in set(true_index for a in annotations_re_indexed if (true_index := a.true_index_dim) != a.index_dim): + for key in set( + true_index + for a in annotations_re_indexed + if (true_index := a.true_index_dim) != a.index_dim + ): del ds[key] if DUMMY_RANGE_INDEX_KEY in ds: del ds[DUMMY_RANGE_INDEX_KEY] @@ -1606,7 +1625,11 @@ def concat( # noqa: PLR0912, PLR0913, PLR0915 concat_indices.name = DS_CONCAT_DUMMY_INDEX_NAME concat_annot.index = concat_indices if label is not None: - concat_annot[label] = label_col if not isinstance(concat_annot, Dataset2D) else (DS_CONCAT_DUMMY_INDEX_NAME, label_col) + concat_annot[label] = ( + label_col + if not isinstance(concat_annot, Dataset2D) + else (DS_CONCAT_DUMMY_INDEX_NAME, label_col) + ) # Annotation for other axis alt_annotations = [getattr(a, alt_axis_name) for a in adatas] @@ -1631,7 +1654,9 @@ def concat( # noqa: PLR0912, PLR0913, PLR0915 a.rename({a.true_index_dim: "merge_index"}) for a in annotations_with_only_dask ] - alt_annot = Dataset2D(xr.merge(annotations_with_only_dask, join=join, compat="override")) + alt_annot = Dataset2D( + xr.merge(annotations_with_only_dask, join=join, compat="override") + ) alt_annot.true_index_dim = "merge_index" X = concat_Xs(adatas, reindexers, axis=axis, fill_value=fill_value) diff --git a/src/anndata/_core/storage.py b/src/anndata/_core/storage.py index b601adbab..76914a82a 100644 --- a/src/anndata/_core/storage.py +++ b/src/anndata/_core/storage.py @@ -10,13 +10,13 @@ from anndata.compat import CSArray, CSMatrix from .._warnings import ImplicitModificationWarning -from .xarray import Dataset2D from ..compat import XDataset from ..utils import ( ensure_df_homogeneous, join_english, raise_value_error_if_multiindex_columns, ) +from .xarray import Dataset2D if TYPE_CHECKING: from typing import Any diff --git a/src/anndata/_core/views.py b/src/anndata/_core/views.py index 29e815cf4..d3254ec79 100644 --- a/src/anndata/_core/views.py +++ b/src/anndata/_core/views.py @@ -22,8 +22,8 @@ DaskArray, ZappyArray, ) -from .xarray import Dataset2D from .access import ElementRef +from .xarray import Dataset2D if TYPE_CHECKING: from 
collections.abc import Callable, Iterable, KeysView, Sequence @@ -362,6 +362,7 @@ def as_view_cupy_csr(mtx, view_args): def as_view_cupy_csc(mtx, view_args): return CupySparseCSCView(mtx, view_args=view_args) + @as_view.register(Dataset2D) def _(a: Dataset2D, view_args): return a diff --git a/src/anndata/_core/xarray.py b/src/anndata/_core/xarray.py index 130b59f8c..c6468701c 100644 --- a/src/anndata/_core/xarray.py +++ b/src/anndata/_core/xarray.py @@ -7,10 +7,6 @@ from ..compat import XDataset if TYPE_CHECKING: - from collections.abc import Hashable, Iterable - from typing import Any, Literal - - from .index import Index from ..compat import XArray @@ -107,7 +103,7 @@ def __getitem__(self, idx): def __getitem__(self, idx) -> Dataset2D: ret = super().__getitem__(idx) - if idx == []: # empty XDataset + if idx == []: # empty XDataset ret.coords[self.index_dim] = self.xr_index return ret diff --git a/src/anndata/_io/specs/lazy_methods.py b/src/anndata/_io/specs/lazy_methods.py index 58a83ff41..192bea0f2 100644 --- a/src/anndata/_io/specs/lazy_methods.py +++ b/src/anndata/_io/specs/lazy_methods.py @@ -12,9 +12,9 @@ import anndata as ad from anndata._core.file_backing import filename, get_elem_name -from anndata.abc import CSCDataset, CSRDataset -from anndata.compat import DaskArray, H5Array, H5Group, ZarrArray, ZarrGroup, XArray from anndata._core.xarray import Dataset2D +from anndata.abc import CSCDataset, CSRDataset +from anndata.compat import DaskArray, H5Array, H5Group, XArray, ZarrArray, ZarrGroup from .registry import _LAZY_REGISTRY, IOSpec @@ -221,9 +221,10 @@ def _gen_xarray_dict_iterator_from_elems( dim_name: str, index: np.NDArray, ) -> Generator[tuple[str, XArray], None, None]: + from anndata.experimental.backed._lazy_arrays import CategoricalArray, MaskedArray + from ...compat import XArray from ...compat import xarray as xr - from anndata.experimental.backed._lazy_arrays import CategoricalArray, MaskedArray for k, v in elem_dict.items(): if isinstance(v, DaskArray) and k != dim_name: @@ -243,9 +244,7 @@ def _gen_xarray_dict_iterator_from_elems( }, ) elif k == dim_name: - data_array = XArray( - index, coords=[index], dims=[dim_name], name=dim_name - ) + data_array = XArray(index, coords=[index], dims=[dim_name], name=dim_name) else: msg = f"Could not read {k}: {v} from into xarray Dataset2D" raise ValueError(msg) diff --git a/src/anndata/_io/specs/registry.py b/src/anndata/_io/specs/registry.py index 0ac5ece3d..85ff9bd13 100644 --- a/src/anndata/_io/specs/registry.py +++ b/src/anndata/_io/specs/registry.py @@ -24,10 +24,11 @@ WriteCallback, _WriteInternal, ) - from ...compat import Dataset2D from anndata.experimental.backed._lazy_arrays import CategoricalArray, MaskedArray from anndata.typing import RWAble + from ...compat import Dataset2D + T = TypeVar("T") W = TypeVar("W", bound=_WriteInternal) LazyDataStructures = DaskArray | Dataset2D | CategoricalArray | MaskedArray diff --git a/src/anndata/compat/__init__.py b/src/anndata/compat/__init__.py index aa09917ea..8e9281fcf 100644 --- a/src/anndata/compat/__init__.py +++ b/src/anndata/compat/__init__.py @@ -99,13 +99,16 @@ class DaskArray: def __repr__(): return "mock dask.array.core.Array" + if find_spec("xarray") or TYPE_CHECKING: import xarray - from xarray import DataArray as XArray, Dataset as XDataset + from xarray import DataArray as XArray + from xarray import Dataset as XDataset from xarray.backends import BackendArray as XBackendArray from xarray.backends.zarr import ZarrArrayWrapper as XZarrArrayWrapper else: xarray = 
None + class XArray: def __repr__(self) -> str: return "mock DataArray" @@ -122,6 +125,7 @@ class XBackendArray: def __repr__(self) -> str: return "mock BackendArray" + # https://github.com/scverse/anndata/issues/1749 def is_cupy_importable() -> bool: try: diff --git a/src/anndata/experimental/backed/_compat.py b/src/anndata/experimental/backed/_compat.py index 3d7338fcd..0657a4be3 100644 --- a/src/anndata/experimental/backed/_compat.py +++ b/src/anndata/experimental/backed/_compat.py @@ -1,6 +1,5 @@ from __future__ import annotations -from importlib.util import find_spec from typing import TYPE_CHECKING from ..._core.xarray import Dataset2D diff --git a/src/anndata/experimental/backed/_io.py b/src/anndata/experimental/backed/_io.py index 896d3ec74..b58001eee 100644 --- a/src/anndata/experimental/backed/_io.py +++ b/src/anndata/experimental/backed/_io.py @@ -141,7 +141,10 @@ def callback(func: Read, /, elem_name: str, elem: StorageType, *, iospec: IOSpec } or "nullable" in iospec.encoding_type ): - if "dataframe" == iospec.encoding_type and (elem_name[:4] in {"/obs", "/var"} or elem_name[:8] in {"/raw/obs", "/raw/var"}): + if iospec.encoding_type == "dataframe" and ( + elem_name[:4] in {"/obs", "/var"} + or elem_name[:8] in {"/raw/obs", "/raw/var"} + ): return read_elem_lazy(elem, use_range_index=not load_annotation_index) return read_elem_lazy(elem) elif iospec.encoding_type in {"awkward-array"}: diff --git a/src/anndata/experimental/backed/_lazy_arrays.py b/src/anndata/experimental/backed/_lazy_arrays.py index 0aa367ce9..7b6b3514e 100644 --- a/src/anndata/experimental/backed/_lazy_arrays.py +++ b/src/anndata/experimental/backed/_lazy_arrays.py @@ -11,7 +11,7 @@ from anndata.compat import H5Array, ZarrArray from ..._settings import settings -from ...compat import XBackendArray, XArray, XZarrArrayWrapper +from ...compat import XArray, XBackendArray, XZarrArrayWrapper from ...compat import xarray as xr if TYPE_CHECKING: diff --git a/src/anndata/tests/helpers.py b/src/anndata/tests/helpers.py index ec7fa6b8a..0246dcc2b 100644 --- a/src/anndata/tests/helpers.py +++ b/src/anndata/tests/helpers.py @@ -25,8 +25,6 @@ from anndata._core.views import ArrayView from anndata.compat import ( AwkArray, - XArray, - XDataset, CSArray, CSMatrix, CupyArray, @@ -34,6 +32,8 @@ CupyCSRMatrix, CupySparseMatrix, DaskArray, + XArray, + XDataset, ZarrArray, is_zarr_v2, ) @@ -88,8 +88,7 @@ ) GEN_ADATA_NO_XARRAY_ARGS = dict( - obsm_types=(*DEFAULT_KEY_TYPES, AwkArray), - varm_types=(*DEFAULT_KEY_TYPES, AwkArray) + obsm_types=(*DEFAULT_KEY_TYPES, AwkArray), varm_types=(*DEFAULT_KEY_TYPES, AwkArray) ) @@ -344,7 +343,9 @@ def gen_adata( # noqa: PLR0913 df=gen_typed_df(M, obs_names, dtypes=obs_dtypes), awk_2d_ragged=gen_awkward((M, None)), da=da.random.random((M, 50)), - xdataset=xr.Dataset.from_dataframe(gen_typed_df(M, obs_names, dtypes=obs_dtypes)) + xdataset=xr.Dataset.from_dataframe( + gen_typed_df(M, obs_names, dtypes=obs_dtypes) + ), ) obsm = {k: v for k, v in obsm.items() if type(v) in obsm_types} obsm = maybe_add_sparse_array( @@ -360,7 +361,9 @@ def gen_adata( # noqa: PLR0913 df=gen_typed_df(N, var_names, dtypes=var_dtypes), awk_2d_ragged=gen_awkward((N, None)), da=da.random.random((N, 50)), - xdataset=xr.Dataset.from_dataframe(gen_typed_df(N, var_names, dtypes=var_dtypes)) + xdataset=xr.Dataset.from_dataframe( + gen_typed_df(N, var_names, dtypes=var_dtypes) + ), ) varm = {k: v for k, v in varm.items() if type(v) in varm_types} varm = maybe_add_sparse_array( @@ -740,10 +743,14 @@ def 
assert_equal_extension_array( _elem_name=elem_name, ) + @assert_equal.register(XArray) -def assert_equal_xarray(a: XArray, b: object, *, exact: bool=False, elem_name: str | None = None): +def assert_equal_xarray( + a: XArray, b: object, *, exact: bool = False, elem_name: str | None = None +): report_name(a.equals)(b, _elem_name=elem_name) + @assert_equal.register(Raw) def assert_equal_raw( a: Raw, b: object, *, exact: bool = False, elem_name: str | None = None diff --git a/src/anndata/typing.py b/src/anndata/typing.py index ab82926ca..fd11cc2d1 100644 --- a/src/anndata/typing.py +++ b/src/anndata/typing.py @@ -16,9 +16,9 @@ CupySparseMatrix, DaskArray, H5Array, + XArray, ZappyArray, ZarrArray, - XArray ) from .compat import Index as _Index diff --git a/tests/lazy/conftest.py b/tests/lazy/conftest.py index d5f7a69fb..6e181c70b 100644 --- a/tests/lazy/conftest.py +++ b/tests/lazy/conftest.py @@ -15,8 +15,8 @@ from anndata.tests.helpers import ( DEFAULT_COL_TYPES, DEFAULT_KEY_TYPES, - AwkArray, AccessTrackingStore, + AwkArray, as_dense_dask_array, gen_adata, gen_typed_df, diff --git a/tests/lazy/test_concat.py b/tests/lazy/test_concat.py index 4c1205dc1..152d7e50c 100644 --- a/tests/lazy/test_concat.py +++ b/tests/lazy/test_concat.py @@ -11,7 +11,7 @@ import anndata as ad from anndata._core.file_backing import to_memory from anndata.experimental import read_lazy -from anndata.tests.helpers import assert_equal, gen_adata, GEN_ADATA_NO_XARRAY_ARGS +from anndata.tests.helpers import GEN_ADATA_NO_XARRAY_ARGS, assert_equal, gen_adata from .conftest import ANNDATA_ELEMS, get_key_trackers_for_columns_on_axis diff --git a/tests/lazy/test_read.py b/tests/lazy/test_read.py index a63956dda..f6d8d75f3 100644 --- a/tests/lazy/test_read.py +++ b/tests/lazy/test_read.py @@ -7,7 +7,12 @@ from anndata.compat import DaskArray from anndata.experimental import read_lazy -from anndata.tests.helpers import AccessTrackingStore, assert_equal, gen_adata, GEN_ADATA_NO_XARRAY_ARGS +from anndata.tests.helpers import ( + GEN_ADATA_NO_XARRAY_ARGS, + AccessTrackingStore, + assert_equal, + gen_adata, +) from .conftest import ANNDATA_ELEMS diff --git a/tests/test_base.py b/tests/test_base.py index dfc37895e..1ff0db903 100644 --- a/tests/test_base.py +++ b/tests/test_base.py @@ -15,7 +15,12 @@ import anndata as ad from anndata import AnnData, ImplicitModificationWarning from anndata._settings import settings -from anndata.tests.helpers import assert_equal, gen_adata, get_multiindex_columns_df, GEN_ADATA_NO_XARRAY_ARGS +from anndata.tests.helpers import ( + GEN_ADATA_NO_XARRAY_ARGS, + assert_equal, + gen_adata, + get_multiindex_columns_df, +) if TYPE_CHECKING: from pathlib import Path @@ -276,7 +281,9 @@ def test_setting_index_names_error(attr): @pytest.mark.parametrize("dim", ["obs", "var"]) -@pytest.mark.parametrize(("obs_xdataset", "var_xdataset"), [(False, False), (True, True)]) +@pytest.mark.parametrize( + ("obs_xdataset", "var_xdataset"), [(False, False), (True, True)] +) def test_setting_dim_index(dim, obs_xdataset, var_xdataset): index_attr = f"{dim}_names" mapping_attr = f"{dim}m" diff --git a/tests/test_concatenate.py b/tests/test_concatenate.py index 7327c8545..4a62b48e1 100644 --- a/tests/test_concatenate.py +++ b/tests/test_concatenate.py @@ -21,7 +21,14 @@ from anndata import AnnData, Raw, concat from anndata._core import merge from anndata._core.index import _subset -from anndata.compat import AwkArray, CSArray, CSMatrix, CupySparseMatrix, DaskArray, XDataset +from anndata.compat import ( + AwkArray, + CSArray, 
+ CSMatrix, + CupySparseMatrix, + DaskArray, + XDataset, +) from anndata.tests import helpers from anndata.tests.helpers import ( BASE_MATRIX_PARAMS, @@ -151,7 +158,11 @@ def fix_known_differences( for colname, col in resattr.items(): # concatenation of XDatasets happens via Dask arrays and those don't know about Pandas Extension arrays # so categoricals and nullable arrays are all converted to other dtypes - if col.dtype != origattr[colname].dtype and pd.api.types.is_extension_array_dtype(origattr[colname].dtype): + if col.dtype != origattr[ + colname + ].dtype and pd.api.types.is_extension_array_dtype( + origattr[colname].dtype + ): resattr[colname] = col.astype(origattr[colname].dtype) result.strings_to_categoricals() # Should this be implicit in concatenation? @@ -176,9 +187,14 @@ def fix_known_differences( return orig, result -@pytest.mark.parametrize(("obs_xdataset", "var_xdataset"), [(False, False), (True, True)]) +@pytest.mark.parametrize( + ("obs_xdataset", "var_xdataset"), [(False, False), (True, True)] +) def test_concat_interface_errors(obs_xdataset, var_xdataset): - adatas = [gen_adata((5, 10), obs_xdataset=obs_xdataset, var_xdataset=var_xdataset), gen_adata((5, 10), obs_xdataset=obs_xdataset, var_xdataset=var_xdataset)] + adatas = [ + gen_adata((5, 10), obs_xdataset=obs_xdataset, var_xdataset=var_xdataset), + gen_adata((5, 10), obs_xdataset=obs_xdataset, var_xdataset=var_xdataset), + ] with pytest.raises(ValueError, match="`axis` must be.*0, 1, 'obs', or 'var'"): concat(adatas, axis=3) @@ -196,9 +212,19 @@ def test_concat_interface_errors(obs_xdataset, var_xdataset): (lambda x, **kwargs: x[0].concatenate(x[1:], **kwargs), True), ], ) -@pytest.mark.parametrize(("obs_xdataset", "var_xdataset"), [(False, False), (True, True)]) -def test_concatenate_roundtrip(join_type, array_type, concat_func, backwards_compat, obs_xdataset, var_xdataset): - adata = gen_adata((100, 10), X_type=array_type, obs_xdataset=obs_xdataset, var_xdataset=var_xdataset, **GEN_ADATA_DASK_ARGS) +@pytest.mark.parametrize( + ("obs_xdataset", "var_xdataset"), [(False, False), (True, True)] +) +def test_concatenate_roundtrip( + join_type, array_type, concat_func, backwards_compat, obs_xdataset, var_xdataset +): + adata = gen_adata( + (100, 10), + X_type=array_type, + obs_xdataset=obs_xdataset, + var_xdataset=var_xdataset, + **GEN_ADATA_DASK_ARGS, + ) if backwards_compat and (obs_xdataset or var_xdataset): pytest.xfail("https://github.com/pydata/xarray/issues/10218") @@ -1180,12 +1206,28 @@ def test_concatenate_uns(unss, merge_strategy, result, value_gen): assert_equal(merged, result, elem_name="uns") -@pytest.mark.parametrize(("obs_xdataset", "var_xdataset"), [(False, False), (True, True)]) -def test_transposed_concat(array_type, axis_name, join_type, merge_strategy, obs_xdataset, var_xdataset): +@pytest.mark.parametrize( + ("obs_xdataset", "var_xdataset"), [(False, False), (True, True)] +) +def test_transposed_concat( + array_type, axis_name, join_type, merge_strategy, obs_xdataset, var_xdataset +): axis, axis_name = merge._resolve_axis(axis_name) alt_axis = 1 - axis - lhs = gen_adata((10, 10), X_type=array_type, obs_xdataset=obs_xdataset, var_xdataset=var_xdataset, **GEN_ADATA_DASK_ARGS) - rhs = gen_adata((10, 12), X_type=array_type, obs_xdataset=obs_xdataset, var_xdataset=var_xdataset, **GEN_ADATA_DASK_ARGS) + lhs = gen_adata( + (10, 10), + X_type=array_type, + obs_xdataset=obs_xdataset, + var_xdataset=var_xdataset, + **GEN_ADATA_DASK_ARGS, + ) + rhs = gen_adata( + (10, 12), + X_type=array_type, + 
obs_xdataset=obs_xdataset, + var_xdataset=var_xdataset, + **GEN_ADATA_DASK_ARGS, + ) a = concat([lhs, rhs], axis=axis, join=join_type, merge=merge_strategy) b = concat([lhs.T, rhs.T], axis=alt_axis, join=join_type, merge=merge_strategy).T @@ -1193,14 +1235,26 @@ def test_transposed_concat(array_type, axis_name, join_type, merge_strategy, obs assert_equal(a, b) -@pytest.mark.parametrize(("obs_xdataset", "var_xdataset"), [(False, False), (True, True)]) +@pytest.mark.parametrize( + ("obs_xdataset", "var_xdataset"), [(False, False), (True, True)] +) def test_batch_key(axis_name, obs_xdataset, var_xdataset): """Test that concat only adds a label if the key is provided""" get_annot = attrgetter(axis_name) - lhs = gen_adata((10, 10), obs_xdataset=obs_xdataset, var_xdataset=var_xdataset, **GEN_ADATA_DASK_ARGS) - rhs = gen_adata((10, 12), obs_xdataset=obs_xdataset, var_xdataset=var_xdataset, **GEN_ADATA_DASK_ARGS) + lhs = gen_adata( + (10, 10), + obs_xdataset=obs_xdataset, + var_xdataset=var_xdataset, + **GEN_ADATA_DASK_ARGS, + ) + rhs = gen_adata( + (10, 12), + obs_xdataset=obs_xdataset, + var_xdataset=var_xdataset, + **GEN_ADATA_DASK_ARGS, + ) # There is probably a prettier way to do this annot = get_annot(concat([lhs, rhs], axis=axis_name)) @@ -1221,7 +1275,9 @@ def test_batch_key(axis_name, obs_xdataset, var_xdataset): ) == ["batch"] -@pytest.mark.parametrize(("obs_xdataset", "var_xdataset"), [(False, False), (True, True)]) +@pytest.mark.parametrize( + ("obs_xdataset", "var_xdataset"), [(False, False), (True, True)] +) def test_concat_categories_from_mapping(obs_xdataset, var_xdataset): mapping = { "a": gen_adata((10, 10), obs_xdataset=obs_xdataset, var_xdataset=var_xdataset), @@ -1361,7 +1417,9 @@ def test_bool_promotion(): assert result.obs["bool"].dtype == np.dtype(bool) -@pytest.mark.parametrize(("obs_xdataset", "var_xdataset"), [(False, False), (True, True)]) +@pytest.mark.parametrize( + ("obs_xdataset", "var_xdataset"), [(False, False), (True, True)] +) def test_concat_names(axis_name, obs_xdataset, var_xdataset): get_annot = attrgetter(axis_name) @@ -1399,14 +1457,30 @@ def expected_shape( @pytest.mark.parametrize( "shape", [pytest.param((8, 0), id="no_var"), pytest.param((0, 10), id="no_obs")] ) -@pytest.mark.parametrize(("obs_xdataset", "var_xdataset"), [(False, False), (True, True)]) -def test_concat_size_0_axis(axis_name, join_type, merge_strategy, shape, obs_xdataset, var_xdataset): +@pytest.mark.parametrize( + ("obs_xdataset", "var_xdataset"), [(False, False), (True, True)] +) +def test_concat_size_0_axis( + axis_name, join_type, merge_strategy, shape, obs_xdataset, var_xdataset +): """Regression test for https://github.com/scverse/anndata/issues/526""" axis, axis_name = merge._resolve_axis(axis_name) alt_axis = 1 - axis col_dtypes = (*DEFAULT_COL_TYPES, pd.StringDtype) - a = gen_adata((5, 7), obs_dtypes=col_dtypes, var_dtypes=col_dtypes, obs_xdataset=obs_xdataset, var_xdataset=var_xdataset) - b = gen_adata(shape, obs_dtypes=col_dtypes, var_dtypes=col_dtypes, obs_xdataset=obs_xdataset, var_xdataset=var_xdataset) + a = gen_adata( + (5, 7), + obs_dtypes=col_dtypes, + var_dtypes=col_dtypes, + obs_xdataset=obs_xdataset, + var_xdataset=var_xdataset, + ) + b = gen_adata( + shape, + obs_dtypes=col_dtypes, + var_dtypes=col_dtypes, + obs_xdataset=obs_xdataset, + var_xdataset=var_xdataset, + ) expected_size = expected_shape(a, b, axis=axis, join=join_type) @@ -1465,10 +1539,22 @@ def test_concat_size_0_axis(axis_name, join_type, merge_strategy, shape, obs_xda @pytest.mark.parametrize("elem", 
["sparse", "array", "df", "da"]) @pytest.mark.parametrize("axis", ["obs", "var"]) -@pytest.mark.parametrize(("obs_xdataset", "var_xdataset"), [(False, False), (True, True)]) +@pytest.mark.parametrize( + ("obs_xdataset", "var_xdataset"), [(False, False), (True, True)] +) def test_concat_outer_aligned_mapping(elem, axis, obs_xdataset, var_xdataset): - a = gen_adata((5, 5), obs_xdataset=obs_xdataset, var_xdataset=var_xdataset, **GEN_ADATA_DASK_ARGS) - b = gen_adata((3, 5), obs_xdataset=obs_xdataset, var_xdataset=var_xdataset, **GEN_ADATA_DASK_ARGS) + a = gen_adata( + (5, 5), + obs_xdataset=obs_xdataset, + var_xdataset=var_xdataset, + **GEN_ADATA_DASK_ARGS, + ) + b = gen_adata( + (3, 5), + obs_xdataset=obs_xdataset, + var_xdataset=var_xdataset, + **GEN_ADATA_DASK_ARGS, + ) del getattr(b, f"{axis}m")[elem] concated = concat({"a": a, "b": b}, join="outer", label="group", axis=axis) @@ -1494,9 +1580,14 @@ def test_concatenate_size_0_axis(): b.concatenate([a]).shape == (10, 0) -@pytest.mark.parametrize(("obs_xdataset", "var_xdataset"), [(False, False), (True, True)]) +@pytest.mark.parametrize( + ("obs_xdataset", "var_xdataset"), [(False, False), (True, True)] +) def test_concat_null_X(obs_xdataset, var_xdataset): - adatas_orig = {k: gen_adata((20, 10), obs_xdataset=obs_xdataset, var_xdataset=var_xdataset) for k in list("abc")} + adatas_orig = { + k: gen_adata((20, 10), obs_xdataset=obs_xdataset, var_xdataset=var_xdataset) + for k in list("abc") + } adatas_no_X = {} for k, v in adatas_orig.items(): v = v.copy() diff --git a/tests/test_io_conversion.py b/tests/test_io_conversion.py index 60acf748a..a1a778f62 100644 --- a/tests/test_io_conversion.py +++ b/tests/test_io_conversion.py @@ -11,7 +11,7 @@ import anndata as ad from anndata.compat import CSMatrix -from anndata.tests.helpers import assert_equal, gen_adata, GEN_ADATA_NO_XARRAY_ARGS +from anndata.tests.helpers import GEN_ADATA_NO_XARRAY_ARGS, assert_equal, gen_adata @pytest.fixture( diff --git a/tests/test_io_dispatched.py b/tests/test_io_dispatched.py index a0bdc7ed4..5a642fff3 100644 --- a/tests/test_io_dispatched.py +++ b/tests/test_io_dispatched.py @@ -10,7 +10,7 @@ from anndata._io.zarr import open_write_group from anndata.compat import CSArray, CSMatrix, ZarrGroup, is_zarr_v2 from anndata.experimental import read_dispatched, write_dispatched -from anndata.tests.helpers import assert_equal, gen_adata, GEN_ADATA_NO_XARRAY_ARGS +from anndata.tests.helpers import GEN_ADATA_NO_XARRAY_ARGS, assert_equal, gen_adata if TYPE_CHECKING: from collections.abc import Callable diff --git a/tests/test_io_elementwise.py b/tests/test_io_elementwise.py index 8be9d302b..861e37760 100644 --- a/tests/test_io_elementwise.py +++ b/tests/test_io_elementwise.py @@ -24,12 +24,12 @@ from anndata.experimental import read_elem_lazy from anndata.io import read_elem, write_elem from anndata.tests.helpers import ( + GEN_ADATA_NO_XARRAY_ARGS, as_cupy, as_cupy_sparse_dask_array, as_dense_cupy_dask_array, assert_equal, gen_adata, - GEN_ADATA_NO_XARRAY_ARGS ) if TYPE_CHECKING: @@ -124,7 +124,9 @@ def create_sparse_store( pytest.param(True, "numeric-scalar", id="py_bool"), pytest.param(1.0, "numeric-scalar", id="py_float"), pytest.param({"a": 1}, "dict", id="py_dict"), - pytest.param(gen_adata((3, 2), **GEN_ADATA_NO_XARRAY_ARGS), "anndata", id="anndata"), + pytest.param( + gen_adata((3, 2), **GEN_ADATA_NO_XARRAY_ARGS), "anndata", id="anndata" + ), pytest.param( sparse.random(5, 3, format="csr", density=0.5), "csr_matrix", diff --git a/tests/test_io_warnings.py 
b/tests/test_io_warnings.py index 43c39ba45..219263478 100644 --- a/tests/test_io_warnings.py +++ b/tests/test_io_warnings.py @@ -10,7 +10,7 @@ from packaging.version import Version import anndata as ad -from anndata.tests.helpers import gen_adata, GEN_ADATA_NO_XARRAY_ARGS +from anndata.tests.helpers import GEN_ADATA_NO_XARRAY_ARGS, gen_adata @pytest.mark.skipif(not find_spec("scanpy"), reason="Scanpy is not installed") diff --git a/tests/test_readwrite.py b/tests/test_readwrite.py index 5b19c6e3f..6c2ea5fc3 100644 --- a/tests/test_readwrite.py +++ b/tests/test_readwrite.py @@ -29,7 +29,12 @@ _read_attr, is_zarr_v2, ) -from anndata.tests.helpers import as_dense_dask_array, assert_equal, gen_adata, GEN_ADATA_NO_XARRAY_ARGS +from anndata.tests.helpers import ( + GEN_ADATA_NO_XARRAY_ARGS, + as_dense_dask_array, + assert_equal, + gen_adata, +) if TYPE_CHECKING: from typing import Literal diff --git a/tests/test_x.py b/tests/test_x.py index 9a6282baa..d7da59a0c 100644 --- a/tests/test_x.py +++ b/tests/test_x.py @@ -10,7 +10,7 @@ import anndata as ad from anndata import AnnData from anndata._warnings import ImplicitModificationWarning -from anndata.tests.helpers import assert_equal, gen_adata, GEN_ADATA_NO_XARRAY_ARGS +from anndata.tests.helpers import GEN_ADATA_NO_XARRAY_ARGS, assert_equal, gen_adata from anndata.utils import asarray UNLABELLED_ARRAY_TYPES = [ From 801902c9b7c7582ffb1a2565b5db367fdd1861f0 Mon Sep 17 00:00:00 2001 From: Ilia Kats Date: Fri, 11 Apr 2025 17:53:50 +0200 Subject: [PATCH 10/43] fix linter errors --- src/anndata/_core/aligned_df.py | 2 +- src/anndata/_core/file_backing.py | 2 +- src/anndata/_core/merge.py | 8 ++------ src/anndata/_core/xarray.py | 3 ++- src/anndata/experimental/backed/_lazy_arrays.py | 9 ++++----- src/anndata/tests/helpers.py | 4 ++-- 6 files changed, 12 insertions(+), 16 deletions(-) diff --git a/src/anndata/_core/aligned_df.py b/src/anndata/_core/aligned_df.py index 5fa52130a..ca1fa7912 100644 --- a/src/anndata/_core/aligned_df.py +++ b/src/anndata/_core/aligned_df.py @@ -137,7 +137,7 @@ def _gen_dataframe_xr( @_gen_dataframe.register(XDataset) def _gen_dataframe_xdataset( - anno: Dataset, + anno: XDataset, index_names: Iterable[str], *, source: Literal["X", "shape"], diff --git a/src/anndata/_core/file_backing.py b/src/anndata/_core/file_backing.py index dd2435638..0e1dbf336 100644 --- a/src/anndata/_core/file_backing.py +++ b/src/anndata/_core/file_backing.py @@ -164,7 +164,7 @@ def _(x: AwkArray, *, copy: bool = False): @to_memory.register(Dataset2D) -def _(x: Dataset2D, copy: bool = False): +def _(x: Dataset2D, *, copy: bool = False): return x.to_memory(copy=copy) diff --git a/src/anndata/_core/merge.py b/src/anndata/_core/merge.py index 821f43090..576d69efd 100644 --- a/src/anndata/_core/merge.py +++ b/src/anndata/_core/merge.py @@ -1230,7 +1230,7 @@ def get_chunk(block_info=None): chunk = np.array(data_array.data[idx].array) return chunk - if col.dtype == "category" or col.dtype == "string" or use_only_object_dtype: + if col.dtype == "category" or col.dtype == "string" or use_only_object_dtype: # noqa PLR1714 dtype = "object" else: dtype = col.dtype.numpy_dtype @@ -1342,11 +1342,7 @@ def concat_dataset2d_on_annot_axis( # prevent duplicate values index.coords[DS_CONCAT_DUMMY_INDEX_NAME] = ds.coords[DS_CONCAT_DUMMY_INDEX_NAME] ds.coords[DS_CONCAT_DUMMY_INDEX_NAME] = index - for key in set( - true_index - for a in annotations_re_indexed - if (true_index := a.true_index_dim) != a.index_dim - ): + for key in {true_index for a in 
annotations_re_indexed if (true_index := a.true_index_dim) != a.index_dim}: del ds[key] if DUMMY_RANGE_INDEX_KEY in ds: del ds[DUMMY_RANGE_INDEX_KEY] diff --git a/src/anndata/_core/xarray.py b/src/anndata/_core/xarray.py index c6468701c..b9fbf2e16 100644 --- a/src/anndata/_core/xarray.py +++ b/src/anndata/_core/xarray.py @@ -37,7 +37,8 @@ def true_index_dim(self) -> str: def true_index_dim(self, val: str): if val not in self.dims: if val not in self.data_vars: - raise ValueError(f"Unknown variable `{val}`.") + msg = f"Unknown variable `{val}`." + raise ValueError(msg) self.attrs["indexing_key"] = val @property diff --git a/src/anndata/experimental/backed/_lazy_arrays.py b/src/anndata/experimental/backed/_lazy_arrays.py index 7b6b3514e..942fd5f3a 100644 --- a/src/anndata/experimental/backed/_lazy_arrays.py +++ b/src/anndata/experimental/backed/_lazy_arrays.py @@ -11,7 +11,7 @@ from anndata.compat import H5Array, ZarrArray from ..._settings import settings -from ...compat import XArray, XBackendArray, XZarrArrayWrapper +from ...compat import XBackendArray, XArray, XZarrArrayWrapper from ...compat import xarray as xr if TYPE_CHECKING: @@ -31,8 +31,7 @@ class ZarrOrHDF5Wrapper(XZarrArrayWrapper, Generic[K]): def __init__(self, array: K): self.chunks = array.chunks if isinstance(array, ZarrArray): - super().__init__(array) - return + return super().__init__(array) self._array = array self.shape = self._array.shape self.dtype = self._array.dtype @@ -169,12 +168,12 @@ def dtype(self): @_subset.register(XArray) -def _subset_masked(a: DataArray, subset_idx: Index): +def _subset_masked(a: XArray, subset_idx: Index): return a[subset_idx] @as_view.register(XArray) -def _view_pd_boolean_array(a: DataArray, view_args): +def _view_pd_boolean_array(a: XArray, view_args): return a diff --git a/src/anndata/tests/helpers.py b/src/anndata/tests/helpers.py index 0246dcc2b..c07742e9f 100644 --- a/src/anndata/tests/helpers.py +++ b/src/anndata/tests/helpers.py @@ -280,8 +280,8 @@ def gen_adata( # noqa: PLR0913 ] = DEFAULT_COL_TYPES, obs_xdataset: bool = False, var_xdataset: bool = False, - obsm_types: Collection[type] = DEFAULT_KEY_TYPES + (AwkArray, XDataset), - varm_types: Collection[type] = DEFAULT_KEY_TYPES + (AwkArray, XDataset), + obsm_types: Collection[type] = (*DEFAULT_KEY_TYPES, AwkArray, XDataset), + varm_types: Collection[type] = (*DEFAULT_KEY_TYPES, AwkArray, XDataset), layers_types: Collection[type] = DEFAULT_KEY_TYPES, random_state: np.random.Generator | None = None, sparse_fmt: Literal["csr", "csc"] = "csr", From 4b7867e25e96b0b3c8fd4ee507422e945f1eba30 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 11 Apr 2025 15:54:07 +0000 Subject: [PATCH 11/43] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/anndata/_core/merge.py | 8 ++++++-- src/anndata/experimental/backed/_lazy_arrays.py | 2 +- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/src/anndata/_core/merge.py b/src/anndata/_core/merge.py index 576d69efd..ec44d9e4e 100644 --- a/src/anndata/_core/merge.py +++ b/src/anndata/_core/merge.py @@ -1230,7 +1230,7 @@ def get_chunk(block_info=None): chunk = np.array(data_array.data[idx].array) return chunk - if col.dtype == "category" or col.dtype == "string" or use_only_object_dtype: # noqa PLR1714 + if col.dtype == "category" or col.dtype == "string" or use_only_object_dtype: # noqa PLR1714 dtype = "object" else: dtype = col.dtype.numpy_dtype @@ -1342,7 +1342,11 
@@ def concat_dataset2d_on_annot_axis( # prevent duplicate values index.coords[DS_CONCAT_DUMMY_INDEX_NAME] = ds.coords[DS_CONCAT_DUMMY_INDEX_NAME] ds.coords[DS_CONCAT_DUMMY_INDEX_NAME] = index - for key in {true_index for a in annotations_re_indexed if (true_index := a.true_index_dim) != a.index_dim}: + for key in { + true_index + for a in annotations_re_indexed + if (true_index := a.true_index_dim) != a.index_dim + }: del ds[key] if DUMMY_RANGE_INDEX_KEY in ds: del ds[DUMMY_RANGE_INDEX_KEY] diff --git a/src/anndata/experimental/backed/_lazy_arrays.py b/src/anndata/experimental/backed/_lazy_arrays.py index 942fd5f3a..a95257d93 100644 --- a/src/anndata/experimental/backed/_lazy_arrays.py +++ b/src/anndata/experimental/backed/_lazy_arrays.py @@ -11,7 +11,7 @@ from anndata.compat import H5Array, ZarrArray from ..._settings import settings -from ...compat import XBackendArray, XArray, XZarrArrayWrapper +from ...compat import XArray, XBackendArray, XZarrArrayWrapper from ...compat import xarray as xr if TYPE_CHECKING: From bc88f12fbce6af78aba50f053e61d00f04852c9d Mon Sep 17 00:00:00 2001 From: Ilia Kats Date: Fri, 11 Apr 2025 17:56:02 +0200 Subject: [PATCH 12/43] more linter fixes --- src/anndata/experimental/backed/_lazy_arrays.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/anndata/experimental/backed/_lazy_arrays.py b/src/anndata/experimental/backed/_lazy_arrays.py index a95257d93..8356d14f0 100644 --- a/src/anndata/experimental/backed/_lazy_arrays.py +++ b/src/anndata/experimental/backed/_lazy_arrays.py @@ -31,7 +31,8 @@ class ZarrOrHDF5Wrapper(XZarrArrayWrapper, Generic[K]): def __init__(self, array: K): self.chunks = array.chunks if isinstance(array, ZarrArray): - return super().__init__(array) + super().__init__(array) + return self._array = array self.shape = self._array.shape self.dtype = self._array.dtype From 5344e8027d75d1de58e10de0adc5a40e6744cfbe Mon Sep 17 00:00:00 2001 From: Ilia Kats Date: Fri, 11 Apr 2025 18:05:02 +0200 Subject: [PATCH 13/43] add xarray to test dependencies --- pyproject.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/pyproject.toml b/pyproject.toml index 3c76dfc63..3ea7008ec 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -100,6 +100,7 @@ test = [ "boltons", "scanpy>=1.10", "httpx", # For data downloading + "xarray>=2024.06.0", "dask[distributed]", "awkward>=2.3", "pyarrow", From 80563bb841941ead5b04cf1f7b2ed1ff1d670f93 Mon Sep 17 00:00:00 2001 From: Ilia Kats Date: Fri, 11 Apr 2025 18:12:20 +0200 Subject: [PATCH 14/43] fix docs --- docs/api.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/api.md b/docs/api.md index ba6de634d..358968b3e 100644 --- a/docs/api.md +++ b/docs/api.md @@ -180,7 +180,7 @@ Types used by the former: experimental.StorageType experimental.backed._lazy_arrays.MaskedArray experimental.backed._lazy_arrays.CategoricalArray - experimental.backed._xarray.Dataset2D + _core.xarray.Dataset2D ``` (extensions-api)= From 938f0b35dc89d72ce0b00b63d51607e9df11e740 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 11 Apr 2025 18:14:12 +0000 Subject: [PATCH 15/43] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/anndata/_core/aligned_mapping.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/anndata/_core/aligned_mapping.py b/src/anndata/_core/aligned_mapping.py index 095d29c16..e883032a1 100644 --- a/src/anndata/_core/aligned_mapping.py +++ 
b/src/anndata/_core/aligned_mapping.py @@ -12,7 +12,6 @@ from .._warnings import ExperimentalFeatureWarning, ImplicitModificationWarning from ..compat import AwkArray, CSArray, CSMatrix, CupyArray, XDataset -from .xarray import Dataset2D from ..utils import ( axis_len, convert_to_dict, From 01386a48fd471ef6f6dd914532186a29896a6b82 Mon Sep 17 00:00:00 2001 From: Ilia Kats Date: Mon, 14 Apr 2025 09:28:39 +0200 Subject: [PATCH 16/43] bump min. xarray version to 2024.10.0 --- pyproject.toml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 3ea7008ec..7fb5c5dbe 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -100,7 +100,7 @@ test = [ "boltons", "scanpy>=1.10", "httpx", # For data downloading - "xarray>=2024.06.0", + "xarray>=2024.10.0", "dask[distributed]", "awkward>=2.3", "pyarrow", @@ -110,7 +110,7 @@ gpu = [ "cupy" ] cu12 = [ "cupy-cuda12x" ] cu11 = [ "cupy-cuda11x" ] # requests and aiohttp needed for zarr remote data -lazy = [ "xarray>=2024.06.0", "aiohttp", "requests", "anndata[dask]" ] +lazy = [ "xarray>=2024.10.0", "aiohttp", "requests", "anndata[dask]" ] # https://github.com/dask/dask/issues/11290 # https://github.com/dask/dask/issues/11752 dask = [ "dask[array]>=2023.5.1,!=2024.8.*,!=2024.9.*,<2025.2.0" ] From ff21784657f62b39bb8e774932f389207ac16ff2 Mon Sep 17 00:00:00 2001 From: Ilia Kats Date: Mon, 14 Apr 2025 09:31:30 +0200 Subject: [PATCH 17/43] bump min pandas version for tests to 2.1.0 to satisfy CI --- pyproject.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/pyproject.toml b/pyproject.toml index 7fb5c5dbe..8509094f3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -101,6 +101,7 @@ test = [ "scanpy>=1.10", "httpx", # For data downloading "xarray>=2024.10.0", + "pandas>=2.1.0", "dask[distributed]", "awkward>=2.3", "pyarrow", From 522d6799127a43d487cf22fd28f9447f844e46f8 Mon Sep 17 00:00:00 2001 From: Ilia Kats Date: Mon, 14 Apr 2025 09:55:54 +0200 Subject: [PATCH 18/43] fix min-deps.py script it can now handle the same dependency being specified multiple times with different min. versions. 
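
A sketch of the behaviour this relies on (illustration only, not part of
the change; the package name and versions below are made up). `packaging`'s
`SpecifierSet` supports `&`, which keeps the constraints of both sides:

    from packaging.requirements import Requirement
    from packaging.version import Version

    # two extras pin the same package with different floors
    a = Requirement("numpy>=1.21")
    b = Requirement("numpy>=1.24")

    # '&' keeps every constraint from both specifier sets
    a.specifier &= b.specifier  # numpy now constrained by >=1.21,>=1.24

    # the effective minimum is the largest lower bound; min_dep() then
    # pins the compatible release of that floor
    floor = max(Version(s.version) for s in a.specifier if s.operator == ">=")
    print(Requirement(f"numpy~={floor}.0"))  # numpy~=1.24.0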
---
 ci/scripts/min-deps.py | 8 ++++++--
 1 file changed, 6 insertions(+), 2 deletions(-)

diff --git a/ci/scripts/min-deps.py b/ci/scripts/min-deps.py
index b2e5fa716..40fd81736 100755
--- a/ci/scripts/min-deps.py
+++ b/ci/scripts/min-deps.py
@@ -55,7 +55,7 @@ def min_dep(req: Requirement) -> Requirement:
     elif spec.operator == "==":
         min_version = Version(spec.version)

-    return Requirement(f"{req_name}=={min_version}.*")
+    return Requirement(f"{req_name}~={min_version}.0")


 def extract_min_deps(
@@ -64,6 +64,7 @@ def extract_min_deps(
     dependencies = deque(dependencies)  # We'll be mutating this
     project_name = pyproject["project"]["name"]

+    deps = {}
     while len(dependencies) > 0:
         req = dependencies.pop()

@@ -76,7 +77,10 @@ def extract_min_deps(
             extra_deps = pyproject["project"]["optional-dependencies"][extra]
             dependencies += map(Requirement, extra_deps)
         else:
-            yield min_dep(req)
+            if req.name in deps:
+                req.specifier &= deps[req.name].specifier
+            deps[req.name] = min_dep(req)
+    yield from deps.values()


 class Args(argparse.Namespace):

From 8e5d1223ca0a8f2a96d33157edd236eaeb75698f Mon Sep 17 00:00:00 2001
From: Ilia Kats
Date: Mon, 14 Apr 2025 10:22:28 +0200
Subject: [PATCH 19/43] set the true index to the column specified in the
 _index attr

---
 src/anndata/_io/specs/lazy_methods.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/anndata/_io/specs/lazy_methods.py b/src/anndata/_io/specs/lazy_methods.py
index 192bea0f2..63681eeef 100644
--- a/src/anndata/_io/specs/lazy_methods.py
+++ b/src/anndata/_io/specs/lazy_methods.py
@@ -288,7 +288,7 @@ def read_dataframe(
     ds = Dataset2D(elem_xarray_dict)
     # We ensure the indexing_key attr always points to the true index
     # so that the roundtrip works even for the `use_range_index` `True` case
-    ds.true_index_dim = "_index"
+    ds.true_index_dim = elem.attrs["_index"]
     return ds

From d00b7a04bf1674d946bba193d9a77822dea2e182 Mon Sep 17 00:00:00 2001
From: Ilia Kats
Date: Mon, 14 Apr 2025 10:23:45 +0200
Subject: [PATCH 20/43] fix min-deps.py docstring

---
 ci/scripts/min-deps.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/ci/scripts/min-deps.py b/ci/scripts/min-deps.py
index 40fd81736..ae0c4e886 100755
--- a/ci/scripts/min-deps.py
+++ b/ci/scripts/min-deps.py
@@ -34,7 +34,7 @@ def min_dep(req: Requirement) -> Requirement:
     -------

     >>> min_dep(Requirement("numpy>=1.0"))
-    <Requirement('numpy==1.0.*')>
+    <Requirement('numpy~=1.0.0')>
     >>> min_dep(Requirement("numpy<3.0"))
     <Requirement('numpy<3.0')>
     """

From bb41f0af24d9aeee161e64a8c9fd2fde1c44fa22 Mon Sep 17 00:00:00 2001
From: Ilia Kats
Date: Mon, 14 Apr 2025 10:51:27 +0200
Subject: [PATCH 21/43] attempt to fix docs build

---
 docs/release-notes/0.12.0rc1.md | 2 +-
 src/anndata/_core/xarray.py     | 5 +----
 2 files changed, 2 insertions(+), 5 deletions(-)

diff --git a/docs/release-notes/0.12.0rc1.md b/docs/release-notes/0.12.0rc1.md
index cc9b3b831..eb79d51e8 100644
--- a/docs/release-notes/0.12.0rc1.md
+++ b/docs/release-notes/0.12.0rc1.md
@@ -10,7 +10,7 @@

 #### Bug fixes

-- Disallow writing of {class}`~anndata.experimental.backed._xarray.Dataset2D` objects {user}`ilan-gold` ({pr}`1887`)
+- Disallow writing of {class}`~anndata._core.xarray.Dataset2D` objects {user}`ilan-gold` ({pr}`1887`)
 - Upgrade old deprecation warning to a `FutureWarning` on `BaseCompressedSparseDataset.__setitem__`, showing our intent to remove the feature in the next release.
{user}`ilan-gold` ({pr}`1928`) - Don't use {func}`asyncio.run` internally for any operations {user}`ilan-gold` ({pr}`1933`) - Disallow forward slashes in keys for writing {user}`ilan-gold` ({pr}`1940`) diff --git a/src/anndata/_core/xarray.py b/src/anndata/_core/xarray.py index b9fbf2e16..f82fc2214 100644 --- a/src/anndata/_core/xarray.py +++ b/src/anndata/_core/xarray.py @@ -4,10 +4,7 @@ import pandas as pd -from ..compat import XDataset - -if TYPE_CHECKING: - from ..compat import XArray +from ..compat import XArray, XDataset class Dataset2D(XDataset): From d7f3b404ffaefe75756290e6b620b15e8267cabb Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 14 Apr 2025 08:52:48 +0000 Subject: [PATCH 22/43] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/anndata/_core/xarray.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/anndata/_core/xarray.py b/src/anndata/_core/xarray.py index f82fc2214..5caa484ca 100644 --- a/src/anndata/_core/xarray.py +++ b/src/anndata/_core/xarray.py @@ -1,7 +1,5 @@ from __future__ import annotations -from typing import TYPE_CHECKING - import pandas as pd from ..compat import XArray, XDataset From 8e501fb9d1ab896037f4a1a21a6d2b077df136d2 Mon Sep 17 00:00:00 2001 From: Ilia Kats Date: Mon, 14 Apr 2025 10:59:48 +0200 Subject: [PATCH 23/43] type_checking fixes --- src/anndata/_core/xarray.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/src/anndata/_core/xarray.py b/src/anndata/_core/xarray.py index 5caa484ca..b9fbf2e16 100644 --- a/src/anndata/_core/xarray.py +++ b/src/anndata/_core/xarray.py @@ -1,8 +1,13 @@ from __future__ import annotations +from typing import TYPE_CHECKING + import pandas as pd -from ..compat import XArray, XDataset +from ..compat import XDataset + +if TYPE_CHECKING: + from ..compat import XArray class Dataset2D(XDataset): From 07b63d0776e5c4a0110b6c66614b3bf14de194bf Mon Sep 17 00:00:00 2001 From: Ilia Kats Date: Mon, 14 Apr 2025 11:05:52 +0200 Subject: [PATCH 24/43] more docs fixes --- src/anndata/_io/specs/registry.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/anndata/_io/specs/registry.py b/src/anndata/_io/specs/registry.py index 85ff9bd13..95942706b 100644 --- a/src/anndata/_io/specs/registry.py +++ b/src/anndata/_io/specs/registry.py @@ -27,7 +27,7 @@ from anndata.experimental.backed._lazy_arrays import CategoricalArray, MaskedArray from anndata.typing import RWAble - from ...compat import Dataset2D + from ..._core.xarray import Dataset2D T = TypeVar("T") W = TypeVar("W", bound=_WriteInternal) From 6cbbaf82ac082a430b4e4e06ad89ef2a13ec5d4d Mon Sep 17 00:00:00 2001 From: Ilia Kats Date: Thu, 24 Apr 2025 11:33:11 +0200 Subject: [PATCH 25/43] s/XArray/XDataArray/ --- src/anndata/_core/index.py | 4 ++-- src/anndata/_core/merge.py | 8 ++++---- src/anndata/_core/xarray.py | 6 +++--- src/anndata/_io/specs/lazy_methods.py | 14 +++++++------- src/anndata/compat/__init__.py | 4 ++-- src/anndata/experimental/backed/_lazy_arrays.py | 10 +++++----- src/anndata/tests/helpers.py | 6 +++--- src/anndata/typing.py | 4 ++-- 8 files changed, 28 insertions(+), 28 deletions(-) diff --git a/src/anndata/_core/index.py b/src/anndata/_core/index.py index 12c266e24..1128bc665 100644 --- a/src/anndata/_core/index.py +++ b/src/anndata/_core/index.py @@ -10,7 +10,7 @@ import pandas as pd from scipy.sparse import issparse -from ..compat import AwkArray, CSArray, CSMatrix, DaskArray, XArray 
+from ..compat import AwkArray, CSArray, CSMatrix, DaskArray, XDataArray from .xarray import Dataset2D if TYPE_CHECKING: @@ -111,7 +111,7 @@ def name_idx(i): ) raise KeyError(msg) return positions # np.ndarray[int] - elif isinstance(indexer, XArray): + elif isinstance(indexer, XDataArray): if isinstance(indexer.data, DaskArray): return indexer.data.compute() return indexer.data diff --git a/src/anndata/_core/merge.py b/src/anndata/_core/merge.py index ec44d9e4e..ffd7ca1b6 100644 --- a/src/anndata/_core/merge.py +++ b/src/anndata/_core/merge.py @@ -46,7 +46,7 @@ from anndata._types import Join_T - from ..compat import XArray + from ..compat import XDataArray T = TypeVar("T") @@ -1168,7 +1168,7 @@ def concat_Xs(adatas, reindexers, axis, fill_value): def make_dask_col_from_extension_dtype( - col: XArray, *, use_only_object_dtype: bool = False + col: XDataArray, *, use_only_object_dtype: bool = False ) -> DaskArray: """ Creates dask arrays from :class:`pandas.api.extensions.ExtensionArray` dtype :class:`xarray.DataArray`s. @@ -1191,7 +1191,7 @@ def make_dask_col_from_extension_dtype( get_chunksize, maybe_open_h5, ) - from anndata.compat import XArray + from anndata.compat import XDataArray from anndata.compat import xarray as xr from anndata.experimental import read_elem_lazy @@ -1218,7 +1218,7 @@ def get_chunk(block_info=None): variable = xr.Variable( data=xr.core.indexing.LazilyIndexedArray(v), dims=dims ) - data_array = XArray( + data_array = XDataArray( variable, coords=coords, dims=dims, diff --git a/src/anndata/_core/xarray.py b/src/anndata/_core/xarray.py index b9fbf2e16..d8b05f19e 100644 --- a/src/anndata/_core/xarray.py +++ b/src/anndata/_core/xarray.py @@ -7,7 +7,7 @@ from ..compat import XDataset if TYPE_CHECKING: - from ..compat import XArray + from ..compat import XDataArray class Dataset2D(XDataset): @@ -42,7 +42,7 @@ def true_index_dim(self, val: str): self.attrs["indexing_key"] = val @property - def xr_index(self) -> XArray: + def xr_index(self) -> XDataArray: return self[self.index_dim] @property @@ -66,7 +66,7 @@ def index(self, val) -> None: del self.attrs["indexing_key"] @property - def true_xr_index(self) -> XArray: + def true_xr_index(self) -> XDataArray: return self[self.true_index_dim] @property diff --git a/src/anndata/_io/specs/lazy_methods.py b/src/anndata/_io/specs/lazy_methods.py index 63681eeef..de9cc6f8b 100644 --- a/src/anndata/_io/specs/lazy_methods.py +++ b/src/anndata/_io/specs/lazy_methods.py @@ -14,7 +14,7 @@ from anndata._core.file_backing import filename, get_elem_name from anndata._core.xarray import Dataset2D from anndata.abc import CSCDataset, CSRDataset -from anndata.compat import DaskArray, H5Array, H5Group, XArray, ZarrArray, ZarrGroup +from anndata.compat import DaskArray, H5Array, H5Group, XDataArray, ZarrArray, ZarrGroup from .registry import _LAZY_REGISTRY, IOSpec @@ -220,20 +220,20 @@ def _gen_xarray_dict_iterator_from_elems( elem_dict: dict[str, LazyDataStructures], dim_name: str, index: np.NDArray, -) -> Generator[tuple[str, XArray], None, None]: +) -> Generator[tuple[str, XDataArray], None, None]: from anndata.experimental.backed._lazy_arrays import CategoricalArray, MaskedArray - from ...compat import XArray + from ...compat import XDataArray from ...compat import xarray as xr for k, v in elem_dict.items(): if isinstance(v, DaskArray) and k != dim_name: - data_array = XArray(v, coords=[index], dims=[dim_name], name=k) + data_array = XDataArray(v, coords=[index], dims=[dim_name], name=k) elif isinstance(v, CategoricalArray | MaskedArray) and 
k != dim_name: variable = xr.Variable( data=xr.core.indexing.LazilyIndexedArray(v), dims=[dim_name] ) - data_array = XArray( + data_array = XDataArray( variable, coords=[index], dims=[dim_name], @@ -244,7 +244,7 @@ def _gen_xarray_dict_iterator_from_elems( }, ) elif k == dim_name: - data_array = XArray(index, coords=[index], dims=[dim_name], name=dim_name) + data_array = XDataArray(index, coords=[index], dims=[dim_name], name=dim_name) else: msg = f"Could not read {k}: {v} from into xarray Dataset2D" raise ValueError(msg) @@ -279,7 +279,7 @@ def read_dataframe( _gen_xarray_dict_iterator_from_elems(elem_dict, dim_name, index) ) if use_range_index: - elem_xarray_dict[DUMMY_RANGE_INDEX_KEY] = XArray( + elem_xarray_dict[DUMMY_RANGE_INDEX_KEY] = XDataArray( index, coords=[index], dims=[DUMMY_RANGE_INDEX_KEY], diff --git a/src/anndata/compat/__init__.py b/src/anndata/compat/__init__.py index 8e9281fcf..4ea0baaff 100644 --- a/src/anndata/compat/__init__.py +++ b/src/anndata/compat/__init__.py @@ -102,14 +102,14 @@ def __repr__(): if find_spec("xarray") or TYPE_CHECKING: import xarray - from xarray import DataArray as XArray + from xarray import DataArray as XDataArray from xarray import Dataset as XDataset from xarray.backends import BackendArray as XBackendArray from xarray.backends.zarr import ZarrArrayWrapper as XZarrArrayWrapper else: xarray = None - class XArray: + class XDataArray: def __repr__(self) -> str: return "mock DataArray" diff --git a/src/anndata/experimental/backed/_lazy_arrays.py b/src/anndata/experimental/backed/_lazy_arrays.py index 8356d14f0..80e9a0d8a 100644 --- a/src/anndata/experimental/backed/_lazy_arrays.py +++ b/src/anndata/experimental/backed/_lazy_arrays.py @@ -11,7 +11,7 @@ from anndata.compat import H5Array, ZarrArray from ..._settings import settings -from ...compat import XArray, XBackendArray, XZarrArrayWrapper +from ...compat import XDataArray, XBackendArray, XZarrArrayWrapper from ...compat import xarray as xr if TYPE_CHECKING: @@ -168,13 +168,13 @@ def dtype(self): raise RuntimeError(msg) -@_subset.register(XArray) -def _subset_masked(a: XArray, subset_idx: Index): +@_subset.register(XDataArray) +def _subset_masked(a: XDataArray, subset_idx: Index): return a[subset_idx] -@as_view.register(XArray) -def _view_pd_boolean_array(a: XArray, view_args): +@as_view.register(XDataArray) +def _view_pd_boolean_array(a: XDataArray, view_args): return a diff --git a/src/anndata/tests/helpers.py b/src/anndata/tests/helpers.py index c07742e9f..b58e97549 100644 --- a/src/anndata/tests/helpers.py +++ b/src/anndata/tests/helpers.py @@ -32,7 +32,7 @@ CupyCSRMatrix, CupySparseMatrix, DaskArray, - XArray, + XDataArray, XDataset, ZarrArray, is_zarr_v2, @@ -744,9 +744,9 @@ def assert_equal_extension_array( ) -@assert_equal.register(XArray) +@assert_equal.register(XDataArray) def assert_equal_xarray( - a: XArray, b: object, *, exact: bool = False, elem_name: str | None = None + a: XDataArray, b: object, *, exact: bool = False, elem_name: str | None = None ): report_name(a.equals)(b, _elem_name=elem_name) diff --git a/src/anndata/typing.py b/src/anndata/typing.py index fd11cc2d1..25e279248 100644 --- a/src/anndata/typing.py +++ b/src/anndata/typing.py @@ -16,7 +16,7 @@ CupySparseMatrix, DaskArray, H5Array, - XArray, + XDataArray, ZappyArray, ZarrArray, ) @@ -46,7 +46,7 @@ | CupyArray | CupySparseMatrix ) -ArrayDataStructureTypes: TypeAlias = XDataType | AwkArray | XArray +ArrayDataStructureTypes: TypeAlias = XDataType | AwkArray | XDataArray InMemoryArrayOrScalarType: TypeAlias = ( 
From e63a17f27cba99fc5ec5cb1479b74eca44b8d121 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 24 Apr 2025 09:33:48 +0000 Subject: [PATCH 26/43] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/anndata/_io/specs/lazy_methods.py | 4 +++- src/anndata/experimental/backed/_lazy_arrays.py | 2 +- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/src/anndata/_io/specs/lazy_methods.py b/src/anndata/_io/specs/lazy_methods.py index de9cc6f8b..585bbbb60 100644 --- a/src/anndata/_io/specs/lazy_methods.py +++ b/src/anndata/_io/specs/lazy_methods.py @@ -244,7 +244,9 @@ def _gen_xarray_dict_iterator_from_elems( }, ) elif k == dim_name: - data_array = XDataArray(index, coords=[index], dims=[dim_name], name=dim_name) + data_array = XDataArray( + index, coords=[index], dims=[dim_name], name=dim_name + ) else: msg = f"Could not read {k}: {v} from into xarray Dataset2D" raise ValueError(msg) diff --git a/src/anndata/experimental/backed/_lazy_arrays.py b/src/anndata/experimental/backed/_lazy_arrays.py index 80e9a0d8a..68685ae36 100644 --- a/src/anndata/experimental/backed/_lazy_arrays.py +++ b/src/anndata/experimental/backed/_lazy_arrays.py @@ -11,7 +11,7 @@ from anndata.compat import H5Array, ZarrArray from ..._settings import settings -from ...compat import XDataArray, XBackendArray, XZarrArrayWrapper +from ...compat import XBackendArray, XDataArray, XZarrArrayWrapper from ...compat import xarray as xr if TYPE_CHECKING: From efba50287628bc38527e622b721b40c223044404 Mon Sep 17 00:00:00 2001 From: Ilia Kats Date: Thu, 24 Apr 2025 18:14:22 +0200 Subject: [PATCH 27/43] add force_lazy argument to concat this controls whether concatenation of in-memory xarray Datasets is lazy (using dask) or not --- src/anndata/_core/merge.py | 49 ++++++++++++----- src/anndata/_core/xarray.py | 11 ++++ src/anndata/_io/specs/lazy_methods.py | 5 +- tests/test_concatenate.py | 78 +++++++++++++-------------- 4 files changed, 86 insertions(+), 57 deletions(-) diff --git a/src/anndata/_core/merge.py b/src/anndata/_core/merge.py index 8dadd5d78..4d920d00c 100644 --- a/src/anndata/_core/merge.py +++ b/src/anndata/_core/merge.py @@ -304,13 +304,25 @@ def try_unifying_dtype( ordered = ordered | dtype.ordered elif not pd.isnull(dtype): return None - if len(dtypes) > 0 and not ordered: + if len(dtypes) > 0: categories = reduce( lambda x, y: x.union(y), - [dtype.categories for dtype in dtypes if not pd.isnull(dtype)], + (dtype.categories for dtype in dtypes if not pd.isnull(dtype)), ) - return pd.CategoricalDtype(natsorted(categories), ordered=False) + if not ordered: + return pd.CategoricalDtype(natsorted(categories), ordered=False) + else: # for xarray Datasets, see https://github.com/pydata/xarray/issues/10247 + categories_intersection = reduce(lambda x, y: x.intersection(y), (dtype.categories for dtype in dtypes if not pd.isnull(dtype) and len(dtype.categories) > 0)) + if len(categories_intersection) < len(categories): + return object + else: + same_orders = all(dtype.ordered for dtype in dtypes if not pd.isnull(dtype) and len(dtype.categories) > 0) + same_orders &= all(np.all(categories == dtype.categories) for dtype in dtypes if not pd.isnull(dtype) and len(dtype.categories) > 0) + if same_orders: + return next(iter(dtypes)) + else: + return object # Boolean elif all(pd.api.types.is_bool_dtype(dtype) or dtype is None for dtype in col): if any(dtype is None for dtype in col): @@ -816,7 +828,7 @@ def 
np_bool_to_pd_bool_array(df: pd.DataFrame): return df -def concat_arrays(arrays, reindexers, axis=0, index=None, fill_value=None): # noqa: PLR0911, PLR0912 +def concat_arrays(arrays, reindexers, axis=0, index=None, fill_value=None, force_lazy: bool = False): # noqa: PLR0911, PLR0912 from anndata.experimental.backed._compat import Dataset2D arrays = list(arrays) @@ -830,7 +842,7 @@ def concat_arrays(arrays, reindexers, axis=0, index=None, fill_value=None): # n msg = f"Cannot concatenate a Dataset2D with other array types {[type(a) for a in arrays if not isinstance(a, Dataset2D)]}." raise ValueError(msg) else: - return concat_dataset2d_on_annot_axis(arrays, join="outer") + return concat_dataset2d_on_annot_axis(arrays, join="outer", force_lazy=force_lazy) if any(isinstance(a, pd.DataFrame) for a in arrays): # TODO: This is hacky, 0 is a sentinel for outer_concat_aligned_mapping if not all( @@ -920,7 +932,7 @@ def concat_arrays(arrays, reindexers, axis=0, index=None, fill_value=None): # n def inner_concat_aligned_mapping( - mappings, *, reindexers=None, index=None, axis=0, concat_axis=None + mappings, *, reindexers=None, index=None, axis=0, concat_axis=None, force_lazy: bool = False ): if concat_axis is None: concat_axis = axis @@ -935,7 +947,7 @@ def inner_concat_aligned_mapping( else: cur_reindexers = reindexers - result[k] = concat_arrays(els, cur_reindexers, index=index, axis=concat_axis) + result[k] = concat_arrays(els, cur_reindexers, index=index, axis=concat_axis, force_lazy=force_lazy) return result @@ -1031,7 +1043,7 @@ def missing_element( def outer_concat_aligned_mapping( - mappings, *, reindexers=None, index=None, axis=0, concat_axis=None, fill_value=None + mappings, *, reindexers=None, index=None, axis=0, concat_axis=None, fill_value=None, force_lazy:bool=False ): if concat_axis is None: concat_axis = axis @@ -1073,6 +1085,7 @@ def outer_concat_aligned_mapping( axis=concat_axis, index=index, fill_value=fill_value, + force_lazy=force_lazy ) return result @@ -1243,8 +1256,8 @@ def get_chunk(block_info=None): meta=np.array([], dtype=dtype), dtype=dtype, ) - else: # in-memory - return da.from_array(col.values, chunks=-1) + + return da.from_array(col.values, chunks=-1) # in-memory def make_xarray_extension_dtypes_dask( @@ -1289,6 +1302,7 @@ def make_xarray_extension_dtypes_dask( def concat_dataset2d_on_annot_axis( annotations: Iterable[Dataset2D], join: Join_T, + force_lazy: bool ) -> Dataset2D: """Create a concatenate dataset from a list of :class:`~anndata.experimental.backed._xarray.Dataset2D` objects. 
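    Columns are converted to dask-backed arrays only when at least one input is
    backed or ``force_lazy`` is set; otherwise dtypes are unified eagerly.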
The goal of this function is to mimic `pd.concat(..., ignore_index=True)` so has some complicated logic @@ -1310,7 +1324,12 @@ def concat_dataset2d_on_annot_axis( from anndata.compat import xarray as xr annotations_re_indexed = [] - for a in make_xarray_extension_dtypes_dask(annotations): + have_backed = any(a.is_backed for a in annotations) + if have_backed or force_lazy: + annotations = make_xarray_extension_dtypes_dask(annotations) + else: + annotations = unify_dtypes(annotations) + for a in annotations: old_key = a.index_dim is_fake_index = old_key != a.true_index_dim # First create a dummy index @@ -1331,6 +1350,7 @@ def concat_dataset2d_on_annot_axis( ds = Dataset2D( xr.concat(annotations_re_indexed, join=join, dim=DS_CONCAT_DUMMY_INDEX_NAME), ) + ds.is_backed = have_backed ds.coords[DS_CONCAT_DUMMY_INDEX_NAME] = pd.RangeIndex( ds.coords[DS_CONCAT_DUMMY_INDEX_NAME].shape[0] ) @@ -1368,6 +1388,7 @@ def concat( # noqa: PLR0912, PLR0913, PLR0915 index_unique: str | None = None, fill_value: Any | None = None, pairwise: bool = False, + force_lazy: bool = False, ) -> AnnData: """Concatenates AnnData objects along an axis. @@ -1416,6 +1437,9 @@ def concat( # noqa: PLR0912, PLR0913, PLR0915 pairwise Whether pairwise elements along the concatenated dimension should be included. This is False by default, since the resulting arrays are often not meaningful. + force_lazy + Whether to lazily concatenate elements using dask even when eager concatenation is possible. + At the moment, this only affects obs/var and elements of obsm/varm that are xarray Datasets. Notes ----- @@ -1624,7 +1648,7 @@ def concat( # noqa: PLR0912, PLR0913, PLR0915 ) concat_annot.index = concat_indices else: - concat_annot = concat_dataset2d_on_annot_axis(annotations, join) + concat_annot = concat_dataset2d_on_annot_axis(annotations, join, force_lazy) concat_indices.name = DS_CONCAT_DUMMY_INDEX_NAME concat_annot.index = concat_indices if label is not None: @@ -1684,6 +1708,7 @@ def concat( # noqa: PLR0912, PLR0913, PLR0915 axis=axis, concat_axis=0, index=concat_indices, + force_lazy=force_lazy ) if pairwise: concat_pairwise = concat_pairwise_mapping( diff --git a/src/anndata/_core/xarray.py b/src/anndata/_core/xarray.py index d8b05f19e..a5ee566c7 100644 --- a/src/anndata/_core/xarray.py +++ b/src/anndata/_core/xarray.py @@ -21,6 +21,17 @@ class Dataset2D(XDataset): __slots__ = () + @property + def is_backed(self) -> bool: + return self.attrs.get("is_backed", False) + + @is_backed.setter + def is_backed(self, isbacked: bool): + if not isbacked and "is_backed" in self.attrs: + del self.attrs["is_backed"] + else: + self.attrs["is_backed"] = isbacked + @property def index_dim(self) -> str: if len(self.sizes) != 1: diff --git a/src/anndata/_io/specs/lazy_methods.py b/src/anndata/_io/specs/lazy_methods.py index e14744924..ee1f74118 100644 --- a/src/anndata/_io/specs/lazy_methods.py +++ b/src/anndata/_io/specs/lazy_methods.py @@ -244,9 +244,7 @@ def _gen_xarray_dict_iterator_from_elems( }, ) elif k == dim_name: - data_array = XDataArray( - index, coords=[index], dims=[dim_name], name=dim_name - ) + data_array = XDataArray(index, coords=[index], dims=[dim_name], name=dim_name) else: msg = f"Could not read {k}: {v} from into xarray Dataset2D" raise ValueError(msg) @@ -288,6 +286,7 @@ def read_dataframe( name=DUMMY_RANGE_INDEX_KEY, ) ds = Dataset2D(elem_xarray_dict) + ds.is_backed = True # We ensure the indexing_key attr always points to the true index # so that the roundtrip works even for the `use_range_index` `True` case 
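    # (`elem.attrs["_index"]` is the attribute the on-disk dataframe format
    # uses to record which column was the dataframe's index when written)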
ds.true_index_dim = elem.attrs["_index"] diff --git a/tests/test_concatenate.py b/tests/test_concatenate.py index f488e24ae..a87c4216b 100644 --- a/tests/test_concatenate.py +++ b/tests/test_concatenate.py @@ -213,11 +213,13 @@ def test_concat_interface_errors(obs_xdataset, var_xdataset): ], ) @pytest.mark.parametrize( - ("obs_xdataset", "var_xdataset"), [(False, False), (True, True)] + ("obs_xdataset", "var_xdataset", "force_lazy"), [(False, False, False), (True, True, False), (True, True, True)] ) def test_concatenate_roundtrip( - join_type, array_type, concat_func, backwards_compat, obs_xdataset, var_xdataset + join_type, array_type, concat_func, backwards_compat, obs_xdataset, var_xdataset, force_lazy ): + if backwards_compat and force_lazy: + pytest.skip("unsupported") adata = gen_adata( (100, 10), X_type=array_type, @@ -1207,10 +1209,10 @@ def test_concatenate_uns(unss, merge_strategy, result, value_gen): @pytest.mark.parametrize( - ("obs_xdataset", "var_xdataset"), [(False, False), (True, True)] + ("obs_xdataset", "var_xdataset", "force_lazy"), [(False, False, False), (True, True, False), (True, True, True)] ) def test_transposed_concat( - array_type, axis_name, join_type, merge_strategy, obs_xdataset, var_xdataset + array_type, axis_name, join_type, merge_strategy, obs_xdataset, var_xdataset, force_lazy ): axis, axis_name = merge._resolve_axis(axis_name) alt_axis = 1 - axis @@ -1229,16 +1231,16 @@ def test_transposed_concat( **GEN_ADATA_DASK_ARGS, ) - a = concat([lhs, rhs], axis=axis, join=join_type, merge=merge_strategy) - b = concat([lhs.T, rhs.T], axis=alt_axis, join=join_type, merge=merge_strategy).T + a = concat([lhs, rhs], axis=axis, join=join_type, merge=merge_strategy, force_lazy=force_lazy) + b = concat([lhs.T, rhs.T], axis=alt_axis, join=join_type, merge=merge_strategy, force_lazy=force_lazy).T assert_equal(a, b) @pytest.mark.parametrize( - ("obs_xdataset", "var_xdataset"), [(False, False), (True, True)] + ("obs_xdataset", "var_xdataset", "force_lazy"), [(False, False, False), (True, True, False), (True, True, True)] ) -def test_batch_key(axis_name, obs_xdataset, var_xdataset): +def test_batch_key(axis_name, obs_xdataset, var_xdataset, force_lazy): """Test that concat only adds a label if the key is provided""" get_annot = attrgetter(axis_name) @@ -1257,7 +1259,7 @@ def test_batch_key(axis_name, obs_xdataset, var_xdataset): ) # There is probably a prettier way to do this - annot = get_annot(concat([lhs, rhs], axis=axis_name)) + annot = get_annot(concat([lhs, rhs], axis=axis_name, force_lazy=force_lazy)) assert ( list( annot.columns.difference( @@ -1267,7 +1269,7 @@ def test_batch_key(axis_name, obs_xdataset, var_xdataset): == [] ) - batch_annot = get_annot(concat([lhs, rhs], axis=axis_name, label="batch")) + batch_annot = get_annot(concat([lhs, rhs], axis=axis_name, label="batch", force_lazy=force_lazy)) assert list( batch_annot.columns.difference( get_annot(lhs).columns.union(get_annot(rhs).columns) @@ -1276,9 +1278,9 @@ def test_batch_key(axis_name, obs_xdataset, var_xdataset): @pytest.mark.parametrize( - ("obs_xdataset", "var_xdataset"), [(False, False), (True, True)] + ("obs_xdataset", "var_xdataset", "force_lazy"), [(False, False, False), (True, True, False), (True, True, True)] ) -def test_concat_categories_from_mapping(obs_xdataset, var_xdataset): +def test_concat_categories_from_mapping(obs_xdataset, var_xdataset, force_lazy): mapping = { "a": gen_adata((10, 10), obs_xdataset=obs_xdataset, var_xdataset=var_xdataset), "b": gen_adata((10, 10), 
obs_xdataset=obs_xdataset, var_xdataset=var_xdataset), @@ -1286,8 +1288,8 @@ def test_concat_categories_from_mapping(obs_xdataset, var_xdataset): keys = list(mapping.keys()) adatas = list(mapping.values()) - mapping_call = partial(concat, mapping) - iter_call = partial(concat, adatas, keys=keys) + mapping_call = partial(concat, mapping, force_lazy=force_lazy) + iter_call = partial(concat, adatas, keys=keys, force_lazy=force_lazy) assert_equal(mapping_call(), iter_call()) assert_equal(mapping_call(label="batch"), iter_call(label="batch")) @@ -1418,17 +1420,17 @@ def test_bool_promotion(): @pytest.mark.parametrize( - ("obs_xdataset", "var_xdataset"), [(False, False), (True, True)] + ("obs_xdataset", "var_xdataset", "force_lazy"), [(False, False, False), (True, True, False), (True, True, True)] ) -def test_concat_names(axis_name, obs_xdataset, var_xdataset): +def test_concat_names(axis_name, obs_xdataset, var_xdataset, force_lazy): get_annot = attrgetter(axis_name) lhs = gen_adata((10, 10), obs_xdataset=obs_xdataset, var_xdataset=var_xdataset) rhs = gen_adata((10, 10), obs_xdataset=obs_xdataset, var_xdataset=var_xdataset) - assert not get_annot(concat([lhs, rhs], axis=axis_name)).index.is_unique + assert not get_annot(concat([lhs, rhs], axis=axis_name, force_lazy=force_lazy)).index.is_unique assert get_annot( - concat([lhs, rhs], axis=axis_name, index_unique="-") + concat([lhs, rhs], axis=axis_name, index_unique="-", force_lazy=force_lazy) ).index.is_unique @@ -1458,10 +1460,10 @@ def expected_shape( "shape", [pytest.param((8, 0), id="no_var"), pytest.param((0, 10), id="no_obs")] ) @pytest.mark.parametrize( - ("obs_xdataset", "var_xdataset"), [(False, False), (True, True)] + ("obs_xdataset", "var_xdataset", "force_lazy"), [(False, False, False), (True, True, False), (True, True, True)] ) def test_concat_size_0_axis( - axis_name, join_type, merge_strategy, shape, obs_xdataset, var_xdataset + axis_name, join_type, merge_strategy, shape, obs_xdataset, var_xdataset, force_lazy ): """Regression test for https://github.com/scverse/anndata/issues/526""" axis, axis_name = merge._resolve_axis(axis_name) @@ -1484,23 +1486,15 @@ def test_concat_size_0_axis( expected_size = expected_shape(a, b, axis=axis, join=join_type) - ctx_concat_empty = ( - pytest.warns( - FutureWarning, - match=r"The behavior of DataFrame concatenation with empty or all-NA entries is deprecated", - ) - if shape[axis] == 0 and Version(pd.__version__) >= Version("2.1") - else nullcontext() - ) - with ctx_concat_empty: - result = concat( - {"a": a, "b": b}, - axis=axis, - join=join_type, - merge=merge_strategy, - pairwise=True, - index_unique="-", - ) + result = concat( + {"a": a, "b": b}, + axis=axis, + join=join_type, + merge=merge_strategy, + pairwise=True, + index_unique="-", + force_lazy=force_lazy + ) assert result.shape == expected_size if join_type == "outer": @@ -1540,9 +1534,9 @@ def test_concat_size_0_axis( @pytest.mark.parametrize("elem", ["sparse", "array", "df", "da"]) @pytest.mark.parametrize("axis", ["obs", "var"]) @pytest.mark.parametrize( - ("obs_xdataset", "var_xdataset"), [(False, False), (True, True)] + ("obs_xdataset", "var_xdataset", "force_lazy"), [(False, False, False), (True, True, False), (True, True, True)] ) -def test_concat_outer_aligned_mapping(elem, axis, obs_xdataset, var_xdataset): +def test_concat_outer_aligned_mapping(elem, axis, obs_xdataset, var_xdataset, force_lazy): a = gen_adata( (5, 5), obs_xdataset=obs_xdataset, @@ -1557,7 +1551,7 @@ def test_concat_outer_aligned_mapping(elem, axis, 
obs_xdataset, var_xdataset): ) del getattr(b, f"{axis}m")[elem] - concated = concat({"a": a, "b": b}, join="outer", label="group", axis=axis) + concated = concat({"a": a, "b": b}, join="outer", label="group", axis=axis, force_lazy=force_lazy) mask = getattr(concated, axis)["group"] == "b" result = getattr( @@ -1581,9 +1575,9 @@ def test_concatenate_size_0_axis(): @pytest.mark.parametrize( - ("obs_xdataset", "var_xdataset"), [(False, False), (True, True)] + ("obs_xdataset", "var_xdataset", "force_lazy"), [(False, False, False), (True, True, False), (True, True, True)] ) -def test_concat_null_X(obs_xdataset, var_xdataset): +def test_concat_null_X(obs_xdataset, var_xdataset, force_lazy): adatas_orig = { k: gen_adata((20, 10), obs_xdataset=obs_xdataset, var_xdataset=var_xdataset) for k in list("abc") From 0f506f2173684fe785ea1aa8d90223205809ed62 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 24 Apr 2025 16:15:49 +0000 Subject: [PATCH 28/43] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/anndata/_core/merge.py | 62 ++++++++++++++++------ src/anndata/_io/specs/lazy_methods.py | 4 +- tests/test_concatenate.py | 75 ++++++++++++++++++++------- 3 files changed, 107 insertions(+), 34 deletions(-) diff --git a/src/anndata/_core/merge.py b/src/anndata/_core/merge.py index 4d920d00c..8eb9a0f50 100644 --- a/src/anndata/_core/merge.py +++ b/src/anndata/_core/merge.py @@ -312,13 +312,28 @@ def try_unifying_dtype( if not ordered: return pd.CategoricalDtype(natsorted(categories), ordered=False) - else: # for xarray Datasets, see https://github.com/pydata/xarray/issues/10247 - categories_intersection = reduce(lambda x, y: x.intersection(y), (dtype.categories for dtype in dtypes if not pd.isnull(dtype) and len(dtype.categories) > 0)) + else: # for xarray Datasets, see https://github.com/pydata/xarray/issues/10247 + categories_intersection = reduce( + lambda x, y: x.intersection(y), + ( + dtype.categories + for dtype in dtypes + if not pd.isnull(dtype) and len(dtype.categories) > 0 + ), + ) if len(categories_intersection) < len(categories): return object else: - same_orders = all(dtype.ordered for dtype in dtypes if not pd.isnull(dtype) and len(dtype.categories) > 0) - same_orders &= all(np.all(categories == dtype.categories) for dtype in dtypes if not pd.isnull(dtype) and len(dtype.categories) > 0) + same_orders = all( + dtype.ordered + for dtype in dtypes + if not pd.isnull(dtype) and len(dtype.categories) > 0 + ) + same_orders &= all( + np.all(categories == dtype.categories) + for dtype in dtypes + if not pd.isnull(dtype) and len(dtype.categories) > 0 + ) if same_orders: return next(iter(dtypes)) else: @@ -828,7 +843,9 @@ def np_bool_to_pd_bool_array(df: pd.DataFrame): return df -def concat_arrays(arrays, reindexers, axis=0, index=None, fill_value=None, force_lazy: bool = False): # noqa: PLR0911, PLR0912 +def concat_arrays( + arrays, reindexers, axis=0, index=None, fill_value=None, force_lazy: bool = False +): # noqa: PLR0911, PLR0912 from anndata.experimental.backed._compat import Dataset2D arrays = list(arrays) @@ -842,7 +859,9 @@ def concat_arrays(arrays, reindexers, axis=0, index=None, fill_value=None, force msg = f"Cannot concatenate a Dataset2D with other array types {[type(a) for a in arrays if not isinstance(a, Dataset2D)]}." 
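            # mixing a Dataset2D with e.g. a plain DataFrame or ndarray is not supported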
raise ValueError(msg) else: - return concat_dataset2d_on_annot_axis(arrays, join="outer", force_lazy=force_lazy) + return concat_dataset2d_on_annot_axis( + arrays, join="outer", force_lazy=force_lazy + ) if any(isinstance(a, pd.DataFrame) for a in arrays): # TODO: This is hacky, 0 is a sentinel for outer_concat_aligned_mapping if not all( @@ -932,7 +951,13 @@ def concat_arrays(arrays, reindexers, axis=0, index=None, fill_value=None, force def inner_concat_aligned_mapping( - mappings, *, reindexers=None, index=None, axis=0, concat_axis=None, force_lazy: bool = False + mappings, + *, + reindexers=None, + index=None, + axis=0, + concat_axis=None, + force_lazy: bool = False, ): if concat_axis is None: concat_axis = axis @@ -947,7 +972,9 @@ def inner_concat_aligned_mapping( else: cur_reindexers = reindexers - result[k] = concat_arrays(els, cur_reindexers, index=index, axis=concat_axis, force_lazy=force_lazy) + result[k] = concat_arrays( + els, cur_reindexers, index=index, axis=concat_axis, force_lazy=force_lazy + ) return result @@ -1043,7 +1070,14 @@ def missing_element( def outer_concat_aligned_mapping( - mappings, *, reindexers=None, index=None, axis=0, concat_axis=None, fill_value=None, force_lazy:bool=False + mappings, + *, + reindexers=None, + index=None, + axis=0, + concat_axis=None, + fill_value=None, + force_lazy: bool = False, ): if concat_axis is None: concat_axis = axis @@ -1085,7 +1119,7 @@ def outer_concat_aligned_mapping( axis=concat_axis, index=index, fill_value=fill_value, - force_lazy=force_lazy + force_lazy=force_lazy, ) return result @@ -1257,7 +1291,7 @@ def get_chunk(block_info=None): dtype=dtype, ) - return da.from_array(col.values, chunks=-1) # in-memory + return da.from_array(col.values, chunks=-1) # in-memory def make_xarray_extension_dtypes_dask( @@ -1300,9 +1334,7 @@ def make_xarray_extension_dtypes_dask( def concat_dataset2d_on_annot_axis( - annotations: Iterable[Dataset2D], - join: Join_T, - force_lazy: bool + annotations: Iterable[Dataset2D], join: Join_T, force_lazy: bool ) -> Dataset2D: """Create a concatenate dataset from a list of :class:`~anndata.experimental.backed._xarray.Dataset2D` objects. 
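    Columns are converted to dask-backed arrays only when at least one input is
    backed or ``force_lazy`` is set; otherwise dtypes are unified eagerly.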
The goal of this function is to mimic `pd.concat(..., ignore_index=True)` so has some complicated logic @@ -1708,7 +1740,7 @@ def concat( # noqa: PLR0912, PLR0913, PLR0915 axis=axis, concat_axis=0, index=concat_indices, - force_lazy=force_lazy + force_lazy=force_lazy, ) if pairwise: concat_pairwise = concat_pairwise_mapping( diff --git a/src/anndata/_io/specs/lazy_methods.py b/src/anndata/_io/specs/lazy_methods.py index ee1f74118..681716daf 100644 --- a/src/anndata/_io/specs/lazy_methods.py +++ b/src/anndata/_io/specs/lazy_methods.py @@ -244,7 +244,9 @@ def _gen_xarray_dict_iterator_from_elems( }, ) elif k == dim_name: - data_array = XDataArray(index, coords=[index], dims=[dim_name], name=dim_name) + data_array = XDataArray( + index, coords=[index], dims=[dim_name], name=dim_name + ) else: msg = f"Could not read {k}: {v} from into xarray Dataset2D" raise ValueError(msg) diff --git a/tests/test_concatenate.py b/tests/test_concatenate.py index a87c4216b..daf8fe06d 100644 --- a/tests/test_concatenate.py +++ b/tests/test_concatenate.py @@ -2,7 +2,6 @@ import warnings from collections.abc import Hashable -from contextlib import nullcontext from copy import deepcopy from functools import partial, singledispatch from itertools import chain, permutations, product @@ -213,10 +212,17 @@ def test_concat_interface_errors(obs_xdataset, var_xdataset): ], ) @pytest.mark.parametrize( - ("obs_xdataset", "var_xdataset", "force_lazy"), [(False, False, False), (True, True, False), (True, True, True)] + ("obs_xdataset", "var_xdataset", "force_lazy"), + [(False, False, False), (True, True, False), (True, True, True)], ) def test_concatenate_roundtrip( - join_type, array_type, concat_func, backwards_compat, obs_xdataset, var_xdataset, force_lazy + join_type, + array_type, + concat_func, + backwards_compat, + obs_xdataset, + var_xdataset, + force_lazy, ): if backwards_compat and force_lazy: pytest.skip("unsupported") @@ -1209,10 +1215,17 @@ def test_concatenate_uns(unss, merge_strategy, result, value_gen): @pytest.mark.parametrize( - ("obs_xdataset", "var_xdataset", "force_lazy"), [(False, False, False), (True, True, False), (True, True, True)] + ("obs_xdataset", "var_xdataset", "force_lazy"), + [(False, False, False), (True, True, False), (True, True, True)], ) def test_transposed_concat( - array_type, axis_name, join_type, merge_strategy, obs_xdataset, var_xdataset, force_lazy + array_type, + axis_name, + join_type, + merge_strategy, + obs_xdataset, + var_xdataset, + force_lazy, ): axis, axis_name = merge._resolve_axis(axis_name) alt_axis = 1 - axis @@ -1231,14 +1244,27 @@ def test_transposed_concat( **GEN_ADATA_DASK_ARGS, ) - a = concat([lhs, rhs], axis=axis, join=join_type, merge=merge_strategy, force_lazy=force_lazy) - b = concat([lhs.T, rhs.T], axis=alt_axis, join=join_type, merge=merge_strategy, force_lazy=force_lazy).T + a = concat( + [lhs, rhs], + axis=axis, + join=join_type, + merge=merge_strategy, + force_lazy=force_lazy, + ) + b = concat( + [lhs.T, rhs.T], + axis=alt_axis, + join=join_type, + merge=merge_strategy, + force_lazy=force_lazy, + ).T assert_equal(a, b) @pytest.mark.parametrize( - ("obs_xdataset", "var_xdataset", "force_lazy"), [(False, False, False), (True, True, False), (True, True, True)] + ("obs_xdataset", "var_xdataset", "force_lazy"), + [(False, False, False), (True, True, False), (True, True, True)], ) def test_batch_key(axis_name, obs_xdataset, var_xdataset, force_lazy): """Test that concat only adds a label if the key is provided""" @@ -1269,7 +1295,9 @@ def 
test_batch_key(axis_name, obs_xdataset, var_xdataset, force_lazy): == [] ) - batch_annot = get_annot(concat([lhs, rhs], axis=axis_name, label="batch", force_lazy=force_lazy)) + batch_annot = get_annot( + concat([lhs, rhs], axis=axis_name, label="batch", force_lazy=force_lazy) + ) assert list( batch_annot.columns.difference( get_annot(lhs).columns.union(get_annot(rhs).columns) @@ -1278,7 +1306,8 @@ def test_batch_key(axis_name, obs_xdataset, var_xdataset, force_lazy): @pytest.mark.parametrize( - ("obs_xdataset", "var_xdataset", "force_lazy"), [(False, False, False), (True, True, False), (True, True, True)] + ("obs_xdataset", "var_xdataset", "force_lazy"), + [(False, False, False), (True, True, False), (True, True, True)], ) def test_concat_categories_from_mapping(obs_xdataset, var_xdataset, force_lazy): mapping = { @@ -1420,7 +1449,8 @@ def test_bool_promotion(): @pytest.mark.parametrize( - ("obs_xdataset", "var_xdataset", "force_lazy"), [(False, False, False), (True, True, False), (True, True, True)] + ("obs_xdataset", "var_xdataset", "force_lazy"), + [(False, False, False), (True, True, False), (True, True, True)], ) def test_concat_names(axis_name, obs_xdataset, var_xdataset, force_lazy): get_annot = attrgetter(axis_name) @@ -1428,7 +1458,9 @@ def test_concat_names(axis_name, obs_xdataset, var_xdataset, force_lazy): lhs = gen_adata((10, 10), obs_xdataset=obs_xdataset, var_xdataset=var_xdataset) rhs = gen_adata((10, 10), obs_xdataset=obs_xdataset, var_xdataset=var_xdataset) - assert not get_annot(concat([lhs, rhs], axis=axis_name, force_lazy=force_lazy)).index.is_unique + assert not get_annot( + concat([lhs, rhs], axis=axis_name, force_lazy=force_lazy) + ).index.is_unique assert get_annot( concat([lhs, rhs], axis=axis_name, index_unique="-", force_lazy=force_lazy) ).index.is_unique @@ -1460,7 +1492,8 @@ def expected_shape( "shape", [pytest.param((8, 0), id="no_var"), pytest.param((0, 10), id="no_obs")] ) @pytest.mark.parametrize( - ("obs_xdataset", "var_xdataset", "force_lazy"), [(False, False, False), (True, True, False), (True, True, True)] + ("obs_xdataset", "var_xdataset", "force_lazy"), + [(False, False, False), (True, True, False), (True, True, True)], ) def test_concat_size_0_axis( axis_name, join_type, merge_strategy, shape, obs_xdataset, var_xdataset, force_lazy @@ -1493,7 +1526,7 @@ def test_concat_size_0_axis( merge=merge_strategy, pairwise=True, index_unique="-", - force_lazy=force_lazy + force_lazy=force_lazy, ) assert result.shape == expected_size @@ -1534,9 +1567,12 @@ def test_concat_size_0_axis( @pytest.mark.parametrize("elem", ["sparse", "array", "df", "da"]) @pytest.mark.parametrize("axis", ["obs", "var"]) @pytest.mark.parametrize( - ("obs_xdataset", "var_xdataset", "force_lazy"), [(False, False, False), (True, True, False), (True, True, True)] + ("obs_xdataset", "var_xdataset", "force_lazy"), + [(False, False, False), (True, True, False), (True, True, True)], ) -def test_concat_outer_aligned_mapping(elem, axis, obs_xdataset, var_xdataset, force_lazy): +def test_concat_outer_aligned_mapping( + elem, axis, obs_xdataset, var_xdataset, force_lazy +): a = gen_adata( (5, 5), obs_xdataset=obs_xdataset, @@ -1551,7 +1587,9 @@ def test_concat_outer_aligned_mapping(elem, axis, obs_xdataset, var_xdataset, fo ) del getattr(b, f"{axis}m")[elem] - concated = concat({"a": a, "b": b}, join="outer", label="group", axis=axis, force_lazy=force_lazy) + concated = concat( + {"a": a, "b": b}, join="outer", label="group", axis=axis, force_lazy=force_lazy + ) mask = getattr(concated, 
axis)["group"] == "b" result = getattr( @@ -1575,7 +1613,8 @@ def test_concatenate_size_0_axis(): @pytest.mark.parametrize( - ("obs_xdataset", "var_xdataset", "force_lazy"), [(False, False, False), (True, True, False), (True, True, True)] + ("obs_xdataset", "var_xdataset", "force_lazy"), + [(False, False, False), (True, True, False), (True, True, True)], ) def test_concat_null_X(obs_xdataset, var_xdataset, force_lazy): adatas_orig = { From 6998f2310f6cadc954cc9595769ea85add26f2db Mon Sep 17 00:00:00 2001 From: Ilia Kats Date: Fri, 25 Apr 2025 10:47:09 +0200 Subject: [PATCH 29/43] fix linter warnings --- src/anndata/_core/merge.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/src/anndata/_core/merge.py b/src/anndata/_core/merge.py index 8eb9a0f50..2cd35b152 100644 --- a/src/anndata/_core/merge.py +++ b/src/anndata/_core/merge.py @@ -279,7 +279,7 @@ def unify_dtypes( return dfs -def try_unifying_dtype( +def try_unifying_dtype( # noqa PLR0911, PLR0912 col: Sequence[np.dtype | ExtensionDtype], ) -> pd.core.dtypes.base.ExtensionDtype | None: """ @@ -843,9 +843,9 @@ def np_bool_to_pd_bool_array(df: pd.DataFrame): return df -def concat_arrays( - arrays, reindexers, axis=0, index=None, fill_value=None, force_lazy: bool = False -): # noqa: PLR0911, PLR0912 +def concat_arrays( # noqa: PLR0911, PLR0912 + arrays, reindexers, axis=0, index=None, fill_value=None, *, force_lazy: bool = False +): from anndata.experimental.backed._compat import Dataset2D arrays = list(arrays) @@ -1334,7 +1334,7 @@ def make_xarray_extension_dtypes_dask( def concat_dataset2d_on_annot_axis( - annotations: Iterable[Dataset2D], join: Join_T, force_lazy: bool + annotations: Iterable[Dataset2D], join: Join_T, *, force_lazy: bool ) -> Dataset2D: """Create a concatenate dataset from a list of :class:`~anndata.experimental.backed._xarray.Dataset2D` objects. The goal of this function is to mimic `pd.concat(..., ignore_index=True)` so has some complicated logic @@ -1680,7 +1680,9 @@ def concat( # noqa: PLR0912, PLR0913, PLR0915 ) concat_annot.index = concat_indices else: - concat_annot = concat_dataset2d_on_annot_axis(annotations, join, force_lazy) + concat_annot = concat_dataset2d_on_annot_axis( + annotations, join, force_lazy=force_lazy + ) concat_indices.name = DS_CONCAT_DUMMY_INDEX_NAME concat_annot.index = concat_indices if label is not None: From d6b8f7f0384b9822b7b0208ab205338889efb212 Mon Sep 17 00:00:00 2001 From: Ilia Kats Date: Fri, 25 Apr 2025 13:43:14 +0200 Subject: [PATCH 30/43] add tests for Dataset2D --- src/anndata/_core/xarray.py | 6 ++- tests/test_xarray.py | 91 +++++++++++++++++++++++++++++++++++++ 2 files changed, 95 insertions(+), 2 deletions(-) create mode 100644 tests/test_xarray.py diff --git a/src/anndata/_core/xarray.py b/src/anndata/_core/xarray.py index a5ee566c7..52fd33081 100644 --- a/src/anndata/_core/xarray.py +++ b/src/anndata/_core/xarray.py @@ -46,7 +46,9 @@ def true_index_dim(self) -> str: @true_index_dim.setter def true_index_dim(self, val: str): - if val not in self.dims: + if val is None or (val == self.index_dim and "indexing_key" in self.attrs): + del self.attrs["indexing_key"] + elif val not in self.dims: if val not in self.data_vars: msg = f"Unknown variable `{val}`." 
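                # the true index must name either a dimension or an existing data variable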
                raise ValueError(msg)
@@ -115,7 +117,7 @@ def __getitem__(self, idx):
 
     def __getitem__(self, idx) -> Dataset2D:
         ret = super().__getitem__(idx)
-        if idx == []:  # empty XDataset
+        if len(idx) == 0 and not isinstance(idx, tuple):  # empty XDataset
             ret.coords[self.index_dim] = self.xr_index
         return ret
 
diff --git a/tests/test_xarray.py b/tests/test_xarray.py
new file mode 100644
index 000000000..51e44ca07
--- /dev/null
+++ b/tests/test_xarray.py
@@ -0,0 +1,91 @@
+from __future__ import annotations
+
+import string
+
+import numpy as np
+import pandas as pd
+import pytest
+
+from anndata._core.xarray import Dataset2D
+from anndata.tests.helpers import gen_typed_df
+
+
+@pytest.fixture
+def df():
+    return gen_typed_df(10)
+
+
+@pytest.fixture
+def dataset2d(df):
+    return Dataset2D.from_dataframe(df)
+
+
+def test_shape(df, dataset2d):
+    assert dataset2d.shape == df.shape
+
+
+def test_columns(df, dataset2d):
+    assert np.all(dataset2d.columns.sort_values() == df.columns.sort_values())
+
+
+def test_to_memory(df, dataset2d):
+    memory_df = dataset2d.to_memory()
+    assert np.all(df == memory_df)
+    assert np.all(df.index == memory_df.index)
+    assert np.all(df.columns.sort_values() == memory_df.columns.sort_values())
+
+
+def test_getitem(df, dataset2d):
+    col = df.columns[0]
+    assert np.all(dataset2d[col] == df[col])
+
+    empty_dset = dataset2d[[]]
+    assert empty_dset.shape == (df.shape[0], 0)
+    assert np.all(empty_dset.index == dataset2d.index)
+
+
+def test_backed_property(dataset2d):
+    assert not dataset2d.is_backed
+
+    dataset2d.is_backed = True
+    assert dataset2d.is_backed
+
+    dataset2d.is_backed = False
+    assert not dataset2d.is_backed
+
+
+def test_index_dim(dataset2d):
+    assert dataset2d.index_dim == "index"
+    assert dataset2d.true_index_dim == dataset2d.index_dim
+
+    col = next(iter(dataset2d.keys()))
+    dataset2d.true_index_dim = col
+    assert dataset2d.index_dim == "index"
+    assert dataset2d.true_index_dim == col
+
+    with pytest.raises(ValueError, match=r"Unknown variable `test`\."):
+        dataset2d.true_index_dim = "test"
+
+    dataset2d.true_index_dim = None
+    assert dataset2d.true_index_dim == dataset2d.index_dim
+
+
+def test_index(dataset2d):
+    alphabet = np.asarray(
+        list(string.ascii_letters + string.digits + string.punctuation)
+    )
+    new_idx = pd.Index(
+        [
+            "".join(np.random.choice(alphabet, size=10))
+            for _ in range(dataset2d.shape[0])
+        ],
+        name="test_index",
+    )
+
+    col = next(iter(dataset2d.keys()))
+    dataset2d.true_index_dim = col
+
+    dataset2d.index = new_idx
+    assert np.all(dataset2d.index == new_idx)
+    assert dataset2d.true_index_dim == dataset2d.index_dim == new_idx.name
+    assert list(dataset2d.coords.keys()) == [new_idx.name]

From 4a8c4fce396ec99052c4b613171a0fbfab337ec5 Mon Sep 17 00:00:00 2001
From: Ilia Kats
Date: Fri, 25 Apr 2025 14:40:45 +0200
Subject: [PATCH 31/43] fix minimum awkward version

---
 pyproject.toml | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/pyproject.toml b/pyproject.toml
index 6c43c98a0..17a505ef7 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -79,7 +79,6 @@ doc = [
     "sphinx_design>=0.5.0",
     # for unreleased changes
     "anndata[dev-doc,dask]",
-    "awkward>=2.3",
 ]
 dev-doc = [ "towncrier>=24.8.0" ] # release notes tool
 test-full = [ "anndata[test,lazy]" ]
@@ -102,7 +101,7 @@ test = [
     "xarray>=2024.10.0",
     "pandas>=2.1.0",
     "dask[distributed]",
-    "awkward>=2.3",
+    "awkward>=2.3.2",
     "pyarrow",
     "anndata[dask]",
 ]

From 923562501749ab3cf7860bea2e0788ce12c60372 Mon Sep 17 00:00:00 2001
From: Ilia Kats
Date: Tue, 6 May 2025 11:34:15 +0200
Subject: [PATCH 32/43] remove unreachable code --- src/anndata/_core/merge.py | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/src/anndata/_core/merge.py b/src/anndata/_core/merge.py index 2cd35b152..9ea17f67f 100644 --- a/src/anndata/_core/merge.py +++ b/src/anndata/_core/merge.py @@ -777,13 +777,8 @@ def _apply_to_dataset2d(self, el: Dataset2D, *, axis, fill_value=None): ) return el else: - cols = el.columns - tokeep = cols[cols.isin(self.new_idx)] - el = el[tokeep.to_list()] - newcols = self.new_idx[~self.new_idx.isin(cols)] - for col in newcols: - el[col] = (el.index_dim, np.broadcast_to(fill_value, el.shape[0])) - return el + msg = "This should be unreachable, please open an issue." + raise Exception(msg) @property def idx(self): From f1f0d6e31ea4e907c79ccb886cd704c61b8ebba9 Mon Sep 17 00:00:00 2001 From: Ilia Kats Date: Tue, 6 May 2025 14:08:28 +0200 Subject: [PATCH 33/43] properly version-gate xfailing test --- src/anndata/_core/merge.py | 6 +++++- tests/test_concatenate.py | 27 +++++++++++++++++++-------- 2 files changed, 24 insertions(+), 9 deletions(-) diff --git a/src/anndata/_core/merge.py b/src/anndata/_core/merge.py index 9ea17f67f..5046b974d 100644 --- a/src/anndata/_core/merge.py +++ b/src/anndata/_core/merge.py @@ -122,7 +122,11 @@ def equal(a, b) -> bool: b = asarray(b) if a.ndim == b.ndim == 0: return bool(a == b) - return np.array_equal(a, b) + a_na = ( + pd.isna(a) if a.dtype.names is None else np.False_ + ) # pd.isna doesn't work for record arrays + b_na = pd.isna(b) if b.dtype.names is None else np.False_ + return np.array_equal(a_na, b_na) and np.array_equal(a[~a_na], b[~b_na]) @equal.register(pd.DataFrame) diff --git a/tests/test_concatenate.py b/tests/test_concatenate.py index daf8fe06d..7b0cac8f8 100644 --- a/tests/test_concatenate.py +++ b/tests/test_concatenate.py @@ -12,6 +12,7 @@ import pandas as pd import pytest import scipy +import xarray as xr from boltons.iterutils import default_exit, remap, research from numpy import ma from packaging.version import Version @@ -146,6 +147,14 @@ def fix_known_differences( orig = orig.copy() result = result.copy() + if backwards_compat: + del orig.varm + del orig.varp + if isinstance(result.obs, XDataset): + result.obs = result.obs.drop_vars(["batch"]) + else: + result.obs.drop(columns=["batch"], inplace=True) + for attrname in ("obs", "var"): if isinstance(getattr(result, attrname), XDataset): for adata in (orig, result): @@ -171,11 +180,6 @@ def fix_known_differences( # * merge obsp, but some information should be lost del orig.obsp # TODO - if backwards_compat: - del orig.varm - del orig.varp - result.obs.drop(columns=["batch"], inplace=True) - # Possibly need to fix this, ordered categoricals lose orderedness for get_df in [lambda k: k.obs, lambda k: k.obsm["df"]]: str_to_df_converted = get_df(result) @@ -234,9 +238,6 @@ def test_concatenate_roundtrip( **GEN_ADATA_DASK_ARGS, ) - if backwards_compat and (obs_xdataset or var_xdataset): - pytest.xfail("https://github.com/pydata/xarray/issues/10218") - remaining = adata.obs_names subsets = [] while len(remaining) > 0: @@ -245,7 +246,17 @@ def test_concatenate_roundtrip( subsets.append(adata[subset_idx]) remaining = remaining.difference(subset_idx) + if ( + backwards_compat + and (obs_xdataset or var_xdataset) + and Version(xr.__version__) < Version("2025.4.0") + ): + pytest.xfail("https://github.com/pydata/xarray/issues/10218") result = concat_func(subsets, join=join_type, uns_merge="same", index_unique=None) + if backwards_compat and 
var_xdataset: + result.var = xr.Dataset.from_dataframe( + result.var + ) # backwards compat always returns a dataframe # Correcting for known differences orig, result = fix_known_differences( From 8c349992a2a1cf286ed048780b510bba7f409a68 Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Tue, 6 May 2025 14:45:00 +0200 Subject: [PATCH 34/43] (fix): small fix for new xarray + test deps --- pyproject.toml | 8 +++----- src/anndata/_core/merge.py | 2 +- 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 17a505ef7..1a61b50fd 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -81,8 +81,7 @@ doc = [ "anndata[dev-doc,dask]", ] dev-doc = [ "towncrier>=24.8.0" ] # release notes tool -test-full = [ "anndata[test,lazy]" ] -test = [ +test-min = [ "loompy>=3.0.5", "pytest>=8.2,<8.3.4", "pytest-cov", @@ -98,18 +97,17 @@ test = [ "boltons", "scanpy>=1.10", "httpx", # For data downloading - "xarray>=2024.10.0", - "pandas>=2.1.0", "dask[distributed]", "awkward>=2.3.2", "pyarrow", "anndata[dask]", ] +test = [ "anndata[test-min,lazy]" ] gpu = [ "cupy" ] cu12 = [ "cupy-cuda12x" ] cu11 = [ "cupy-cuda11x" ] # requests and aiohttp needed for zarr remote data -lazy = [ "xarray>=2024.10.0", "aiohttp", "requests", "anndata[dask]" ] +lazy = [ "xarray>=2025.04.0", "aiohttp", "requests", "anndata[dask]" ] # https://github.com/dask/dask/issues/11290 # https://github.com/dask/dask/issues/11752 dask = [ "dask[array]>=2023.5.1,!=2024.8.*,!=2024.9.*,<2025.2.0" ] diff --git a/src/anndata/_core/merge.py b/src/anndata/_core/merge.py index 5046b974d..8a05a36fc 100644 --- a/src/anndata/_core/merge.py +++ b/src/anndata/_core/merge.py @@ -1276,7 +1276,7 @@ def get_chunk(block_info=None): slice(start, stop) for start, stop in block_info[None]["array-location"] ) - chunk = np.array(data_array.data[idx].array) + chunk = np.array(data_array.data[idx]) return chunk if col.dtype == "category" or col.dtype == "string" or use_only_object_dtype: # noqa PLR1714 From e8631a19f0e2beca5d896a36d977d4d332ee396d Mon Sep 17 00:00:00 2001 From: Ilia Kats Date: Tue, 6 May 2025 14:55:59 +0200 Subject: [PATCH 35/43] add back min. 
pandas version for tests

---
 pyproject.toml | 1 +
 1 file changed, 1 insertion(+)

diff --git a/pyproject.toml b/pyproject.toml
index 1a61b50fd..2081a5a60 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -100,6 +100,7 @@ test-min = [
     "dask[distributed]",
     "awkward>=2.3.2",
     "pyarrow",
+    "pandas>=2.1", # for xarray
     "anndata[dask]",
 ]
 test = [ "anndata[test-min,lazy]" ]

From dc1805d3e0266314802271cd0b24e6a1dca47460 Mon Sep 17 00:00:00 2001
From: ilan-gold
Date: Tue, 6 May 2025 15:01:08 +0200
Subject: [PATCH 36/43] (chore): add `obsm` access test

---
 tests/lazy/test_read.py | 19 ++++++++++++++++++-
 1 file changed, 18 insertions(+), 1 deletion(-)

diff --git a/tests/lazy/test_read.py b/tests/lazy/test_read.py
index f7b2c8b23..ad112b178 100644
--- a/tests/lazy/test_read.py
+++ b/tests/lazy/test_read.py
@@ -3,8 +3,11 @@
 from importlib.util import find_spec
 from typing import TYPE_CHECKING
 
+import numpy as np
+import pandas as pd
 import pytest
 
+from anndata import AnnData
 from anndata.compat import DaskArray
 from anndata.experimental import read_lazy
 from anndata.tests.helpers import (
@@ -20,7 +23,6 @@
     from collections.abc import Callable
     from pathlib import Path
 
-    from anndata import AnnData
     from anndata._types import AnnDataElem
 
 pytestmark = pytest.mark.skipif(not find_spec("xarray"), reason="xarray not installed")
@@ -112,6 +114,21 @@ def test_to_memory(adata_remote: AnnData, adata_orig: AnnData):
     assert_equal(remote_to_memory, adata_orig)
 
 
+def test_access_counts_obsm_df(tmp_path: Path):
+    adata = AnnData(
+        X=np.array(np.random.rand(100, 20)),
+    )
+    adata.obsm["df"] = pd.DataFrame(
+        {"col1": np.random.rand(100), "col2": np.random.rand(100)},
+        index=adata.obs_names,
+    )
+    adata.write_zarr(tmp_path)
+    store = AccessTrackingStore(tmp_path)
+    store.initialize_key_trackers(["obsm/df"])
+    read_lazy(store, load_annotation_index=False)
+    store.assert_access_count("obsm/df", 0)
+
+
 def test_view_to_memory(adata_remote: AnnData, adata_orig: AnnData):
     obs_cats = adata_orig.obs["obs_cat"].cat.categories
     subset_obs = adata_orig.obs["obs_cat"] == obs_cats[0]

From cce12753bb3eb8542ad1dc6f73fb185963188060 Mon Sep 17 00:00:00 2001
From: ilan-gold
Date: Wed, 7 May 2025 16:10:54 +0200
Subject: [PATCH 37/43] (fix): only use minimal dependencies for min

---
 .github/workflows/test-cpu.yml | 2 +-
 pyproject.toml                 | 1 -
 2 files changed, 1 insertion(+), 2 deletions(-)

diff --git a/.github/workflows/test-cpu.yml b/.github/workflows/test-cpu.yml
index 0233f5876..3dd4ee4b4 100644
--- a/.github/workflows/test-cpu.yml
+++ b/.github/workflows/test-cpu.yml
@@ -62,7 +62,7 @@ jobs:
         if: matrix.dependencies-version == 'minimum'
         run: |
           uv pip install --system --compile tomli packaging
-          deps=$(python3 ci/scripts/min-deps.py pyproject.toml --extra dev test)
+          deps=$(python3 ci/scripts/min-deps.py pyproject.toml --extra dev test-min)
           uv pip install --system --compile $deps "anndata @ ."
 
       - name: Install dependencies release candidates
diff --git a/pyproject.toml b/pyproject.toml
index 2081a5a60..1a61b50fd 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -100,7 +100,6 @@ test-min = [
     "dask[distributed]",
     "awkward>=2.3.2",
     "pyarrow",
-    "pandas>=2.1", # for xarray
     "anndata[dask]",
 ]
 test = [ "anndata[test-min,lazy]" ]

From 7384e22b419a8f8885afa022c2f8659f4bc71094 Mon Sep 17 00:00:00 2001
From: ilan-gold
Date: Wed, 7 May 2025 16:17:06 +0200
Subject: [PATCH 38/43] (fix): add pandas as `test` min dep

---
 pyproject.toml | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/pyproject.toml b/pyproject.toml
index 1a61b50fd..d9d96d18a 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -102,7 +102,10 @@ test-min = [
     "pyarrow",
     "anndata[dask]",
 ]
-test = [ "anndata[test-min,lazy]" ]
+test = [
+    "anndata[test-min,lazy]",
+    "pandas>=2.1.0",
+] # pandas 2.1.0 needs to be specified for xarray to work with min-deps script
 gpu = [ "cupy" ]
 cu12 = [ "cupy-cuda12x" ]
 cu11 = [ "cupy-cuda11x" ]

From c248e5c2b23e26f098f0e4c5f5bbc8e96779d733 Mon Sep 17 00:00:00 2001
From: ilan-gold
Date: Fri, 9 May 2025 16:04:21 +0200
Subject: [PATCH 39/43] (fix): docs

---
 docs/conf.py               | 1 +
 src/anndata/_core/merge.py | 6 +++---
 2 files changed, 4 insertions(+), 3 deletions(-)

diff --git a/docs/conf.py b/docs/conf.py
index 3ce0d4899..74151eab6 100644
--- a/docs/conf.py
+++ b/docs/conf.py
@@ -142,6 +142,7 @@ def setup(app: Sphinx):
     "anndata.compat.DaskArray": "dask.array.Array",
     "anndata.compat.CupyArray": "cupy.ndarray",
     "anndata.compat.CupySparseMatrix": "cupyx.scipy.sparse.spmatrix",
+    "anndata.compat.XDataArray": "xarray.DataArray",
     "awkward.highlevel.Array": "ak.Array",
     "numpy.int64": ("py:attr", "numpy.int64"),
     "pandas.DataFrame.iloc": ("py:attr", "pandas.DataFrame.iloc"),
diff --git a/src/anndata/_core/merge.py b/src/anndata/_core/merge.py
index 8a05a36fc..58457e23d 100644
--- a/src/anndata/_core/merge.py
+++ b/src/anndata/_core/merge.py
@@ -1335,20 +1335,20 @@ def make_xarray_extension_dtypes_dask(
 def concat_dataset2d_on_annot_axis(
     annotations: Iterable[Dataset2D], join: Join_T, *, force_lazy: bool
 ) -> Dataset2D:
-    """Create a concatenate dataset from a list of :class:`~anndata.experimental.backed._xarray.Dataset2D` objects.
+    """Create a concatenated dataset from a list of :class:`~anndata._core.xarray.Dataset2D` objects.
     The goal of this function is to mimic `pd.concat(..., ignore_index=True)` so has some complicated logic
     for handling the "index" to ensure (a) nothing is loaded into memory
     and (b) the true index is always tracked.
 
     Parameters
     ----------
     annotations
-        The :class:`~anndata.experimental.backed._xarray.Dataset2D` objects to be concatenated.
+        The :class:`~anndata._core.xarray.Dataset2D` objects to be concatenated.
     join
         Type of join operation
 
     Returns
     -------
-    Concatenated :class:`~anndata.experimental.backed._xarray.Dataset2D`
+    Concatenated :class:`~anndata._core.xarray.Dataset2D`
     """
     from anndata._core.xarray import Dataset2D
     from anndata._io.specs.lazy_methods import DUMMY_RANGE_INDEX_KEY

From 97cbacc4a551b9006a3f328fb8ec2c0dde971531 Mon Sep 17 00:00:00 2001
From: Ilia Kats
Date: Fri, 9 May 2025 16:46:10 +0200
Subject: [PATCH 40/43] ci/min-deps.py: correctly handle same dependency with different extras

---
 ci/scripts/min-deps.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/ci/scripts/min-deps.py b/ci/scripts/min-deps.py
index ae0c4e886..f04ef62eb 100755
--- a/ci/scripts/min-deps.py
+++ b/ci/scripts/min-deps.py
@@ -79,6 +79,7 @@ def extract_min_deps(
         else:
             if req.name in deps:
                 req.specifier &= deps[req.name].specifier
+                req.extras |= deps[req.name].extras
             deps[req.name] = min_dep(req)
 
     yield from deps.values()

From 01194d9e6a753dfbc409f49b752b3d12a79de6a7 Mon Sep 17 00:00:00 2001
From: Ilia Kats
Date: Mon, 12 May 2025 10:36:51 +0200
Subject: [PATCH 41/43] reorganize concat label handling

---
 src/anndata/_core/merge.py | 26 ++++++++++++++++++--------
 1 file changed, 18 insertions(+), 8 deletions(-)

diff --git a/src/anndata/_core/merge.py b/src/anndata/_core/merge.py
index 58457e23d..5e9aa8481 100644
--- a/src/anndata/_core/merge.py
+++ b/src/anndata/_core/merge.py
@@ -1333,7 +1333,12 @@ def make_xarray_extension_dtypes_dask(
 
 
 def concat_dataset2d_on_annot_axis(
-    annotations: Iterable[Dataset2D], join: Join_T, *, force_lazy: bool
+    annotations: Iterable[Dataset2D],
+    join: Join_T,
+    *,
+    force_lazy: bool,
+    label: str | None = None,
+    label_col: pd.Categorical | None = None,
 ) -> Dataset2D:
     """Create a concatenated dataset from a list of :class:`~anndata._core.xarray.Dataset2D` objects.
     The goal of this function is to mimic `pd.concat(..., ignore_index=True)` so has some complicated logic
@@ -1345,6 +1350,13 @@ def concat_dataset2d_on_annot_axis(
         The :class:`~anndata._core.xarray.Dataset2D` objects to be concatenated.
     join
         Type of join operation
+    force_lazy
+        Whether to lazily concatenate elements using dask even when eager concatenation is possible.
+    label
+        Column in axis annotation (i.e. `.obs` or `.var`) to place batch information in.
+        If it's None, no column is added.
+    label_col
+        The batch information annotation.
Returns ------- @@ -1404,6 +1416,8 @@ def concat_dataset2d_on_annot_axis( del ds[key] if DUMMY_RANGE_INDEX_KEY in ds: del ds[DUMMY_RANGE_INDEX_KEY] + if label is not None and label_col is not None: + ds[label] = (DS_CONCAT_DUMMY_INDEX_NAME, label_col) return ds @@ -1678,18 +1692,14 @@ def concat( # noqa: PLR0912, PLR0913, PLR0915 ignore_index=True, ) concat_annot.index = concat_indices + if label is not None: + concat_annot[label] = label_col else: concat_annot = concat_dataset2d_on_annot_axis( - annotations, join, force_lazy=force_lazy + annotations, join, force_lazy=force_lazy, label=label, label_col=label_col ) concat_indices.name = DS_CONCAT_DUMMY_INDEX_NAME concat_annot.index = concat_indices - if label is not None: - concat_annot[label] = ( - label_col - if not isinstance(concat_annot, Dataset2D) - else (DS_CONCAT_DUMMY_INDEX_NAME, label_col) - ) # Annotation for other axis alt_annotations = [getattr(a, alt_axis_name) for a in adatas] From 3a12723f800a1b6366a6a6079b74fbcc516a6d9a Mon Sep 17 00:00:00 2001 From: Ilan Gold Date: Mon, 12 May 2025 12:37:23 +0200 Subject: [PATCH 42/43] Update src/anndata/_core/merge.py --- src/anndata/_core/merge.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/anndata/_core/merge.py b/src/anndata/_core/merge.py index 5e9aa8481..276db242f 100644 --- a/src/anndata/_core/merge.py +++ b/src/anndata/_core/merge.py @@ -340,8 +340,7 @@ def try_unifying_dtype( # noqa PLR0911, PLR0912 ) if same_orders: return next(iter(dtypes)) - else: - return object + return object # Boolean elif all(pd.api.types.is_bool_dtype(dtype) or dtype is None for dtype in col): if any(dtype is None for dtype in col): From 40567aa12ae6a87e16e613519fdfcb857168b270 Mon Sep 17 00:00:00 2001 From: ilia-kats Date: Mon, 12 May 2025 13:36:42 +0200 Subject: [PATCH 43/43] Apply suggestions from code review Co-authored-by: Ilan Gold --- src/anndata/_core/xarray.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/anndata/_core/xarray.py b/src/anndata/_core/xarray.py index 52fd33081..82ee71b40 100644 --- a/src/anndata/_core/xarray.py +++ b/src/anndata/_core/xarray.py @@ -41,8 +41,7 @@ def index_dim(self) -> str: @property def true_index_dim(self) -> str: - index_dim = self.attrs.get("indexing_key", None) - return index_dim if index_dim is not None else self.index_dim + return self.attrs.get("indexing_key", self.index_dim) @true_index_dim.setter def true_index_dim(self, val: str): @@ -75,6 +74,7 @@ def index(self, val) -> None: if isinstance(val, pd.Index) and val.name is not None and val.name != index_dim: self.update(self.rename({self.index_dim: val.name})) del self.coords[index_dim] + # without `indexing_key` explicitly set on `self.attrs`, `self.true_index_dim` will use the `self.index_dim` if "indexing_key" in self.attrs: del self.attrs["indexing_key"]
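
Note on the index handling settled by patches 30, 41, and 43 above: `Dataset2D.true_index_dim` is now backed by a single `indexing_key` entry in `.attrs`. A minimal usage sketch follows, assuming the `Dataset2D.from_dataframe` constructor exercised in tests/test_xarray.py; the column name `col1` and the data values are illustrative only:

    import pandas as pd

    from anndata._core.xarray import Dataset2D

    # Build a small Dataset2D; an unnamed pandas index becomes the "index" dim.
    ds = Dataset2D.from_dataframe(
        pd.DataFrame({"col1": [1.0, 2.0, 3.0]}, index=pd.Index(list("abc")))
    )

    # With no "indexing_key" entry in .attrs, true_index_dim falls back to index_dim.
    assert ds.true_index_dim == ds.index_dim == "index"

    # Pointing the true index at a data variable records it in .attrs.
    ds.true_index_dim = "col1"
    assert ds.true_index_dim == "col1"
    assert ds.attrs["indexing_key"] == "col1"

    # Assigning None (or the index dim itself) clears the override again.
    ds.true_index_dim = None
    assert ds.true_index_dim == "index"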