Skip to content

add score_genes support for Dask #408

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Jul 29, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion docs/release-notes/0.13.0.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,14 @@
```{rubric} Features
```
* Add support for aggregate operations on CSC matrices, Fortran-ordered arrays, and Dask with sparse CSR and dense matrices {pr}`395` {smaller}`S Dicks`
* Add dask support for `tl.score_genes` & `tl.score_genes_cell_cycle` {pr}`408` {smaller}`S Dicks`

```{rubric} Performance
```

```{rubric} Bug fixes
```

* Fix a bug in `_get_mean_var` with dask chunk sizes {pr}`408` {smaller}`S Dicks`

```{rubric} Misc
```
7 changes: 3 additions & 4 deletions src/rapids_singlecell/get/_anndata.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,10 +91,9 @@ def X_to_GPU(
if isinstance(X, GPU_ARRAY_TYPE):
pass
elif isinstance(X, DaskArray):
if isinstance(X._meta, csc_matrix_cpu):
pass
meta = _meta_sparse if isinstance(X._meta, csr_matrix_cpu) else _meta_dense
X = X.map_blocks(X_to_GPU, meta=meta(X.dtype))
if isinstance(X._meta, csr_matrix_cpu | np.ndarray):
meta = _meta_sparse if isinstance(X._meta, csr_matrix_cpu) else _meta_dense
X = X.map_blocks(X_to_GPU, meta=meta(X.dtype))
elif isspmatrix_csr_cpu(X):
X = csr_matrix_gpu(X)
elif isspmatrix_csc_cpu(X):
Expand Down
28 changes: 12 additions & 16 deletions src/rapids_singlecell/preprocessing/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
import cupy as cp
import numpy as np
import pandas as pd
from cuml.internals.memory_utils import with_cupy_rmm
from cupyx.scipy.sparse import issparse, isspmatrix_csc, isspmatrix_csr, spmatrix
from natsort import natsorted
from pandas.api.types import infer_dtype
Expand Down Expand Up @@ -98,7 +97,6 @@ def _mean_var_minor(X, major, minor):
return mean, var


@with_cupy_rmm
def _mean_var_minor_dask(X, major, minor):
"""
Implements sum operation for dask array when the backend is cupy sparse csr matrix
Expand Down Expand Up @@ -134,7 +132,6 @@ def __mean_var(X_part):


# todo: Implement this dynamically for csc matrix as well
@with_cupy_rmm
def _mean_var_major_dask(X, major, minor):
"""
Implements sum operation for dask array when the backend is cupy sparse csr matrix
Expand Down Expand Up @@ -165,23 +162,23 @@ def __mean_var(X_part):
minor,
),
)
return cp.vstack([mean, var])
return cp.stack([mean, var], axis=1)
Comment on lines -168 to +165
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I’m curious: why the axis switch? is this more local?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This works; the other form, for some reason, doesn't.


