diff --git a/.github/workflows/test-cpu.yml b/.github/workflows/test-cpu.yml
index 0233f5876..3dd4ee4b4 100644
--- a/.github/workflows/test-cpu.yml
+++ b/.github/workflows/test-cpu.yml
@@ -62,7 +62,7 @@ jobs:
         if: matrix.dependencies-version == 'minimum'
         run: |
           uv pip install --system --compile tomli packaging
-          deps=$(python3 ci/scripts/min-deps.py pyproject.toml --extra dev test)
+          deps=$(python3 ci/scripts/min-deps.py pyproject.toml --extra dev test-min)
           uv pip install --system --compile $deps "anndata @ ."
       - name: Install dependencies release candidates
diff --git a/ci/scripts/min-deps.py b/ci/scripts/min-deps.py
index b2e5fa716..f04ef62eb 100755
--- a/ci/scripts/min-deps.py
+++ b/ci/scripts/min-deps.py
@@ -34,7 +34,7 @@ def min_dep(req: Requirement) -> Requirement:
     -------
     >>> min_dep(Requirement("numpy>=1.0"))
-    <Requirement('numpy==1.0.*')>
+    <Requirement('numpy~=1.0.0')>
     >>> min_dep(Requirement("numpy<3.0"))
     """
@@ -55,7 +55,7 @@ def min_dep(req: Requirement) -> Requirement:
         elif spec.operator == "==":
             min_version = Version(spec.version)

-    return Requirement(f"{req_name}=={min_version}.*")
+    return Requirement(f"{req_name}~={min_version}.0")


 def extract_min_deps(
@@ -64,6 +64,7 @@ def extract_min_deps(
     dependencies = deque(dependencies)  # We'll be mutating this
     project_name = pyproject["project"]["name"]
+    deps = {}

     while len(dependencies) > 0:
         req = dependencies.pop()
@@ -76,7 +77,11 @@ def extract_min_deps(
             extra_deps = pyproject["project"]["optional-dependencies"][extra]
             dependencies += map(Requirement, extra_deps)
         else:
-            yield min_dep(req)
+            if req.name in deps:
+                req.specifier &= deps[req.name].specifier
+                req.extras |= deps[req.name].extras
+            deps[req.name] = min_dep(req)
+    yield from deps.values()


 class Args(argparse.Namespace):
diff --git a/docs/api.md b/docs/api.md
index ba6de634d..358968b3e 100644
--- a/docs/api.md
+++ b/docs/api.md
@@ -180,7 +180,7 @@ Types used by the former:
     experimental.StorageType
     experimental.backed._lazy_arrays.MaskedArray
     experimental.backed._lazy_arrays.CategoricalArray
-    experimental.backed._xarray.Dataset2D
+    _core.xarray.Dataset2D
 ```

 (extensions-api)=
diff --git a/docs/conf.py b/docs/conf.py
index 3ce0d4899..74151eab6 100644
--- a/docs/conf.py
+++ b/docs/conf.py
@@ -142,6 +142,7 @@ def setup(app: Sphinx):
     "anndata.compat.DaskArray": "dask.array.Array",
     "anndata.compat.CupyArray": "cupy.ndarray",
     "anndata.compat.CupySparseMatrix": "cupyx.scipy.sparse.spmatrix",
+    "anndata.compat.XDataArray": "xarray.DataArray",
     "awkward.highlevel.Array": "ak.Array",
     "numpy.int64": ("py:attr", "numpy.int64"),
     "pandas.DataFrame.iloc": ("py:attr", "pandas.DataFrame.iloc"),
diff --git a/docs/release-notes/0.12.0rc1.md b/docs/release-notes/0.12.0rc1.md
index cc9b3b831..eb79d51e8 100644
--- a/docs/release-notes/0.12.0rc1.md
+++ b/docs/release-notes/0.12.0rc1.md
@@ -10,7 +10,7 @@

 #### Bug fixes

-- Disallow writing of {class}`~anndata.experimental.backed._xarray.Dataset2D` objects {user}`ilan-gold` ({pr}`1887`)
+- Disallow writing of {class}`~anndata._core.xarray.Dataset2D` objects {user}`ilan-gold` ({pr}`1887`)
 - Upgrade old deprecation warning to a `FutureWarning` on `BaseCompressedSparseDataset.__setitem__`, showing our intent to remove the feature in the next release. {user}`ilan-gold` ({pr}`1928`)
 - Don't use {func}`asyncio.run` internally for any operations {user}`ilan-gold` ({pr}`1933`)
 - Disallow forward slashes in keys for writing {user}`ilan-gold` ({pr}`1940`)
diff --git a/docs/release-notes/1966.feature.md b/docs/release-notes/1966.feature.md
new file mode 100755
index 000000000..a702d9d47
--- /dev/null
+++ b/docs/release-notes/1966.feature.md
@@ -0,0 +1 @@
+Allow xarray Datasets to be used for obs/var/obsm/varm. {user}`ilia-kats`
diff --git a/pyproject.toml b/pyproject.toml
index 9ae7e53a5..d9d96d18a 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -79,11 +79,9 @@ doc = [
     "sphinx_design>=0.5.0",
     # for unreleased changes
     "anndata[dev-doc,dask]",
-    "awkward>=2.3",
 ]
 dev-doc = [ "towncrier>=24.8.0" ]  # release notes tool
-test-full = [ "anndata[test,lazy]" ]
-test = [
+test-min = [
     "loompy>=3.0.5",
     "pytest>=8.2,<8.3.4",
     "pytest-cov",
@@ -100,15 +98,19 @@ test = [
     "scanpy>=1.10",
     "httpx",  # For data downloading
     "dask[distributed]",
-    "awkward>=2.3",
+    "awkward>=2.3.2",
     "pyarrow",
     "anndata[dask]",
 ]
+test = [
+    "anndata[test-min,lazy]",
+    "pandas>=2.1.0",
+]  # pandas 2.1.0 needs to be specified for xarray to work with min-deps script
 gpu = [ "cupy" ]
 cu12 = [ "cupy-cuda12x" ]
 cu11 = [ "cupy-cuda11x" ]
 # requests and aiohttp needed for zarr remote data
-lazy = [ "xarray>=2024.06.0", "aiohttp", "requests", "anndata[dask]" ]
+lazy = [ "xarray>=2025.04.0", "aiohttp", "requests", "anndata[dask]" ]
 # https://github.com/dask/dask/issues/11290
 # https://github.com/dask/dask/issues/11752
 dask = [ "dask[array]>=2023.5.1,!=2024.8.*,!=2024.9.*,<2025.2.0" ]
diff --git a/src/anndata/_core/aligned_df.py b/src/anndata/_core/aligned_df.py
index 259068826..d49e72bbb 100644
--- a/src/anndata/_core/aligned_df.py
+++ b/src/anndata/_core/aligned_df.py
@@ -9,6 +9,8 @@
 from pandas.api.types import is_string_dtype

 from .._warnings import ImplicitModificationWarning
+from ..compat import XDataset
+from .xarray import Dataset2D

 if TYPE_CHECKING:
     from collections.abc import Iterable
@@ -108,8 +110,8 @@ def _mk_df_error(
     expected: int,
     actual: int,
 ):
+    what = "row" if attr == "obs" else "column"
     if source == "X":
-        what = "row" if attr == "obs" else "column"
         msg = (
             f"Observations annot. `{attr}` must have as many rows as `X` has {what}s "
             f"({expected}), but has {actual} rows."
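The next hunk registers `_gen_dataframe` handlers for `Dataset2D` and for plain `xarray.Dataset`, which is what lets `adata.obs`/`adata.var` assignment accept datasets. A minimal sketch of the resulting behavior — assuming a build with this diff applied and `xarray` installed; the `index` coordinate name and column names are illustrative:

```python
import numpy as np
import pandas as pd
import xarray as xr
import anndata as ad

adata = ad.AnnData(X=np.zeros((3, 2)))
df = pd.DataFrame({"cond": ["a", "b", "a"]}, index=adata.obs_names.rename("index"))
# A plain xarray.Dataset is accepted and wrapped in Dataset2D by _gen_dataframe
adata.obs = xr.Dataset.from_dataframe(df)
print(type(adata.obs).__name__)  # -> "Dataset2D"
```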
@@ -117,6 +119,30 @@ def _mk_df_error(
     else:
         msg = (
             f"`shape` is inconsistent with `{attr}` "
-            "({actual} {what}s instead of {expected})"
+            f"({actual} {what}s instead of {expected})"
         )
     return ValueError(msg)
+
+
+@_gen_dataframe.register(Dataset2D)
+def _gen_dataframe_xr(
+    anno: Dataset2D,
+    index_names: Iterable[str],
+    *,
+    source: Literal["X", "shape"],
+    attr: Literal["obs", "var"],
+    length: int | None = None,
+):
+    return anno
+
+
+@_gen_dataframe.register(XDataset)
+def _gen_dataframe_xdataset(
+    anno: XDataset,
+    index_names: Iterable[str],
+    *,
+    source: Literal["X", "shape"],
+    attr: Literal["obs", "var"],
+    length: int | None = None,
+):
+    return Dataset2D(anno)
diff --git a/src/anndata/_core/aligned_mapping.py b/src/anndata/_core/aligned_mapping.py
index b5649d0f4..e883032a1 100644
--- a/src/anndata/_core/aligned_mapping.py
+++ b/src/anndata/_core/aligned_mapping.py
@@ -11,7 +11,7 @@
 import pandas as pd

 from .._warnings import ExperimentalFeatureWarning, ImplicitModificationWarning
-from ..compat import AwkArray, CSArray, CSMatrix, CupyArray
+from ..compat import AwkArray, CSArray, CSMatrix, CupyArray, XDataset
 from ..utils import (
     axis_len,
     convert_to_dict,
@@ -23,6 +23,7 @@
 from .index import _subset
 from .storage import coerce_array
 from .views import as_view, view_update
+from .xarray import Dataset2D

 if TYPE_CHECKING:
     from collections.abc import Callable, Iterable, Iterator, Mapping
@@ -75,8 +76,10 @@ def _validate_value(self, val: Value, key: str) -> Value:
                 ExperimentalFeatureWarning,
                 # stacklevel=3,
             )
-        if isinstance(val, np.ndarray | CupyArray) and len(val.shape) == 1:
+        elif isinstance(val, np.ndarray | CupyArray) and len(val.shape) == 1:
             val = val.reshape((val.shape[0], 1))
+        elif isinstance(val, XDataset):
+            val = Dataset2D(data_vars=val.data_vars, coords=val.coords, attrs=val.attrs)
         for i, axis in enumerate(self.axes):
             if self.parent.shape[axis] == axis_len(val, i):
                 continue
@@ -275,6 +278,9 @@ def _validate_value(self, val: Value, key: str) -> Value:
             else:
                 msg = "Index.equals and pd.testing.assert_index_equal disagree"
                 raise AssertionError(msg)
+        val.index.name = (
+            self.dim_names.name
+        )  # this is consistent with AnnData.obsm.setter and AnnData.varm.setter
         return super()._validate_value(val, key)

     @property
diff --git a/src/anndata/_core/anndata.py b/src/anndata/_core/anndata.py
index 66505b74d..5f9d777da 100644
--- a/src/anndata/_core/anndata.py
+++ b/src/anndata/_core/anndata.py
@@ -47,6 +47,7 @@
     _resolve_idxs,
     as_view,
 )
+from .xarray import Dataset2D

 if TYPE_CHECKING:
     from collections.abc import Iterable
@@ -55,7 +56,7 @@

     from zarr.storage import StoreLike

-    from ..compat import Index1D
+    from ..compat import Index1D, XDataset
     from ..typing import XDataType
     from .aligned_mapping import AxisArraysView, LayersView, PairwiseArraysView
     from .index import Index
@@ -746,10 +747,14 @@ def n_vars(self) -> int:
         """Number of variables/features."""
         return len(self.var_names)

-    def _set_dim_df(self, value: pd.DataFrame, attr: Literal["obs", "var"]):
-        if not isinstance(value, pd.DataFrame):
-            msg = f"Can only assign pd.DataFrame to {attr}."
-            raise ValueError(msg)
+    def _set_dim_df(self, value: pd.DataFrame | XDataset, attr: Literal["obs", "var"]):
+        value = _gen_dataframe(
+            value,
+            [f"{attr}_names", f"{'row' if attr == 'obs' else 'col'}_names"],
+            source="shape",
+            attr=attr,
+            length=self.n_obs if attr == "obs" else self.n_vars,
+        )
         raise_value_error_if_multiindex_columns(value, attr)
         value_idx = self._prep_dim_index(value.index, attr)
         if self.is_view:
@@ -804,12 +809,12 @@ def _set_dim_index(self, value: pd.Index, attr: str):
             v.index = value

     @property
-    def obs(self) -> pd.DataFrame:
+    def obs(self) -> pd.DataFrame | Dataset2D:
         """One-dimensional annotation of observations (`pd.DataFrame`)."""
         return self._obs

     @obs.setter
-    def obs(self, value: pd.DataFrame):
+    def obs(self, value: pd.DataFrame | XDataset):
         self._set_dim_df(value, "obs")

     @obs.deleter
@@ -827,12 +832,12 @@ def obs_names(self, names: Sequence[str]):
         self._set_dim_index(names, "obs")

     @property
-    def var(self) -> pd.DataFrame:
+    def var(self) -> pd.DataFrame | Dataset2D:
         """One-dimensional annotation of variables/ features (`pd.DataFrame`)."""
         return self._var

     @var.setter
-    def var(self, value: pd.DataFrame):
+    def var(self, value: pd.DataFrame | XDataset):
         self._set_dim_df(value, "var")

     @var.deleter
@@ -2079,6 +2084,14 @@ def _get_and_delete_multicol_field(self, a, key_multicol):
     return values


+@AnnData._remove_unused_categories.register(Dataset2D)
+@staticmethod
+def _remove_unused_categories_xr(
+    df_full: Dataset2D, df_sub: Dataset2D, uns: dict[str, Any]
+):
+    pass  # this is handled automatically by the categorical arrays themselves, i.e. they dedup upon access.
+
+
 def _check_2d_shape(X):
     """\
     Check shape of array or sparse matrix.
diff --git a/src/anndata/_core/file_backing.py b/src/anndata/_core/file_backing.py
index 45275e651..0e1dbf336 100644
--- a/src/anndata/_core/file_backing.py
+++ b/src/anndata/_core/file_backing.py
@@ -10,6 +10,7 @@

 from ..compat import AwkArray, DaskArray, ZarrArray, ZarrGroup
 from .sparse_dataset import BaseCompressedSparseDataset
+from .xarray import Dataset2D

 if TYPE_CHECKING:
     from collections.abc import Iterator
@@ -162,6 +163,11 @@ def _(x: AwkArray, *, copy: bool = False):
     return x


+@to_memory.register(Dataset2D)
+def _(x: Dataset2D, *, copy: bool = False):
+    return x.to_memory(copy=copy)
+
+
 @singledispatch
 def filename(x):
     msg = f"Not implemented for {type(x)}"
diff --git a/src/anndata/_core/index.py b/src/anndata/_core/index.py
index f6435b64a..1128bc665 100644
--- a/src/anndata/_core/index.py
+++ b/src/anndata/_core/index.py
@@ -10,7 +10,8 @@
 import pandas as pd
 from scipy.sparse import issparse

-from ..compat import AwkArray, CSArray, CSMatrix, DaskArray
+from ..compat import AwkArray, CSArray, CSMatrix, DaskArray, XDataArray
+from .xarray import Dataset2D

 if TYPE_CHECKING:
     from ..compat import Index, Index1D
@@ -44,8 +45,6 @@ def _normalize_index(  # noqa: PLR0911, PLR0912
     | pd.Index,
     index: pd.Index,
 ) -> slice | int | np.ndarray:  # ndarray of int or bool
-    from ..experimental.backed._compat import DataArray
-
     # TODO: why is this here? All tests pass without it and it seems at the minimum not strict enough.
     if not isinstance(index, pd.RangeIndex) and index.dtype in (np.float64, np.int64):
         msg = f"Don’t call _normalize_index with non-categorical/string names and non-range index {index}"
         raise ValueError(msg)
@@ -112,7 +111,7 @@ def name_idx(i):
             )
             raise KeyError(msg)
         return positions  # np.ndarray[int]
-    elif isinstance(indexer, DataArray):
+    elif isinstance(indexer, XDataArray):
         if isinstance(indexer.data, DaskArray):
             return indexer.data.compute()
         return indexer.data
@@ -210,6 +209,15 @@ def _subset_awkarray(a: AwkArray, subset_idx: Index):
     return a[subset_idx]


+@_subset.register(Dataset2D)
+def _(a: Dataset2D, subset_idx: Index):
+    key = a.index_dim
+    # xarray seems to have some code looking for a second entry in tuples
+    if isinstance(subset_idx, tuple) and len(subset_idx) == 1:
+        subset_idx = subset_idx[0]
+    return a.isel(**{key: subset_idx})
+
+
 # Registration for SparseDataset occurs in sparse_dataset.py
 @_subset.register(h5py.Dataset)
 def _subset_dataset(d, subset_idx):
diff --git a/src/anndata/_core/merge.py b/src/anndata/_core/merge.py
index 05737b4c4..276db242f 100644
--- a/src/anndata/_core/merge.py
+++ b/src/anndata/_core/merge.py
@@ -17,6 +17,7 @@
 import scipy
 from natsort import natsorted
 from packaging.version import Version
+from pandas.api.types import is_extension_array_dtype
 from scipy import sparse

 from anndata._core.file_backing import to_memory
@@ -35,6 +36,7 @@
 from ..utils import asarray, axis_len, warn_once
 from .anndata import AnnData
 from .index import _subset, make_slice
+from .xarray import Dataset2D

 if TYPE_CHECKING:
     from collections.abc import Collection, Generator, Iterable, Sequence
@@ -43,7 +45,8 @@
     from pandas.api.extensions import ExtensionDtype

     from anndata._types import Join_T
-    from anndata.experimental.backed._compat import DataArray, Dataset2D
+
+    from ..compat import XDataArray

 T = TypeVar("T")
@@ -119,7 +122,11 @@ def equal(a, b) -> bool:
     b = asarray(b)
     if a.ndim == b.ndim == 0:
         return bool(a == b)
-    return np.array_equal(a, b)
+    a_na = (
+        pd.isna(a) if a.dtype.names is None else np.False_
+    )  # pd.isna doesn't work for record arrays
+    b_na = pd.isna(b) if b.dtype.names is None else np.False_
+    return np.array_equal(a_na, b_na) and np.array_equal(a[~a_na], b[~b_na])


 @equal.register(pd.DataFrame)
@@ -208,6 +215,11 @@ def equal_awkward(a, b) -> bool:
     return ak.almost_equal(a, b)


+@equal.register(Dataset2D)
+def equal_dataset2d(a, b) -> bool:
+    return a.equals(b)
+
+
 def as_sparse(x, *, use_sparse_array: bool = False) -> CSMatrix | CSArray:
     if not isinstance(x, CSMatrix | CSArray):
         in_memory_array_class = (
@@ -271,7 +283,7 @@ def unify_dtypes(
     return dfs


-def try_unifying_dtype(
+def try_unifying_dtype(  # noqa: PLR0911, PLR0912
     col: Sequence[np.dtype | ExtensionDtype],
 ) -> pd.core.dtypes.base.ExtensionDtype | None:
     """
@@ -296,13 +308,39 @@ def try_unifying_dtype(
             ordered = ordered | dtype.ordered
         elif not pd.isnull(dtype):
             return None
-    if len(dtypes) > 0 and not ordered:
+    if len(dtypes) > 0:
         categories = reduce(
             lambda x, y: x.union(y),
-            [dtype.categories for dtype in dtypes if not pd.isnull(dtype)],
+            (dtype.categories for dtype in dtypes if not pd.isnull(dtype)),
         )
-        return pd.CategoricalDtype(natsorted(categories), ordered=False)
+        if not ordered:
+            return pd.CategoricalDtype(natsorted(categories), ordered=False)
+        else:  # for xarray Datasets, see https://github.com/pydata/xarray/issues/10247
+            categories_intersection = reduce(
+                lambda x, y: x.intersection(y),
+                (
+                    dtype.categories
+                    for dtype in dtypes
+                    if not pd.isnull(dtype) and len(dtype.categories) > 0
+                ),
+            )
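+            # If some inputs lack categories that others have, no single
+            # categorical dtype fits all inputs, so fall back to object dtype;
+            # otherwise keep an ordered categorical only when every input
+            # agrees on both categories and orderedness.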
+            if len(categories_intersection) < len(categories):
+                return object
+            else:
+                same_orders = all(
+                    dtype.ordered
+                    for dtype in dtypes
+                    if not pd.isnull(dtype) and len(dtype.categories) > 0
+                )
+                same_orders &= all(
+                    np.all(categories == dtype.categories)
+                    for dtype in dtypes
+                    if not pd.isnull(dtype) and len(dtype.categories) > 0
+                )
+                if same_orders:
+                    return next(iter(dtypes))
+                return object
     # Boolean
     elif all(pd.api.types.is_bool_dtype(dtype) or dtype is None for dtype in col):
         if any(dtype is None for dtype in col):
@@ -448,7 +486,7 @@ def _merge_nested(
     vals = [d[k] for d in ds if k in d]
     if len(vals) == 0:
         return MissingVal
-    elif all(isinstance(v, Mapping) for v in vals):
+    elif all(isinstance(v, Mapping) and not isinstance(v, Dataset2D) for v in vals):
         new_map = merge_nested(vals, keys_join, value_join)
         if len(new_map) == 0:
             return MissingVal
@@ -556,6 +594,8 @@ def apply(self, el, *, axis, fill_value=None):  # noqa: PLR0911
             return self._apply_to_dask_array(el, axis=axis, fill_value=fill_value)
         elif isinstance(el, CupyArray):
             return self._apply_to_cupy_array(el, axis=axis, fill_value=fill_value)
+        elif isinstance(el, Dataset2D):
+            return self._apply_to_dataset2d(el, axis=axis, fill_value=fill_value)
         else:
             return self._apply_to_array(el, axis=axis, fill_value=fill_value)
@@ -718,6 +758,31 @@ def _apply_to_awkward(self, el: AwkArray, *, axis, fill_value=None):
             el = ak.pad_none(el, 1, axis=axis)  # axis == 0
         return el[self.idx]

+    def _apply_to_dataset2d(self, el: Dataset2D, *, axis, fill_value=None):
+        if fill_value is None:
+            fill_value = np.nan
+        index_dim = el.index_dim
+        if axis == 0:
+            # Dataset.reindex() can't handle ExtensionArrays
+            extension_arrays = {
+                col: arr for col, arr in el.items() if is_extension_array_dtype(arr)
+            }
+            el = el.drop_vars(extension_arrays.keys())
+            el = el.reindex(
+                {index_dim: self.new_idx}, method=None, fill_value=fill_value
+            )
+            for col, arr in extension_arrays.items():
+                el[col] = (
+                    index_dim,
+                    pd.Series(arr, index=self.old_idx).reindex(
+                        self.new_idx, fill_value=fill_value
+                    ),
+                )
+            return el
+        else:
+            msg = "This should be unreachable, please open an issue."
+            raise Exception(msg)
+
     @property
     def idx(self):
         return self.old_idx.get_indexer(self.new_idx)
@@ -776,7 +841,9 @@ def np_bool_to_pd_bool_array(df: pd.DataFrame):
     return df


-def concat_arrays(arrays, reindexers, axis=0, index=None, fill_value=None):  # noqa: PLR0911, PLR0912
+def concat_arrays(  # noqa: PLR0911, PLR0912
+    arrays, reindexers, axis=0, index=None, fill_value=None, *, force_lazy: bool = False
+):
     from anndata.experimental.backed._compat import Dataset2D

     arrays = list(arrays)
@@ -790,7 +857,9 @@ def concat_arrays(arrays, reindexers, axis=0, index=None, fill_value=None):  # n
             msg = f"Cannot concatenate a Dataset2D with other array types {[type(a) for a in arrays if not isinstance(a, Dataset2D)]}."
             raise ValueError(msg)
         else:
-            return concat_dataset2d_on_annot_axis(arrays, join="outer")
+            return concat_dataset2d_on_annot_axis(
+                arrays, join="outer", force_lazy=force_lazy
+            )
     if any(isinstance(a, pd.DataFrame) for a in arrays):
         # TODO: This is hacky, 0 is a sentinel for outer_concat_aligned_mapping
         if not all(
@@ -880,7 +949,13 @@
 def inner_concat_aligned_mapping(
-    mappings, *, reindexers=None, index=None, axis=0, concat_axis=None
+    mappings,
+    *,
+    reindexers=None,
+    index=None,
+    axis=0,
+    concat_axis=None,
+    force_lazy: bool = False,
 ):
     if concat_axis is None:
         concat_axis = axis
@@ -895,7 +970,9 @@
         else:
             cur_reindexers = reindexers

-        result[k] = concat_arrays(els, cur_reindexers, index=index, axis=concat_axis)
+        result[k] = concat_arrays(
+            els, cur_reindexers, index=index, axis=concat_axis, force_lazy=force_lazy
+        )
     return result


@@ -991,7 +1068,14 @@ def missing_element(
 def outer_concat_aligned_mapping(
-    mappings, *, reindexers=None, index=None, axis=0, concat_axis=None, fill_value=None
+    mappings,
+    *,
+    reindexers=None,
+    index=None,
+    axis=0,
+    concat_axis=None,
+    fill_value=None,
+    force_lazy: bool = False,
 ):
     if concat_axis is None:
         concat_axis = axis
@@ -1033,6 +1117,7 @@
             axis=concat_axis,
             index=index,
             fill_value=fill_value,
+            force_lazy=force_lazy,
         )
     return result


@@ -1099,7 +1184,11 @@ def _resolve_axis(
 def axis_indices(adata: AnnData, axis: Literal["obs", 0, "var", 1]) -> pd.Index:
     """Helper function to get adata.{dim}_names."""
     _, axis_name = _resolve_axis(axis)
-    return getattr(adata, f"{axis_name}_names")
+    attr = getattr(adata, axis_name)
+    if isinstance(attr, Dataset2D):
+        return attr.true_index
+    else:
+        return attr.index


 # TODO: Resolve https://github.com/scverse/anndata/issues/678 and remove this function
@@ -1127,7 +1216,7 @@ def concat_Xs(adatas, reindexers, axis, fill_value):
 def make_dask_col_from_extension_dtype(
-    col: DataArray, *, use_only_object_dtype: bool = False
+    col: XDataArray, *, use_only_object_dtype: bool = False
 ) -> DaskArray:
     """
     Creates dask arrays from :class:`pandas.api.extensions.ExtensionArray` dtype :class:`xarray.DataArray`s.
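The rewrite in the next hunk makes `make_dask_col_from_extension_dtype` handle both store-backed columns (chunked reads through `read_elem_lazy`) and purely in-memory ones (`da.from_array`), which is what the new `force_lazy` flag threaded through the helpers above relies on. A rough usage sketch — hypothetical data, assuming a build with this diff; per the docstring below, only xarray-backed annotations are affected by the flag:

```python
import numpy as np
import pandas as pd
import xarray as xr
import anndata as ad

def make_adata(n: int) -> ad.AnnData:
    a = ad.AnnData(X=np.zeros((n, 2)))
    df = pd.DataFrame(
        {"grp": pd.Categorical(["x"] * n)}, index=a.obs_names.rename("index")
    )
    a.obs = xr.Dataset.from_dataframe(df)  # stored as Dataset2D
    return a

# Both inputs are in memory, but force_lazy=True routes their obs through the
# dask-backed concatenation path instead of the eager one.
merged = ad.concat([make_adata(2), make_adata(3)], index_unique="-", force_lazy=True)
```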
@@ -1150,50 +1239,57 @@ def make_dask_col_from_extension_dtype(
         get_chunksize,
         maybe_open_h5,
     )
+    from anndata.compat import XDataArray
+    from anndata.compat import xarray as xr
     from anndata.experimental import read_elem_lazy
-    from anndata.experimental.backed._compat import DataArray
-    from anndata.experimental.backed._compat import xarray as xr

     base_path_or_zarr_group = col.attrs.get("base_path_or_zarr_group")
     elem_name = col.attrs.get("elem_name")
-    dims = col.dims
-    coords = col.coords.copy()
-    with maybe_open_h5(base_path_or_zarr_group, elem_name) as f:
-        maybe_chunk_size = get_chunksize(read_elem_lazy(f))
-        chunk_size = (
-            compute_chunk_layout_for_axis_size(
-                1000 if maybe_chunk_size is None else maybe_chunk_size[0], col.shape[0]
-            ),
-        )
-
-    def get_chunk(block_info=None):
-        # reopening is important to get around h5py's unserializable lock in processes
+    if (
+        base_path_or_zarr_group is not None and elem_name is not None
+    ):  # lazy, backed by store
+        dims = col.dims
+        coords = col.coords.copy()
         with maybe_open_h5(base_path_or_zarr_group, elem_name) as f:
-            v = read_elem_lazy(f)
-            variable = xr.Variable(
-                data=xr.core.indexing.LazilyIndexedArray(v), dims=dims
-            )
-            data_array = DataArray(
-                variable,
-                coords=coords,
-                dims=dims,
-            )
-            idx = tuple(
-                slice(start, stop) for start, stop in block_info[None]["array-location"]
-            )
-            chunk = np.array(data_array.data[idx].array)
-            return chunk
-
-    if col.dtype in ("category", "string") or use_only_object_dtype:
-        dtype = "object"
-    else:
-        dtype = col.dtype.numpy_dtype
-    return da.map_blocks(
-        get_chunk,
-        chunks=chunk_size,
-        meta=np.array([], dtype=dtype),
-        dtype=dtype,
-    )
+            maybe_chunk_size = get_chunksize(read_elem_lazy(f))
+            chunk_size = (
+                compute_chunk_layout_for_axis_size(
+                    1000 if maybe_chunk_size is None else maybe_chunk_size[0],
+                    col.shape[0],
+                ),
+            )
+
+        def get_chunk(block_info=None):
+            # reopening is important to get around h5py's unserializable lock in processes
+            with maybe_open_h5(base_path_or_zarr_group, elem_name) as f:
+                v = read_elem_lazy(f)
+                variable = xr.Variable(
+                    data=xr.core.indexing.LazilyIndexedArray(v), dims=dims
+                )
+                data_array = XDataArray(
+                    variable,
+                    coords=coords,
+                    dims=dims,
+                )
+                idx = tuple(
+                    slice(start, stop)
+                    for start, stop in block_info[None]["array-location"]
+                )
+                chunk = np.array(data_array.data[idx])
+                return chunk
+
+        if col.dtype == "category" or col.dtype == "string" or use_only_object_dtype:  # noqa: PLR1714
+            dtype = "object"
+        else:
+            dtype = col.dtype.numpy_dtype
+        return da.map_blocks(
+            get_chunk,
+            chunks=chunk_size,
+            meta=np.array([], dtype=dtype),
+            dtype=dtype,
+        )
+
+    return da.from_array(col.values, chunks=-1)  # in-memory


 def make_xarray_extension_dtypes_dask(
@@ -1238,33 +1334,50 @@ def make_xarray_extension_dtypes_dask(
 def concat_dataset2d_on_annot_axis(
     annotations: Iterable[Dataset2D],
     join: Join_T,
+    *,
+    force_lazy: bool,
+    label: str | None = None,
+    label_col: pd.Categorical | None = None,
 ) -> Dataset2D:
-    """Create a concatenate dataset from a list of :class:`~anndata.experimental.backed._xarray.Dataset2D` objects.
+    """Create a concatenated dataset from a list of :class:`~anndata._core.xarray.Dataset2D` objects.

     The goal of this function is to mimic `pd.concat(..., ignore_index=True)` so has
     some complicated logic for handling the "index" to ensure (a) nothing is loaded into memory
     and (b) the true index is always tracked.

     Parameters
     ----------
     annotations
-        The :class:`~anndata.experimental.backed._xarray.Dataset2D` objects to be concatenated.
+        The :class:`~anndata._core.xarray.Dataset2D` objects to be concatenated.
     join
         Type of join operation
+    force_lazy
+        Whether to lazily concatenate elements using dask even when eager concatenation is possible.
+    label
+        Column in axis annotation (i.e. `.obs` or `.var`) to place batch information in.
+        If it's None, no column is added.
+    label_col
+        The batch information annotation.

     Returns
     -------
-    Concatenated :class:`~anndata.experimental.backed._xarray.Dataset2D`
+    Concatenated :class:`~anndata._core.xarray.Dataset2D`
     """
+    from anndata._core.xarray import Dataset2D
     from anndata._io.specs.lazy_methods import DUMMY_RANGE_INDEX_KEY
-    from anndata.experimental.backed._compat import Dataset2D
-    from anndata.experimental.backed._compat import xarray as xr
+    from anndata.compat import xarray as xr

     annotations_re_indexed = []
-    for a in make_xarray_extension_dtypes_dask(annotations):
-        old_key = next(iter(a.coords.keys()))
+    have_backed = any(a.is_backed for a in annotations)
+    if have_backed or force_lazy:
+        annotations = make_xarray_extension_dtypes_dask(annotations)
+    else:
+        annotations = unify_dtypes(annotations)
+    for a in annotations:
+        old_key = a.index_dim
+        is_fake_index = old_key != a.true_index_dim
         # First create a dummy index
         a.coords[DS_CONCAT_DUMMY_INDEX_NAME] = (
             old_key,
-            pd.RangeIndex(a[a.attrs["indexing_key"]].shape[0]).astype("str"),
+            pd.RangeIndex(a.shape[0]),
         )
         # Set all the dimensions to this new dummy index
         a = a.swap_dims({old_key: DS_CONCAT_DUMMY_INDEX_NAME})
@@ -1272,30 +1385,38 @@
         old_coord = a.coords[old_key]
         del a.coords[old_key]
         a[old_key] = old_coord
+        if not is_fake_index:
+            a.true_index_dim = old_key
         annotations_re_indexed.append(a)
     # Concat along the dummy index
     ds = Dataset2D(
         xr.concat(annotations_re_indexed, join=join, dim=DS_CONCAT_DUMMY_INDEX_NAME),
-        attrs={"indexing_key": f"true_{DS_CONCAT_DUMMY_INDEX_NAME}"},
     )
+    ds.is_backed = have_backed
     ds.coords[DS_CONCAT_DUMMY_INDEX_NAME] = pd.RangeIndex(
         ds.coords[DS_CONCAT_DUMMY_INDEX_NAME].shape[0]
-    ).astype("str")
+    )
     # Drop any lingering dimensions (swap doesn't delete)
     ds = ds.drop_dims(d for d in ds.dims if d != DS_CONCAT_DUMMY_INDEX_NAME)
     # Create a new true index and then delete the columns resulting from the concatenation for each index.
     # This includes the dummy column (which is neither a dimension nor a true indexing column)
     index = xr.concat(
-        [a[a.attrs["indexing_key"]] for a in annotations_re_indexed],
+        [a.true_xr_index for a in annotations_re_indexed],
         dim=DS_CONCAT_DUMMY_INDEX_NAME,
     )
     # prevent duplicate values
     index.coords[DS_CONCAT_DUMMY_INDEX_NAME] = ds.coords[DS_CONCAT_DUMMY_INDEX_NAME]
-    ds[f"true_{DS_CONCAT_DUMMY_INDEX_NAME}"] = index
-    for key in {a.attrs["indexing_key"] for a in annotations_re_indexed}:
+    ds.coords[DS_CONCAT_DUMMY_INDEX_NAME] = index
+    for key in {
+        true_index
+        for a in annotations_re_indexed
+        if (true_index := a.true_index_dim) != a.index_dim
+    }:
         del ds[key]
     if DUMMY_RANGE_INDEX_KEY in ds:
         del ds[DUMMY_RANGE_INDEX_KEY]
+    if label is not None and label_col is not None:
+        ds[label] = (DS_CONCAT_DUMMY_INDEX_NAME, label_col)
     return ds
@@ -1311,6 +1432,7 @@ def concat(  # noqa: PLR0912, PLR0913, PLR0915
     index_unique: str | None = None,
     fill_value: Any | None = None,
     pairwise: bool = False,
+    force_lazy: bool = False,
 ) -> AnnData:
     """Concatenates AnnData objects along an axis.

@@ -1359,6 +1481,9 @@ def concat(  # noqa: PLR0912, PLR0913, PLR0915
     pairwise
         Whether pairwise elements along the concatenated dimension should be included.
         This is False by default, since the resulting arrays are often not meaningful.
+    force_lazy
+        Whether to lazily concatenate elements using dask even when eager concatenation is possible.
+        At the moment, this only affects obs/var and elements of obsm/varm that are xarray Datasets.

     Notes
     -----
@@ -1503,8 +1628,8 @@ def concat(  # noqa: PLR0912, PLR0913, PLR0915
     {'a': 1, 'b': 2, 'c': {'c.a': 3, 'c.b': 4, 'c.c': 5}}
     """
-    from anndata.experimental.backed._compat import Dataset2D
-    from anndata.experimental.backed._compat import xarray as xr
+    from anndata._core.xarray import Dataset2D
+    from anndata.compat import xarray as xr

     # Argument normalization
     merge = resolve_merge_strategy(merge)
@@ -1566,11 +1691,14 @@
             ignore_index=True,
         )
         concat_annot.index = concat_indices
+        if label is not None:
+            concat_annot[label] = label_col
     else:
-        concat_annot = concat_dataset2d_on_annot_axis(annotations, join)
+        concat_annot = concat_dataset2d_on_annot_axis(
+            annotations, join, force_lazy=force_lazy, label=label, label_col=label_col
+        )
         concat_indices.name = DS_CONCAT_DUMMY_INDEX_NAME
-    if label is not None:
-        concat_annot[label] = label_col
+        concat_annot.index = concat_indices

     # Annotation for other axis
     alt_annotations = [getattr(a, alt_axis_name) for a in adatas]
@@ -1592,13 +1720,13 @@
             )
         )
         annotations_with_only_dask = [
-            a.rename({a.attrs["indexing_key"]: "merge_index"})
+            a.rename({a.true_index_dim: "merge_index"})
             for a in annotations_with_only_dask
         ]
         alt_annot = Dataset2D(
-            xr.merge(annotations_with_only_dask, join=join, compat="override"),
-            attrs={"indexing_key": "merge_index"},
+            xr.merge(annotations_with_only_dask, join=join, compat="override")
         )
+        alt_annot.true_index_dim = "merge_index"

     X = concat_Xs(adatas, reindexers, axis=axis, fill_value=fill_value)

@@ -1622,6 +1750,7 @@
             axis=axis,
             concat_axis=0,
             index=concat_indices,
+            force_lazy=force_lazy,
         )
     if pairwise:
         concat_pairwise = concat_pairwise_mapping(
diff --git a/src/anndata/_core/storage.py b/src/anndata/_core/storage.py
index 0dc59d06d..a986d6ce7 100644
--- a/src/anndata/_core/storage.py
+++ b/src/anndata/_core/storage.py
@@ -10,11 +10,13 @@
 from anndata.compat import CSArray, CSMatrix

 from .._warnings import ImplicitModificationWarning
+from ..compat import XDataset
 from ..utils import (
     ensure_df_homogeneous,
     join_english,
     raise_value_error_if_multiindex_columns,
 )
+from .xarray import Dataset2D

 if TYPE_CHECKING:
     from typing import Any
@@ -27,15 +29,6 @@ def coerce_array(
     allow_df: bool = False,
     allow_array_like: bool = False,
 ):
-    try:
-        from anndata.experimental.backed._compat import Dataset2D
-    except ImportError:
-
-        class Dataset2D:
-            @staticmethod
-            def __repr__():
-                return "mock anndata.experimental.backed._xarray."
-
     """Coerce arrays stored in layers/X, and aligned arrays ({obs,var}{m,p})."""
     from ..typing import ArrayDataStructureTypes

@@ -44,6 +37,8 @@ def __repr__():
         return value
     # If value is one of the allowed types, return it
     array_data_structure_types = get_args(ArrayDataStructureTypes)
+    if isinstance(value, XDataset) and not isinstance(value, Dataset2D):
+        value = Dataset2D(value.data_vars, value.coords, value.attrs)
     if isinstance(value, (*array_data_structure_types, Dataset2D)):
         if isinstance(value, np.matrix):
             msg = f"{name} should not be a np.matrix, use np.ndarray instead."
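With `coerce_array` above wrapping bare `xarray.Dataset`s, aligned mappings such as `obsm`/`varm` accept datasets as well. A small sketch under the same assumptions as the earlier examples:

```python
import numpy as np
import pandas as pd
import xarray as xr
import anndata as ad

adata = ad.AnnData(X=np.zeros((4, 2)))
ds = xr.Dataset.from_dataframe(
    pd.DataFrame({"score": np.arange(4.0)}, index=adata.obs_names.rename("index"))
)
adata.obsm["ds"] = ds  # coerced to Dataset2D
assert adata.obsm["ds"].shape == (4, 1)  # (n_obs, number of data variables)
```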
diff --git a/src/anndata/_core/views.py b/src/anndata/_core/views.py
index 87d8724dd..ac9a0dd0f 100644
--- a/src/anndata/_core/views.py
+++ b/src/anndata/_core/views.py
@@ -23,6 +23,7 @@
     ZappyArray,
 )
 from .access import ElementRef
+from .xarray import Dataset2D

 if TYPE_CHECKING:
     from collections.abc import Callable, Iterable, KeysView, Sequence
@@ -362,6 +363,11 @@ def as_view_cupy_csc(mtx, view_args):
     return CupySparseCSCView(mtx, view_args=view_args)


+@as_view.register(Dataset2D)
+def _(a: Dataset2D, view_args):
+    return a
+
+
 try:
     import weakref
diff --git a/src/anndata/_core/xarray.py b/src/anndata/_core/xarray.py
new file mode 100644
index 000000000..82ee71b40
--- /dev/null
+++ b/src/anndata/_core/xarray.py
@@ -0,0 +1,145 @@
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
+import pandas as pd
+
+from ..compat import XDataset
+
+if TYPE_CHECKING:
+    from ..compat import XDataArray
+
+
+class Dataset2D(XDataset):
+    """
+    A wrapper class meant to enable working with lazy dataframe data.
+    We do not guarantee the stability of this API beyond that guaranteed
+    by :class:`xarray.Dataset` and the `to_memory` function, a thin wrapper
+    around :meth:`xarray.Dataset.to_dataframe` to ensure roundtrip
+    compatibility here.
+    """
+
+    __slots__ = ()
+
+    @property
+    def is_backed(self) -> bool:
+        return self.attrs.get("is_backed", False)
+
+    @is_backed.setter
+    def is_backed(self, isbacked: bool):
+        if not isbacked and "is_backed" in self.attrs:
+            del self.attrs["is_backed"]
+        else:
+            self.attrs["is_backed"] = isbacked
+
+    @property
+    def index_dim(self) -> str:
+        if len(self.sizes) != 1:
+            msg = f"xarray Dataset should not have more than 1 dim, found {len(self.sizes)} {self.sizes}, {self}"
+            raise ValueError(msg)
+        return next(iter(self.coords.keys()))
+
+    @property
+    def true_index_dim(self) -> str:
+        return self.attrs.get("indexing_key", self.index_dim)
+
+    @true_index_dim.setter
+    def true_index_dim(self, val: str):
+        if val is None or (val == self.index_dim and "indexing_key" in self.attrs):
+            del self.attrs["indexing_key"]
+        elif val not in self.dims:
+            if val not in self.data_vars:
+                msg = f"Unknown variable `{val}`."
+                raise ValueError(msg)
+            self.attrs["indexing_key"] = val
+
+    @property
+    def xr_index(self) -> XDataArray:
+        return self[self.index_dim]
+
+    @property
+    def index(self) -> pd.Index:
+        """:attr:`~anndata.AnnData` internally looks for :attr:`~pandas.DataFrame.index` so this ensures usability
+
+        Returns
+        -------
+        The index of the dataframe as resolved from :attr:`~xarray.Dataset.coords`.
+        """
+        return self.indexes[self.index_dim]
+
+    @index.setter
+    def index(self, val) -> None:
+        index_dim = self.index_dim
+        self.coords[index_dim] = (index_dim, val)
+        if isinstance(val, pd.Index) and val.name is not None and val.name != index_dim:
+            self.update(self.rename({self.index_dim: val.name}))
+            del self.coords[index_dim]
+            # without `indexing_key` explicitly set on `self.attrs`, `self.true_index_dim` will use the `self.index_dim`
+            if "indexing_key" in self.attrs:
+                del self.attrs["indexing_key"]
+
+    @property
+    def true_xr_index(self) -> XDataArray:
+        return self[self.true_index_dim]
+
+    @property
+    def true_index(self) -> pd.Index:
+        return self.true_xr_index.to_index()
+
+    @property
+    def shape(self) -> tuple[int, int]:
+        """:attr:`~anndata.AnnData` internally looks for :attr:`~pandas.DataFrame.shape` so this ensures usability
+
+        Returns
+        -------
+        The (2D) shape of the dataframe resolved from :attr:`~xarray.Dataset.sizes`.
+ """ + return (self.sizes[self.index_dim], len(self)) + + @property + def iloc(self): + """:attr:`~anndata.AnnData` internally looks for :attr:`~pandas.DataFrame.iloc` so this ensures usability + + Returns + ------- + Handler class for doing the iloc-style indexing using :meth:`~xarray.Dataset.isel`. + """ + + class IlocGetter: + def __init__(self, ds): + self._ds = ds + + def __getitem__(self, idx): + coord = self._ds.index_dim + return self._ds.isel(**{coord: idx}) + + return IlocGetter(self) + + def __getitem__(self, idx) -> Dataset2D: + ret = super().__getitem__(idx) + if len(idx) == 0 and not isinstance(idx, tuple): # empty XDataset + ret.coords[self.index_dim] = self.xr_index + return ret + + def to_memory(self, *, copy=False) -> pd.DataFrame: + df = self.to_dataframe() + index_key = self.attrs.get("indexing_key", None) + if df.index.name != index_key and index_key is not None: + df = df.set_index(index_key) + df.index.name = None # matches old AnnData object + return df + + @property + def columns(self) -> pd.Index: + """ + :class:`~anndata.AnnData` internally looks for :attr:`~pandas.DataFrame.columns` so this ensures usability + + Returns + ------- + :class:`pandas.Index` that represents the "columns." + """ + columns = set(self.keys()) + index_key = self.attrs.get("indexing_key", None) + if index_key is not None: + columns.discard(index_key) + return pd.Index(columns) diff --git a/src/anndata/_io/specs/lazy_methods.py b/src/anndata/_io/specs/lazy_methods.py index 6cfbc7ff2..681716daf 100644 --- a/src/anndata/_io/specs/lazy_methods.py +++ b/src/anndata/_io/specs/lazy_methods.py @@ -12,8 +12,9 @@ import anndata as ad from anndata._core.file_backing import filename, get_elem_name +from anndata._core.xarray import Dataset2D from anndata.abc import CSCDataset, CSRDataset -from anndata.compat import DaskArray, H5Array, H5Group, ZarrArray, ZarrGroup +from anndata.compat import DaskArray, H5Array, H5Group, XDataArray, ZarrArray, ZarrGroup from .registry import _LAZY_REGISTRY, IOSpec @@ -21,7 +22,6 @@ from collections.abc import Generator, Mapping, Sequence from typing import Literal, ParamSpec, TypeVar - from anndata.experimental.backed._compat import DataArray, Dataset2D from anndata.experimental.backed._lazy_arrays import CategoricalArray, MaskedArray from ...compat import CSArray, CSMatrix, H5File @@ -220,19 +220,20 @@ def _gen_xarray_dict_iterator_from_elems( elem_dict: dict[str, LazyDataStructures], dim_name: str, index: np.NDArray, -) -> Generator[tuple[str, DataArray], None, None]: - from anndata.experimental.backed._compat import DataArray - from anndata.experimental.backed._compat import xarray as xr +) -> Generator[tuple[str, XDataArray], None, None]: from anndata.experimental.backed._lazy_arrays import CategoricalArray, MaskedArray + from ...compat import XDataArray + from ...compat import xarray as xr + for k, v in elem_dict.items(): if isinstance(v, DaskArray) and k != dim_name: - data_array = DataArray(v, coords=[index], dims=[dim_name], name=k) + data_array = XDataArray(v, coords=[index], dims=[dim_name], name=k) elif isinstance(v, CategoricalArray | MaskedArray) and k != dim_name: variable = xr.Variable( data=xr.core.indexing.LazilyIndexedArray(v), dims=[dim_name] ) - data_array = DataArray( + data_array = XDataArray( variable, coords=[index], dims=[dim_name], @@ -243,7 +244,7 @@ def _gen_xarray_dict_iterator_from_elems( }, ) elif k == dim_name: - data_array = DataArray( + data_array = XDataArray( index, coords=[index], dims=[dim_name], name=dim_name ) else: @@ -263,8 
+264,6 @@ def read_dataframe( _reader: LazyReader, use_range_index: bool = False, ) -> Dataset2D: - from anndata.experimental.backed._compat import DataArray, Dataset2D - elem_dict = { k: _reader.read_elem(elem[k]) for k in [*elem.attrs["column-order"], elem.attrs["_index"]] @@ -282,15 +281,17 @@ def read_dataframe( _gen_xarray_dict_iterator_from_elems(elem_dict, dim_name, index) ) if use_range_index: - elem_xarray_dict[DUMMY_RANGE_INDEX_KEY] = DataArray( + elem_xarray_dict[DUMMY_RANGE_INDEX_KEY] = XDataArray( index, coords=[index], dims=[DUMMY_RANGE_INDEX_KEY], name=DUMMY_RANGE_INDEX_KEY, ) + ds = Dataset2D(elem_xarray_dict) + ds.is_backed = True # We ensure the indexing_key attr always points to the true index # so that the roundtrip works even for the `use_range_index` `True` case - ds = Dataset2D(elem_xarray_dict, attrs={"indexing_key": elem.attrs["_index"]}) + ds.true_index_dim = elem.attrs["_index"] return ds diff --git a/src/anndata/_io/specs/registry.py b/src/anndata/_io/specs/registry.py index de158f2b1..95942706b 100644 --- a/src/anndata/_io/specs/registry.py +++ b/src/anndata/_io/specs/registry.py @@ -24,10 +24,11 @@ WriteCallback, _WriteInternal, ) - from anndata.experimental.backed._compat import Dataset2D from anndata.experimental.backed._lazy_arrays import CategoricalArray, MaskedArray from anndata.typing import RWAble + from ..._core.xarray import Dataset2D + T = TypeVar("T") W = TypeVar("W", bound=_WriteInternal) LazyDataStructures = DaskArray | Dataset2D | CategoricalArray | MaskedArray diff --git a/src/anndata/compat/__init__.py b/src/anndata/compat/__init__.py index aa89c0bbe..a81843f35 100644 --- a/src/anndata/compat/__init__.py +++ b/src/anndata/compat/__init__.py @@ -99,6 +99,32 @@ def __repr__(): return "mock dask.array.core.Array" +if find_spec("xarray") or TYPE_CHECKING: + import xarray + from xarray import DataArray as XDataArray + from xarray import Dataset as XDataset + from xarray.backends import BackendArray as XBackendArray + from xarray.backends.zarr import ZarrArrayWrapper as XZarrArrayWrapper +else: + xarray = None + + class XDataArray: + def __repr__(self) -> str: + return "mock DataArray" + + class XDataset: + def __repr__(self) -> str: + return "mock Dataset" + + class XZarrArrayWrapper: + def __repr__(self) -> str: + return "mock ZarrArrayWrapper" + + class XBackendArray: + def __repr__(self) -> str: + return "mock BackendArray" + + # https://github.com/scverse/anndata/issues/1749 def is_cupy_importable() -> bool: try: diff --git a/src/anndata/experimental/backed/_compat.py b/src/anndata/experimental/backed/_compat.py index 7ea06e93b..0657a4be3 100644 --- a/src/anndata/experimental/backed/_compat.py +++ b/src/anndata/experimental/backed/_compat.py @@ -1,33 +1,8 @@ from __future__ import annotations -from importlib.util import find_spec from typing import TYPE_CHECKING -if find_spec("xarray") or TYPE_CHECKING: - import xarray - from xarray import DataArray - from xarray.backends import BackendArray - from xarray.backends.zarr import ZarrArrayWrapper - - -else: - - class DataArray: - def __repr__(self) -> str: - return "mock DataArray" - - xarray = None - - class ZarrArrayWrapper: - def __repr__(self) -> str: - return "mock ZarrArrayWrapper" - - class BackendArray: - def __repr__(self) -> str: - return "mock BackendArray" - - -from ._xarray import Dataset, Dataset2D # noqa: F401 +from ..._core.xarray import Dataset2D if TYPE_CHECKING: from anndata import AnnData diff --git a/src/anndata/experimental/backed/_io.py b/src/anndata/experimental/backed/_io.py 
index 14237bc21..5dbdd9950 100644
--- a/src/anndata/experimental/backed/_io.py
+++ b/src/anndata/experimental/backed/_io.py
@@ -141,7 +141,10 @@ def callback(func: Read, /, elem_name: str, elem: StorageType, *, iospec: IOSpec
             }
             or "nullable" in iospec.encoding_type
         ):
-            if iospec.encoding_type == "dataframe" and elem_name in {"/obs", "/var"}:
+            if iospec.encoding_type == "dataframe" and (
+                elem_name[:4] in {"/obs", "/var"}
+                or elem_name[:8] in {"/raw/obs", "/raw/var"}
+            ):
                 return read_elem_lazy(elem, use_range_index=not load_annotation_index)
             return read_elem_lazy(elem)
         elif iospec.encoding_type in {"awkward-array"}:
diff --git a/src/anndata/experimental/backed/_lazy_arrays.py b/src/anndata/experimental/backed/_lazy_arrays.py
index 70b5ac9b6..68685ae36 100644
--- a/src/anndata/experimental/backed/_lazy_arrays.py
+++ b/src/anndata/experimental/backed/_lazy_arrays.py
@@ -11,8 +11,8 @@
 from anndata.compat import H5Array, ZarrArray

 from ..._settings import settings
-from ._compat import BackendArray, DataArray, ZarrArrayWrapper
-from ._compat import xarray as xr
+from ...compat import XBackendArray, XDataArray, XZarrArrayWrapper
+from ...compat import xarray as xr

 if TYPE_CHECKING:
     from pathlib import Path
@@ -27,7 +27,7 @@
 K = TypeVar("K", H5Array, ZarrArray)


-class ZarrOrHDF5Wrapper(ZarrArrayWrapper, Generic[K]):
+class ZarrOrHDF5Wrapper(XZarrArrayWrapper, Generic[K]):
     def __init__(self, array: K):
         self.chunks = array.chunks
         if isinstance(array, ZarrArray):
@@ -48,7 +48,7 @@ def __getitem__(self, key: xr.core.indexing.ExplicitIndexer):
         )


-class CategoricalArray(BackendArray, Generic[K]):
+class CategoricalArray(XBackendArray, Generic[K]):
     """
     A wrapper class meant to enable working with lazy categorical data.
     We do not guarantee the stability of this API beyond that guaranteed
@@ -103,7 +103,7 @@ def dtype(self):
         return pd.CategoricalDtype(categories=self.categories, ordered=self._ordered)


-class MaskedArray(BackendArray, Generic[K]):
+class MaskedArray(XBackendArray, Generic[K]):
     """
     A wrapper class meant to enable working with lazy masked data.
     We do not guarantee the stability of this API beyond that guaranteed
@@ -168,13 +168,13 @@ def dtype(self):
         raise RuntimeError(msg)


-@_subset.register(DataArray)
-def _subset_masked(a: DataArray, subset_idx: Index):
+@_subset.register(XDataArray)
+def _subset_masked(a: XDataArray, subset_idx: Index):
     return a[subset_idx]


-@as_view.register(DataArray)
-def _view_pd_boolean_array(a: DataArray, view_args):
+@as_view.register(XDataArray)
+def _view_pd_boolean_array(a: XDataArray, view_args):
     return a
diff --git a/src/anndata/experimental/backed/_xarray.py b/src/anndata/experimental/backed/_xarray.py
deleted file mode 100644
index e5420a45b..000000000
--- a/src/anndata/experimental/backed/_xarray.py
+++ /dev/null
@@ -1,147 +0,0 @@
-from __future__ import annotations
-
-from typing import TYPE_CHECKING
-
-import pandas as pd
-
-from ..._core.anndata import AnnData, _gen_dataframe
-from ..._core.file_backing import to_memory
-from ..._core.index import _subset
-from ..._core.views import as_view
-
-try:
-    from xarray import Dataset
-except ImportError:
-
-    class Dataset:
-        def __repr__(self) -> str:
-            return "mock Dataset"
-
-
-if TYPE_CHECKING:
-    from collections.abc import Hashable, Iterable
-    from typing import Any, Literal
-
-    from ..._core.index import Index
-    from ._compat import xarray as xr
-
-
-def get_index_dim(ds: xr.DataArray) -> Hashable:
-    if len(ds.sizes) != 1:
-        msg = f"xarray Dataset should not have more than 1 dims, found {len(ds.sizes)} {ds.sizes}, {ds}"
-        raise ValueError(msg)
-    return next(iter(ds.indexes.keys()))
-
-
-class Dataset2D(Dataset):
-    """
-    A wrapper class meant to enable working with lazy dataframe data.
-    We do not guarantee the stability of this API beyond that guaranteed
-    by :class:`xarray.Dataset` and the `to_memory` function, a thin wrapper
-    around :meth:`xarray.Dataset.to_dataframe` to ensure roundtrip
-    compatibility here.
-    """
-
-    __slots__ = ()
-
-    @property
-    def index(self) -> pd.Index:
-        """:attr:`~anndata.AnnData` internally looks for :attr:`~pandas.DataFrame.index` so this ensures usability
-
-        Returns
-        -------
-        The index of the of the dataframe as resolved from :attr:`~xarray.Dataset.coords`.
-        """
-        coord = get_index_dim(self)
-        return self.indexes[coord]
-
-    @index.setter
-    def index(self, val) -> None:
-        coord = get_index_dim(self)
-        self.coords[coord] = val
-
-    @property
-    def shape(self) -> tuple[int, int]:
-        """:attr:`~anndata.AnnData` internally looks for :attr:`~pandas.DataFrame.shape` so this ensures usability
-
-        Returns
-        -------
-        The (2D) shape of the dataframe resolved from :attr:`~xarray.Dataset.sizes`.
-        """
-        return (self.sizes[get_index_dim(self)], len(self))
-
-    @property
-    def iloc(self):
-        """:attr:`~anndata.AnnData` internally looks for :attr:`~pandas.DataFrame.iloc` so this ensures usability
-
-        Returns
-        -------
-        Handler class for doing the iloc-style indexing using :meth:`~xarray.Dataset.isel`.
- """ - - class IlocGetter: - def __init__(self, ds): - self._ds = ds - - def __getitem__(self, idx): - coord = get_index_dim(self._ds) - return self._ds.isel(**{coord: idx}) - - return IlocGetter(self) - - def to_memory(self, *, copy=False) -> pd.DataFrame: - df = self.to_dataframe() - index_key = self.attrs.get("indexing_key", None) - if df.index.name != index_key and index_key is not None: - df = df.set_index(index_key) - df.index.name = None # matches old AnnData object - return df - - @property - def columns(self) -> pd.Index: - """ - :class:`~anndata.AnnData` internally looks for :attr:`~pandas.DataFrame.columns` so this ensures usability - - Returns - ------- - :class:`pandas.Index` that represents the "columns." - """ - columns_list = list(self.keys()) - return pd.Index(columns_list) - - -@_subset.register(Dataset2D) -def _(a: Dataset2D, subset_idx: Index): - key = get_index_dim(a) - # xarray seems to have some code looking for a second entry in tuples - if isinstance(subset_idx, tuple) and len(subset_idx) == 1: - subset_idx = subset_idx[0] - return a.isel(**{key: subset_idx}) - - -@as_view.register(Dataset2D) -def _(a: Dataset2D, view_args): - return a - - -@_gen_dataframe.register(Dataset2D) -def _gen_dataframe_xr( - anno: Dataset2D, - index_names: Iterable[str], - *, - source: Literal["X", "shape"], - attr: Literal["obs", "var"], - length: int | None = None, -): - return anno - - -@AnnData._remove_unused_categories.register(Dataset2D) -@staticmethod -def _remove_unused_categories_xr( - df_full: Dataset2D, df_sub: Dataset2D, uns: dict[str, Any] -): - pass # this is handled automatically by the categorical arrays themselves i.e., they dedup upon access. - - -to_memory.register(Dataset2D, Dataset2D.to_memory) diff --git a/src/anndata/tests/helpers.py b/src/anndata/tests/helpers.py index ecbbda7ba..674fa4d96 100644 --- a/src/anndata/tests/helpers.py +++ b/src/anndata/tests/helpers.py @@ -30,6 +30,8 @@ CupyCSRMatrix, CupySparseMatrix, DaskArray, + XDataArray, + XDataset, ZarrArray, is_zarr_v2, ) @@ -56,32 +58,6 @@ ) -# Give this to gen_adata when dask array support is expected. -GEN_ADATA_DASK_ARGS = dict( - obsm_types=( - sparse.csr_matrix, - np.ndarray, - pd.DataFrame, - DaskArray, - sparse.csr_array, - ), - varm_types=( - sparse.csr_matrix, - np.ndarray, - pd.DataFrame, - DaskArray, - sparse.csr_array, - ), - layers_types=( - sparse.csr_matrix, - np.ndarray, - pd.DataFrame, - DaskArray, - sparse.csr_array, - ), -) - - DEFAULT_KEY_TYPES = ( sparse.csr_matrix, np.ndarray, @@ -102,6 +78,18 @@ ) +# Give this to gen_adata when dask array support is expected. 
+GEN_ADATA_DASK_ARGS = dict(
+    obsm_types=(*DEFAULT_KEY_TYPES, DaskArray),
+    varm_types=(*DEFAULT_KEY_TYPES, DaskArray),
+    layers_types=(*DEFAULT_KEY_TYPES, DaskArray),
+)
+
+GEN_ADATA_NO_XARRAY_ARGS = dict(
+    obsm_types=(*DEFAULT_KEY_TYPES, AwkArray), varm_types=(*DEFAULT_KEY_TYPES, AwkArray)
+)
+
+
 def gen_vstr_recarray(m, n, dtype=None):
     size = m * n
     lengths = np.random.randint(3, 5, size)
@@ -288,8 +276,10 @@ def gen_adata(  # noqa: PLR0913
     var_dtypes: Collection[
         np.dtype | pd.api.extensions.ExtensionDtype
     ] = DEFAULT_COL_TYPES,
+    obs_xdataset: bool = False,
+    var_xdataset: bool = False,
-    obsm_types: Collection[type] = (*DEFAULT_KEY_TYPES, AwkArray),
-    varm_types: Collection[type] = (*DEFAULT_KEY_TYPES, AwkArray),
+    obsm_types: Collection[type] = (*DEFAULT_KEY_TYPES, AwkArray, XDataset),
+    varm_types: Collection[type] = (*DEFAULT_KEY_TYPES, AwkArray, XDataset),
     layers_types: Collection[type] = DEFAULT_KEY_TYPES,
     random_state: np.random.Generator | None = None,
     sparse_fmt: Literal["csr", "csc"] = "csr",
@@ -321,6 +311,7 @@ def gen_adata(  # noqa: PLR0913
         (csr, csc)
     """
     import dask.array as da
+    import xarray as xr

     if random_state is None:
         random_state = np.random.default_rng()
@@ -334,6 +325,11 @@ def gen_adata(  # noqa: PLR0913
     obs.rename(columns=dict(cat="obs_cat"), inplace=True)
     var.rename(columns=dict(cat="var_cat"), inplace=True)

+    if obs_xdataset:
+        obs = XDataset.from_dataframe(obs)
+    if var_xdataset:
+        var = XDataset.from_dataframe(var)
+
     if X_type is None:
         X = None
     else:
@@ -345,6 +341,9 @@ def gen_adata(  # noqa: PLR0913
         df=gen_typed_df(M, obs_names, dtypes=obs_dtypes),
         awk_2d_ragged=gen_awkward((M, None)),
         da=da.random.random((M, 50)),
+        xdataset=xr.Dataset.from_dataframe(
+            gen_typed_df(M, obs_names, dtypes=obs_dtypes)
+        ),
     )
     obsm = {k: v for k, v in obsm.items() if type(v) in obsm_types}
     obsm = maybe_add_sparse_array(
@@ -360,6 +359,9 @@ def gen_adata(  # noqa: PLR0913
         df=gen_typed_df(N, var_names, dtypes=var_dtypes),
         awk_2d_ragged=gen_awkward((N, None)),
         da=da.random.random((N, 50)),
+        xdataset=xr.Dataset.from_dataframe(
+            gen_typed_df(N, var_names, dtypes=var_dtypes)
+        ),
     )
     varm = {k: v for k, v in varm.items() if type(v) in varm_types}
     varm = maybe_add_sparse_array(
@@ -740,6 +742,13 @@ def assert_equal_extension_array(
     )


+@assert_equal.register(XDataArray)
+def assert_equal_xarray(
+    a: XDataArray, b: object, *, exact: bool = False, elem_name: str | None = None
+):
+    report_name(a.equals)(b, _elem_name=elem_name)
+
+
 @assert_equal.register(Raw)
 def assert_equal_raw(
     a: Raw, b: object, *, exact: bool = False, elem_name: str | None = None
diff --git a/src/anndata/typing.py b/src/anndata/typing.py
index f0cf974b4..25e279248 100644
--- a/src/anndata/typing.py
+++ b/src/anndata/typing.py
@@ -16,6 +16,7 @@
     CupySparseMatrix,
     DaskArray,
     H5Array,
+    XDataArray,
     ZappyArray,
     ZarrArray,
 )
@@ -45,7 +46,7 @@
     | CupyArray
     | CupySparseMatrix
 )
-ArrayDataStructureTypes: TypeAlias = XDataType | AwkArray
+ArrayDataStructureTypes: TypeAlias = XDataType | AwkArray | XDataArray


 InMemoryArrayOrScalarType: TypeAlias = (
diff --git a/tests/lazy/conftest.py b/tests/lazy/conftest.py
index 4a153c25f..6e181c70b 100644
--- a/tests/lazy/conftest.py
+++ b/tests/lazy/conftest.py
@@ -14,7 +14,9 @@
 from anndata.experimental import read_lazy
 from anndata.tests.helpers import (
     DEFAULT_COL_TYPES,
+    DEFAULT_KEY_TYPES,
     AccessTrackingStore,
+    AwkArray,
     as_dense_dask_array,
     gen_adata,
     gen_typed_df,
@@ -92,6 +94,8 @@ def adata_remote_orig_with_path(
         mtx_format,
         obs_dtypes=(*DEFAULT_COL_TYPES, pd.StringDtype),
         var_dtypes=(*DEFAULT_COL_TYPES, pd.StringDtype),
+        obsm_types=(*DEFAULT_KEY_TYPES, AwkArray),
+        varm_types=(*DEFAULT_KEY_TYPES, AwkArray),
     )
     orig.raw = orig.copy()
     with ad.settings.override(allow_write_nullable_strings=True):
diff --git a/tests/lazy/test_concat.py b/tests/lazy/test_concat.py
index f04db4046..152d7e50c 100644
--- a/tests/lazy/test_concat.py
+++ b/tests/lazy/test_concat.py
@@ -11,7 +11,7 @@
 import anndata as ad
 from anndata._core.file_backing import to_memory
 from anndata.experimental import read_lazy
-from anndata.tests.helpers import assert_equal, gen_adata
+from anndata.tests.helpers import GEN_ADATA_NO_XARRAY_ARGS, assert_equal, gen_adata

 from .conftest import ANNDATA_ELEMS, get_key_trackers_for_columns_on_axis
@@ -254,7 +254,7 @@ def test_concat_data_subsetting(
     join: Join_T,
     index: slice | NDArray | Literal["a"] | None,
 ):
-    from anndata.experimental.backed._compat import Dataset2D
+    from anndata._core.xarray import Dataset2D

     remote_concatenated = ad.concat([adata_remote, adata_remote], join=join)
     if index is not None:
@@ -312,7 +312,7 @@ def with_elem_in_memory(adata: AnnData, attr: str, key: str | None) -> AnnData:


 def test_concat_bad_mixed_types(tmp_path: Path):
-    orig = gen_adata((100, 200), np.array)
+    orig = gen_adata((100, 200), np.array, **GEN_ADATA_NO_XARRAY_ARGS)
     orig.write_zarr(tmp_path)
     remote = read_lazy(tmp_path)
     orig.obsm["df"] = orig.obsm["array"]
diff --git a/tests/lazy/test_read.py b/tests/lazy/test_read.py
index cef2d97e5..ad112b178 100644
--- a/tests/lazy/test_read.py
+++ b/tests/lazy/test_read.py
@@ -3,11 +3,19 @@
 from importlib.util import find_spec
 from typing import TYPE_CHECKING

+import numpy as np
+import pandas as pd
 import pytest

+from anndata import AnnData
 from anndata.compat import DaskArray
 from anndata.experimental import read_lazy
-from anndata.tests.helpers import AccessTrackingStore, assert_equal, gen_adata
+from anndata.tests.helpers import (
+    GEN_ADATA_NO_XARRAY_ARGS,
+    AccessTrackingStore,
+    assert_equal,
+    gen_adata,
+)

 from .conftest import ANNDATA_ELEMS
@@ -15,7 +23,6 @@
     from collections.abc import Callable
     from pathlib import Path

-    from anndata import AnnData
     from anndata._types import AnnDataElem

 pytestmark = pytest.mark.skipif(not find_spec("xarray"), reason="xarray not installed")
@@ -107,6 +114,21 @@ def test_to_memory(adata_remote: AnnData, adata_orig: AnnData):
     assert_equal(remote_to_memory, adata_orig)


+def test_access_counts_obsm_df(tmp_path: Path):
+    adata = AnnData(
+        X=np.array(np.random.rand(100, 20)),
+    )
+    adata.obsm["df"] = pd.DataFrame(
+        {"col1": np.random.rand(100), "col2": np.random.rand(100)},
+        index=adata.obs_names,
+    )
+    adata.write_zarr(tmp_path)
+    store = AccessTrackingStore(tmp_path)
+    store.initialize_key_trackers(["obsm/df"])
+    read_lazy(store, load_annotation_index=False)
+    store.assert_access_count("obsm/df", 0)
+
+
 def test_view_to_memory(adata_remote: AnnData, adata_orig: AnnData):
     obs_cats = adata_orig.obs["obs_cat"].cat.categories
     subset_obs = adata_orig.obs["obs_cat"] == obs_cats[0]
@@ -144,7 +166,7 @@ def test_view_of_view_to_memory(adata_remote: AnnData, adata_orig: AnnData):


 def test_unconsolidated(tmp_path: Path, mtx_format):
-    adata = gen_adata((1000, 1000), mtx_format)
+    adata = gen_adata((1000, 1000), mtx_format, **GEN_ADATA_NO_XARRAY_ARGS)
     orig_pth = tmp_path / "orig.zarr"
     adata.write_zarr(orig_pth)
     (orig_pth / ".zmetadata").unlink()
diff --git a/tests/test_backed_hdf5.py b/tests/test_backed_hdf5.py
index ecf2ef03e..75239adee 100644
--- a/tests/test_backed_hdf5.py
+++ b/tests/test_backed_hdf5.py
@@ -13,6 +13,7 @@
 from anndata.compat import CSArray, CSMatrix
 from anndata.tests.helpers import (
     GEN_ADATA_DASK_ARGS,
+    GEN_ADATA_NO_XARRAY_ARGS,
     as_dense_dask_array,
     assert_equal,
     gen_adata,
@@ -196,7 +197,7 @@ def test_backed_raw(tmp_path):
 def test_backed_raw_subset(tmp_path, array_type, subset_func, subset_func2):
     backed_pth = tmp_path / "backed.h5ad"
     final_pth = tmp_path / "final.h5ad"
-    mem_adata = gen_adata((10, 10), X_type=array_type)
+    mem_adata = gen_adata((10, 10), X_type=array_type, **GEN_ADATA_NO_XARRAY_ARGS)
     mem_adata.raw = mem_adata
     obs_idx = subset_func(mem_adata.obs_names)
     var_idx = subset_func2(mem_adata.var_names)
diff --git a/tests/test_base.py b/tests/test_base.py
index 18f9f52ab..5d326c88e 100644
--- a/tests/test_base.py
+++ b/tests/test_base.py
@@ -17,7 +17,12 @@
 from anndata import AnnData, ImplicitModificationWarning
 from anndata._core.raw import Raw
 from anndata._settings import settings
-from anndata.tests.helpers import assert_equal, gen_adata, get_multiindex_columns_df
+from anndata.tests.helpers import (
+    GEN_ADATA_NO_XARRAY_ARGS,
+    assert_equal,
+    gen_adata,
+    get_multiindex_columns_df,
+)

 if TYPE_CHECKING:
     from pathlib import Path
@@ -278,11 +283,14 @@ def test_setting_index_names_error(attr):

 @pytest.mark.parametrize("dim", ["obs", "var"])
-def test_setting_dim_index(dim):
+@pytest.mark.parametrize(
+    ("obs_xdataset", "var_xdataset"), [(False, False), (True, True)]
+)
+def test_setting_dim_index(dim, obs_xdataset, var_xdataset):
     index_attr = f"{dim}_names"
     mapping_attr = f"{dim}m"

-    orig = gen_adata((5, 5))
+    orig = gen_adata((5, 5), obs_xdataset=obs_xdataset, var_xdataset=var_xdataset)
     orig.raw = orig.copy()
     curr = orig.copy()
     view = orig[:, :]
@@ -516,12 +524,10 @@ def test_set_obs():
     adata = AnnData(np.array([[1, 2, 3], [4, 5, 6]]))

     adata.obs = pd.DataFrame(dict(a=[3, 4]))
-    assert adata.obs_names.tolist() == [0, 1]
+    assert adata.obs_names.tolist() == ["0", "1"]

-    with pytest.raises(ValueError, match="but this AnnData has shape"):
+    with pytest.raises(ValueError, match="`shape` is inconsistent with `obs`"):
         adata.obs = pd.DataFrame(dict(a=[3, 4, 5]))
-    with pytest.raises(ValueError, match="Can only assign pd.DataFrame"):
-        adata.obs = dict(a=[1, 2])


 def test_multicol():
@@ -730,7 +736,7 @@ def assert_eq_not_id(a, b):


 def test_to_memory_no_copy():
-    adata = gen_adata((3, 5))
+    adata = gen_adata((3, 5), **GEN_ADATA_NO_XARRAY_ARGS)
     mem = adata.to_memory()

     assert mem.X is adata.X
diff --git a/tests/test_concatenate.py b/tests/test_concatenate.py
index e8f01d8bb..7b0cac8f8 100644
--- a/tests/test_concatenate.py
+++ b/tests/test_concatenate.py
@@ -2,7 +2,6 @@

 import warnings
 from collections.abc import Hashable
-from contextlib import nullcontext
 from copy import deepcopy
 from functools import partial, singledispatch
 from itertools import chain, permutations, product
@@ -13,6 +12,7 @@
 import pandas as pd
 import pytest
 import scipy
+import xarray as xr
 from boltons.iterutils import default_exit, remap, research
 from numpy import ma
 from packaging.version import Version
@@ -21,7 +21,14 @@
 from anndata import AnnData, Raw, concat
 from anndata._core import merge
 from anndata._core.index import _subset
-from anndata.compat import AwkArray, CSArray, CSMatrix, CupySparseMatrix, DaskArray
+from anndata.compat import (
+    AwkArray,
+    CSArray,
+    CSMatrix,
+    CupySparseMatrix,
+    DaskArray,
+    XDataset,
+)
 from anndata.tests import helpers
 from anndata.tests.helpers import (
     BASE_MATRIX_PARAMS,
@@ -140,6 +147,32 @@ def fix_known_differences(
+    if backwards_compat:
+        del orig.varm
+        del orig.varp
+        if isinstance(result.obs, XDataset):
+            result.obs = result.obs.drop_vars(["batch"])
+        else:
+            result.obs.drop(columns=["batch"], inplace=True)
+
+    for attrname in ("obs", "var"):
+        if isinstance(getattr(result, attrname), XDataset):
+            for adata in (orig, result):
+                df = getattr(adata, attrname).to_dataframe()
+                df.index.name = "index"
+                setattr(adata, attrname, df)
+            resattr = getattr(result, attrname)
+            origattr = getattr(orig, attrname)
+            for colname, col in resattr.items():
+                # concatenation of XDatasets happens via Dask arrays and those don't know about Pandas Extension arrays
+                # so categoricals and nullable arrays are all converted to other dtypes
+                if col.dtype != origattr[
+                    colname
+                ].dtype and pd.api.types.is_extension_array_dtype(
+                    origattr[colname].dtype
+                ):
+                    resattr[colname] = col.astype(origattr[colname].dtype)
+
     result.strings_to_categoricals()  # Should this be implicit in concatenation?
 
     # TODO
@@ -147,11 +180,6 @@ def fix_known_differences(
     # * merge obsp, but some information should be lost
     del orig.obsp  # TODO
 
-    if backwards_compat:
-        del orig.varm
-        del orig.varp
-        result.obs.drop(columns=["batch"], inplace=True)
-
     # Possibly need to fix this, ordered categoricals lose orderedness
     for get_df in [lambda k: k.obs, lambda k: k.obsm["df"]]:
         str_to_df_converted = get_df(result)
@@ -162,8 +190,14 @@ def fix_known_differences(
     return orig, result
 
 
-def test_concat_interface_errors():
-    adatas = [gen_adata((5, 10)), gen_adata((5, 10))]
+@pytest.mark.parametrize(
+    ("obs_xdataset", "var_xdataset"), [(False, False), (True, True)]
+)
+def test_concat_interface_errors(obs_xdataset, var_xdataset):
+    adatas = [
+        gen_adata((5, 10), obs_xdataset=obs_xdataset, var_xdataset=var_xdataset),
+        gen_adata((5, 10), obs_xdataset=obs_xdataset, var_xdataset=var_xdataset),
+    ]
 
     with pytest.raises(ValueError, match="`axis` must be.*0, 1, 'obs', or 'var'"):
         concat(adatas, axis=3)
@@ -181,8 +215,28 @@ def test_concat_interface_errors():
         (lambda x, **kwargs: x[0].concatenate(x[1:], **kwargs), True),
     ],
 )
-def test_concatenate_roundtrip(join_type, array_type, concat_func, backwards_compat):
-    adata = gen_adata((100, 10), X_type=array_type, **GEN_ADATA_DASK_ARGS)
+@pytest.mark.parametrize(
+    ("obs_xdataset", "var_xdataset", "force_lazy"),
+    [(False, False, False), (True, True, False), (True, True, True)],
+)
+def test_concatenate_roundtrip(
+    join_type,
+    array_type,
+    concat_func,
+    backwards_compat,
+    obs_xdataset,
+    var_xdataset,
+    force_lazy,
+):
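+    # The legacy .concatenate() code path is inherently eager, so the lazy
+    # variant is skipped for that combination.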
+    if backwards_compat and force_lazy:
+        pytest.skip("unsupported")
+    adata = gen_adata(
+        (100, 10),
+        X_type=array_type,
+        obs_xdataset=obs_xdataset,
+        var_xdataset=var_xdataset,
+        **GEN_ADATA_DASK_ARGS,
+    )
 
     remaining = adata.obs_names
     subsets = []
@@ -192,7 +246,17 @@ def test_concatenate_roundtrip(join_type, array_type, concat_func, backwards_com
         subsets.append(adata[subset_idx])
         remaining = remaining.difference(subset_idx)
 
+    if (
+        backwards_compat
+        and (obs_xdataset or var_xdataset)
+        and Version(xr.__version__) < Version("2025.4.0")
+    ):
+        pytest.xfail("https://github.com/pydata/xarray/issues/10218")
     result = concat_func(subsets, join=join_type, uns_merge="same", index_unique=None)
+    if backwards_compat and var_xdataset:
+        result.var = xr.Dataset.from_dataframe(
+            result.var
+        )  # backwards compat always returns a dataframe
 
     # Correcting for known differences
     orig, result = fix_known_differences(
@@ -497,19 +561,19 @@ def get_obs_els(adata):
     adata1.obsm = {
         k: v
         for k, v in adata1.obsm.items()
-        if not isinstance(v, pd.DataFrame | AwkArray)
+        if not isinstance(v, pd.DataFrame | AwkArray | XDataset)
     }
     adata2 = gen_adata((10, 5))
     adata2.obsm = {
         k: v[:, : v.shape[1] // 2]
         for k, v in adata2.obsm.items()
-        if not isinstance(v, pd.DataFrame | AwkArray)
+        if not isinstance(v, pd.DataFrame | AwkArray | XDataset)
     }
     adata3 = gen_adata((7, 3))
     adata3.obsm = {
         k: v[:, : v.shape[1] // 3]
         for k, v in adata3.obsm.items()
-        if not isinstance(v, pd.DataFrame | AwkArray)
+        if not isinstance(v, pd.DataFrame | AwkArray | XDataset)
     }
     # remove AwkArrays from adata.var, as outer joins are not yet implemented for them
     for tmp_ad in [adata1, adata2, adata3]:
@@ -1161,28 +1225,78 @@ def test_concatenate_uns(unss, merge_strategy, result, value_gen):
     assert_equal(merged, result, elem_name="uns")
 
 
-def test_transposed_concat(array_type, axis_name, join_type, merge_strategy):
+@pytest.mark.parametrize(
+    ("obs_xdataset", "var_xdataset", "force_lazy"),
+    [(False, False, False), (True, True, False), (True, True, True)],
+)
+def test_transposed_concat(
+    array_type,
+    axis_name,
+    join_type,
+    merge_strategy,
+    obs_xdataset,
+    var_xdataset,
+    force_lazy,
+):
     axis, axis_name = merge._resolve_axis(axis_name)
     alt_axis = 1 - axis
-    lhs = gen_adata((10, 10), X_type=array_type, **GEN_ADATA_DASK_ARGS)
-    rhs = gen_adata((10, 12), X_type=array_type, **GEN_ADATA_DASK_ARGS)
-
-    a = concat([lhs, rhs], axis=axis, join=join_type, merge=merge_strategy)
-    b = concat([lhs.T, rhs.T], axis=alt_axis, join=join_type, merge=merge_strategy).T
+    lhs = gen_adata(
+        (10, 10),
+        X_type=array_type,
+        obs_xdataset=obs_xdataset,
+        var_xdataset=var_xdataset,
+        **GEN_ADATA_DASK_ARGS,
+    )
+    rhs = gen_adata(
+        (10, 12),
+        X_type=array_type,
+        obs_xdataset=obs_xdataset,
+        var_xdataset=var_xdataset,
+        **GEN_ADATA_DASK_ARGS,
+    )
+
+    a = concat(
+        [lhs, rhs],
+        axis=axis,
+        join=join_type,
+        merge=merge_strategy,
+        force_lazy=force_lazy,
+    )
+    b = concat(
+        [lhs.T, rhs.T],
+        axis=alt_axis,
+        join=join_type,
+        merge=merge_strategy,
+        force_lazy=force_lazy,
+    ).T
 
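+    # Concatenating transposed inputs along the opposite axis and transposing
+    # back must agree with the direct concatenation, eager or lazy.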
     assert_equal(a, b)
 
 
-def test_batch_key(axis_name):
+@pytest.mark.parametrize(
+    ("obs_xdataset", "var_xdataset", "force_lazy"),
+    [(False, False, False), (True, True, False), (True, True, True)],
+)
+def test_batch_key(axis_name, obs_xdataset, var_xdataset, force_lazy):
     """Test that concat only adds a label if the key is provided"""
     get_annot = attrgetter(axis_name)
 
-    lhs = gen_adata((10, 10), **GEN_ADATA_DASK_ARGS)
-    rhs = gen_adata((10, 12), **GEN_ADATA_DASK_ARGS)
+    lhs = gen_adata(
+        (10, 10),
+        obs_xdataset=obs_xdataset,
+        var_xdataset=var_xdataset,
+        **GEN_ADATA_DASK_ARGS,
+    )
+    rhs = gen_adata(
+        (10, 12),
+        obs_xdataset=obs_xdataset,
+        var_xdataset=var_xdataset,
+        **GEN_ADATA_DASK_ARGS,
+    )
 
     # There is probably a prettier way to do this
-    annot = get_annot(concat([lhs, rhs], axis=axis_name))
+    annot = get_annot(concat([lhs, rhs], axis=axis_name, force_lazy=force_lazy))
     assert (
         list(
             annot.columns.difference(
@@ -1192,7 +1306,9 @@ def test_batch_key(axis_name):
         == []
     )
 
-    batch_annot = get_annot(concat([lhs, rhs], axis=axis_name, label="batch"))
+    batch_annot = get_annot(
+        concat([lhs, rhs], axis=axis_name, label="batch", force_lazy=force_lazy)
+    )
     assert list(
         batch_annot.columns.difference(
             get_annot(lhs).columns.union(get_annot(rhs).columns)
@@ -1200,16 +1316,20 @@ def test_batch_key(axis_name):
     ) == ["batch"]
 
 
-def test_concat_categories_from_mapping():
+@pytest.mark.parametrize(
+    ("obs_xdataset", "var_xdataset", "force_lazy"),
+    [(False, False, False), (True, True, False), (True, True, True)],
+)
+def test_concat_categories_from_mapping(obs_xdataset, var_xdataset, force_lazy):
     mapping = {
-        "a": gen_adata((10, 10)),
-        "b": gen_adata((10, 10)),
+        "a": gen_adata((10, 10), obs_xdataset=obs_xdataset, var_xdataset=var_xdataset),
+        "b": gen_adata((10, 10), obs_xdataset=obs_xdataset, var_xdataset=var_xdataset),
     }
     keys = list(mapping.keys())
     adatas = list(mapping.values())
 
-    mapping_call = partial(concat, mapping)
-    iter_call = partial(concat, adatas, keys=keys)
+    mapping_call = partial(concat, mapping, force_lazy=force_lazy)
+    iter_call = partial(concat, adatas, keys=keys, force_lazy=force_lazy)
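+    # Passing a mapping must be equivalent to passing its values plus explicit
+    # keys, for both eager and lazy concatenation.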
 
     assert_equal(mapping_call(), iter_call())
     assert_equal(mapping_call(label="batch"), iter_call(label="batch"))
@@ -1339,15 +1459,21 @@ def test_bool_promotion():
     assert result.obs["bool"].dtype == np.dtype(bool)
 
 
-def test_concat_names(axis_name):
+@pytest.mark.parametrize(
+    ("obs_xdataset", "var_xdataset", "force_lazy"),
+    [(False, False, False), (True, True, False), (True, True, True)],
+)
+def test_concat_names(axis_name, obs_xdataset, var_xdataset, force_lazy):
     get_annot = attrgetter(axis_name)
 
-    lhs = gen_adata((10, 10))
-    rhs = gen_adata((10, 10))
+    lhs = gen_adata((10, 10), obs_xdataset=obs_xdataset, var_xdataset=var_xdataset)
+    rhs = gen_adata((10, 10), obs_xdataset=obs_xdataset, var_xdataset=var_xdataset)
 
-    assert not get_annot(concat([lhs, rhs], axis=axis_name)).index.is_unique
+    assert not get_annot(
+        concat([lhs, rhs], axis=axis_name, force_lazy=force_lazy)
+    ).index.is_unique
     assert get_annot(
-        concat([lhs, rhs], axis=axis_name, index_unique="-")
+        concat([lhs, rhs], axis=axis_name, index_unique="-", force_lazy=force_lazy)
     ).index.is_unique
 
 
@@ -1376,33 +1502,43 @@ def expected_shape(
 
 
 @pytest.mark.parametrize(
     "shape", [pytest.param((8, 0), id="no_var"), pytest.param((0, 10), id="no_obs")]
 )
-def test_concat_size_0_axis(axis_name, join_type, merge_strategy, shape):
+@pytest.mark.parametrize(
+    ("obs_xdataset", "var_xdataset", "force_lazy"),
+    [(False, False, False), (True, True, False), (True, True, True)],
+)
+def test_concat_size_0_axis(
+    axis_name, join_type, merge_strategy, shape, obs_xdataset, var_xdataset, force_lazy
+):
     """Regression test for https://github.com/scverse/anndata/issues/526"""
     axis, axis_name = merge._resolve_axis(axis_name)
     alt_axis = 1 - axis
     col_dtypes = (*DEFAULT_COL_TYPES, pd.StringDtype)
-    a = gen_adata((5, 7), obs_dtypes=col_dtypes, var_dtypes=col_dtypes)
-    b = gen_adata(shape, obs_dtypes=col_dtypes, var_dtypes=col_dtypes)
+    a = gen_adata(
+        (5, 7),
+        obs_dtypes=col_dtypes,
+        var_dtypes=col_dtypes,
+        obs_xdataset=obs_xdataset,
+        var_xdataset=var_xdataset,
+    )
+    b = gen_adata(
+        shape,
+        obs_dtypes=col_dtypes,
+        var_dtypes=col_dtypes,
+        obs_xdataset=obs_xdataset,
+        var_xdataset=var_xdataset,
+    )
 
     expected_size = expected_shape(a, b, axis=axis, join=join_type)
 
-    ctx_concat_empty = (
-        pytest.warns(
-            FutureWarning,
-            match=r"The behavior of DataFrame concatenation with empty or all-NA entries is deprecated",
-        )
-        if shape[axis] == 0 and Version(pd.__version__) >= Version("2.1")
-        else nullcontext()
-    )
-    with ctx_concat_empty:
-        result = concat(
-            {"a": a, "b": b},
-            axis=axis,
-            join=join_type,
-            merge=merge_strategy,
-            pairwise=True,
-            index_unique="-",
-        )
+    result = concat(
+        {"a": a, "b": b},
+        axis=axis,
+        join=join_type,
+        merge=merge_strategy,
+        pairwise=True,
+        index_unique="-",
+        force_lazy=force_lazy,
+    )
     assert result.shape == expected_size
 
     if join_type == "outer":
@@ -1441,12 +1577,30 @@ def test_concat_size_0_axis(axis_name, join_type, merge_strategy, shape):
 
 
 @pytest.mark.parametrize("elem", ["sparse", "array", "df", "da"])
 @pytest.mark.parametrize("axis", ["obs", "var"])
-def test_concat_outer_aligned_mapping(elem, axis):
-    a = gen_adata((5, 5), **GEN_ADATA_DASK_ARGS)
-    b = gen_adata((3, 5), **GEN_ADATA_DASK_ARGS)
+@pytest.mark.parametrize(
+    ("obs_xdataset", "var_xdataset", "force_lazy"),
+    [(False, False, False), (True, True, False), (True, True, True)],
+)
+def test_concat_outer_aligned_mapping(
+    elem, axis, obs_xdataset, var_xdataset, force_lazy
+):
+    a = gen_adata(
+        (5, 5),
+        obs_xdataset=obs_xdataset,
+        var_xdataset=var_xdataset,
+        **GEN_ADATA_DASK_ARGS,
+    )
+    b = gen_adata(
+        (3, 5),
+        obs_xdataset=obs_xdataset,
+        var_xdataset=var_xdataset,
+        **GEN_ADATA_DASK_ARGS,
+    )
     del getattr(b, f"{axis}m")[elem]
 
-    concated = concat({"a": a, "b": b}, join="outer", label="group", axis=axis)
+    concated = concat(
+        {"a": a, "b": b}, join="outer", label="group", axis=axis, force_lazy=force_lazy
+    )
 
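+    # The element removed from `b` should reappear as missing values for the
+    # rows/columns contributed by `b` in the outer join.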
     mask = getattr(concated, axis)["group"] == "b"
     result = getattr(
@@ -1469,8 +1623,15 @@ def test_concatenate_size_0_axis():
     assert b.concatenate([a]).shape == (10, 0)
 
 
-def test_concat_null_X():
-    adatas_orig = {k: gen_adata((20, 10)) for k in list("abc")}
+@pytest.mark.parametrize(
+    ("obs_xdataset", "var_xdataset", "force_lazy"),
+    [(False, False, False), (True, True, False), (True, True, True)],
+)
+def test_concat_null_X(obs_xdataset, var_xdataset, force_lazy):
+    adatas_orig = {
+        k: gen_adata((20, 10), obs_xdataset=obs_xdataset, var_xdataset=var_xdataset)
+        for k in list("abc")
+    }
     adatas_no_X = {}
     for k, v in adatas_orig.items():
         v = v.copy()
diff --git a/tests/test_io_conversion.py b/tests/test_io_conversion.py
index 763c89233..a1a778f62 100644
--- a/tests/test_io_conversion.py
+++ b/tests/test_io_conversion.py
@@ -11,7 +11,7 @@
 import anndata as ad
 from anndata.compat import CSMatrix
-from anndata.tests.helpers import assert_equal, gen_adata
+from anndata.tests.helpers import GEN_ADATA_NO_XARRAY_ARGS, assert_equal, gen_adata
 
 
 @pytest.fixture(
@@ -39,7 +39,7 @@ def test_sparse_to_dense_disk(tmp_path, mtx_format, to_convert):
     mem_pth = tmp_path / "orig.h5ad"
     dense_from_mem_pth = tmp_path / "dense_mem.h5ad"
     dense_from_disk_pth = tmp_path / "dense_disk.h5ad"
-    mem = gen_adata((50, 50), mtx_format)
+    mem = gen_adata((50, 50), mtx_format, **GEN_ADATA_NO_XARRAY_ARGS)
     mem.raw = mem.copy()
 
     mem.write_h5ad(mem_pth)
@@ -66,7 +66,7 @@ def test_sparse_to_dense_disk(tmp_path, mtx_format, to_convert):
 
 def test_sparse_to_dense_inplace(tmp_path, spmtx_format):
     pth = tmp_path / "adata.h5ad"
-    orig = gen_adata((50, 50), spmtx_format)
+    orig = gen_adata((50, 50), spmtx_format, **GEN_ADATA_NO_XARRAY_ARGS)
     orig.raw = orig.copy()
     orig.write(pth)
     backed = ad.read_h5ad(pth, backed="r+")
@@ -97,7 +97,7 @@ def test_sparse_to_dense_errors(tmp_path):
 
 def test_dense_to_sparse_memory(tmp_path, spmtx_format, to_convert):
     dense_path = tmp_path / "dense.h5ad"
-    orig = gen_adata((50, 50), np.array)
+    orig = gen_adata((50, 50), np.array, **GEN_ADATA_NO_XARRAY_ARGS)
     orig.raw = orig.copy()
     orig.write_h5ad(dense_path)
     assert not isinstance(orig.X, CSMatrix)
diff --git a/tests/test_io_dispatched.py b/tests/test_io_dispatched.py
index d7dc8354a..2cd48db44 100644
--- a/tests/test_io_dispatched.py
+++ b/tests/test_io_dispatched.py
@@ -10,7 +10,7 @@
 from anndata._io.zarr import open_write_group
 from anndata.compat import CSArray, CSMatrix, ZarrGroup, is_zarr_v2
 from anndata.experimental import read_dispatched, write_dispatched
-from anndata.tests.helpers import assert_equal, gen_adata
+from anndata.tests.helpers import GEN_ADATA_NO_XARRAY_ARGS, assert_equal, gen_adata
 
 if TYPE_CHECKING:
     from collections.abc import Callable
@@ -26,7 +26,7 @@ def read_only_axis_dfs(func, elem_name: str, elem, iospec):
         else:
             return None
 
-    adata = gen_adata((1000, 100))
+    adata = gen_adata((1000, 100), **GEN_ADATA_NO_XARRAY_ARGS)
 
     z = open_write_group(tmp_path)
     ad.io.write_elem(z, "/", adata)
@@ -57,7 +57,7 @@ def read_as_dask_array(func, elem_name: str, elem, iospec):
         else:
             return func(elem)
 
-    adata = gen_adata((1000, 100))
+    adata = gen_adata((1000, 100), **GEN_ADATA_NO_XARRAY_ARGS)
     z = open_write_group(tmp_path)
     ad.io.write_elem(z, "/", adata)
     # TODO: see https://github.com/zarr-developers/zarr-python/issues/2716
@@ -77,7 +77,7 @@ def read_as_dask_array(func, elem_name: str, elem, iospec):
 
 
 def test_read_dispatched_null_case(tmp_path: Path):
-    adata = gen_adata((100, 100))
+    adata = gen_adata((100, 100), **GEN_ADATA_NO_XARRAY_ARGS)
     z = open_write_group(tmp_path)
     ad.io.write_elem(z, "/", adata)
     # TODO: see https://github.com/zarr-developers/zarr-python/issues/2716
@@ -99,7 +99,7 @@ def determine_chunks(elem_shape, specified_chunks):
             for e, c in zip(elem_shape, chunk_iterator, strict=False)
         )
 
-    adata = gen_adata((1000, 100))
+    adata = gen_adata((1000, 100), **GEN_ADATA_NO_XARRAY_ARGS)
 
     def write_chunked(func, store, k, elem, dataset_kwargs, iospec):
         M, N = 13, 42
@@ -197,7 +197,7 @@ def zarr_reader(func, elem_name: str, elem, iospec):
         zarr_read_keys.append(elem_name if is_zarr_v2() else elem_name.strip("/"))
        return func(elem)
 
-    adata = gen_adata((50, 100))
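+    # These round-trip tests write to disk, so xarray-backed annotations are
+    # excluded: Dataset2D objects are not writable.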
+    adata = gen_adata((50, 100), **GEN_ADATA_NO_XARRAY_ARGS)
 
     with h5py.File(h5ad_path, "w") as f:
         write_dispatched(f, "/", adata, callback=h5ad_writer)
diff --git a/tests/test_io_elementwise.py b/tests/test_io_elementwise.py
index bb6fae527..b568fc0e0 100644
--- a/tests/test_io_elementwise.py
+++ b/tests/test_io_elementwise.py
@@ -24,6 +24,7 @@
 from anndata.experimental import read_elem_lazy
 from anndata.io import read_elem, write_elem
 from anndata.tests.helpers import (
+    GEN_ADATA_NO_XARRAY_ARGS,
     as_cupy,
     as_cupy_sparse_dask_array,
     as_dense_cupy_dask_array,
@@ -123,7 +124,9 @@ def create_sparse_store(
     pytest.param(True, "numeric-scalar", id="py_bool"),
     pytest.param(1.0, "numeric-scalar", id="py_float"),
     pytest.param({"a": 1}, "dict", id="py_dict"),
-    pytest.param(gen_adata((3, 2)), "anndata", id="anndata"),
+    pytest.param(
+        gen_adata((3, 2), **GEN_ADATA_NO_XARRAY_ARGS), "anndata", id="anndata"
+    ),
     pytest.param(
         sparse.random(5, 3, format="csr", density=0.5),
         "csr_matrix",
@@ -428,7 +431,7 @@ def test_write_indptr_dtype_override(store, sparse_format):
 
 
 def test_io_spec_raw(store):
-    adata = gen_adata((3, 2))
+    adata = gen_adata((3, 2), **GEN_ADATA_NO_XARRAY_ARGS)
     adata.raw = adata.copy()
 
     write_elem(store, "adata", adata)
@@ -440,7 +443,7 @@ def test_io_spec_raw(store):
 
 
 def test_write_anndata_to_root(store):
-    adata = gen_adata((3, 2))
+    adata = gen_adata((3, 2), **GEN_ADATA_NO_XARRAY_ARGS)
 
     write_elem(store, "/", adata)
     # TODO: see https://github.com/zarr-developers/zarr-python/issues/2716
@@ -460,7 +463,7 @@ def test_write_anndata_to_root(store):
     ],
 )
 def test_read_iospec_not_found(store, attribute, value):
-    adata = gen_adata((3, 2))
+    adata = gen_adata((3, 2), **GEN_ADATA_NO_XARRAY_ARGS)
 
     write_elem(store, "/", adata)
     store["obs"].attrs.update({attribute: value})
@@ -527,7 +530,7 @@ def _(store, key, adata):
     "value",
     [
         pytest.param({"a": 1}, id="dict"),
-        pytest.param(gen_adata((3, 2)), id="anndata"),
+        pytest.param(gen_adata((3, 2), **GEN_ADATA_NO_XARRAY_ARGS), id="anndata"),
         pytest.param(sparse.random(5, 3, format="csr", density=0.5), id="csr_matrix"),
         pytest.param(sparse.random(5, 3, format="csc", density=0.5), id="csc_matrix"),
         pytest.param(pd.DataFrame({"a": [1, 2, 3]}), id="dataframe"),
@@ -578,7 +581,7 @@ def test_write_to_root(store, value):
 def test_read_zarr_from_group(tmp_path, consolidated):
     # https://github.com/scverse/anndata/issues/1056
     pth = tmp_path / "test.zarr"
-    adata = gen_adata((3, 2))
+    adata = gen_adata((3, 2), **GEN_ADATA_NO_XARRAY_ARGS)
 
     z = open_write_group(pth)
     write_elem(z, "table/table", adata)
@@ -628,7 +631,7 @@ def test_io_pd_cow(store, copy_on_write):
         pytest.xfail("copy_on_write option is not available in pandas < 2")
     # https://github.com/zarr-developers/numcodecs/issues/514
     with pd.option_context("mode.copy_on_write", copy_on_write):
-        orig = gen_adata((3, 2))
+        orig = gen_adata((3, 2), **GEN_ADATA_NO_XARRAY_ARGS)
         write_elem(store, "adata", orig)
         from_store = read_elem(store["adata"])
         assert_equal(orig, from_store)
diff --git a/tests/test_io_warnings.py b/tests/test_io_warnings.py
index 2cdc99775..e8b069a16 100644
--- a/tests/test_io_warnings.py
+++ b/tests/test_io_warnings.py
@@ -10,7 +10,7 @@
 from packaging.version import Version
 
 import anndata as ad
-from anndata.tests.helpers import gen_adata
+from anndata.tests.helpers import GEN_ADATA_NO_XARRAY_ARGS, gen_adata
 
 
 @pytest.mark.skipif(not find_spec("scanpy"), reason="Scanpy is not installed")
@@ -38,7 +38,7 @@ def msg_re(entry: str) -> str:
 
 def test_old_format_warning_not_thrown(tmp_path):
     pth = tmp_path / "current.h5ad"
-    adata = gen_adata((20, 10))
+    adata = gen_adata((20, 10), **GEN_ADATA_NO_XARRAY_ARGS)
     adata.write_h5ad(pth)
 
     with warnings.catch_warnings(record=True) as record:
diff --git a/tests/test_readwrite.py b/tests/test_readwrite.py
index 4e9bdfdbc..81c76c76b 100644
--- a/tests/test_readwrite.py
+++ b/tests/test_readwrite.py
@@ -29,7 +29,12 @@
     _read_attr,
     is_zarr_v2,
 )
-from anndata.tests.helpers import as_dense_dask_array, assert_equal, gen_adata
+from anndata.tests.helpers import (
+    GEN_ADATA_NO_XARRAY_ARGS,
+    as_dense_dask_array,
+    assert_equal,
+    gen_adata,
+)
 
 if TYPE_CHECKING:
     from typing import Literal
@@ -88,7 +93,7 @@ def dataset_kwargs(request):
 @pytest.fixture
 def rw(backing_h5ad):
     M, N = 100, 101
-    orig = gen_adata((M, N))
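+    # The fixture writes straight to h5ad, so it likewise opts out of
+    # xarray-backed annotations.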
+    orig = gen_adata((M, N), **GEN_ADATA_NO_XARRAY_ARGS)
     orig.write(backing_h5ad)
     curr = ad.read_h5ad(backing_h5ad)
     return curr, orig
@@ -256,7 +261,7 @@ def test_readwrite_equivalent_h5ad_zarr(tmp_path, typ):
     zarr_pth = tmp_path / "adata.zarr"
 
     M, N = 100, 101
-    adata = gen_adata((M, N), X_type=typ)
+    adata = gen_adata((M, N), X_type=typ, **GEN_ADATA_NO_XARRAY_ARGS)
     adata.raw = adata.copy()
 
     adata.write_h5ad(h5ad_pth)
@@ -287,7 +292,7 @@ def store_context(path: Path):
     ],
 )
 def test_read_full_io_error(tmp_path, name, read, write):
-    adata = gen_adata((4, 3))
+    adata = gen_adata((4, 3), **GEN_ADATA_NO_XARRAY_ARGS)
     path = tmp_path / name
     write(adata, path)
     with store_context(path) as store:
@@ -326,7 +331,7 @@ def test_read_full_io_error(tmp_path, name, read, write):
 def test_hdf5_compression_opts(tmp_path, compression, compression_opts):
     # https://github.com/scverse/anndata/issues/497
     pth = Path(tmp_path) / "adata.h5ad"
-    adata = gen_adata((10, 8))
+    adata = gen_adata((10, 8), **GEN_ADATA_NO_XARRAY_ARGS)
     kwargs = {}
     if compression is not None:
         kwargs["compression"] = compression
@@ -363,7 +368,7 @@ def check_compressed(key, value):
 def test_zarr_compression(tmp_path, zarr_write_format):
     ad.settings.zarr_write_format = zarr_write_format
     pth = str(Path(tmp_path) / "adata.zarr")
-    adata = gen_adata((10, 8))
+    adata = gen_adata((10, 8), **GEN_ADATA_NO_XARRAY_ARGS)
 
     if zarr_write_format == 2 or is_zarr_v2():
         from numcodecs import Blosc
@@ -416,7 +421,7 @@ def check_compressed(value, key):
 
 def test_changed_obs_var_names(tmp_path, diskfmt):
     filepth = tmp_path / f"test.{diskfmt}"
 
-    orig = gen_adata((10, 10))
+    orig = gen_adata((10, 10), **GEN_ADATA_NO_XARRAY_ARGS)
     orig.obs_names.name = "obs"
     orig.var_names.name = "var"
     modified = orig.copy()
@@ -752,7 +757,7 @@ def test_zarr_chunk_X(tmp_path):
     import zarr
 
     zarr_pth = Path(tmp_path) / "test.zarr"
-    adata = gen_adata((100, 100), X_type=np.array)
+    adata = gen_adata((100, 100), X_type=np.array, **GEN_ADATA_NO_XARRAY_ARGS)
     adata.write_zarr(zarr_pth, chunks=(10, 10))
 
     z = zarr.open(str(zarr_pth))  # As of v2.3.2 zarr won’t take a Path
@@ -880,13 +885,13 @@ def test_backwards_compat_zarr():
 def test_adata_in_uns(tmp_path, diskfmt, roundtrip):
     pth = tmp_path / f"adatas_in_uns.{diskfmt}"
 
-    orig = gen_adata((4, 5))
+    orig = gen_adata((4, 5), **GEN_ADATA_NO_XARRAY_ARGS)
     orig.uns["adatas"] = {
-        "a": gen_adata((1, 2)),
-        "b": gen_adata((12, 8)),
+        "a": gen_adata((1, 2), **GEN_ADATA_NO_XARRAY_ARGS),
+        "b": gen_adata((12, 8), **GEN_ADATA_NO_XARRAY_ARGS),
     }
-    another_one = gen_adata((2, 5))
-    another_one.raw = gen_adata((2, 7))
+    another_one = gen_adata((2, 5), **GEN_ADATA_NO_XARRAY_ARGS)
+    another_one.raw = gen_adata((2, 7), **GEN_ADATA_NO_XARRAY_ARGS)
     orig.uns["adatas"]["b"].uns["another_one"] = another_one
 
     curr = roundtrip(orig, pth)
diff --git a/tests/test_x.py b/tests/test_x.py
index 4c0b62516..d7da59a0c 100644
--- a/tests/test_x.py
+++ b/tests/test_x.py
@@ -10,7 +10,7 @@
 import anndata as ad
 from anndata import AnnData
 from anndata._warnings import ImplicitModificationWarning
-from anndata.tests.helpers import assert_equal, gen_adata
+from anndata.tests.helpers import GEN_ADATA_NO_XARRAY_ARGS, assert_equal, gen_adata
 from anndata.utils import asarray
 
 UNLABELLED_ARRAY_TYPES = [
@@ -156,7 +156,7 @@ def test_io_missing_X(tmp_path, diskfmt):
     write = lambda obj, pth: getattr(obj, f"write_{diskfmt}")(pth)
     read = lambda pth: getattr(ad, f"read_{diskfmt}")(pth)
 
-    adata = gen_adata((20, 30))
+    adata = gen_adata((20, 30), **GEN_ADATA_NO_XARRAY_ARGS)
     del adata.X
 
     write(adata, file_pth)
diff --git a/tests/test_xarray.py b/tests/test_xarray.py
new file mode 100644
index 000000000..51e44ca07
--- /dev/null
+++ b/tests/test_xarray.py
@@ -0,0 +1,91 @@
+from __future__ import annotations
+
+import string
+
+import numpy as np
+import pandas as pd
+import pytest
+
+from anndata._core.xarray import Dataset2D
+from anndata.tests.helpers import gen_typed_df
+
+
+@pytest.fixture
+def df():
+    return gen_typed_df(10)
+
+
+@pytest.fixture
+def dataset2d(df):
+    return Dataset2D.from_dataframe(df)
+
+
+def test_shape(df, dataset2d):
+    assert dataset2d.shape == df.shape
+
+
+def test_columns(df, dataset2d):
+    assert np.all(dataset2d.columns.sort_values() == df.columns.sort_values())
+
+
+def test_to_memory(df, dataset2d):
+    memory_df = dataset2d.to_memory()
+    assert np.all(df == memory_df)
+    assert np.all(df.index == memory_df.index)
+    assert np.all(df.columns.sort_values() == memory_df.columns.sort_values())
+
+
+def test_getitem(df, dataset2d):
+    col = df.columns[0]
+    assert np.all(dataset2d[col] == df[col])
+
+    empty_dset = dataset2d[[]]
+    assert empty_dset.shape == (df.shape[0], 0)
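+    # Selecting zero columns should still carry the index coordinate along.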
+    assert np.all(empty_dset.index == dataset2d.index)
+
+
+def test_backed_property(dataset2d):
+    assert not dataset2d.is_backed
+
+    dataset2d.is_backed = True
+    assert dataset2d.is_backed
+
+    dataset2d.is_backed = False
+    assert not dataset2d.is_backed
+
+
+def test_index_dim(dataset2d):
+    assert dataset2d.index_dim == "index"
+    assert dataset2d.true_index_dim == dataset2d.index_dim
+
+    col = next(iter(dataset2d.keys()))
+    dataset2d.true_index_dim = col
+    assert dataset2d.index_dim == "index"
+    assert dataset2d.true_index_dim == col
+
+    with pytest.raises(ValueError, match=r"Unknown variable `test`\."):
+        dataset2d.true_index_dim = "test"
+
+    dataset2d.true_index_dim = None
+    assert dataset2d.true_index_dim == dataset2d.index_dim
+
+
+def test_index(dataset2d):
+    alphabet = np.asarray(
+        list(string.ascii_letters + string.digits + string.punctuation)
+    )
+    new_idx = pd.Index(
+        [
+            "".join(np.random.choice(alphabet, size=10))
+            for _ in range(dataset2d.shape[0])
+        ],
+        name="test_index",
+    )
+
+    col = next(iter(dataset2d.keys()))
+    dataset2d.true_index_dim = col
+
+    dataset2d.index = new_idx
+    assert np.all(dataset2d.index == new_idx)
+    assert dataset2d.true_index_dim == dataset2d.index_dim == new_idx.name
+    assert list(dataset2d.coords.keys()) == [new_idx.name]
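+
+
+def test_from_dataframe_roundtrip(df, dataset2d):
+    # Illustrative sketch (not part of the original patch): to_memory()
+    # yields a DataFrame, so feeding it back through from_dataframe should
+    # reproduce an equivalent Dataset2D.
+    roundtripped = Dataset2D.from_dataframe(dataset2d.to_memory())
+    assert roundtripped.shape == dataset2d.shape
+    assert np.all(roundtripped.index == dataset2d.index)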