Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/release-notes/1880.feature.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add support for Dask DataFrames in `.obsm` and `.varm` ({user}`ilia-kats`)
11 changes: 10 additions & 1 deletion src/anndata/_core/aligned_mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,7 +264,16 @@ def to_df(self) -> pd.DataFrame:
def _validate_value(self, val: Value, key: str) -> Value:
if isinstance(val, pd.DataFrame):
raise_value_error_if_multiindex_columns(val, f"{self.attrname}[{key!r}]")
if not val.index.equals(self.dim_names):
if (
not val.index.equals(self.dim_names)
and (
val.index.dtype != "string"
and self.dim_names.dtype != "O"
or val.index.dtype != "O"
and self.dim_names.dtype != "string"
)
and (val.index != self.dim_names).any()
):
# Could probably also re-order index if it’s contained
try:
pd.testing.assert_index_equal(val.index, self.dim_names)
Expand Down
5 changes: 3 additions & 2 deletions src/anndata/_core/file_backing.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

import h5py

from ..compat import AwkArray, DaskArray, ZarrArray, ZarrGroup
from ..compat import AwkArray, DaskArray, DaskDataFrame, ZarrArray, ZarrGroup
from .sparse_dataset import BaseCompressedSparseDataset

if TYPE_CHECKING:
Expand Down Expand Up @@ -143,7 +143,8 @@ def _(x: BaseCompressedSparseDataset, *, copy: bool = False):


@to_memory.register(DaskArray)
@to_memory.register(DaskDataFrame)
def _(x: DaskArray | DaskDataFrame, *, copy: bool = False):
    """Materialize a lazy dask array or dataframe into memory.

    ``compute()`` already produces a new in-memory object, so the *copy*
    flag needs no special handling here.
    """
    # NOTE(review): the paste interleaved the stale pre-change `def` line
    # with the new one; this is the reconstructed post-change definition.
    return x.compute()


Expand Down
4 changes: 2 additions & 2 deletions src/anndata/_core/storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import pandas as pd
from scipy import sparse

from anndata.compat import CSArray
from anndata.compat import CSArray, DaskDataFrame

