Skip to content

Commit 600f25e

Browse files
committed
Simplify sparsity handling in plot waveforms/templates and fix plot_traces sortingview
1 parent 63a31c0 commit 600f25e

File tree

3 files changed

+38
-44
lines changed

3 files changed

+38
-44
lines changed

src/spikeinterface/widgets/unit_templates.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,8 @@ def plot_sortingview(self, data_plot, **backend_kwargs):
2424
assert len(dp.templates_shading) <= 4, "Only 2 ans 4 templates shading are supported in sortingview"
2525

2626
# ensure serializable for sortingview
27-
unit_id_to_channel_ids = dp.sparsity.unit_id_to_channel_ids
28-
unit_id_to_channel_indices = dp.sparsity.unit_id_to_channel_indices
27+
unit_id_to_channel_ids = dp.final_sparsity.unit_id_to_channel_ids
28+
unit_id_to_channel_indices = dp.final_sparsity.unit_id_to_channel_indices
2929

3030
unit_ids, channel_ids = make_serializable(dp.unit_ids, dp.channel_ids)
3131

src/spikeinterface/widgets/unit_waveforms.py

Lines changed: 35 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -123,51 +123,47 @@ def __init__(
123123
unit_colors = get_unit_colors(sorting_analyzer_or_templates)
124124

125125
channel_locations = sorting_analyzer_or_templates.get_channel_locations()
126-
extra_sparsity = False
126+
extra_sparsity = None
127127
# handle sparsity
128128
sparsity_mismatch_warning = (
129129
"The provided 'sparsity' includes additional channels not in the analyzer sparsity. "
130130
"These extra channels will be plotted as flat lines."
131131
)
132132
analyzer_sparsity = sorting_analyzer_or_templates.sparsity
133133
if channel_ids is not None:
134+
assert sparsity is None, "If 'channel_ids' is provided, 'sparsity' should be None!"
134135
channel_mask = np.tile(
135136
np.isin(sorting_analyzer_or_templates.channel_ids, channel_ids),
136137
(len(sorting_analyzer_or_templates.unit_ids), 1),
137138
)
138-
sparsity = ChannelSparsity(
139+
extra_sparsity = ChannelSparsity(
139140
mask=channel_mask,
140141
channel_ids=sorting_analyzer_or_templates.channel_ids,
141142
unit_ids=sorting_analyzer_or_templates.unit_ids,
142143
)
143-
extra_sparsity = True
144-
elif analyzer_sparsity is not None:
145-
if sparsity is None:
146-
sparsity = analyzer_sparsity
147-
else:
148-
extra_sparsity = True
149-
else:
150-
if sparsity is None:
151-
unit_id_to_channel_ids = {
152-
u: sorting_analyzer_or_templates.channel_ids for u in sorting_analyzer_or_templates.unit_ids
153-
}
154-
sparsity = ChannelSparsity.from_unit_id_to_channel_ids(
155-
unit_id_to_channel_ids=unit_id_to_channel_ids,
156-
unit_ids=sorting_analyzer_or_templates.unit_ids,
157-
channel_ids=sorting_analyzer_or_templates.channel_ids,
158-
)
159-
else:
160-
assert isinstance(sparsity, ChannelSparsity), "'sparsity' should be a ChannelSparsity object!"
144+
elif sparsity is not None:
145+
extra_sparsity = sparsity
161146

162147
if channel_ids is None:
163148
channel_ids = sorting_analyzer_or_templates.channel_ids
164149

165150
# assert provided sparsity is a subset of waveform sparsity
166-
if extra_sparsity:
167-
combined_mask = np.logical_or(analyzer_sparsity.mask, sparsity.mask)
168-
if not np.all(np.sum(combined_mask, 1) - np.sum(sorting_analyzer_or_templates.sparsity.mask, 1) == 0):
151+
if extra_sparsity is not None and analyzer_sparsity is not None:
152+
combined_mask = np.logical_or(analyzer_sparsity.mask, extra_sparsity.mask)
153+
if not np.all(np.sum(combined_mask, 1) - np.sum(analyzer_sparsity.mask, 1) == 0):
169154
warn(sparsity_mismatch_warning)
170155

156+
final_sparsity = extra_sparsity if extra_sparsity is not None else analyzer_sparsity
157+
if final_sparsity is None:
158+
final_sparsity = ChannelSparsity(
159+
mask=np.ones(
160+
(len(sorting_analyzer_or_templates.unit_ids), len(sorting_analyzer_or_templates.channel_ids)),
161+
dtype=bool,
162+
),
163+
unit_ids=sorting_analyzer_or_templates.unit_ids,
164+
channel_ids=sorting_analyzer_or_templates.channel_ids,
165+
)
166+
171167
# get templates
172168
if isinstance(sorting_analyzer_or_templates, Templates):
173169
templates = sorting_analyzer_or_templates.templates_array
@@ -195,9 +191,7 @@ def __init__(
195191
wf_ext = sorting_analyzer_or_templates.get_extension("waveforms")
196192
if wf_ext is None:
197193
raise ValueError("plot_waveforms() needs the extension 'waveforms'")
198-
wfs_by_ids = self._get_wfs_by_ids(
199-
sorting_analyzer_or_templates, unit_ids, sparsity, extra_sparsity=extra_sparsity
200-
)
194+
wfs_by_ids = self._get_wfs_by_ids(sorting_analyzer_or_templates, unit_ids, extra_sparsity=extra_sparsity)
201195
else:
202196
wfs_by_ids = None
203197

@@ -207,7 +201,8 @@ def __init__(
207201
nbefore=nbefore,
208202
unit_ids=unit_ids,
209203
channel_ids=channel_ids,
210-
sparsity=sparsity,
204+
final_sparsity=final_sparsity,
205+
extra_sparsity=extra_sparsity,
211206
unit_colors=unit_colors,
212207
channel_locations=channel_locations,
213208
scale=scale,
@@ -234,7 +229,6 @@ def __init__(
234229
alpha_templates=alpha_templates,
235230
hide_unit_selector=hide_unit_selector,
236231
plot_legend=plot_legend,
237-
extra_sparsity=extra_sparsity,
238232
)
239233
BaseWidget.__init__(self, plot_data, backend=backend, **backend_kwargs)
240234

@@ -269,7 +263,7 @@ def plot_matplotlib(self, data_plot, **backend_kwargs):
269263
ax = self.axes.flatten()[i]
270264
color = dp.unit_colors[unit_id]
271265

272-
chan_inds = dp.sparsity.unit_id_to_channel_indices[unit_id]
266+
chan_inds = dp.final_sparsity.unit_id_to_channel_indices[unit_id]
273267
xvectors_flat = xvectors[:, chan_inds].T.flatten()
274268

275269
# plot waveforms
@@ -501,28 +495,27 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs):
501495
if backend_kwargs["display"]:
502496
display(self.widget)
503497

504-
def _get_wfs_by_ids(self, sorting_analyzer, unit_ids, sparsity, extra_sparsity=False):
498+
def _get_wfs_by_ids(self, sorting_analyzer, unit_ids, extra_sparsity):
505499
wfs_by_ids = {}
506500
wf_ext = sorting_analyzer.get_extension("waveforms")
507501
for unit_id in unit_ids:
508502
unit_index = list(sorting_analyzer.unit_ids).index(unit_id)
509-
if not extra_sparsity:
510-
# get waveforms with default sparsity
511-
if sorting_analyzer.is_sparse():
512-
wfs = wf_ext.get_waveforms_one_unit(unit_id, force_dense=False)
513-
else:
514-
wfs = wf_ext.get_waveforms_one_unit(unit_id)
515-
wfs = wfs[:, :, sparsity.mask[unit_index]]
503+
if extra_sparsity is None:
504+
wfs = wf_ext.get_waveforms_one_unit(unit_id, force_dense=False)
516505
else:
517506
# in this case we have to construct waveforms based on the extra sparsity and add the
518507
# sparse waveforms on the valid channels
508+
if sorting_analyzer.is_sparse():
509+
original_mask = sorting_analyzer.sparsity.mask[unit_index]
510+
else:
511+
original_mask = np.ones(len(sorting_analyzer.channel_ids), dtype=bool)
519512
wfs_orig = wf_ext.get_waveforms_one_unit(unit_id, force_dense=False)
520513
wfs = np.zeros(
521-
(wfs_orig.shape[0], wfs_orig.shape[1], sparsity.mask[unit_index].sum()), dtype=wfs_orig.dtype
514+
(wfs_orig.shape[0], wfs_orig.shape[1], extra_sparsity.mask[unit_index].sum()), dtype=wfs_orig.dtype
522515
)
523516
# fill in the existing waveforms channels
524-
valid_wfs_indices = sparsity.mask[unit_index][sorting_analyzer.sparsity.mask[unit_index]]
525-
valid_extra_indices = sorting_analyzer.sparsity.mask[unit_index][sparsity.mask[unit_index]]
517+
valid_wfs_indices = extra_sparsity.mask[unit_index][original_mask]
518+
valid_extra_indices = original_mask[extra_sparsity.mask[unit_index]]
526519
wfs[:, :, valid_extra_indices] = wfs_orig[:, :, valid_wfs_indices]
527520

528521
wfs_by_ids[unit_id] = wfs
@@ -592,7 +585,7 @@ def _update_plot(self, change):
592585

593586
if data_plot["plot_waveforms"]:
594587
wfs_by_ids = self._get_wfs_by_ids(
595-
self.sorting_analyzer, unit_ids, data_plot["sparsity"], extra_sparsity=data_plot["extra_sparsity"]
588+
self.sorting_analyzer, unit_ids, extra_sparsity=data_plot["extra_sparsity"]
596589
)
597590
data_plot["wfs_by_ids"] = wfs_by_ids
598591

@@ -638,7 +631,7 @@ def _plot_probe(self, ax, channel_locations, unit_ids):
638631

639632
# TODO this could be done with probeinterface plotting plotting tools!!
640633
for unit in unit_ids:
641-
channel_inds = self.data_plot["sparsity"].unit_id_to_channel_indices[unit]
634+
channel_inds = self.data_plot["final_sparsity"].unit_id_to_channel_indices[unit]
642635
ax.plot(
643636
channel_locations[channel_inds, 0],
644637
channel_locations[channel_inds, 1],

src/spikeinterface/widgets/utils.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -151,6 +151,7 @@ def array_to_image(
151151
output_image : 3D numpy array
152152
153153
"""
154+
import matplotlib.pyplot as plt
154155

155156
from scipy.ndimage import zoom
156157

0 commit comments

Comments
 (0)