diff --git a/docs/release-notes/1880.feature.md b/docs/release-notes/1880.feature.md new file mode 100755 index 000000000..405d8e980 --- /dev/null +++ b/docs/release-notes/1880.feature.md @@ -0,0 +1 @@ +Add support for Dask DataFrames in `.obsm` and `.varm` ({user}`ilia-kats`) diff --git a/src/anndata/_core/aligned_mapping.py b/src/anndata/_core/aligned_mapping.py index 115e8d065..1f91c6b72 100644 --- a/src/anndata/_core/aligned_mapping.py +++ b/src/anndata/_core/aligned_mapping.py @@ -264,7 +264,16 @@ def to_df(self) -> pd.DataFrame: def _validate_value(self, val: Value, key: str) -> Value: if isinstance(val, pd.DataFrame): raise_value_error_if_multiindex_columns(val, f"{self.attrname}[{key!r}]") - if not val.index.equals(self.dim_names): + if ( + not val.index.equals(self.dim_names) + and ( + val.index.dtype != "string" + and self.dim_names.dtype != "O" + or val.index.dtype != "O" + and self.dim_names.dtype != "string" + ) + and (val.index != self.dim_names).any() + ): # Could probably also re-order index if it’s contained try: pd.testing.assert_index_equal(val.index, self.dim_names) diff --git a/src/anndata/_core/file_backing.py b/src/anndata/_core/file_backing.py index 005a47b97..a671553a4 100644 --- a/src/anndata/_core/file_backing.py +++ b/src/anndata/_core/file_backing.py @@ -8,7 +8,7 @@ import h5py -from ..compat import AwkArray, DaskArray, ZarrArray, ZarrGroup +from ..compat import AwkArray, DaskArray, DaskDataFrame, ZarrArray, ZarrGroup from .sparse_dataset import BaseCompressedSparseDataset if TYPE_CHECKING: @@ -143,7 +143,8 @@ def _(x: BaseCompressedSparseDataset, *, copy: bool = False): @to_memory.register(DaskArray) -def _(x: DaskArray, *, copy: bool = False): +@to_memory.register(DaskDataFrame) +def _(x: DaskArray | DaskDataFrame, *, copy: bool = False): return x.compute() diff --git a/src/anndata/_core/storage.py b/src/anndata/_core/storage.py index 63a7de76d..6f8022ce0 100644 --- a/src/anndata/_core/storage.py +++ b/src/anndata/_core/storage.py @@ 
-7,7 +7,7 @@ import pandas as pd from scipy import sparse -from anndata.compat import CSArray +from anndata.compat import CSArray, DaskDataFrame from .._warnings import ImplicitModificationWarning from ..utils import ( @@ -51,7 +51,7 @@ def coerce_array( if any(is_non_csc_r_array_or_matrix): msg = f"Only CSR and CSC {'matrices' if isinstance(value, sparse.spmatrix) else 'arrays'} are supported." raise ValueError(msg) - if isinstance(value, pd.DataFrame): + if isinstance(value, pd.DataFrame | DaskDataFrame): if allow_df: raise_value_error_if_multiindex_columns(value, name) return value if allow_df else ensure_df_homogeneous(value, name) diff --git a/src/anndata/_io/specs/methods.py b/src/anndata/_io/specs/methods.py index f4ef6d44f..da296f513 100644 --- a/src/anndata/_io/specs/methods.py +++ b/src/anndata/_io/specs/methods.py @@ -29,6 +29,7 @@ CupyCSCMatrix, CupyCSRMatrix, DaskArray, + DaskDataFrame, H5Array, H5File, H5Group, @@ -896,6 +897,19 @@ def write_dataframe( ) +@_REGISTRY.register_write(H5Group, DaskDataFrame, IOSpec("dataframe", "0.2.0")) +@_REGISTRY.register_write(ZarrGroup, DaskDataFrame, IOSpec("dataframe", "0.2.0")) +def write_dask_dataframe( + f: GroupStorageType, + key: str, + df: DaskDataFrame, + *, + _writer: Writer, + dataset_kwargs: Mapping[str, Any] = MappingProxyType({}), +): + _writer.write_elem(f, key, df.compute(), dataset_kwargs=dataset_kwargs) + + @_REGISTRY.register_read(H5Group, IOSpec("dataframe", "0.2.0")) @_REGISTRY.register_read(ZarrGroup, IOSpec("dataframe", "0.2.0")) def read_dataframe(elem: GroupStorageType, *, _reader: Reader) -> pd.DataFrame: @@ -1051,6 +1065,12 @@ def read_partial_categorical(elem, *, items=None, indices=(slice(None),)): @_REGISTRY.register_write( ZarrGroup, pd.arrays.StringArray, IOSpec("nullable-string-array", "0.1.0") ) +@_REGISTRY.register_write( + H5Group, pd.arrays.ArrowStringArray, IOSpec("nullable-string-array", "0.1.0") +) +@_REGISTRY.register_write( + ZarrGroup, pd.arrays.ArrowStringArray, 
IOSpec("nullable-string-array", "0.1.0") +) def write_nullable( f: GroupStorageType, k: str, @@ -1073,7 +1093,7 @@ def write_nullable( g = f.require_group(k) values = ( v.to_numpy(na_value="") - if isinstance(v, pd.arrays.StringArray) + if isinstance(v, pd.arrays.StringArray | pd.arrays.ArrowStringArray) else v.to_numpy(na_value=0, dtype=v.dtype.numpy_dtype) ) _writer.write_elem(g, "values", values, dataset_kwargs=dataset_kwargs) diff --git a/src/anndata/compat/__init__.py b/src/anndata/compat/__init__.py index b3d313f17..bff94e9d1 100644 --- a/src/anndata/compat/__init__.py +++ b/src/anndata/compat/__init__.py @@ -100,6 +100,7 @@ def __repr__(): from dask.array.core import Array as DaskArray elif find_spec("dask"): from dask.array import Array as DaskArray + from dask.dataframe import DataFrame as DaskDataFrame else: class DaskArray: @@ -107,6 +108,11 @@ class DaskArray: def __repr__(): return "mock dask.array.core.Array" + class DaskDataFrame: + @staticmethod + def __repr__(): + return "mock dask.dataframe.dask_expr._collection.DataFrame" + # https://github.com/scverse/anndata/issues/1749 def is_cupy_importable() -> bool: diff --git a/src/anndata/tests/helpers.py b/src/anndata/tests/helpers.py index f6bbfd5a5..aa3feb01f 100644 --- a/src/anndata/tests/helpers.py +++ b/src/anndata/tests/helpers.py @@ -32,6 +32,7 @@ CupyCSRMatrix, CupySparseMatrix, DaskArray, + DaskDataFrame, ZarrArray, ) from anndata.utils import asarray @@ -644,8 +645,13 @@ def assert_equal_h5py_dataset( @assert_equal.register(DaskArray) +@assert_equal.register(DaskDataFrame) def assert_equal_dask_array( - a: DaskArray, b: object, *, exact: bool = False, elem_name: str | None = None + a: DaskArray | DaskDataFrame, + b: object, + *, + exact: bool = False, + elem_name: str | None = None, ): assert_equal(b, a.compute(), exact=exact, elem_name=elem_name) diff --git a/src/anndata/utils.py b/src/anndata/utils.py index ba57cc3a2..82b36d376 100644 --- a/src/anndata/utils.py +++ b/src/anndata/utils.py @@ 
-13,7 +13,7 @@ import anndata from ._core.sparse_dataset import BaseCompressedSparseDataset -from .compat import CSArray, CupyArray, CupySparseMatrix, DaskArray +from .compat import CSArray, CupyArray, CupySparseMatrix, DaskArray, DaskDataFrame from .logging import get_logger if TYPE_CHECKING: @@ -115,6 +115,11 @@ def axis_len(x, axis: Literal[0, 1]) -> int | None: return x.shape[axis] +@axis_len.register(DaskDataFrame) +def axis_len_dask_df(df, axis: Literal[0, 1]) -> int | None: + return df.shape[axis].compute() + + try: from .compat import awkward as ak diff --git a/tests/test_dask.py b/tests/test_dask.py index 21db60e7e..0a969af0c 100644 --- a/tests/test_dask.py +++ b/tests/test_dask.py @@ -11,7 +11,7 @@ import anndata as ad from anndata._core.anndata import AnnData -from anndata.compat import CupyArray, DaskArray +from anndata.compat import CupyArray, DaskArray, DaskDataFrame from anndata.experimental.merge import as_group from anndata.tests.helpers import ( GEN_ADATA_DASK_ARGS, @@ -74,6 +74,7 @@ def test_dask_X_view(): def test_dask_write(adata, tmp_path, diskfmt): import dask.array as da + import dask.dataframe as ddf import numpy as np pth = tmp_path / f"test_write.{diskfmt}" @@ -84,6 +85,12 @@ def test_dask_write(adata, tmp_path, diskfmt): adata.obsm["a"] = da.random.random((M, 10)) adata.obsm["b"] = da.random.random((M, 10)) adata.varm["a"] = da.random.random((N, 10)) + adata.varm["b"] = ddf.from_pandas( + pd.DataFrame( + {"A": np.arange(N), "B": np.random.randint(1e6, size=N)}, + index=adata.var_names, + ) + ) orig = adata write(orig, pth) @@ -93,6 +100,7 @@ def test_dask_write(adata, tmp_path, diskfmt): assert_equal(curr.obsm["a"], curr.obsm["b"]) assert_equal(curr.varm["a"], orig.varm["a"]) + assert_equal(orig.varm["b"], curr.varm["b"]) assert_equal(curr.obsm["a"], orig.obsm["a"]) assert isinstance(curr.X, np.ndarray) @@ -105,6 +113,7 @@ def test_dask_distributed_write(adata, tmp_path, diskfmt): import
dask.array as da + import dask.dataframe as ddf import dask.distributed as dd import numpy as np @@ -119,6 +128,12 @@ def test_dask_distributed_write(adata, tmp_path, diskfmt): adata.obsm["a"] = da.random.random((M, 10)) adata.obsm["b"] = da.random.random((M, 10)) adata.varm["a"] = da.random.random((N, 10)) + adata.varm["b"] = ddf.from_pandas( + pd.DataFrame( + {"A": np.arange(N), "B": np.random.randint(1e6, size=N)}, + index=adata.var_names, + ) + ) orig = adata if diskfmt == "h5ad": with pytest.raises(ValueError, match=r"Cannot write dask arrays to hdf5"): @@ -131,6 +146,7 @@ def test_dask_distributed_write(adata, tmp_path, diskfmt): assert_equal(curr.obsm["a"], curr.obsm["b"]) assert_equal(curr.varm["a"], orig.varm["a"]) + assert_equal(orig.varm["b"], curr.varm["b"]) assert_equal(curr.obsm["a"], orig.obsm["a"]) assert isinstance(curr.X, np.ndarray) @@ -143,6 +159,7 @@ def test_dask_distributed_write(adata, tmp_path, diskfmt): def test_dask_to_memory_check_array_types(adata, tmp_path, diskfmt): import dask.array as da + import dask.dataframe as ddf import numpy as np pth = tmp_path / f"test_write.{diskfmt}" @@ -153,6 +170,12 @@ def test_dask_to_memory_check_array_types(adata, tmp_path, diskfmt): adata.obsm["a"] = da.random.random((M, 10)) adata.obsm["b"] = da.random.random((M, 10)) adata.varm["a"] = da.random.random((N, 10)) + adata.varm["b"] = ddf.from_pandas( + pd.DataFrame( + {"A": np.arange(N), "B": np.random.randint(1e6, size=N)}, + index=adata.var_names, + ) + ) orig = adata write(orig, pth) @@ -161,6 +184,7 @@ def test_dask_to_memory_check_array_types(adata, tmp_path, diskfmt): assert isinstance(orig.X, DaskArray) assert isinstance(orig.obsm["a"], DaskArray) assert isinstance(orig.varm["a"], DaskArray) + assert isinstance(orig.varm["b"], DaskDataFrame) mem = orig.to_memory() @@ -171,20 +195,25 @@ def test_dask_to_memory_check_array_types(adata, tmp_path, diskfmt): assert_equal(curr.obsm["a"], orig.obsm["a"]) assert_equal(mem.obsm["a"], orig.obsm["a"]) 
assert_equal(mem.varm["a"], orig.varm["a"]) + assert_equal(orig.varm["b"], mem.varm["b"]) assert isinstance(curr.X, np.ndarray) assert isinstance(curr.obsm["a"], np.ndarray) assert isinstance(curr.varm["a"], np.ndarray) + assert isinstance(curr.varm["b"], pd.DataFrame) assert isinstance(mem.X, np.ndarray) assert isinstance(mem.obsm["a"], np.ndarray) assert isinstance(mem.varm["a"], np.ndarray) + assert isinstance(mem.varm["b"], pd.DataFrame) assert isinstance(orig.X, DaskArray) assert isinstance(orig.obsm["a"], DaskArray) assert isinstance(orig.varm["a"], DaskArray) + assert isinstance(orig.varm["b"], DaskDataFrame) def test_dask_to_memory_copy_check_array_types(adata, tmp_path, diskfmt): import dask.array as da + import dask.dataframe as ddf import numpy as np pth = tmp_path / f"test_write.{diskfmt}" @@ -195,6 +224,12 @@ def test_dask_to_memory_copy_check_array_types(adata, tmp_path, diskfmt): adata.obsm["a"] = da.random.random((M, 10)) adata.obsm["b"] = da.random.random((M, 10)) adata.varm["a"] = da.random.random((N, 10)) + adata.varm["b"] = ddf.from_pandas( + pd.DataFrame( + {"A": np.arange(N), "B": np.random.randint(1e6, size=N)}, + index=adata.var_names, + ) + ) orig = adata write(orig, pth) @@ -209,25 +244,36 @@ def test_dask_to_memory_copy_check_array_types(adata, tmp_path, diskfmt): assert_equal(curr.obsm["a"], orig.obsm["a"]) assert_equal(mem.obsm["a"], orig.obsm["a"]) assert_equal(mem.varm["a"], orig.varm["a"]) + assert_equal(orig.varm["b"], mem.varm["b"]) assert isinstance(curr.X, np.ndarray) assert isinstance(curr.obsm["a"], np.ndarray) assert isinstance(curr.varm["a"], np.ndarray) + assert isinstance(curr.varm["b"], pd.DataFrame) assert isinstance(mem.X, np.ndarray) assert isinstance(mem.obsm["a"], np.ndarray) assert isinstance(mem.varm["a"], np.ndarray) + assert isinstance(mem.varm["b"], pd.DataFrame) assert isinstance(orig.X, DaskArray) assert isinstance(orig.obsm["a"], DaskArray) assert isinstance(orig.varm["a"], DaskArray) + assert 
isinstance(orig.varm["b"], DaskDataFrame) def test_dask_copy_check_array_types(adata): import dask.array as da + import dask.dataframe as ddf M, N = adata.X.shape adata.obsm["a"] = da.random.random((M, 10)) adata.obsm["b"] = da.random.random((M, 10)) adata.varm["a"] = da.random.random((N, 10)) + adata.varm["b"] = ddf.from_pandas( + pd.DataFrame( + {"A": np.arange(N), "B": np.random.randint(1e6, size=N)}, + index=adata.var_names, + ) + ) orig = adata curr = adata.copy() @@ -236,14 +282,17 @@ def test_dask_copy_check_array_types(adata): assert_equal(curr.obsm["a"], curr.obsm["b"]) assert_equal(curr.varm["a"], orig.varm["a"]) + assert_equal(orig.varm["b"], curr.varm["b"]) assert_equal(curr.obsm["a"], orig.obsm["a"]) assert isinstance(curr.X, DaskArray) assert isinstance(curr.obsm["a"], DaskArray) assert isinstance(curr.varm["a"], DaskArray) + assert isinstance(curr.varm["b"], DaskDataFrame) assert isinstance(orig.X, DaskArray) assert isinstance(orig.obsm["a"], DaskArray) assert isinstance(orig.varm["a"], DaskArray) + assert isinstance(orig.varm["b"], DaskDataFrame) def test_assign_X(adata):