From d23b8e037a64f2959f399a98801f45b5b0c09133 Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Tue, 5 Nov 2024 12:00:46 +0100 Subject: [PATCH 1/4] (feat): seurat v3 with dask --- .../preprocessing/_highly_variable_genes.py | 89 ++++++++++++++----- tests/test_highly_variable_genes.py | 15 +++- 2 files changed, 77 insertions(+), 27 deletions(-) diff --git a/src/scanpy/preprocessing/_highly_variable_genes.py b/src/scanpy/preprocessing/_highly_variable_genes.py index fa7971d21e..fe9fab5115 100644 --- a/src/scanpy/preprocessing/_highly_variable_genes.py +++ b/src/scanpy/preprocessing/_highly_variable_genes.py @@ -2,6 +2,7 @@ import warnings from dataclasses import dataclass +from functools import singledispatch from inspect import signature from typing import TYPE_CHECKING, cast @@ -26,6 +27,62 @@ from numpy.typing import NDArray +@singledispatch +def clip_square_sum( + data_batch: np.ndarray, clip_val: np.ndarray +) -> tuple[np.ndarray, np.ndarray]: + batch_counts = data_batch.astype(np.float64).copy() + clip_val_broad = np.broadcast_to(clip_val, batch_counts.shape) + np.putmask( + batch_counts, + batch_counts > clip_val_broad, + clip_val_broad, + ) + + squared_batch_counts_sum = np.square(batch_counts).sum(axis=0) + batch_counts_sum = batch_counts.sum(axis=0) + return squared_batch_counts_sum, batch_counts_sum + + +@clip_square_sum.register(DaskArray) +def _(data_batch: DaskArray, clip_val: np.ndarray) -> tuple[np.ndarray, np.ndarray]: + n_blocks = data_batch.blocks.size + + def sum_and_sum_squares_clipped_from_block(block): + return np.vstack(clip_square_sum(block, clip_val))[None, ...] + + squared_batch_counts_sum, batch_counts_sum = ( + data_batch.map_blocks( + sum_and_sum_squares_clipped_from_block, + new_axis=(1,), + chunks=((1,) * n_blocks, (2,), (data_batch.shape[1],)), + meta=np.array([]), + dtype=np.float64, + ) + .sum(axis=0) + .compute() + ) + return squared_batch_counts_sum, batch_counts_sum + + +@clip_square_sum.register(sp_sparse.spmatrix) +def _( + data_batch: sp_sparse.spmatrix, clip_val: np.ndarray +) -> tuple[np.ndarray, np.ndarray]: + if sp_sparse.isspmatrix_csr(data_batch): + batch_counts = data_batch + else: + batch_counts = sp_sparse.csr_matrix(data_batch) + + return _sum_and_sum_squares_clipped( + batch_counts.indices, + batch_counts.data, + n_cols=batch_counts.shape[1], + clip_val=clip_val, + nnz=batch_counts.nnz, + ) + + def _highly_variable_genes_seurat_v3( adata: AnnData, *, @@ -89,6 +146,10 @@ def _highly_variable_genes_seurat_v3( data_batch = data[batch_info == b] mean, var = _get_mean_var(data_batch) + if isinstance(mean, DaskArray) and isinstance(var, DaskArray): + import dask.array as da + + mean, var = da.compute(mean, var) not_const = var > 0 estimat_var = np.zeros(data.shape[1], dtype=np.float64) @@ -103,30 +164,9 @@ def _highly_variable_genes_seurat_v3( N = data_batch.shape[0] vmax = np.sqrt(N) clip_val = reg_std * vmax + mean - if sp_sparse.issparse(data_batch): - if sp_sparse.isspmatrix_csr(data_batch): - batch_counts = data_batch - else: - batch_counts = sp_sparse.csr_matrix(data_batch) - - squared_batch_counts_sum, batch_counts_sum = _sum_and_sum_squares_clipped( - batch_counts.indices, - batch_counts.data, - n_cols=batch_counts.shape[1], - clip_val=clip_val, - nnz=batch_counts.nnz, - ) - else: - batch_counts = data_batch.astype(np.float64).copy() - clip_val_broad = np.broadcast_to(clip_val, batch_counts.shape) - np.putmask( - batch_counts, - batch_counts > clip_val_broad, - clip_val_broad, - ) - - squared_batch_counts_sum = np.square(batch_counts).sum(axis=0) - batch_counts_sum = batch_counts.sum(axis=0) + squared_batch_counts_sum, batch_counts_sum = clip_square_sum( + data_batch, clip_val + ) norm_gene_var = (1 / ((N - 1) * np.square(reg_std))) * ( (N * np.square(mean)) @@ -198,6 +238,7 @@ def _highly_variable_genes_seurat_v3( df = df.iloc[df["highly_variable"].to_numpy(), :] return df + return None @numba.njit(cache=True) diff --git a/tests/test_highly_variable_genes.py b/tests/test_highly_variable_genes.py index 7d9fdac9fa..5f3ec82734 100644 --- a/tests/test_highly_variable_genes.py +++ b/tests/test_highly_variable_genes.py @@ -676,23 +676,32 @@ def test_subset_inplace_consistency(flavor, array_type, batch_key): assert adatas[True].var_names.equals(dfs[True].index) -@pytest.mark.parametrize("flavor", ["seurat", "cell_ranger"]) +@needs.skmisc +@needs.dask +@pytest.mark.parametrize( + "flavor", ["seurat", "cell_ranger", "seurat_v3", "seurat_v3_paper"] +) @pytest.mark.parametrize("batch_key", [None, "batch"], ids=["single", "batched"]) @pytest.mark.parametrize( "to_dask", [p for p in ARRAY_TYPES if "dask" in p.values[0].__name__] ) def test_dask_consistency(adata: AnnData, flavor, batch_key, to_dask): + # current blob produces singularities in loess....maybe a bad sign of the data? + if "seurat_v3" in flavor: + adata = pbmc3k() adata.X = np.abs(adata.X).astype(int) if batch_key is not None: adata.obs[batch_key] = np.tile(["a", "b"], adata.shape[0] // 2) sc.pp.normalize_total(adata, target_sum=1e4) sc.pp.log1p(adata) - adata_dask = adata.copy() adata_dask.X = to_dask(adata_dask.X) + adata_dask.X = adata_dask.X.rechunk((adata_dask.X.chunksize[0], -1)) output_mem, output_dask = ( - sc.pp.highly_variable_genes(ad, flavor=flavor, n_top_genes=15, inplace=False) + sc.pp.highly_variable_genes( + ad, flavor=flavor, n_top_genes=15, inplace=False, batch_key=batch_key + ) for ad in [adata, adata_dask] ) From 11a711e97b84c01dcd73d5b17b8e68c6b6d047c6 Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Tue, 5 Nov 2024 12:04:14 +0100 Subject: [PATCH 2/4] (chore): add release note --- docs/release-notes/3340.feature.md | 1 + 1 file changed, 1 insertion(+) create mode 100644 docs/release-notes/3340.feature.md diff --git a/docs/release-notes/3340.feature.md b/docs/release-notes/3340.feature.md new file mode 100644 index 0000000000..08a78f2ec0 --- /dev/null +++ b/docs/release-notes/3340.feature.md @@ -0,0 +1 @@ +{func}`scanpy.pp.highly_variable_genes` flavors `seurat_v3` and `seurat_v3_paper` are now `dask`-compatible {smaller}`I Gold` From 78fc8ee8ce14760aefc6914bfac8612e996d55d2 Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Mon, 11 Nov 2024 12:39:58 +0100 Subject: [PATCH 3/4] (chore): try aborting early --- src/scanpy/_utils/__init__.py | 22 ++++++++++++++++++++-- 1 file changed, 20 insertions(+), 2 deletions(-) diff --git a/src/scanpy/_utils/__init__.py b/src/scanpy/_utils/__init__.py index d97b23f7ae..7011ef5c1a 100644 --- a/src/scanpy/_utils/__init__.py +++ b/src/scanpy/_utils/__init__.py @@ -845,8 +845,26 @@ def _check_nonnegative_integers_in_mem(X: _MemoryArray) -> bool: @check_nonnegative_integers.register(DaskArray) -def _check_nonnegative_integers_dask(X: DaskArray) -> DaskArray: - return X.map_blocks(check_nonnegative_integers, dtype=bool, drop_axis=(0, 1)) +def _check_nonnegative_integers_dask(X: DaskArray) -> bool: + import dask.distributed as dd + + X_nonnegative: DaskArray = X.map_blocks( + check_nonnegative_integers, dtype=bool, drop_axis=(0, 1) + ) + try: + client = dd.default_client() + has_client = True + except ValueError: + has_client = False + if has_client: + blocks = X_nonnegative.to_delayed().ravel() + return any( + not block.result() + for block in dd.as_completed( + client.submit(lambda block: block.compute(), block) for block in blocks + ) + ) + return X_nonnegative.compute() def select_groups( From f83e3f38711347c8775ed4d9b662902ee07427b5 Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Wed, 13 Nov 2024 20:41:07 +0100 Subject: [PATCH 4/4] (fix): use `.any()` to prevent erroneous computation --- src/scanpy/_utils/__init__.py | 17 +---------------- 1 file changed, 1 insertion(+), 16 deletions(-) diff --git a/src/scanpy/_utils/__init__.py b/src/scanpy/_utils/__init__.py index 7011ef5c1a..abc72c55fc 100644 --- a/src/scanpy/_utils/__init__.py +++ b/src/scanpy/_utils/__init__.py @@ -846,25 +846,10 @@ def _check_nonnegative_integers_in_mem(X: _MemoryArray) -> bool: @check_nonnegative_integers.register(DaskArray) def _check_nonnegative_integers_dask(X: DaskArray) -> bool: - import dask.distributed as dd - X_nonnegative: DaskArray = X.map_blocks( check_nonnegative_integers, dtype=bool, drop_axis=(0, 1) ) - try: - client = dd.default_client() - has_client = True - except ValueError: - has_client = False - if has_client: - blocks = X_nonnegative.to_delayed().ravel() - return any( - not block.result() - for block in dd.as_completed( - client.submit(lambda block: block.compute(), block) for block in blocks - ) - ) - return X_nonnegative.compute() + return X_nonnegative.any().compute() def select_groups(