diff --git a/docs/release-notes/3340.feature.md b/docs/release-notes/3340.feature.md
new file mode 100644
index 0000000000..08a78f2ec0
--- /dev/null
+++ b/docs/release-notes/3340.feature.md
@@ -0,0 +1 @@
+{func}`scanpy.pp.highly_variable_genes` flavors `seurat_v3` and `seurat_v3_paper` are now `dask`-compatible {smaller}`I Gold`
diff --git a/src/scanpy/_utils/__init__.py b/src/scanpy/_utils/__init__.py
index d97b23f7ae..abc72c55fc 100644
--- a/src/scanpy/_utils/__init__.py
+++ b/src/scanpy/_utils/__init__.py
@@ -845,8 +845,11 @@ def _check_nonnegative_integers_in_mem(X: _MemoryArray) -> bool:
 
 
 @check_nonnegative_integers.register(DaskArray)
-def _check_nonnegative_integers_dask(X: DaskArray) -> DaskArray:
-    return X.map_blocks(check_nonnegative_integers, dtype=bool, drop_axis=(0, 1))
+def _check_nonnegative_integers_dask(X: DaskArray) -> bool:
+    X_nonnegative: DaskArray = X.map_blocks(
+        check_nonnegative_integers, dtype=bool, drop_axis=(0, 1)
+    )
+    return X_nonnegative.all().compute()
 
 
 def select_groups(
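A minimal caller-side sketch of the behavior the updated dispatch aims for, assuming a row-chunked dask array of counts; the data, the chunking, and the import of the private helper are illustrative assumptions, not part of the patch:

# Illustrative sketch only: the DaskArray overload maps the in-memory check over
# blocks and reduces the per-block results to a single truth value.
import dask.array as da
import numpy as np

from scanpy._utils import check_nonnegative_integers  # private helper, imported purely for illustration

counts = da.from_array(np.random.default_rng(0).poisson(1.0, size=(200, 50)), chunks=(50, -1))
assert check_nonnegative_integers(counts)          # every block holds non-negative integers
assert not check_nonnegative_integers(counts - 1)  # negative entries make the check fail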
diff --git a/src/scanpy/preprocessing/_highly_variable_genes.py b/src/scanpy/preprocessing/_highly_variable_genes.py
index e34340b256..aafc7e8557 100644
--- a/src/scanpy/preprocessing/_highly_variable_genes.py
+++ b/src/scanpy/preprocessing/_highly_variable_genes.py
@@ -2,6 +2,7 @@
 
 import warnings
 from dataclasses import dataclass
+from functools import singledispatch
 from inspect import signature
 from typing import TYPE_CHECKING, cast
 
@@ -26,6 +27,62 @@
     from numpy.typing import NDArray
 
 
+@singledispatch
+def clip_square_sum(
+    data_batch: np.ndarray, clip_val: np.ndarray
+) -> tuple[np.ndarray, np.ndarray]:
+    batch_counts = data_batch.astype(np.float64).copy()
+    clip_val_broad = np.broadcast_to(clip_val, batch_counts.shape)
+    np.putmask(
+        batch_counts,
+        batch_counts > clip_val_broad,
+        clip_val_broad,
+    )
+
+    squared_batch_counts_sum = np.square(batch_counts).sum(axis=0)
+    batch_counts_sum = batch_counts.sum(axis=0)
+    return squared_batch_counts_sum, batch_counts_sum
+
+
+@clip_square_sum.register(DaskArray)
+def _(data_batch: DaskArray, clip_val: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
+    n_blocks = data_batch.blocks.size
+
+    def sum_and_sum_squares_clipped_from_block(block):
+        return np.vstack(clip_square_sum(block, clip_val))[None, ...]
+
+    squared_batch_counts_sum, batch_counts_sum = (
+        data_batch.map_blocks(
+            sum_and_sum_squares_clipped_from_block,
+            new_axis=(1,),
+            chunks=((1,) * n_blocks, (2,), (data_batch.shape[1],)),
+            meta=np.array([]),
+            dtype=np.float64,
+        )
+        .sum(axis=0)
+        .compute()
+    )
+    return squared_batch_counts_sum, batch_counts_sum
+
+
+@clip_square_sum.register(sp_sparse.spmatrix)
+def _(
+    data_batch: sp_sparse.spmatrix, clip_val: np.ndarray
+) -> tuple[np.ndarray, np.ndarray]:
+    if sp_sparse.isspmatrix_csr(data_batch):
+        batch_counts = data_batch
+    else:
+        batch_counts = sp_sparse.csr_matrix(data_batch)
+
+    return _sum_and_sum_squares_clipped(
+        batch_counts.indices,
+        batch_counts.data,
+        n_cols=batch_counts.shape[1],
+        clip_val=clip_val,
+        nnz=batch_counts.nnz,
+    )
+
+
 def _highly_variable_genes_seurat_v3(
     adata: AnnData,
     *,
@@ -89,6 +146,10 @@ def _highly_variable_genes_seurat_v3(
         data_batch = data[batch_info == b]
 
         mean, var = _get_mean_var(data_batch)
+        if isinstance(mean, DaskArray) and isinstance(var, DaskArray):
+            import dask.array as da
+
+            mean, var = da.compute(mean, var)
         not_const = var > 0
         estimat_var = np.zeros(data.shape[1], dtype=np.float64)
 
@@ -103,30 +164,9 @@
         N = data_batch.shape[0]
         vmax = np.sqrt(N)
         clip_val = reg_std * vmax + mean
-        if sp_sparse.issparse(data_batch):
-            if sp_sparse.isspmatrix_csr(data_batch):
-                batch_counts = data_batch
-            else:
-                batch_counts = sp_sparse.csr_matrix(data_batch)
-
-            squared_batch_counts_sum, batch_counts_sum = _sum_and_sum_squares_clipped(
-                batch_counts.indices,
-                batch_counts.data,
-                n_cols=batch_counts.shape[1],
-                clip_val=clip_val,
-                nnz=batch_counts.nnz,
-            )
-        else:
-            batch_counts = data_batch.astype(np.float64).copy()
-            clip_val_broad = np.broadcast_to(clip_val, batch_counts.shape)
-            np.putmask(
-                batch_counts,
-                batch_counts > clip_val_broad,
-                clip_val_broad,
-            )
-
-            squared_batch_counts_sum = np.square(batch_counts).sum(axis=0)
-            batch_counts_sum = batch_counts.sum(axis=0)
+        squared_batch_counts_sum, batch_counts_sum = clip_square_sum(
+            data_batch, clip_val
+        )
 
         norm_gene_var = (1 / ((N - 1) * np.square(reg_std))) * (
             (N * np.square(mean))
@@ -198,6 +238,7 @@ def _highly_variable_genes_seurat_v3(
             df = df.iloc[df["highly_variable"].to_numpy(), :]
 
         return df
+    return None
 
 
 # parallel=False needed for accuracy
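For orientation, a self-contained sketch of the per-block reduction pattern used by the `DaskArray` overload of `clip_square_sum` above; the helper name, random data, and chunking are illustrative assumptions, not code from the patch:

# Illustrative sketch only: each block contributes a (1, 2, n_genes) stack of
# [clipped sum of squares, clipped sum]; summing over the leading block axis
# reproduces the in-memory result.
import dask.array as da
import numpy as np

rng = np.random.default_rng(0)
x = rng.poisson(2.0, size=(120, 30)).astype(np.float64)
clip_val = np.full(30, 3.0)  # per-gene clipping threshold

dx = da.from_array(x, chunks=(40, -1))  # row-chunked, all genes in one column block

def stats_from_block(block: np.ndarray) -> np.ndarray:
    clipped = np.minimum(block, clip_val)  # same effect as the np.putmask clipping above
    return np.vstack((np.square(clipped).sum(axis=0), clipped.sum(axis=0)))[None, ...]

stacked = dx.map_blocks(
    stats_from_block,
    new_axis=(1,),
    chunks=((1,) * dx.blocks.size, (2,), (dx.shape[1],)),
    meta=np.array([]),
    dtype=np.float64,
)
sq_sum, counts_sum = stacked.sum(axis=0).compute()

np.testing.assert_allclose(sq_sum, np.square(np.minimum(x, clip_val)).sum(axis=0))
np.testing.assert_allclose(counts_sum, np.minimum(x, clip_val).sum(axis=0))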
+ if "seurat_v3" in flavor: + adata = pbmc3k() adata.X = np.abs(adata.X).astype(int) if batch_key is not None: adata.obs[batch_key] = np.tile(["a", "b"], adata.shape[0] // 2) sc.pp.normalize_total(adata, target_sum=1e4) sc.pp.log1p(adata) - adata_dask = adata.copy() adata_dask.X = to_dask(adata_dask.X) + adata_dask.X = adata_dask.X.rechunk((adata_dask.X.chunksize[0], -1)) output_mem, output_dask = ( - sc.pp.highly_variable_genes(ad, flavor=flavor, n_top_genes=15, inplace=False) + sc.pp.highly_variable_genes( + ad, flavor=flavor, n_top_genes=15, inplace=False, batch_key=batch_key + ) for ad in [adata, adata_dask] )
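Finally, an end-to-end usage sketch of the new dask-backed `seurat_v3` path; the synthetic counts, chunk sizes, and parameter values are illustrative assumptions (running it needs `dask` and `scikit-misc` installed):

# Illustrative sketch only: seurat_v3 HVG selection on a dask-backed AnnData.
import dask.array as da
import numpy as np
import scanpy as sc
from anndata import AnnData

rng = np.random.default_rng(0)
rates = rng.uniform(0.5, 20.0, size=200)                  # per-gene expression rates
counts = rng.poisson(rates, size=(500, 200))              # synthetic raw counts

adata = AnnData(da.from_array(counts, chunks=(100, -1)))  # row-chunked, whole genes per block
sc.pp.highly_variable_genes(adata, flavor="seurat_v3", n_top_genes=50)
print(adata.var["highly_variable"].sum())                 # expect 50 genes flagged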