Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/release-notes/3340.feature.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{func}`scanpy.pp.highly_variable_genes` flavors `seurat_v3` and `seurat_v3_paper` are now `dask`-compatible {smaller}`I Gold`
7 changes: 5 additions & 2 deletions src/scanpy/_utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -845,8 +845,11 @@ def _check_nonnegative_integers_in_mem(X: _MemoryArray) -> bool:


@check_nonnegative_integers.register(DaskArray)
def _check_nonnegative_integers_dask(X: DaskArray) -> DaskArray:
return X.map_blocks(check_nonnegative_integers, dtype=bool, drop_axis=(0, 1))
def _check_nonnegative_integers_dask(X: DaskArray) -> bool:
X_nonnegative: DaskArray = X.map_blocks(
check_nonnegative_integers, dtype=bool, drop_axis=(0, 1)
)
return X_nonnegative.any().compute()


def select_groups(
Expand Down
89 changes: 65 additions & 24 deletions src/scanpy/preprocessing/_highly_variable_genes.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import warnings
from dataclasses import dataclass
from functools import singledispatch
from inspect import signature
from typing import TYPE_CHECKING, cast

Expand All @@ -26,6 +27,62 @@
from numpy.typing import NDArray


@singledispatch
def clip_square_sum(
data_batch: np.ndarray, clip_val: np.ndarray
) -> tuple[np.ndarray, np.ndarray]:
batch_counts = data_batch.astype(np.float64).copy()
clip_val_broad = np.broadcast_to(clip_val, batch_counts.shape)
np.putmask(
batch_counts,
batch_counts > clip_val_broad,
clip_val_broad,
)

squared_batch_counts_sum = np.square(batch_counts).sum(axis=0)
batch_counts_sum = batch_counts.sum(axis=0)
return squared_batch_counts_sum, batch_counts_sum


@clip_square_sum.register(DaskArray)
def _(data_batch: DaskArray, clip_val: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
n_blocks = data_batch.blocks.size

def sum_and_sum_squares_clipped_from_block(block):
return np.vstack(clip_square_sum(block, clip_val))[None, ...]

squared_batch_counts_sum, batch_counts_sum = (
data_batch.map_blocks(
sum_and_sum_squares_clipped_from_block,
new_axis=(1,),
chunks=((1,) * n_blocks, (2,), (data_batch.shape[1],)),
meta=np.array([]),
dtype=np.float64,
)
.sum(axis=0)
.compute()
)
return squared_batch_counts_sum, batch_counts_sum


@clip_square_sum.register(sp_sparse.spmatrix)
def _(
data_batch: sp_sparse.spmatrix, clip_val: np.ndarray
) -> tuple[np.ndarray, np.ndarray]:
if sp_sparse.isspmatrix_csr(data_batch):
batch_counts = data_batch
else:
batch_counts = sp_sparse.csr_matrix(data_batch)

Check warning on line 75 in src/scanpy/preprocessing/_highly_variable_genes.py

View check run for this annotation

Codecov / codecov/patch

src/scanpy/preprocessing/_highly_variable_genes.py#L75

Added line #L75 was not covered by tests

return _sum_and_sum_squares_clipped(
batch_counts.indices,
batch_counts.data,
n_cols=batch_counts.shape[1],
clip_val=clip_val,
nnz=batch_counts.nnz,
)


def _highly_variable_genes_seurat_v3(
adata: AnnData,
*,
Expand Down Expand Up @@ -89,6 +146,10 @@
data_batch = data[batch_info == b]

mean, var = _get_mean_var(data_batch)
if isinstance(mean, DaskArray) and isinstance(var, DaskArray):
import dask.array as da

mean, var = da.compute(mean, var)
not_const = var > 0
estimat_var = np.zeros(data.shape[1], dtype=np.float64)

Expand All @@ -103,30 +164,9 @@
N = data_batch.shape[0]
vmax = np.sqrt(N)
clip_val = reg_std * vmax + mean
if sp_sparse.issparse(data_batch):
if sp_sparse.isspmatrix_csr(data_batch):
batch_counts = data_batch
else:
batch_counts = sp_sparse.csr_matrix(data_batch)

squared_batch_counts_sum, batch_counts_sum = _sum_and_sum_squares_clipped(
batch_counts.indices,
batch_counts.data,
n_cols=batch_counts.shape[1],
clip_val=clip_val,
nnz=batch_counts.nnz,
)
else:
batch_counts = data_batch.astype(np.float64).copy()
clip_val_broad = np.broadcast_to(clip_val, batch_counts.shape)
np.putmask(
batch_counts,
batch_counts > clip_val_broad,
clip_val_broad,
)

squared_batch_counts_sum = np.square(batch_counts).sum(axis=0)
batch_counts_sum = batch_counts.sum(axis=0)
squared_batch_counts_sum, batch_counts_sum = clip_square_sum(
data_batch, clip_val
)

norm_gene_var = (1 / ((N - 1) * np.square(reg_std))) * (
(N * np.square(mean))
Expand Down Expand Up @@ -198,6 +238,7 @@
df = df.iloc[df["highly_variable"].to_numpy(), :]

return df
return None


# parallel=False needed for accuracy
Expand Down
15 changes: 12 additions & 3 deletions tests/test_highly_variable_genes.py
Original file line number Diff line number Diff line change
Expand Up @@ -676,23 +676,32 @@ def test_subset_inplace_consistency(flavor, array_type, batch_key):
assert adatas[True].var_names.equals(dfs[True].index)


@pytest.mark.parametrize("flavor", ["seurat", "cell_ranger"])
@needs.skmisc
@needs.dask
@pytest.mark.parametrize(
"flavor", ["seurat", "cell_ranger", "seurat_v3", "seurat_v3_paper"]
)
@pytest.mark.parametrize("batch_key", [None, "batch"], ids=["single", "batched"])
@pytest.mark.parametrize(
"to_dask", [p for p in ARRAY_TYPES if "dask" in p.values[0].__name__]
)
def test_dask_consistency(adata: AnnData, flavor, batch_key, to_dask):
# current blob produces singularities in loess....maybe a bad sign of the data?
if "seurat_v3" in flavor:
adata = pbmc3k()
adata.X = np.abs(adata.X).astype(int)
if batch_key is not None:
adata.obs[batch_key] = np.tile(["a", "b"], adata.shape[0] // 2)
sc.pp.normalize_total(adata, target_sum=1e4)
sc.pp.log1p(adata)

adata_dask = adata.copy()
adata_dask.X = to_dask(adata_dask.X)
adata_dask.X = adata_dask.X.rechunk((adata_dask.X.chunksize[0], -1))

output_mem, output_dask = (
sc.pp.highly_variable_genes(ad, flavor=flavor, n_top_genes=15, inplace=False)
sc.pp.highly_variable_genes(
ad, flavor=flavor, n_top_genes=15, inplace=False, batch_key=batch_key
)
for ad in [adata, adata_dask]
)

Expand Down
Loading