11from __future__ import annotations
22
33import logging
4- import os
54
6- import joblib
75import numpy as np
86import scanpy as sc
9- from scipy .stats import mode
107from sklearn .neighbors import KNeighborsClassifier
118
129from popv import settings
13-
14- if settings .cuml :
15- import rapids_singlecell as rsc
1610from popv .algorithms ._base_algorithm import BaseAlgorithm
1711
1812
@@ -95,48 +89,22 @@ def compute_integration(self, adata):
9589 AnnData object. Modified inplace.
9690 """
9791 logging .info ("Integrating data with bbknn" )
98- if (
99- adata .uns ["_prediction_mode" ] == "inference"
100- and "X_umap_bbknn" in adata .obsm
101- and not settings .recompute_embeddings
102- ):
103- index = joblib .load (os .path .join (adata .uns ["_save_path_trained_models" ], "pynndescent_index.joblib" ))
104- query_features = adata .obsm ["X_pca" ][adata .obs ["_dataset" ] == "query" , :]
105- indices , _ = index .query (query_features .astype (np .float32 ), k = 5 )
106-
107- neighbor_embedding = adata .obsm ["X_umap_bbknn" ][adata .obs ["_dataset" ] == "ref" , :][indices ].astype (
108- np .float32
109- )
110- adata .obsm [self .umap_key ][adata .obs ["_dataset" ] == "query" , :] = np .mean (neighbor_embedding , axis = 1 )
111- adata .obsm [self .umap_key ] = adata .obsm [self .umap_key ].astype (np .float32 )
112-
113- neighbor_probabilities = adata .obs [f"{ self .result_key } _probabilities" ][adata .obs ["_dataset" ] == "ref" , :][
114- indices
115- ].astype (np .float32 )
116- adata .obs .loc [adata .obs ["_dataset" ] == "query" , f"{ self .result_key } _probabilities" ] = np .mean (
117- neighbor_probabilities , axis = 1
118- )
119-
120- neighbor_prediction = adata .obs [f"{ self .result_key } " ][adata .obs ["_dataset" ] == "ref" , :][indices ].astype (
121- np .float32
92+ if len (adata .obs [self .batch_key ].unique ()) > 100 :
93+ self .method_kwargs ["neighbors_within_batch" ] = 1
94+ if settings .cuml :
95+ import rapids_singlecell as rsc
96+
97+ self .method_kwargs .pop ("approx" , None ) # approx not supported in rsc
98+ self .method_kwargs .pop ("use_annoy" , None ) # use_annoy not supported in rsc
99+ rsc .pp .bbknn (
100+ adata , batch_key = self .batch_key , use_rep = "X_pca" , algorithm = "ivfflat" , ** self .method_kwargs , trim = 0
122101 )
123- adata .obs .loc [adata .obs ["_dataset" ] == "query" , f"{ self .result_key } " ] = mode (neighbor_prediction , axis = 1 )
124102 else :
125- if len (adata .obs [self .batch_key ].unique ()) > 100 :
126- logging .warning ("Using PyNNDescent instead of FAISS as high number of batches leads to OOM." )
127- self .method_kwargs ["neighbors_within_batch" ] = 1 # Reduce memory usage.
128- self .method_kwargs ["pynndescent_n_neighbors" ] = 10 # Reduce memory usage.
129- sc .external .pp .bbknn (
130- adata , batch_key = self .batch_key , use_faiss = False , use_rep = "X_pca" , ** self .method_kwargs
131- )
132- else :
133- sc .external .pp .bbknn (
134- adata , batch_key = self .batch_key , use_faiss = True , use_rep = "X_pca" , ** self .method_kwargs
135- )
103+ sc .external .pp .bbknn (adata , batch_key = self .batch_key , use_rep = "X_pca" , ** self .method_kwargs )
136104
137105 def predict (self , adata ):
138106 """
139- Predict celltypes using Celltypist .
107+ Predict celltypes using BBKNN kNN .
140108
141109 Parameters
142110 ----------
@@ -168,7 +136,9 @@ def predict(self, adata):
168136 adata .obs [self .result_key ] = adata .uns ["label_categories" ][knn .predict (test_distances )]
169137
170138 if self .return_probabilities :
171- adata .obs [f"{ self .result_key } _probabilities" ] = np .max (knn .predict_proba (test_distances ), axis = 1 )
139+ probabilities = knn .predict_proba (test_distances )
140+ adata .obs [f"{ self .result_key } _probabilities" ] = np .max (probabilities , axis = 1 )
141+ adata .obsm [f"{ self .result_key } _probabilities" ] = probabilities
172142
173143 def compute_umap (self , adata ):
174144 """
@@ -180,8 +150,10 @@ def compute_umap(self, adata):
180150 AnnData object. Results are stored in adata.obsm[self.umap_key].
181151 """
182152 if self .compute_umap_embedding :
183- logging .info (f'Saving UMAP of bbknn results to adata.obs ["{ self .embedding_key } "]' )
153+ logging .info (f'Saving UMAP of BBKNN results to adata.obsm ["{ self .umap_key } "]' )
184154 if settings .cuml :
155+ import rapids_singlecell as rsc
156+
185157 rsc .pp .neighbors (adata , use_rep = self .embedding_key )
186158 adata .obsm [self .umap_key ] = rsc .tl .umap (adata , copy = True , ** self .embedding_kwargs ).obsm ["X_umap" ]
187159 else :
0 commit comments