Skip to content

Commit d35d714

Browse files
committed
cleaned up data_graph
1 parent 8241507 commit d35d714

File tree

9 files changed

+202
-154
lines changed

9 files changed

+202
-154
lines changed

scanpy/data_structs/data_graph.py

Lines changed: 98 additions & 87 deletions
Original file line number | Diff line number | Diff line change
@@ -6,20 +6,20 @@
66
import scipy as sp
77
import scipy.spatial
88
import scipy.sparse
9+
from scipy.sparse import issparse
910
from joblib import Parallel, delayed
1011
from ..cython import utils_cy
1112
from .. import settings as sett
1213
from .. import logging as logg
1314
from .. import utils
14-
from .ann_data import AnnData
1515

1616

1717
def get_neighbors(X, Y, k):
1818
Dsq = utils.comp_sqeuclidean_distance_using_matrix_mult(X, Y)
1919
chunk_range = np.arange(Dsq.shape[0])[:, None]
2020
indices_chunk = np.argpartition(Dsq, k-1, axis=1)[:, :k]
2121
indices_chunk = indices_chunk[chunk_range,
22-
np.argsort(Dsq[chunk_range, indices_chunk])]
22+
np.argsort(Dsq[chunk_range, indices_chunk])]
2323
indices_chunk = indices_chunk[:, 1:] # exclude first data point (point itself)
2424
distances_chunk = Dsq[chunk_range, indices_chunk]
2525
return indices_chunk, distances_chunk
@@ -59,7 +59,7 @@ def get_distance_matrix_and_neighbors(X, k, sparse=True, n_jobs=1):
5959
result_lst = Parallel(n_jobs=n_jobs, backend='threading')(
6060
delayed(get_neighbors)(X[chunk], X, k) for chunk in chunks)
6161
else:
62-
logg.m('--> can be sped up by setting `n_jobs` > 1')
62+
logg.info('--> can be sped up by setting `n_jobs` > 1')
6363
for i_chunk, chunk in enumerate(chunks):
6464
if n_jobs > 1:
6565
indices_chunk, distances_chunk = result_lst[i_chunk]
@@ -129,44 +129,89 @@ def __getitem__(self, index):
129129
return self.rows[glob_index_0][glob_index_1]
130130

131131
def restrict(self, index_array):
132-
"""Generate a 1d view of the data.
132+
"""Generate a view restricted to a subset of indices.
133133
"""
134134
new_shape = index_array.shape[0], index_array.shape[0]
135135
return OnFlySymMatrix(self.get_row, new_shape, DC_start=self.DC_start,
136136
DC_end=self.DC_end,
137137
rows=self.rows, restrict_array=index_array)
138138

139139

140-
class DataGraph(object):
141-
"""Represent data matrix as graph of closeby data points.
140+
class DataGraph():
141+
"""Represent data matrix as graph of neighborhood relations among data points.
142142
"""
143143

