@@ -123,51 +123,47 @@ def __init__(
123
123
unit_colors = get_unit_colors (sorting_analyzer_or_templates )
124
124
125
125
channel_locations = sorting_analyzer_or_templates .get_channel_locations ()
126
- extra_sparsity = False
126
+ extra_sparsity = None
127
127
# handle sparsity
128
128
sparsity_mismatch_warning = (
129
129
"The provided 'sparsity' includes additional channels not in the analyzer sparsity. "
130
130
"These extra channels will be plotted as flat lines."
131
131
)
132
132
analyzer_sparsity = sorting_analyzer_or_templates .sparsity
133
133
if channel_ids is not None :
134
+ assert sparsity is None , "If 'channel_ids' is provided, 'sparsity' should be None!"
134
135
channel_mask = np .tile (
135
136
np .isin (sorting_analyzer_or_templates .channel_ids , channel_ids ),
136
137
(len (sorting_analyzer_or_templates .unit_ids ), 1 ),
137
138
)
138
- sparsity = ChannelSparsity (
139
+ extra_sparsity = ChannelSparsity (
139
140
mask = channel_mask ,
140
141
channel_ids = sorting_analyzer_or_templates .channel_ids ,
141
142
unit_ids = sorting_analyzer_or_templates .unit_ids ,
142
143
)
143
- extra_sparsity = True
144
- elif analyzer_sparsity is not None :
145
- if sparsity is None :
146
- sparsity = analyzer_sparsity
147
- else :
148
- extra_sparsity = True
149
- else :
150
- if sparsity is None :
151
- unit_id_to_channel_ids = {
152
- u : sorting_analyzer_or_templates .channel_ids for u in sorting_analyzer_or_templates .unit_ids
153
- }
154
- sparsity = ChannelSparsity .from_unit_id_to_channel_ids (
155
- unit_id_to_channel_ids = unit_id_to_channel_ids ,
156
- unit_ids = sorting_analyzer_or_templates .unit_ids ,
157
- channel_ids = sorting_analyzer_or_templates .channel_ids ,
158
- )
159
- else :
160
- assert isinstance (sparsity , ChannelSparsity ), "'sparsity' should be a ChannelSparsity object!"
144
+ elif sparsity is not None :
145
+ extra_sparsity = sparsity
161
146
162
147
if channel_ids is None :
163
148
channel_ids = sorting_analyzer_or_templates .channel_ids
164
149
165
150
# assert provided sparsity is a subset of waveform sparsity
166
- if extra_sparsity :
167
- combined_mask = np .logical_or (analyzer_sparsity .mask , sparsity .mask )
168
- if not np .all (np .sum (combined_mask , 1 ) - np .sum (sorting_analyzer_or_templates . sparsity .mask , 1 ) == 0 ):
151
+ if extra_sparsity is not None and analyzer_sparsity is not None :
152
+ combined_mask = np .logical_or (analyzer_sparsity .mask , extra_sparsity .mask )
153
+ if not np .all (np .sum (combined_mask , 1 ) - np .sum (analyzer_sparsity .mask , 1 ) == 0 ):
169
154
warn (sparsity_mismatch_warning )
170
155
156
+ final_sparsity = extra_sparsity if extra_sparsity is not None else analyzer_sparsity
157
+ if final_sparsity is None :
158
+ final_sparsity = ChannelSparsity (
159
+ mask = np .ones (
160
+ (len (sorting_analyzer_or_templates .unit_ids ), len (sorting_analyzer_or_templates .channel_ids )),
161
+ dtype = bool ,
162
+ ),
163
+ unit_ids = sorting_analyzer_or_templates .unit_ids ,
164
+ channel_ids = sorting_analyzer_or_templates .channel_ids ,
165
+ )
166
+
171
167
# get templates
172
168
if isinstance (sorting_analyzer_or_templates , Templates ):
173
169
templates = sorting_analyzer_or_templates .templates_array
@@ -195,9 +191,7 @@ def __init__(
195
191
wf_ext = sorting_analyzer_or_templates .get_extension ("waveforms" )
196
192
if wf_ext is None :
197
193
raise ValueError ("plot_waveforms() needs the extension 'waveforms'" )
198
- wfs_by_ids = self ._get_wfs_by_ids (
199
- sorting_analyzer_or_templates , unit_ids , sparsity , extra_sparsity = extra_sparsity
200
- )
194
+ wfs_by_ids = self ._get_wfs_by_ids (sorting_analyzer_or_templates , unit_ids , extra_sparsity = extra_sparsity )
201
195
else :
202
196
wfs_by_ids = None
203
197
@@ -207,7 +201,8 @@ def __init__(
207
201
nbefore = nbefore ,
208
202
unit_ids = unit_ids ,
209
203
channel_ids = channel_ids ,
210
- sparsity = sparsity ,
204
+ final_sparsity = final_sparsity ,
205
+ extra_sparsity = extra_sparsity ,
211
206
unit_colors = unit_colors ,
212
207
channel_locations = channel_locations ,
213
208
scale = scale ,
@@ -234,7 +229,6 @@ def __init__(
234
229
alpha_templates = alpha_templates ,
235
230
hide_unit_selector = hide_unit_selector ,
236
231
plot_legend = plot_legend ,
237
- extra_sparsity = extra_sparsity ,
238
232
)
239
233
BaseWidget .__init__ (self , plot_data , backend = backend , ** backend_kwargs )
240
234
@@ -269,7 +263,7 @@ def plot_matplotlib(self, data_plot, **backend_kwargs):
269
263
ax = self .axes .flatten ()[i ]
270
264
color = dp .unit_colors [unit_id ]
271
265
272
- chan_inds = dp .sparsity .unit_id_to_channel_indices [unit_id ]
266
+ chan_inds = dp .final_sparsity .unit_id_to_channel_indices [unit_id ]
273
267
xvectors_flat = xvectors [:, chan_inds ].T .flatten ()
274
268
275
269
# plot waveforms
@@ -501,28 +495,27 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs):
501
495
if backend_kwargs ["display" ]:
502
496
display (self .widget )
503
497
504
- def _get_wfs_by_ids (self , sorting_analyzer , unit_ids , sparsity , extra_sparsity = False ):
498
+ def _get_wfs_by_ids (self , sorting_analyzer , unit_ids , extra_sparsity ):
505
499
wfs_by_ids = {}
506
500
wf_ext = sorting_analyzer .get_extension ("waveforms" )
507
501
for unit_id in unit_ids :
508
502
unit_index = list (sorting_analyzer .unit_ids ).index (unit_id )
509
- if not extra_sparsity :
510
- # get waveforms with default sparsity
511
- if sorting_analyzer .is_sparse ():
512
- wfs = wf_ext .get_waveforms_one_unit (unit_id , force_dense = False )
513
- else :
514
- wfs = wf_ext .get_waveforms_one_unit (unit_id )
515
- wfs = wfs [:, :, sparsity .mask [unit_index ]]
503
+ if extra_sparsity is None :
504
+ wfs = wf_ext .get_waveforms_one_unit (unit_id , force_dense = False )
516
505
else :
517
506
# in this case we have to construct waveforms based on the extra sparsity and add the
518
507
# sparse waveforms on the valid channels
508
+ if sorting_analyzer .is_sparse ():
509
+ original_mask = sorting_analyzer .sparsity .mask [unit_index ]
510
+ else :
511
+ original_mask = np .ones (len (sorting_analyzer .channel_ids ), dtype = bool )
519
512
wfs_orig = wf_ext .get_waveforms_one_unit (unit_id , force_dense = False )
520
513
wfs = np .zeros (
521
- (wfs_orig .shape [0 ], wfs_orig .shape [1 ], sparsity .mask [unit_index ].sum ()), dtype = wfs_orig .dtype
514
+ (wfs_orig .shape [0 ], wfs_orig .shape [1 ], extra_sparsity .mask [unit_index ].sum ()), dtype = wfs_orig .dtype
522
515
)
523
516
# fill in the existing waveforms channels
524
- valid_wfs_indices = sparsity .mask [unit_index ][sorting_analyzer . sparsity . mask [ unit_index ] ]
525
- valid_extra_indices = sorting_analyzer . sparsity . mask [ unit_index ][ sparsity .mask [unit_index ]]
517
+ valid_wfs_indices = extra_sparsity .mask [unit_index ][original_mask ]
518
+ valid_extra_indices = original_mask [ extra_sparsity .mask [unit_index ]]
526
519
wfs [:, :, valid_extra_indices ] = wfs_orig [:, :, valid_wfs_indices ]
527
520
528
521
wfs_by_ids [unit_id ] = wfs
@@ -592,7 +585,7 @@ def _update_plot(self, change):
592
585
593
586
if data_plot ["plot_waveforms" ]:
594
587
wfs_by_ids = self ._get_wfs_by_ids (
595
- self .sorting_analyzer , unit_ids , data_plot [ "sparsity" ], extra_sparsity = data_plot ["extra_sparsity" ]
588
+ self .sorting_analyzer , unit_ids , extra_sparsity = data_plot ["extra_sparsity" ]
596
589
)
597
590
data_plot ["wfs_by_ids" ] = wfs_by_ids
598
591
@@ -638,7 +631,7 @@ def _plot_probe(self, ax, channel_locations, unit_ids):
638
631
639
632
# TODO this could be done with probeinterface plotting plotting tools!!
640
633
for unit in unit_ids :
641
- channel_inds = self .data_plot ["sparsity " ].unit_id_to_channel_indices [unit ]
634
+ channel_inds = self .data_plot ["final_sparsity " ].unit_id_to_channel_indices [unit ]
642
635
ax .plot (
643
636
channel_locations [channel_inds , 0 ],
644
637
channel_locations [channel_inds , 1 ],
0 commit comments