Skip to content

Commit f0cfad6

Browse files
committed
cleaned up graph loading
1 parent d35d714 commit f0cfad6

File tree

7 files changed

+80
-64
lines changed

7 files changed

+80
-64
lines changed

scanpy/api/tools.py

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,10 @@
44
"""
55

66
# order alphabetically
7+
from ..tools.aga import aga
8+
from ..tools.aga import aga_contract_graph
79
from ..tools.dbscan import dbscan
10+
from ..tools.draw_graph import draw_graph
811
from ..tools.diffmap import diffmap
912
from ..tools.rank_genes_groups import rank_genes_groups
1013
from ..tools.dpt import dpt
@@ -13,11 +16,3 @@
1316
from ..tools.sim import sim
1417
from ..tools.spring import spring
1518
from ..tools.tsne import tsne
16-
17-
try:
18-
# development tools
19-
from ..tools.draw_graph import draw_graph
20-
from ..tools.aga import aga
21-
from ..tools.aga import aga_contract_graph
22-
except ImportError:
23-
pass

scanpy/data_structs/data_graph.py

Lines changed: 34 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,27 @@
1414
from .. import utils
1515

1616

17+
def add_graph_to_adata(
        adata,
        n_neighbors=30,
        n_pcs=50,
        recompute_pca=None,
        recompute_graph=False,
        n_jobs=None):
    """Build a `DataGraph` for `adata` and store its results on `adata`.

    A `DataGraph` is constructed from `adata`, its `update_diffmap()` method
    is called, and the resulting attributes are written back in place.

    Parameters
    ----------
    adata
        Annotated data object; modified in place.
    n_neighbors
        Forwarded to `DataGraph` as `k`.
    n_pcs, recompute_pca, recompute_graph, n_jobs
        Forwarded unchanged to `DataGraph`.

    Writes
    ------
    adata.add['distance']       <- graph.Dsq
    adata.add['Ktilde']         <- graph.Ktilde
    adata.smp['X_diffmap']      <- graph.rbasis without its first column
    adata.smp['X_diffmap0']     <- first column of graph.rbasis
    adata.add['diffmap_evals']  <- graph.evals without its first entry

    Returns
    -------
    None
    """
    graph_options = dict(
        k=n_neighbors,
        n_pcs=n_pcs,
        recompute_pca=recompute_pca,
        recompute_graph=recompute_graph,
        n_jobs=n_jobs,
    )
    data_graph = DataGraph(adata, **graph_options)
    data_graph.update_diffmap()

    # graph-level quantities
    adata.add['distance'] = data_graph.Dsq
    adata.add['Ktilde'] = data_graph.Ktilde

    # the first basis column / first eigenvalue are stored separately
    # (presumably the trivial component -- confirm against DataGraph)
    adata.smp['X_diffmap'] = data_graph.rbasis[:, 1:]
    adata.smp['X_diffmap0'] = data_graph.rbasis[:, 0]
    adata.add['diffmap_evals'] = data_graph.evals[1:]
36+
37+
1738
def get_neighbors(X, Y, k):
1839
Dsq = utils.comp_sqeuclidean_distance_using_matrix_mult(X, Y)
1940
chunk_range = np.arange(Dsq.shape[0])[:, None]
@@ -159,7 +180,7 @@ def __init__(self,
159180
and adata.smp['X_diffmap'].shape[1] >= n_dcs-1):
160181
self.n_pcs = n_pcs
161182
self.n_dcs = n_dcs
162-
self.iroot = None if 'iroot' not in adata.add else adata.add['iroot']
183+
self.init_iroot_directly(adata)
163184
self.X = adata.X # this is a hack, PCA?
164185
self.knn = issparse(adata.add['Ktilde'])
165186
self.Ktilde = adata.add['Ktilde']
@@ -177,7 +198,7 @@ def __init__(self,
177198
self.Dchosen = OnFlySymMatrix(self.get_Ddiff_row,
178199
shape=(self.X.shape[0], self.X.shape[0]))
179200
np.set_printoptions(precision=3)
180-
logg.info('use stored data graph with `n_neighbors = {}` and '
201+
logg.info(' using stored data graph with n_neighbors = {} and '
181202
'spectrum\n {}'
182203
.format(self.k,
183204
str(self.evals).replace('\n', '\n ')))
@@ -211,19 +232,24 @@ def __init__(self,
211232
.format(self.k))
212233
self.Dsq = adata.add['distance']
213234

214-
def init_iroot_and_X_from_PCA(self, adata, recompute_pca, n_pcs):
215-
# retrieve xroot
216-
xroot = None
217-
if 'xroot' in adata.add: xroot = adata.add['xroot']
218-
elif 'xroot' in adata.var: xroot = adata.var['xroot']
219-
# set iroot directly
235+
def init_iroot_directly(self, adata):
    """Set `self.iroot` from the root index stored in `adata.add['iroot']`.

    `self.iroot` is always assigned by this method:

    - the stored index, if `'iroot'` is in `adata.add` and the index is
      smaller than `adata.n_smps`;
    - `None`, if no `'iroot'` entry exists, or if the stored index is out
      of range (a warning is logged in that case).

    Parameters
    ----------
    adata
        Annotated data object providing `add` and `n_smps`.
    """
    # Default to None so the attribute exists even when no root index is
    # stored; otherwise a later read of `self.iroot` raises AttributeError.
    self.iroot = None
    if 'iroot' in adata.add:
        if adata.add['iroot'] >= adata.n_smps:
            logg.warn('Root cell index {} does not exist for {} samples. '
                      'Is ignored.'
                      .format(adata.add['iroot'], adata.n_smps))
        else:
            self.iroot = adata.add['iroot']
244+
245+
246+
def init_iroot_and_X_from_PCA(self, adata, recompute_pca, n_pcs):
247+
# retrieve xroot
248+
xroot = None
249+
if 'xroot' in adata.add: xroot = adata.add['xroot']
250+
elif 'xroot' in adata.var: xroot = adata.var['xroot']
251+
# set iroot directly
252+
self.init_iroot_directly(adata)
227253
# see whether we can set self.iroot using the full data matrix
228254
if xroot is not None and xroot.size == self.X.shape[1]:
229255
self.set_root(xroot)

scanpy/tools/aga.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -108,12 +108,14 @@ def aga(adata,
108108
root_cell_was_passed = False
109109
logg.m('... no root cell found, no computation of pseudotime')
110110
msg = \
111-
'''To enable computation of pseudotime, pass the expression "xroot" of a root cell.
112-
Either add
111+
'''To enable computation of pseudotime, pass the index or expression vector
112+
of a root cell. Either add
113+
adata.add['iroot'] = root_cell_index
114+
or (robust to subsampling)
113115
adata.var['xroot'] = adata.X[root_cell_index, :]
114-
where `root_cell_index` is the integer index of the root cell, or
116+
where "root_cell_index" is the integer index of the root cell, or
115117
adata.var['xroot'] = adata[root_cell_name, :].X
116-
where `root_cell_name` is the name (a string) of the root cell.'''
118+
where "root_cell_name" is the name (a string) of the root cell.'''
117119
logg.hint(msg)
118120
fresh_compute_louvain = False
119121
if ((node_groups == 'louvain' and 'louvain_groups' not in adata.smp_keys())
@@ -127,6 +129,7 @@ def aga(adata,
127129
fresh_compute_louvain = True
128130
clusters = node_groups
129131
if node_groups == 'louvain': clusters = 'louvain_groups'
132+
logg.info('running Approximate Graph Abstraction (AGA)', r=True)
130133
aga = AGA(adata,
131134
clusters=clusters,
132135
n_neighbors=n_neighbors,
@@ -149,7 +152,6 @@ def aga(adata,
149152
adata.add['diffmap_evals'] = aga.evals[1:]
150153
adata.add['distance'] = aga.Dsq
151154
adata.add['Ktilde'] = aga.Ktilde
152-
logg.info('running Approximate Graph Abstraction (AGA)', r=True)
153155
if aga.iroot is not None:
154156
aga.set_pseudotime() # pseudotimes are random walk distances from root point
155157
adata.add['iroot'] = aga.iroot # update iroot, might have changed when subsampling, for example

scanpy/tools/dpt.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,7 @@ def dpt(adata, n_branchings=0, n_neighbors=30, knn=True, n_pcs=50, n_dcs=10,
104104
where "root_cell_index" is the integer index of the root cell, or
105105
adata.var['xroot'] = adata[root_cell_name, :].X
106106
where "root_cell_name" is the name (a string) of the root cell.'''
107-
logg.m(msg, v='hint')
107+
logg.hint(msg)
108108
if n_branchings == 0:
109109
logg.m('set parameter `n_branchings` > 0 to detect branchings', v='hint')
110110
dpt = DPT(adata, n_neighbors=n_neighbors, knn=knn, n_pcs=n_pcs, n_dcs=n_dcs,

scanpy/tools/draw_graph.py

Lines changed: 17 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -8,16 +8,21 @@
88
transcriptomics: Weinreb et al., bioRxiv doi:10.1101/090332 (2016)
99
"""
1010