144144
def __init__(self,
145-
adata_or_X,
146-
k=30,
145+
adata,
146+
k=None,
147147
knn=True,
148148
n_jobs=None,
149149
n_pcs=30,
150150
n_dcs=10,
151151
recompute_pca=None,
152-
recompute_diffmap=None,
152+
recompute_distances=False,
153+
recompute_graph=None,
153154
flavor='haghverdi16'):
154-
logg.info('initializing data graph with `n_neighbors={}`'
155-
.format(k))
156-
self.k = k if k is not None else 30
157-
self.knn = knn
158-
self.n_jobs = sett.n_jobs if n_jobs is None else n_jobs
159-
self.n_pcs = n_pcs
160-
self.n_dcs = n_dcs
161-
self.flavor = flavor # this is to experiment around
162155
self.sym = True # we do not allow asymetric cases
163-
self.iroot = None
164-
isadata = isinstance(adata_or_X, AnnData)
165-
if isadata:
166-
adata = adata_or_X
167-
X = adata_or_X.X
156+
# use the graph in adata
157+
if (not recompute_graph
158+
and 'X_diffmap' in adata.smp
159+
and adata.smp['X_diffmap'].shape[1] >= n_dcs-1):
160+
self.n_pcs = n_pcs
161+
self.n_dcs = n_dcs
162+
self.iroot = None if 'iroot' not in adata.add else adata.add['iroot']
163+
self.X = adata.X # this is a hack, PCA?
164+
self.knn = issparse(adata.add['Ktilde'])
165+
self.Ktilde = adata.add['Ktilde']
166+
self.Dsq = adata.add['distance']
167+
if self.knn:
168+
self.k = adata.add['distance'][0].nonzero()[0].size + 1
169+
else:
170+
self.k = adata.X.shape[0]
171+
# for output of spectrum
172+
self.X_diffmap = adata.smp['X_diffmap'][:, :n_dcs-1]
173+
self.evals = np.r_[1, adata.add['diffmap_evals'][:n_dcs-1]]
174+
self.rbasis = np.c_[adata.smp['X_diffmap0'][:, None],
175+
adata.smp['X_diffmap'][:, :n_dcs-1]]
176+
self.lbasis = self.rbasis
177+
self.Dchosen = OnFlySymMatrix(self.get_Ddiff_row,
178+
shape=(self.X.shape[0], self.X.shape[0]))
179+
np.set_printoptions(precision=3)
180+
logg.info('use stored data graph with `n_neighbors = {}` and '
181+
'spectrum\n {}'
182+
.format(self.k,
183+
str(self.evals).replace('\n', '\n ')))
184+
# recompute the graph
168185
else:
169-
X = adata_or_X
186+
self.k = k if k is not None else 30
187+
logg.info('compute data graph with `n_neighbors={}`'
188+
.format(self.k))
189+
self.evals = None
190+
self.rbasis = None
191+
self.lbasis = None
192+
self.X_diffmap = None
193+
self.Dsq = None
194+
self.knn = knn
195+
self.n_jobs = sett.n_jobs if n_jobs is None else n_jobs
196+
self.n_pcs = n_pcs
197+
self.n_dcs = n_dcs
198+
self.flavor = flavor # this is to experiment around
199+
self.iroot = None
200+
self.X = adata.X # might be overwritten with X_pca below
201+
self.Dchosen = None
202+
self.M = None
203+
self.init_iroot_and_X_from_PCA(adata, recompute_pca, n_pcs)
204+
if False: # TODO
205+
# in case we already computed distance relations
206+
if not recompute_distances and 'distance' in adata.add:
207+
n_neighbors = adata.add['distance'][0].nonzero()[0].size + 1
208+
if (knn and issparse(adata.add['distance'])
209+
and n_neighbors == self.k):
210+
logg.info(' using stored distances with `n_neighbors={}`'
211+
.format(self.k))
212+
self.Dsq = adata.add['distance']
213+
214+
def init_iroot_and_X_from_PCA(self, adata, recompute_pca, n_pcs):
170215
# retrieve xroot
171216
xroot = None
172217
if 'xroot' in adata.add: xroot = adata.add['xroot']
@@ -179,60 +224,29 @@ def __init__(self,
179224
.format(adata.add['iroot'], adata.n_smps))
180225
else:
181226
self.iroot = adata.add['iroot']
182-
# use the fulll data matrix X
183-
if (self.n_pcs == 0 # use the full X as n_pcs == 0
184-
or X.shape[1] <= self.n_pcs):
185-
self.X = X
186-
logg.m(' using data matrix X directly for building graph (no PCA)')
187-
if xroot is not None: self.set_root(xroot)
188-
# use the precomupted X_pca
189-
elif (isadata
190-
and not recompute_pca
191-
and 'X_pca' in adata.smp
192-
and adata.smp['X_pca'].shape[1] >= self.n_pcs):
193-
logg.info(' using X_pca for building graph')
194-
if xroot is not None and xroot.size == adata.X.shape[1]:
195-
self.X = adata.X
196-
self.set_root(xroot)
227+
# see whether we can set self.iroot using the full data matrix
228+
if xroot is not None and xroot.size == self.X.shape[1]:
229+
self.set_root(xroot)
230+
# use the fulll data matrix X, nothing to be done
231+
if self.n_pcs == 0 or self.X.shape[1] <= self.n_pcs:
232+
logg.info(' using data matrix X directly for building graph (no PCA)')
233+
# use X_pca
234+
else:
235+
# use a precomputed X_pca
236+
if (not recompute_pca
237+
and 'X_pca' in adata.smp
238+
and adata.smp['X_pca'].shape[1] >= self.n_pcs):
239+
logg.info(' using "X_pca" for building graph')
240+
# compute X_pca
241+
else:
242+
logg.info(' compute "X_pca" for building graph')
243+
from ..preprocessing import pca
244+
pca(adata, n_comps=self.n_pcs)
245+
# set the data matrix
197246
self.X = adata.smp['X_pca'][:, :n_pcs]
247+
# see whether we can find xroot using X_pca
198248
if xroot is not None and xroot.size == adata.smp['X_pca'].shape[1]:
199249
self.set_root(xroot[:n_pcs])
200-
# compute X_pca
201-
else:
202-
self.X = X
203-
if (isadata
204-
and xroot is not None
205-
and xroot.size == adata.X.shape[1]):
206-
self.set_root(xroot)
207-
logg.m(' compute `X_pca` for building graph')
208-
from ..preprocessing import pca
209-
pca(adata, n_comps=self.n_pcs)
210-
self.X = adata.smp['X_pca']
211-
if xroot is not None and xroot.size == adata.smp['X_pca'].shape[1]:
212-
self.set_root(xroot)
213-
self.Dchosen = None
214-
self.Dsq = adata.add['distance'] if knn and 'distance' in adata.add else None
215-
# use diffmap from previous calculation
216-
if (isadata and 'X_diffmap' in adata.smp and not recompute_diffmap
217-
and adata.smp['X_diffmap'].shape[1] >= n_dcs-1):
218-
self.X_diffmap = adata.smp['X_diffmap'][:, :n_dcs-1]
219-
self.evals = np.r_[1, adata.add['diffmap_evals'][:n_dcs-1]]
220-
np.set_printoptions(precision=3)
221-
logg.info(' using stored "X_diffmap" with spectrum\n {}'
222-
.format(str(self.evals).replace('\n', '\n ')))
223-
self.rbasis = np.c_[adata.smp['X_diffmap0'][:, None],
224-
adata.smp['X_diffmap'][:, :n_dcs-1]]
225-
self.lbasis = self.rbasis
226-
if knn: self.Ktilde = adata.add['Ktilde']
227-
self.Dchosen = OnFlySymMatrix(self.get_Ddiff_row,
228-
shape=(self.X.shape[0], self.X.shape[0]))
229-
else:
230-
self.evals = None
231-
self.rbasis = None
232-
self.lbasis = None
233-
self.Dsq = None
234-
# further attributes that might be written during the computation
235-
self.M = None
236250