mean, var = X.map_blocks(
output = X.map_blocks(
__mean_var,
chunks=((2,), X.chunks[0]),
chunks=(X.chunks[0], (2,)),
dtype=cp.float64,
meta=cp.array([]),
)

mean = output[:, 0]
var = output[:, 1]
mean = mean / minor
var = var / minor
var -= cp.power(mean, 2)
var -= mean**2
var *= minor / (minor - 1)
return mean, var


@with_cupy_rmm
def _mean_var_dense_dask(X, axis):
"""
Implements sum operation for dask array when the backend is cupy dense matrix
Expand All @@ -192,25 +189,24 @@ def __mean_var(X_part):
var = sq_sum(X_part, axis=axis)
mean = mean_sum(X_part, axis=axis)
if axis == 0:
mean = mean.reshape(-1, 1)
var = var.reshape(-1, 1)
return cp.vstack([mean.ravel(), var.ravel()])[
None if 1 - axis else slice(None, None), ...
]
return cp.vstack([mean, var])[None, ...]
else:
return cp.stack([mean, var], axis=1)

n_blocks = X.blocks.size
mean_var = X.map_blocks(
__mean_var,
new_axis=(1,) if axis - 1 else None,
chunks=((2,), X.chunks[0]) if axis else ((1,) * n_blocks, (2,), (X.shape[1],)),
chunks=(X.chunks[0], (2,)) if axis else ((1,) * n_blocks, (2,), (X.shape[1],)),
dtype=cp.float64,
meta=cp.array([]),
)

if axis == 0:
mean, var = mean_var.sum(axis=0)
else:
mean, var = mean_var
mean = mean_var[:, 0]
var = mean_var[:, 1]

mean = mean / X.shape[axis]
var = var / X.shape[axis]
Expand Down
12 changes: 9 additions & 3 deletions src/rapids_singlecell/tools/_score_genes.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,11 @@
from typing import TYPE_CHECKING

import cupy as cp
import dask
import numpy as np
import pandas as pd

from rapids_singlecell._compat import DaskArray
from rapids_singlecell.get import X_to_GPU, _get_obs_rep
from rapids_singlecell.preprocessing._utils import _check_gpu_X, _check_use_raw

Expand Down Expand Up @@ -77,8 +79,7 @@ def score_genes(
use_raw = _check_use_raw(adata, use_raw, layer=layer)
X = _get_obs_rep(adata, layer=layer, use_raw=use_raw)
X = X_to_GPU(X)
_check_gpu_X(X)

_check_gpu_X(X, allow_dask=True)
if random_state is not None:
np.random.seed(random_state)

Expand Down Expand Up @@ -108,6 +109,8 @@ def score_genes(
means_control = _nan_mean(
X, axis=1, mask=control_array, n_features=len(control_genes)
)
if isinstance(X, DaskArray):
means_list, means_control = dask.compute(means_list, means_control)

score = means_list - means_control

Expand Down Expand Up @@ -157,7 +160,10 @@ def _score_genes_bins(
) -> Generator[pd.Index[str], None, None]:
# average expression of genes
idx = cp.array(var_names.isin(gene_pool), dtype=cp.bool_)
nanmeans = _nan_mean(X, axis=0, mask=idx, n_features=len(gene_pool)).get()
nanmeans = _nan_mean(X, axis=0, mask=idx, n_features=len(gene_pool))
if isinstance(X, DaskArray):
nanmeans = nanmeans.compute()
nanmeans = nanmeans.get()
obs_avg = pd.Series(nanmeans, index=gene_pool)
# Sometimes (and I don’t know how) missing data may be there, with NaNs for missing entries
obs_avg = obs_avg[np.isfinite(obs_avg)]
Expand Down
124 changes: 124 additions & 0 deletions src/rapids_singlecell/tools/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
import cupy as cp
from cupyx.scipy.sparse import issparse, isspmatrix_csc, isspmatrix_csr

from rapids_singlecell._compat import DaskArray

from . import pca


Expand Down Expand Up @@ -47,6 +49,100 @@ def _choose_representation(adata, use_rep=None, n_pcs=None):
return X


def _nan_mean_minor_dask_sparse(X, major, minor, *, mask=None, n_features=None):
    """NaN-aware mean along the minor axis of a dask-backed CuPy sparse matrix.

    Every dask block produces a minor-axis vector of accumulated values and a
    NaN count; the per-block partials are summed across blocks before the
    mean is formed as value / (n_features - nan_count).
    """
    from ._kernels._nan_mean_kernels import _get_nan_mean_minor

    # Specialize and pre-compile the kernel for this dtype once, up front.
    kernel = _get_nan_mean_minor(X.dtype)
    kernel.compile()

    def _partial_sums(X_part):
        totals = cp.zeros(minor, dtype=cp.float64)
        nan_counts = cp.zeros(minor, dtype=cp.int32)
        # One flat thread per stored element of this block.
        threads = 32
        grid = (math.ceil(X_part.nnz / threads),)
        kernel(
            grid,
            (threads,),
            (X_part.indices, X_part.data, totals, nan_counts, mask, X_part.nnz),
        )
        # Emit a (1, 2, minor) slab so the block results reduce on axis 0.
        return cp.vstack([totals, nan_counts.astype(cp.float64)])[None, ...]

    n_blocks = X.blocks.size
    stacked = X.map_blocks(
        _partial_sums,
        new_axis=(1,),
        chunks=((1,) * n_blocks, (2,), (minor,)),
        dtype=cp.float64,
        meta=cp.array([]),
    )
    mean, nans = stacked.sum(axis=0)
    mean /= n_features - nans
    return mean


def _nan_mean_major_dask_sparse(X, major, minor, *, mask=None, n_features=None):
    """NaN-aware mean along the major axis of a dask-backed CuPy sparse matrix.

    Returns a dask array with one entry per major-axis element, formed from
    the kernel's per-element accumulation as value / (n_features - nan_count).

    Parameters
    ----------
    X
        Dask array whose blocks expose CSR-style ``indptr``/``indices``/``data``
        buffers (passed straight to the kernel).
    major, minor
        Global extents of ``X`` along the major and minor axes.
    mask
        Boolean minor-axis mask selecting the features that participate
        — assumed non-None here; callers supply a default. TODO confirm.
    n_features
        Divisor used before the per-element NaN correction.
    """
    from ._kernels._nan_mean_kernels import _get_nan_mean_major

    # Specialize and pre-compile the kernel for this dtype once, outside the
    # per-block function.
    kernel = _get_nan_mean_major(X.dtype)
    kernel.compile()

    def __nan_mean_major(X_part):
        # Blocks split the major axis, so size outputs to this block only.
        major_part = X_part.shape[0]
        mean = cp.zeros(major_part, dtype=cp.float64)
        nans = cp.zeros(major_part, dtype=cp.int32)
        # One 64-thread block per major-axis element.
        block = (64,)
        grid = (major_part,)
        kernel(
            grid,
            block,
            (
                X_part.indptr,
                X_part.indices,
                X_part.data,
                mean,
                nans,
                mask,
                major_part,
                minor,
            ),
        )
        # (major_part, 2) slab: column 0 = accumulated values, column 1 = NaN
        # counts (as float64 so both fit one array).
        return cp.stack([mean, nans.astype(cp.float64)], axis=1)

    # Major-axis chunking is preserved; the trailing axis always has length 2.
    output = X.map_blocks(
        __nan_mean_major,
        chunks=(X.chunks[0], (2,)),
        dtype=cp.float64,
        meta=cp.array([]),
    )
    mean = output[:, 0]
    nans = output[:, 1]
    mean /= n_features - nans
    return mean


def _nan_mean_dense_dask(X, axis, *, mask, n_features):
    """NaN-aware mean of a dask-backed dense CuPy array along ``axis``.

    Parameters
    ----------
    X
        2D dask array whose blocks are ``cupy.ndarray``s, chunked along axis 0.
    axis
        0 to reduce over the major axis (one mean per selected column),
        1 to reduce over the minor axis (one mean per row).
    mask
        Boolean column mask selecting the features that participate.
    n_features
        Divisor used before the per-entry NaN correction.

    Returns
    -------
    Dask array of means computed as nansum / (n_features - nan_count).
    """

    def _partial_sums(X_part):
        # Work in float64 so per-block sums don't lose precision.
        selected = X_part[:, mask].astype(cp.float64)
        # Renamed from `sum` to avoid shadowing the builtin.
        totals = cp.nansum(selected, axis=axis).ravel()
        nan_counts = cp.sum(cp.isnan(selected), axis=axis).ravel()
        if axis == 1:
            # (rows, 2) slab per block; row chunks simply concatenate.
            return cp.stack([totals, nan_counts.astype(cp.float64)], axis=1)
        # axis == 0: (1, 2, n_cols) slab so blocks can be reduced on axis 0.
        return cp.vstack([totals, nan_counts.astype(cp.float64)])[None, ...]

    n_blocks = X.blocks.size
    # NOTE(review): for axis == 0 the declared trailing chunk is X.shape[1],
    # but each block result has mask.sum() columns — confirm these agree for
    # callers that pass a partial mask.
    output = X.map_blocks(
        _partial_sums,
        # Only the axis == 0 layout introduces a new leading dimension.
        new_axis=(1,) if axis == 0 else None,
        chunks=(X.chunks[0], (2,))
        if axis == 1
        else ((1,) * n_blocks, (2,), (X.shape[1],)),
        dtype=cp.float64,
        meta=cp.array([]),
    )
    if axis == 0:
        mean, nans = output.sum(axis=0)
    else:
        mean = output[:, 0]
        nans = output[:, 1]
    mean /= n_features - nans
    return mean


def _nan_mean_minor(X, major, minor, *, mask=None, n_features=None):
from ._kernels._nan_mean_kernels import _get_nan_mean_minor

Expand Down Expand Up @@ -120,6 +216,34 @@ def _nan_mean(X, axis=0, *, mask=None, n_features=None):
mean = _nan_mean_minor(
X, major, minor, mask=mask, n_features=n_features
)
elif isinstance(X, DaskArray):
if isspmatrix_csr(X._meta):
major, minor = X.shape
if mask is None:
mask = cp.ones(X.shape[1], dtype=cp.bool_)
if axis == 0:
n_features = major
mean = _nan_mean_minor_dask_sparse(
X, major, minor, mask=mask, n_features=n_features
)
elif axis == 1:
n_features = minor if n_features is None else n_features
mean = _nan_mean_major_dask_sparse(
X, major, minor, mask=mask, n_features=n_features
)
else:
raise ValueError("axis must be either 0 or 1")
elif isinstance(X._meta, cp.ndarray):
if mask is None:
mask = cp.ones(X.shape[1], dtype=cp.bool_)
if n_features is None:
n_features = X.shape[axis]
mean = _nan_mean_dense_dask(X, axis, mask=mask, n_features=n_features)
# raise NotImplementedError("Dask dense arrays are not supported yet")
else:
raise ValueError(
"Type not supported. Please provide a CuPy ndarray or a CuPy sparse matrix. Or a Dask array with a CuPy ndarray or a CuPy sparse matrix as meta."
)
else:
if mask is None:
mask = cp.ones(X.shape[1], dtype=cp.bool_)
Expand Down
66 changes: 66 additions & 0 deletions tests/dask/test_dask_mean_var.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
from __future__ import annotations

import cupy as cp
import pytest
from scanpy.datasets import pbmc3k, pbmc68k_reduced

import rapids_singlecell as rsc
from rapids_singlecell._testing import (
as_dense_cupy_dask_array,
as_sparse_cupy_dask_array,
)
from rapids_singlecell.preprocessing._utils import _get_mean_var

from ..test_score_genes import _create_sparse_nan_matrix # noqa: TID252


@pytest.mark.parametrize("data_kind", ["sparse", "dense"])
@pytest.mark.parametrize("axis", [0, 1])
@pytest.mark.parametrize("dtype", [cp.float32, cp.float64])
def test_mean_var(client, data_kind, axis, dtype):
    """Dask-backed `_get_mean_var` matches the in-memory GPU result.

    Builds the same AnnData twice — once moved fully onto the GPU, once with
    ``.X`` as a persisted CuPy-backed dask array (dense or sparse) — and
    checks mean/var agree along both axes for float32 and float64 input.
    ``client`` is presumably the dask distributed client fixture — confirm
    against conftest.
    """
    if data_kind == "dense":
        # Dense path uses the reduced 68k PBMC dataset.
        adata = pbmc68k_reduced()
        adata.X = adata.X.astype(dtype)
        dask_data = adata.copy()
        dask_data.X = as_dense_cupy_dask_array(dask_data.X).persist()
        rsc.get.anndata_to_GPU(adata)
    elif data_kind == "sparse":
        # Sparse path uses the 3k PBMC dataset.
        adata = pbmc3k()
        adata.X = adata.X.astype(dtype)
        dask_data = adata.copy()
        dask_data.X = as_sparse_cupy_dask_array(dask_data.X).persist()
        rsc.get.anndata_to_GPU(adata)

    # Reference result from the non-dask GPU matrix.
    mean, var = _get_mean_var(adata.X, axis=axis)
    # Dask results are lazy; materialize before comparing.
    dask_mean, dask_var = _get_mean_var(dask_data.X, axis=axis)
    dask_mean, dask_var = dask_mean.compute(), dask_var.compute()

    cp.testing.assert_allclose(mean, dask_mean)
    cp.testing.assert_allclose(var, dask_var)


@pytest.mark.parametrize("array_type", ["csr", "dense"])
@pytest.mark.parametrize("percent_nan", [0, 0.3])
def test_sparse_nanmean(client, array_type, percent_nan):
    """Dask-backed `_nan_mean` matches the in-memory GPU result.

    A random 100x50 matrix containing zeros and (optionally) NaNs is
    evaluated both as a plain GPU array and as a persisted CuPy-backed dask
    array (CSR or dense); the NaN-aware means along both axes must agree.

    NOTE(review): the original docstring read "Needs to be fixed" with no
    detail — verify whether a known issue remains here.
    """
    from rapids_singlecell.tools._utils import _nan_mean

    R, C = 100, 50

    # sparse matrix with nan
    S = _create_sparse_nan_matrix(R, C, percent_zero=0.3, percent_nan=percent_nan)
    S = S.astype(cp.float64)
    # Dense reference copy of the same data, moved to the GPU.
    A = S.toarray()
    A = rsc.get.X_to_GPU(A)

    if array_type == "dense":
        S = as_dense_cupy_dask_array(A).persist()
    else:
        S = as_sparse_cupy_dask_array(S).persist()

    # Dask results are lazy; compute before comparing to the reference.
    cp.testing.assert_allclose(
        _nan_mean(A, 1).ravel(), (_nan_mean(S, 1)).ravel().compute()
    )
    cp.testing.assert_allclose(
        _nan_mean(A, 0).ravel(), (_nan_mean(S, 0)).ravel().compute()
    )
Loading
Loading