11+
import numpy as np
12+
from .. import utils
13+
from ..data_structs.data_graph import add_graph_to_adata
14+
1115

1216
def draw_graph(adata,
1317
layout='fr',
18+
root=None,
1419
n_neighbors=30,
1520
n_pcs=50,
16-
root=None,
17-
n_jobs=None,
1821
random_state=0,
22+
recompute_pca=None,
1923
recompute_graph=False,
2024
adjacency=None,
25+
n_jobs=None,
2126
copy=False):
2227
"""Visualize data using standard graph drawing algorithms.
2328
@@ -51,22 +56,21 @@ def draw_graph(adata,
5156
from .. import logging as logg
5257
logg.info('drawing single-cell graph using layout "{}"'.format(layout),
5358
r=True)
54-
import numpy as np
55-
from .. import data_structs
56-
from .. import utils
5759
avail_layouts = {'fr', 'drl', 'kk', 'grid_fr', 'lgl', 'rt', 'rt_circular'}
5860
if layout not in avail_layouts:
5961
raise ValueError('Provide a valid layout, one of {}.'.format(avail_layouts))
6062
adata = adata.copy() if copy else adata
6163
if 'Ktilde' not in adata.add or recompute_graph:
62-
graph = data_structs.DataGraph(adata,
63-
k=n_neighbors,
64-
n_pcs=n_pcs,
65-
n_jobs=n_jobs)
66-
graph.compute_transition_matrix(recompute_distance=True)
67-
adata.add['Ktilde'] = graph.Ktilde
68-
elif n_neighbors is not None and not recompute_graph:
69-
logg.warn('`n_neighbors={}` has no effect (set `recompute_graph=True` to enable it)'
64+
add_graph_to_adata(
65+
adata,
66+
n_neighbors=n_neighbors,
67+
n_pcs=n_pcs,
68+
recompute_pca=recompute_pca,
69+
recompute_graph=recompute_graph,
70+
n_jobs=n_jobs)
71+
else:
72+
n_neighbors = adata.add['distance'][0].nonzero()[0].size + 1
73+
logg.info(' using stored graph with n_neighbors = {}'
7074
.format(n_neighbors))
7175
adjacency = adata.add['Ktilde']
7276
g = utils.get_igraph_from_adjacency(adjacency)

scanpy/tools/louvain.py

Lines changed: 15 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,8 @@
77
import numpy as np
88
from .. import utils
99
from .. import logging as logg
10-
from ..data_structs import DataGraph
10+
from ..data_structs.data_graph import add_graph_to_adata
11+
1112

1213
def louvain(adata,
1314
n_neighbors=30,
@@ -43,28 +44,19 @@ def louvain(adata,
4344
- basic suggestion for single-cell: Levine et al., Cell 162, 184-197 (2015)
4445
- combination with "attachedness" matrix: Wolf et al., bioRxiv (2017)
4546
"""
46-
logg.m('run Louvain clustering', r=True)
47+
logg.m('running Louvain clustering', r=True)
4748
adata = adata.copy() if copy else adata
4849
if 'Ktilde' not in adata.add or recompute_graph:
49-
graph = DataGraph(adata,
50-
k=n_neighbors,
51-
n_pcs=n_pcs,
52-
recompute_pca=recompute_pca,
53-
recompute_graph=recompute_graph,
54-
n_jobs=n_jobs)
55-
# compute diffmap for later use although it's not needed here
56-
# it does not cost much
57-
graph.update_diffmap()
58-
adata.add['distance'] = graph.Dsq
59-
adata.add['Ktilde'] = graph.Ktilde
60-
adata.smp['X_diffmap'] = graph.rbasis[:, 1:]
61-
adata.smp['X_diffmap0'] = graph.rbasis[:, 0]
62-
adata.add['diffmap_evals'] = graph.evals[1:]
50+
add_graph_to_adata(
51+
adata,
52+
n_neighbors=n_neighbors,
53+
n_pcs=n_pcs,
54+
recompute_pca=recompute_pca,
55+
recompute_graph=recompute_graph,
56+
n_jobs=n_jobs)
6357
else:
64-
# do not use the undirected kernel Ktilde here, but the
65-
# sparse distance matrix
6658
n_neighbors = adata.add['distance'][0].nonzero()[0].size + 1
67-
logg.info(' using precomputed graph with n_neighbors={}'
59+
logg.info(' using stored graph with n_neighbors = {}'
6860
.format(n_neighbors))
6961
adjacency = adata.add['Ktilde']
7062
if flavor in {'vtraag', 'igraph'}:
@@ -83,11 +75,10 @@ def louvain(adata,
8375
resolution_parameter=resolution)
8476
adata.add['louvain_quality'] = part.quality()
8577
except AttributeError:
86-
logg.warn('Did not find louvain package >= 0.6 on your system, '
87-
'the result will therefore not be 100% reproducible, but '
88-
'is influenced by randomness in the community detection '
89-
'algorithm. Still you get very meaningful results!\n'
90-
'If you want 100% reproducible results, but 0.6 is not yet '
78+
logg.warn('Did not find package louvain>=0.6, '
79+
'the clustering result will therefore not be 100% reproducible, '
80+
'but still meaningful! '
81+
'If you want 100% reproducible results, but louvain 0.6 is not yet '
9182
'available via "pip install louvain", '
9283
'either get the latest (development) version from '
9384
'https://github.com/vtraag/louvain-igraph or use the option '

scanpy/tools/tsne.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -99,17 +99,15 @@ def tsne(adata, random_state=0, n_pcs=50, perplexity=30, learning_rate=None,
9999
X_tsne = tsne.fit_transform(X.astype(np.float64))
100100
except ImportError:
101101
multicore_failed = True
102-
logg.hint('did not find package MulticoreTSNE: to speed up the computation, install it from\n'
103-
' https://github.com/DmitryUlyanov/Multicore-TSNE')
104102
if multicore_failed:
105103
from sklearn.manifold import TSNE
106104
# unfortunately, we cannot set a minimum number of iterations for barnes-hut
107105
params_sklearn['learning_rate'] = 1000 if learning_rate is None else learning_rate
108106
tsne = TSNE(**params_sklearn)
109-
logg.warn('Consider installing the package MulticoreTSNE.\n'
110-
' https://github.com/DmitryUlyanov/Multicore-TSNE\n'
111-
'Even for `n_jobs=1` this speeds up the computation considerably.')
112107
logg.info(' using sklearn.manifold.TSNE')
108+
logg.warn('Consider installing the package MulticoreTSNE '
109+
' https://github.com/DmitryUlyanov/Multicore-TSNE.'
110+
' Even for `n_jobs=1` this speeds up the computation considerably and might yield better converged results.')
113111
X_tsne = tsne.fit_transform(X)
114112
# update AnnData instance
115113
adata.smp['X_tsne'] = X_tsne

0 commit comments

Comments
 (0)