237251
def update_diffmap(self, n_comps=None):
238252
"""Diffusion Map as of Coifman et al. (2005) and Haghverdi et al. (2016).
@@ -241,7 +255,8 @@ def update_diffmap(self, n_comps=None):
241255
self.n_dcs = n_comps
242256
logg.info(' updating number of DCs to', self.n_dcs)
243257
if self.evals is None or self.evals.size < self.n_dcs:
244-
logg.info('computing Diffusion Map with', self.n_dcs, 'components', r=True)
258+
logg.info('computing spectral decomposition ("diffmap") with',
259+
self.n_dcs, 'components', r=True)
245260
self.compute_transition_matrix()
246261
self.embed(n_evals=self.n_dcs)
247262
return True
@@ -272,7 +287,7 @@ def spec_layout(self):
272287
return ddmap
273288

274289
def compute_distance_matrix(self):
275-
logg.m('... computing distance matrix with n_neighbors={}'
290+
logg.m('computing distance matrix with n_neighbors = {}'
276291
.format(self.k), v=4)
277292
Dsq, indices, distances_sq = get_distance_matrix_and_neighbors(
278293
X=self.X,
@@ -294,10 +309,9 @@ def compute_transition_matrix(self, alpha=1, recompute_distance=False):
294309
neglect_selfloops : bool
295310
Discard selfloops.
296311
297-
See also
298-
--------
299-
Also Haghverdi et al. (2016, 2015) and Coifman and Lafon (2006) and
300-
Coifman et al. (2005).
312+
References
313+
----------
314+
Haghverdi et al. (2016), Coifman and Lafon (2006), Coifman et al. (2005).
301315
"""
302316
if self.Dsq is None or recompute_distance:
303317
Dsq, indices, distances_sq = self.compute_distance_matrix()
@@ -317,7 +331,8 @@ def compute_transition_matrix(self, alpha=1, recompute_distance=False):
317331
# zero - in its sorted position
318332
sigmas_sq = distances_sq[:, -1]/4
319333
sigmas = np.sqrt(sigmas_sq)
320-
logg.m('determined k =', self.k, 'nearest neighbors of each point', t=True, v=4)
334+
logg.m('determined n_neighbors =',
335+
self.k, 'nearest neighbors of each point', t=True, v=4)
321336

322337
if self.flavor == 'unweighted':
323338
if not self.knn:
@@ -365,10 +380,6 @@ def compute_transition_matrix(self, alpha=1, recompute_distance=False):
365380
W = W.tocsr()
366381
logg.m('computed W (weight matrix) with "knn" =', self.knn, t=True, v=4)
367382

368-
# if sp.sparse.issparse(W): W = W.toarray()
369-
# print(W)
370-
# quit()
371-
372383
if False:
373384
pl.matshow(W)
374385
pl.title('$ W$')
@@ -422,13 +433,13 @@ def compute_transition_matrix(self, alpha=1, recompute_distance=False):
422433
row = self.K.indices[self.K.indptr[i]: self.K.indptr[i+1]]
423434
num = self.sqrtz[i] * self.sqrtz[row]
424435
self.Ktilde.data[self.K.indptr[i]: self.K.indptr[i+1]] = self.K.data[self.K.indptr[i]: self.K.indptr[i+1]] / num
425-
logg.m(' computed Ktilde (normalized anistropic kernel)', v=4)
436+
logg.m('computed Ktilde (normalized anistropic kernel)', v=4)
426437