from .._warnings import ImplicitModificationWarning
from ..utils import (
Expand Down Expand Up @@ -51,7 +51,7 @@ def coerce_array(
if any(is_non_csc_r_array_or_matrix):
msg = f"Only CSR and CSC {'matrices' if isinstance(value, sparse.spmatrix) else 'arrays'} are supported."
raise ValueError(msg)
if isinstance(value, pd.DataFrame):
if isinstance(value, pd.DataFrame | DaskDataFrame):
if allow_df:
raise_value_error_if_multiindex_columns(value, name)
return value if allow_df else ensure_df_homogeneous(value, name)
Expand Down
22 changes: 21 additions & 1 deletion src/anndata/_io/specs/methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
CupyCSCMatrix,
CupyCSRMatrix,
DaskArray,
DaskDataFrame,
H5Array,
H5File,
H5Group,
Expand Down Expand Up @@ -896,6 +897,19 @@ def write_dataframe(
)


@_REGISTRY.register_write(H5Group, DaskDataFrame, IOSpec("dataframe", "0.2.0"))
@_REGISTRY.register_write(ZarrGroup, DaskDataFrame, IOSpec("dataframe", "0.2.0"))
def write_dask_dataframe(
    f: GroupStorageType,
    key: str,
    df: DaskDataFrame,
    *,
    _writer: Writer,
    dataset_kwargs: Mapping[str, Any] = MappingProxyType({}),
):
    """Write a dask DataFrame under ``key`` in group ``f``.

    ``df.compute()`` materializes the whole dataframe in memory, then the
    result is delegated to the registered ``pd.DataFrame`` writer (same
    IOSpec "dataframe"/"0.2.0"), so the on-disk format is identical to a
    plain pandas dataframe.
    """
    _writer.write_elem(f, key, df.compute(), dataset_kwargs=dataset_kwargs)


@_REGISTRY.register_read(H5Group, IOSpec("dataframe", "0.2.0"))
@_REGISTRY.register_read(ZarrGroup, IOSpec("dataframe", "0.2.0"))
def read_dataframe(elem: GroupStorageType, *, _reader: Reader) -> pd.DataFrame:
Expand Down Expand Up @@ -1051,6 +1065,12 @@ def read_partial_categorical(elem, *, items=None, indices=(slice(None),)):
@_REGISTRY.register_write(
ZarrGroup, pd.arrays.StringArray, IOSpec("nullable-string-array", "0.1.0")
)
@_REGISTRY.register_write(
H5Group, pd.arrays.ArrowStringArray, IOSpec("nullable-string-array", "0.1.0")
)
@_REGISTRY.register_write(
ZarrGroup, pd.arrays.ArrowStringArray, IOSpec("nullable-string-array", "0.1.0")
)
def write_nullable(
f: GroupStorageType,
k: str,
Expand All @@ -1073,7 +1093,7 @@ def write_nullable(
g = f.require_group(k)
values = (
v.to_numpy(na_value="")
if isinstance(v, pd.arrays.StringArray)
if isinstance(v, pd.arrays.StringArray | pd.arrays.ArrowStringArray)
else v.to_numpy(na_value=0, dtype=v.dtype.numpy_dtype)
)
_writer.write_elem(g, "values", values, dataset_kwargs=dataset_kwargs)
Expand Down
6 changes: 6 additions & 0 deletions src/anndata/compat/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,13 +100,19 @@
from dask.array.core import Array as DaskArray
elif find_spec("dask"):
from dask.array import Array as DaskArray
from dask.dataframe import DataFrame as DaskDataFrame
else:

class DaskArray:
@staticmethod
def __repr__():
return "mock dask.array.core.Array"

class DaskDataFrame:
    """Mock stand-in used when dask is not installed.

    Mirrors the ``DaskArray`` mock above so ``isinstance(x, DaskDataFrame)``
    checks elsewhere in the package work without importing dask.
    """

    @staticmethod
    def __repr__():
        return "mock dask.dataframe.dask_expr._collection.DataFrame"

Check warning on line 114 in src/anndata/compat/__init__.py

View check run for this annotation

Codecov / codecov/patch

src/anndata/compat/__init__.py#L111-L114

Added lines #L111 - L114 were not covered by tests


# https://github.com/scverse/anndata/issues/1749
def is_cupy_importable() -> bool:
Expand Down
8 changes: 7 additions & 1 deletion src/anndata/tests/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
CupyCSRMatrix,
CupySparseMatrix,
DaskArray,
DaskDataFrame,
ZarrArray,
)
from anndata.utils import asarray
Expand Down Expand Up @@ -644,8 +645,13 @@ def assert_equal_h5py_dataset(


@assert_equal.register(DaskArray)
@assert_equal.register(DaskDataFrame)
def assert_equal_dask_array(
    a: DaskArray | DaskDataFrame,
    b: object,
    *,
    exact: bool = False,
    elem_name: str | None = None,
):
    """Compare a dask array/dataframe with *b* after materializing it.

    Computes ``a`` and dispatches back into ``assert_equal`` so the
    comparison uses the handler registered for the concrete result type.
    """
    # NOTE(review): the paste interleaved the stale pre-change signature
    # line with the new parameters; this is the reconstructed definition.
    assert_equal(b, a.compute(), exact=exact, elem_name=elem_name)

Expand Down
7 changes: 6 additions & 1 deletion src/anndata/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
import anndata

from ._core.sparse_dataset import BaseCompressedSparseDataset
from .compat import CSArray, CupyArray, CupySparseMatrix, DaskArray
from .compat import CSArray, CupyArray, CupySparseMatrix, DaskArray, DaskDataFrame
from .logging import get_logger

if TYPE_CHECKING:
Expand Down Expand Up @@ -115,6 +115,11 @@ def axis_len(x, axis: Literal[0, 1]) -> int | None:
return x.shape[axis]


@axis_len.register(DaskDataFrame)
def axis_len_dask_df(df, axis: Literal[0, 1]) -> int | None:
    """Return the length of *df* along *axis*.

    Dask delays only the row count: ``df.shape`` is ``(<delayed>, ncols)``
    where the column count is a plain ``int``. Calling ``.compute()``
    unconditionally therefore raised ``AttributeError`` for ``axis=1``;
    only compute when the entry is actually lazy.
    """
    size = df.shape[axis]
    return size if isinstance(size, int) else size.compute()


try:
from .compat import awkward as ak

Expand Down
51 changes: 50 additions & 1 deletion tests/test_dask.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

import anndata as ad
from anndata._core.anndata import AnnData
from anndata.compat import CupyArray, DaskArray
from anndata.compat import CupyArray, DaskArray, DaskDataFrame
from anndata.experimental.merge import as_group
from anndata.tests.helpers import (
GEN_ADATA_DASK_ARGS,
Expand Down Expand Up @@ -74,6 +74,7 @@ def test_dask_X_view():

def test_dask_write(adata, tmp_path, diskfmt):
import dask.array as da
import dask.dataframe as ddf
import numpy as np

pth = tmp_path / f"test_write.{diskfmt}"
Expand All @@ -84,6 +85,12 @@ def test_dask_write(adata, tmp_path, diskfmt):
adata.obsm["a"] = da.random.random((M, 10))
adata.obsm["b"] = da.random.random((M, 10))
adata.varm["a"] = da.random.random((N, 10))
adata.varm["b"] = ddf.from_pandas(
pd.DataFrame(
{"A": np.arange(N), "B": np.random.randint(1e6, size=N)},
index=adata.var_names,
)
)

orig = adata
write(orig, pth)
Expand All @@ -93,6 +100,7 @@ def test_dask_write(adata, tmp_path, diskfmt):
assert_equal(curr.obsm["a"], curr.obsm["b"])

assert_equal(curr.varm["a"], orig.varm["a"])
# NOTE(review): original compared orig.varm["b"] with itself — a no-op
# assertion. Compare the round-tripped value against the original, matching
# the sibling checks in the other tests of this file.
assert_equal(orig.varm["b"], curr.varm["b"])
assert_equal(curr.obsm["a"], orig.obsm["a"])

assert isinstance(curr.X, np.ndarray)
Expand All @@ -105,6 +113,7 @@ def test_dask_write(adata, tmp_path, diskfmt):

def test_dask_distributed_write(adata, tmp_path, diskfmt):
import dask.array as da
import dask.dataframe as ddf
import dask.distributed as dd
import numpy as np

Expand All @@ -119,6 +128,12 @@ def test_dask_distributed_write(adata, tmp_path, diskfmt):
adata.obsm["a"] = da.random.random((M, 10))
adata.obsm["b"] = da.random.random((M, 10))
adata.varm["a"] = da.random.random((N, 10))
adata.varm["b"] = ddf.from_pandas(
pd.DataFrame(
{"A": np.arange(N), "B": np.random.randint(1e6, size=N)},
index=adata.var_names,
)
)
orig = adata
if diskfmt == "h5ad":
with pytest.raises(ValueError, match=r"Cannot write dask arrays to hdf5"):
Expand All @@ -131,6 +146,7 @@ def test_dask_distributed_write(adata, tmp_path, diskfmt):
assert_equal(curr.obsm["a"], curr.obsm["b"])

assert_equal(curr.varm["a"], orig.varm["a"])
assert_equal(orig.varm["b"], curr.varm["b"])
assert_equal(curr.obsm["a"], orig.obsm["a"])

assert isinstance(curr.X, np.ndarray)
Expand All @@ -143,6 +159,7 @@ def test_dask_distributed_write(adata, tmp_path, diskfmt):

def test_dask_to_memory_check_array_types(adata, tmp_path, diskfmt):
import dask.array as da
import dask.dataframe as ddf
import numpy as np

pth = tmp_path / f"test_write.{diskfmt}"
Expand All @@ -153,6 +170,12 @@ def test_dask_to_memory_check_array_types(adata, tmp_path, diskfmt):
adata.obsm["a"] = da.random.random((M, 10))
adata.obsm["b"] = da.random.random((M, 10))
adata.varm["a"] = da.random.random((N, 10))
adata.varm["b"] = ddf.from_pandas(
pd.DataFrame(
{"A": np.arange(N), "B": np.random.randint(1e6, size=N)},
index=adata.var_names,
)
)

orig = adata
write(orig, pth)
Expand All @@ -161,6 +184,7 @@ def test_dask_to_memory_check_array_types(adata, tmp_path, diskfmt):
assert isinstance(orig.X, DaskArray)
assert isinstance(orig.obsm["a"], DaskArray)
assert isinstance(orig.varm["a"], DaskArray)
assert isinstance(orig.varm["b"], DaskDataFrame)

mem = orig.to_memory()

Expand All @@ -171,20 +195,25 @@ def test_dask_to_memory_check_array_types(adata, tmp_path, diskfmt):
assert_equal(curr.obsm["a"], orig.obsm["a"])
assert_equal(mem.obsm["a"], orig.obsm["a"])
assert_equal(mem.varm["a"], orig.varm["a"])
assert_equal(orig.varm["b"], mem.varm["b"])

assert isinstance(curr.X, np.ndarray)
assert isinstance(curr.obsm["a"], np.ndarray)
assert isinstance(curr.varm["a"], np.ndarray)
assert isinstance(curr.varm["b"], pd.DataFrame)
assert isinstance(mem.X, np.ndarray)
assert isinstance(mem.obsm["a"], np.ndarray)
assert isinstance(mem.varm["a"], np.ndarray)
assert isinstance(mem.varm["b"], pd.DataFrame)
assert isinstance(orig.X, DaskArray)
assert isinstance(orig.obsm["a"], DaskArray)
assert isinstance(orig.varm["a"], DaskArray)
assert isinstance(orig.varm["b"], DaskDataFrame)


def test_dask_to_memory_copy_check_array_types(adata, tmp_path, diskfmt):
import dask.array as da
import dask.dataframe as ddf
import numpy as np

pth = tmp_path / f"test_write.{diskfmt}"
Expand All @@ -195,6 +224,12 @@ def test_dask_to_memory_copy_check_array_types(adata, tmp_path, diskfmt):
adata.obsm["a"] = da.random.random((M, 10))
adata.obsm["b"] = da.random.random((M, 10))
adata.varm["a"] = da.random.random((N, 10))
adata.varm["b"] = ddf.from_pandas(
pd.DataFrame(
{"A": np.arange(N), "B": np.random.randint(1e6, size=N)},
index=adata.var_names,
)
)

orig = adata
write(orig, pth)
Expand All @@ -209,25 +244,36 @@ def test_dask_to_memory_copy_check_array_types(adata, tmp_path, diskfmt):
assert_equal(curr.obsm["a"], orig.obsm["a"])
assert_equal(mem.obsm["a"], orig.obsm["a"])
assert_equal(mem.varm["a"], orig.varm["a"])
assert_equal(orig.varm["b"], mem.varm["b"])

assert isinstance(curr.X, np.ndarray)
assert isinstance(curr.obsm["a"], np.ndarray)
assert isinstance(curr.varm["a"], np.ndarray)
assert isinstance(curr.varm["b"], pd.DataFrame)
assert isinstance(mem.X, np.ndarray)
assert isinstance(mem.obsm["a"], np.ndarray)
assert isinstance(mem.varm["a"], np.ndarray)
assert isinstance(mem.varm["b"], pd.DataFrame)
assert isinstance(orig.X, DaskArray)
assert isinstance(orig.obsm["a"], DaskArray)
assert isinstance(orig.varm["a"], DaskArray)
assert isinstance(orig.varm["b"], DaskDataFrame)


def test_dask_copy_check_array_types(adata):
import dask.array as da
import dask.dataframe as ddf

M, N = adata.X.shape
adata.obsm["a"] = da.random.random((M, 10))
adata.obsm["b"] = da.random.random((M, 10))
adata.varm["a"] = da.random.random((N, 10))
adata.varm["b"] = ddf.from_pandas(
pd.DataFrame(
{"A": np.arange(N), "B": np.random.randint(1e6, size=N)},
index=adata.var_names,
)
)

orig = adata
curr = adata.copy()
Expand All @@ -236,14 +282,17 @@ def test_dask_copy_check_array_types(adata):
assert_equal(curr.obsm["a"], curr.obsm["b"])

assert_equal(curr.varm["a"], orig.varm["a"])
assert_equal(orig.varm["b"], curr.varm["b"])
assert_equal(curr.obsm["a"], orig.obsm["a"])

assert isinstance(curr.X, DaskArray)
assert isinstance(curr.obsm["a"], DaskArray)
assert isinstance(curr.varm["a"], DaskArray)
assert isinstance(curr.varm["b"], DaskDataFrame)
assert isinstance(orig.X, DaskArray)
assert isinstance(orig.obsm["a"], DaskArray)
assert isinstance(orig.varm["a"], DaskArray)
assert isinstance(orig.varm["b"], DaskDataFrame)


def test_assign_X(adata):
Expand Down
Loading