
Commit 9adbeb5

Merge pull request #101 from YosefLab/faiss_nn
Faiss NN
2 parents: 598da03 + 4403e75

22 files changed: +408 -265 lines

.github/workflows/build.yml

Lines changed: 2 additions & 2 deletions

@@ -15,10 +15,10 @@ jobs:
     runs-on: ubuntu-latest
     steps:
       - uses: actions/checkout@v3
-      - name: Set up Python 3.10
+      - name: Set up Python 3.12
        uses: actions/setup-python@v4
        with:
-          python-version: "3.10"
+          python-version: "3.12"
          cache: "pip"
          cache-dependency-path: "**/pyproject.toml"
      - name: Install build dependencies

.github/workflows/release.yml

Lines changed: 2 additions & 2 deletions

@@ -112,7 +112,7 @@ jobs:

       - uses: actions/setup-python@v4
         with:
-          python-version: "3.11"
+          python-version: "3.12"

       - run: pip install build

@@ -158,6 +158,6 @@ jobs:
           cache-from: type=registry,ref=ghcr.io/yoseflab/popv:buildcache
           cache-to: type=inline,ref=ghcr.io/yoseflab/popv:buildcache
           target: build
-          tags: ghcr.io/yoseflab/popv:py3.11-cu12-${{ inputs.tag }}-${{ matrix.dependencies }}
+          tags: ghcr.io/yoseflab/popv:py3.12-cu12-${{ inputs.tag }}-${{ matrix.dependencies }}
           build-args: |
             DEPENDENCIES=${{ matrix.dependencies }}

.github/workflows/test_linux_cuda.yml

Lines changed: 2 additions & 1 deletion

@@ -30,7 +30,7 @@ jobs:
         shell: bash -e {0} # -e to fail on error

     container:
-      image: ghcr.io/yoseflab/popv:py3.11-cu12-0.5.2.post1-
+      image: ghcr.io/yoseflab/popv:py3.12-cu12-0.6.0-
       options: --user root --gpus all --pull always

     name: integration

@@ -54,6 +54,7 @@ jobs:
         python -m uv pip install --system "PopV[tests] @ ."
         python -m pip install jax[cuda]
         python -m pip install nvidia-nccl-cu12
+        python -m pip install faiss-gpu-cu12

       - name: Run pytest
         env:
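Usage note: the CUDA test image now installs the GPU FAISS wheel alongside JAX and NCCL. A minimal sketch (not part of the workflow; it assumes faiss-gpu-cu12 imported cleanly inside the container) for confirming that FAISS can see a device before the GPU code paths in this PR are exercised:

import faiss

# Number of CUDA devices visible to FAISS; the GPU branches added in this PR need at least one.
n_gpus = faiss.get_num_gpus()
print(f"FAISS sees {n_gpus} GPU(s)")
if n_gpus > 0:
    res = faiss.StandardGpuResources()  # per-process GPU scratch resources used when moving an index to the GPU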

docs/tutorials/notebooks/tabula_sapiens_tutorial.ipynb

Lines changed: 1 addition & 1 deletion

@@ -1829,4 +1829,4 @@
  },
  "nbformat": 4,
  "nbformat_minor": 5
-}
+}

popv/_faiss_knn_classifier.py

Lines changed: 97 additions & 0 deletions

@@ -0,0 +1,97 @@
+import os
+
+import faiss
+import numpy as np
+import pandas as pd
+
+from popv import settings
+
+
+class FAISSKNNProba:
+    def __init__(self, n_neighbors=5):
+        self.n_neighbors = n_neighbors
+        self.index = None
+        if settings.cuml and faiss.get_num_gpus() > 0:
+            self.res = faiss.StandardGpuResources()
+            self.use_gpu = True
+        else:
+            self.res = None
+            self.use_gpu = False
+
+    def fit(self, X, labels):
+        X = X.astype("float32")
+        self.labels = labels
+        d = X.shape[1]
+
+        cpu_index = faiss.IndexFlatL2(d)
+
+        if self.use_gpu:
+            gpu_index = faiss.index_cpu_to_gpu(self.res, settings.device, cpu_index)
+            gpu_index.add(X)
+            self.index = faiss.index_gpu_to_cpu(gpu_index)
+        else:
+            cpu_index.add(X)
+            self.index = cpu_index
+
+        return self
+
+    def query(self, X, n_neighbors):
+        X = X.astype("float32")
+        if self.use_gpu:
+            index = faiss.index_cpu_to_gpu(self.res, settings.device, self.index)
+        else:
+            index = self.index
+        _, I = index.search(X, n_neighbors)
+        return I
+
+    def predict(self, X, classes):
+        X = X.astype("float32")
+        if self.use_gpu:
+            index = faiss.index_cpu_to_gpu(self.res, settings.device, self.index)
+        else:
+            index = self.index
+        _, I = index.search(X, self.n_neighbors)
+        preds = classes[np.array([np.bincount(self.labels[i], minlength=len(classes)).argmax() for i in I])]
+        return preds
+
+    def predict_proba(self, X, classes):
+        X = X.astype("float32")
+        if self.use_gpu:
+            index = faiss.index_cpu_to_gpu(self.res, settings.device, self.index)
+        else:
+            index = self.index
+        _, I = index.search(X, self.n_neighbors)
+        probas = []
+        for neighbors in I:
+            counts = np.bincount(self.labels[neighbors], minlength=len(classes))
+            probas.append(counts / counts.sum())
+        return np.array(probas)
+
+    def save(self, path_prefix):
+        """
+        Save the FAISS index to disk.
+
+        Parameters
+        ----------
+        path_prefix : str
+            Path prefix, e.g. "models/faiss_knn"
+        """
+        faiss.write_index(self.index, f"{path_prefix}.index")
+
+    @classmethod
+    def load(cls, path_prefix, index, n_neighbors=5):
+        """
+        Load FAISS index and metadata from disk.
+
+        Parameters
+        ----------
+        path_prefix : str
+            Directory containing "<index>.index" and "ref_labels.csv"
+        n_neighbors : int
+            Number of neighbors to use
+        """
+        obj = cls(n_neighbors=n_neighbors)
+        obj.index = faiss.read_index(os.path.join(path_prefix, f"{index}.index"))
+        labels = pd.read_csv(os.path.join(path_prefix, "ref_labels.csv"), index_col=0)
+        obj.labels = labels.iloc[:, 0].to_numpy()
+        return obj
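Usage note: a minimal sketch of how the new classifier fits together (the data, sizes, and label names below are illustrative, not taken from the PR). Because predict and predict_proba vote with np.bincount, fit() expects labels as non-negative integer codes that index into the classes array passed at prediction time:

import numpy as np
import pandas as pd

from popv._faiss_knn_classifier import FAISSKNNProba

# Illustrative reference/query embeddings (e.g. PCA coordinates) and string cell-type labels.
X_ref = np.random.rand(1000, 50).astype("float32")
X_query = np.random.rand(200, 50).astype("float32")
ref_labels = pd.Categorical(np.random.choice(["B cell", "T cell", "NK cell"], size=1000))

classes = np.asarray(ref_labels.categories)  # position -> class name
codes = ref_labels.codes.astype(np.int64)    # integer codes consumed by the bincount voting

knn = FAISSKNNProba(n_neighbors=5).fit(X_ref, codes)
predictions = knn.predict(X_query, classes)          # majority vote over the 5 nearest reference cells
probabilities = knn.predict_proba(X_query, classes)  # per-class neighbor vote fractions, shape (200, 3)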

popv/_settings.py

Lines changed: 11 additions & 0 deletions

@@ -68,6 +68,7 @@ def __init__(
         recompute_embeddings: bool = False,
         return_probabilities: bool = True,
         compute_umap_embedding: bool = True,
+        device: int | None = 0,
     ):
         """Set up Config manager for PopV."""
         self.seed = seed

@@ -80,6 +81,7 @@ def __init__(
         self.recompute_embeddings = recompute_embeddings
         self.return_probabilities = return_probabilities
         self.compute_umap_embedding = compute_umap_embedding
+        self.device = device

     @property
     def logging_dir(self) -> Path:

@@ -198,5 +200,14 @@ def return_probabilities(self) -> bool:
     def return_probabilities(self, return_probabilities: bool):
         self._return_probabilities = return_probabilities

+    @property
+    def device(self) -> int | None:
+        """GPU device to use for acceleration."""
+        return self._device
+
+    @device.setter
+    def device(self, device: int | None):
+        self._device = device
+

 settings = Config()
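Usage note: settings.device selects the CUDA device used by the GPU code paths; the FAISS classifier above moves its index to this device when settings.cuml is enabled and a GPU is visible. A small sketch, assuming a GPU-enabled install:

import popv

# Run FAISS/RAPIDS work on the second CUDA device; the default is device 0.
popv.settings.device = 1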

popv/algorithms/_base_algorithm.py

Lines changed: 2 additions & 0 deletions

@@ -35,6 +35,8 @@ def __init__(
         umap_key
             Key in obsm in which UMAP embedding of integrated data is stored.
         """
+        if settings.cuml:
+            import rapids_singlecell as rsc  # noqa: F401
         self.batch_key = batch_key
         self.labels_key = labels_key
         if seen_result_key is None:

popv/algorithms/_bbknn.py

Lines changed: 17 additions & 45 deletions

@@ -1,18 +1,12 @@
 from __future__ import annotations

 import logging
-import os

-import joblib
 import numpy as np
 import scanpy as sc
-from scipy.stats import mode
 from sklearn.neighbors import KNeighborsClassifier

 from popv import settings
-
-if settings.cuml:
-    import rapids_singlecell as rsc
 from popv.algorithms._base_algorithm import BaseAlgorithm


@@ -95,48 +89,22 @@ def compute_integration(self, adata):
         AnnData object. Modified inplace.
         """
         logging.info("Integrating data with bbknn")
-        if (
-            adata.uns["_prediction_mode"] == "inference"
-            and "X_umap_bbknn" in adata.obsm
-            and not settings.recompute_embeddings
-        ):
-            index = joblib.load(os.path.join(adata.uns["_save_path_trained_models"], "pynndescent_index.joblib"))
-            query_features = adata.obsm["X_pca"][adata.obs["_dataset"] == "query", :]
-            indices, _ = index.query(query_features.astype(np.float32), k=5)
-
-            neighbor_embedding = adata.obsm["X_umap_bbknn"][adata.obs["_dataset"] == "ref", :][indices].astype(
-                np.float32
-            )
-            adata.obsm[self.umap_key][adata.obs["_dataset"] == "query", :] = np.mean(neighbor_embedding, axis=1)
-            adata.obsm[self.umap_key] = adata.obsm[self.umap_key].astype(np.float32)
-
-            neighbor_probabilities = adata.obs[f"{self.result_key}_probabilities"][adata.obs["_dataset"] == "ref", :][
-                indices
-            ].astype(np.float32)
-            adata.obs.loc[adata.obs["_dataset"] == "query", f"{self.result_key}_probabilities"] = np.mean(
-                neighbor_probabilities, axis=1
-            )
-
-            neighbor_prediction = adata.obs[f"{self.result_key}"][adata.obs["_dataset"] == "ref", :][indices].astype(
-                np.float32
+        if len(adata.obs[self.batch_key].unique()) > 100:
+            self.method_kwargs["neighbors_within_batch"] = 1
+        if settings.cuml:
+            import rapids_singlecell as rsc
+
+            self.method_kwargs.pop("approx", None)  # approx not supported in rsc
+            self.method_kwargs.pop("use_annoy", None)  # use_annoy not supported in rsc
+            rsc.pp.bbknn(
+                adata, batch_key=self.batch_key, use_rep="X_pca", algorithm="ivfflat", **self.method_kwargs, trim=0
             )
-            adata.obs.loc[adata.obs["_dataset"] == "query", f"{self.result_key}"] = mode(neighbor_prediction, axis=1)
         else:
-            if len(adata.obs[self.batch_key].unique()) > 100:
-                logging.warning("Using PyNNDescent instead of FAISS as high number of batches leads to OOM.")
-                self.method_kwargs["neighbors_within_batch"] = 1  # Reduce memory usage.
-                self.method_kwargs["pynndescent_n_neighbors"] = 10  # Reduce memory usage.
-                sc.external.pp.bbknn(
-                    adata, batch_key=self.batch_key, use_faiss=False, use_rep="X_pca", **self.method_kwargs
-                )
-            else:
-                sc.external.pp.bbknn(
-                    adata, batch_key=self.batch_key, use_faiss=True, use_rep="X_pca", **self.method_kwargs
-                )
+            sc.external.pp.bbknn(adata, batch_key=self.batch_key, use_rep="X_pca", **self.method_kwargs)

     def predict(self, adata):
         """
-        Predict celltypes using Celltypist.
+        Predict celltypes using BBKNN kNN.

         Parameters
         ----------

@@ -168,7 +136,9 @@ def predict(self, adata):
         adata.obs[self.result_key] = adata.uns["label_categories"][knn.predict(test_distances)]

         if self.return_probabilities:
-            adata.obs[f"{self.result_key}_probabilities"] = np.max(knn.predict_proba(test_distances), axis=1)
+            probabilities = knn.predict_proba(test_distances)
+            adata.obs[f"{self.result_key}_probabilities"] = np.max(probabilities, axis=1)
+            adata.obsm[f"{self.result_key}_probabilities"] = probabilities

     def compute_umap(self, adata):
         """

@@ -180,8 +150,10 @@
         AnnData object. Results are stored in adata.obsm[self.umap_key].
         """
         if self.compute_umap_embedding:
-            logging.info(f'Saving UMAP of bbknn results to adata.obs["{self.embedding_key}"]')
+            logging.info(f'Saving UMAP of BBKNN results to adata.obsm["{self.umap_key}"]')
             if settings.cuml:
+                import rapids_singlecell as rsc
+
                 rsc.pp.neighbors(adata, use_rep=self.embedding_key)
                 adata.obsm[self.umap_key] = rsc.tl.umap(adata, copy=True, **self.embedding_kwargs).obsm["X_umap"]
             else:
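Usage note: predict() now keeps the full per-class probability matrix in adata.obsm in addition to the per-cell maximum in adata.obs. A small sketch of reading both back after running the algorithm (the helper name is illustrative; result_key is whatever key the algorithm instance was configured with):

import numpy as np
from anndata import AnnData


def read_bbknn_probabilities(adata: AnnData, result_key: str):
    """Return (per-cell max, per-class matrix) as stored by the updated predict()."""
    per_cell_max = adata.obs[f"{result_key}_probabilities"].to_numpy()   # unchanged: max probability per cell
    full_matrix = np.asarray(adata.obsm[f"{result_key}_probabilities"])  # new: one column per label category
    return per_cell_max, full_matrix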

popv/algorithms/_celltypist.py

Lines changed: 16 additions & 6 deletions

@@ -4,16 +4,13 @@
 import os

 import celltypist
-import joblib
 import numpy as np
 import pandas as pd
 import scanpy as sc
 from scipy.stats import mode

 from popv import settings
-
-if settings.cuml:
-    import rapids_singlecell as rsc
+from popv._faiss_knn_classifier import FAISSKNNProba
 from popv.algorithms._base_algorithm import BaseAlgorithm


@@ -86,16 +83,20 @@ def predict(self, adata):
             and "over_clustering" in adata.obs
             and not settings.recompute_embeddings
         ):
-            index = joblib.load(os.path.join(adata.uns["_save_path_trained_models"], "pynndescent_index.joblib"))
+            knn = FAISSKNNProba(n_neighbors=5)
+            knn = knn.load(adata.uns["_save_path_trained_models"], "faiss_index")
+
             query_features = adata.obsm["X_pca"][adata.obs["_dataset"] == "query", :]
-            indices, _ = index.query(query_features.astype(np.float32), k=5)
+            indices = knn.query(query_features.astype(np.float32), n_neighbors=5)
             neighbor_values = adata.obs.loc[adata.obs["_dataset"] == "ref", "over_clustering"].cat.codes.values[indices]
             adata.obs.loc[adata.obs["_dataset"] == "query", "over_clustering"] = adata.obs[
                 "over_clustering"
             ].cat.categories[mode(neighbor_values, axis=1).mode.flatten()]
             over_clustering = adata.obs.loc[adata.obs["_predict_cells"] == "relabel", "over_clustering"]
         else:
             if settings.cuml:
+                import rapids_singlecell as rsc
+
                 rsc.pp.neighbors(adata, n_neighbors=15, use_rep="X_pca")
                 rsc.tl.leiden(adata, resolution=25.0, key_added="over_clustering")
             else:

@@ -136,7 +137,16 @@
         if self.return_probabilities:
             if f"{self.result_key}_probabilities" not in adata.obs.columns:
                 adata.obs[f"{self.result_key}_probabilities"] = pd.Series(dtype="float64")
+            if f"{self.result_key}_probabilities" not in adata.obsm:
+                adata.obsm[f"{self.result_key}_probabilities"] = pd.DataFrame(
+                    np.nan,
+                    index=adata.obs_names,
+                    columns=adata.uns["label_categories"],
+                )
             adata.obs.loc[
                 adata.obs["_predict_cells"] == "relabel",
                 f"{self.result_key}_probabilities",
             ] = predictions.probability_matrix.max(axis=1).values
+            adata.obsm[f"{self.result_key}_probabilities"].loc[adata.obs["_predict_cells"] == "relabel", :] = (
+                predictions.probability_matrix
+            )