427438
def compute_L_matrix(self):
428439
"""Graph Laplacian for K.
429440
"""
430441
self.L = np.diag(self.z) - self.K
431-
sett.mt(0, 'compute graph Laplacian')
442+
logg.info('compute graph Laplacian')
432443

433444
def embed(self, matrix=None, n_evals=15, sym=None, sort='decrease'):
434445
"""Compute eigen decomposition of matrix.

scanpy/plotting/tools.py

Lines changed: 17 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -579,6 +579,7 @@ def aga_graph(
579579
add_noise_to_node_positions=None,
580580
left_margin=0.01,
581581
attachedness_type='relative',
582+
force_labels_to_front=False,
582583
show=None,
583584
save=None,
584585
ax=None):
@@ -620,7 +621,6 @@ def aga_graph(
620621
axs = ax
621622
if len(colors) == 1: axs = [axs]
622623
for icolor, color in enumerate(colors):
623-
show_color = False if icolor != len(colors)-1 else show
624624
_aga_graph_single(
625625
adata,
626626
layout=layout,
@@ -635,7 +635,8 @@ def aga_graph(
635635
ext=ext,
636636
ax=axs[icolor],
637637
title=title[icolor],
638-
add_noise_to_node_positions=add_noise_to_node_positions)
638+
add_noise_to_node_positions=add_noise_to_node_positions,
639+
force_labels_to_front=force_labels_to_front)
639640
if ext == 'pdf':
640641
logg.warn('Be aware that saving as pdf exagerates thin lines.')
641642
utils.savefig_or_show('aga_graph', show=show, ext=ext, save=save)
@@ -657,7 +658,8 @@ def _aga_graph_single(
657658
layout=None,
658659
add_noise_to_node_positions=None,
659660
attachedness_type=False,
660-
draw_edge_labels=False):
661+
draw_edge_labels=False,
662+
force_labels_to_front=False):
661663
from matplotlib import rcParams
662664
if colors is None and 'aga_groups_colors_original' in adata.add:
663665
colors = adata.add['aga_groups_colors_original']
@@ -688,6 +690,9 @@ def _aga_graph_single(
688690
if layout is None: layout = 'simple'
689691
if layout == 'simple':
690692
pos = utils.hierarchy_pos(nx_g, root)
693+
if len(pos) < nx_g.number_of_nodes():
694+
raise ValueError('This is a forest and not a single tree. '
695+
'Try another `layout`, e.g., {"fr"}.')
691696
else:
692697
from .. import utils as sc_utils
693698
g = sc_utils.get_igraph_from_adjacency(adata.add['aga_adjacency'])
@@ -757,7 +762,7 @@ def _aga_graph_single(
757762
ax.set_yticks([])
758763
base_pie_size = 1/(np.sqrt(nx_g.number_of_nodes()) + 10) * node_size
759764
median_group_size = np.median(adata.add['aga_groups_sizes'])
760-
for count, n in enumerate(nx_g):
765+
for count, n in enumerate(nx_g.nodes_iter()):
761766
pie_size = base_pie_size
762767
pie_size *= np.power(adata.add['aga_groups_sizes'][count] / median_group_size,
763768
node_size_power)
@@ -776,21 +781,20 @@ def _aga_graph_single(
776781
color = list(color)
777782
color.append('grey')
778783
fracs.append(1-sum(fracs))
779-
# names[count] += '\n?'
780784
else:
781785
raise ValueError('{} is neither a dict of valid matplotlib colors '
782786
'nor a valid matplotlib color.'.format(colors[count]))
783787
a.pie(fracs, colors=color)
784-
# if names is not None:
785-
# a.text(0.5, 0.5, names[count],
786-
# verticalalignment='center',
787-
# horizontalalignment='center',
788-
# transform=a.transAxes,
789-
# size=fontsize)
788+
if not force_labels_to_front and groups is not None:
789+
a.text(0.5, 0.5, groups[count],
790+
verticalalignment='center',
791+
horizontalalignment='center',
792+
transform=a.transAxes,
793+
size=fontsize)
790794
# TODO: this is a terrible hack, but if we use the solution above, labels
791795
# get hidden behind pies
792-
if groups is not None:
793-
for count, n in enumerate(nx_g):
796+
if force_labels_to_front and groups is not None:
797+
for count, n in enumerate(nx_g.nodes_iter()):
794798
# all copy and paste from above
795799
pie_size = base_pie_size
796800
pie_size *= np.power(adata.add['aga_groups_sizes'][count] / median_group_size,

0 commit comments

Comments
 (0)