diff --git a/spikeinterface_gui/basescatterview.py b/spikeinterface_gui/basescatterview.py index cd9cc19..7879a70 100644 --- a/spikeinterface_gui/basescatterview.py +++ b/spikeinterface_gui/basescatterview.py @@ -8,7 +8,7 @@ class BaseScatterView(ViewBase): _supported_backend = ['qt', 'panel'] _depend_on = None _settings = [ - {'name': 'auto_decimate', 'type': 'bool', 'value' : True }, + {'name': "auto_decimate", 'type': 'bool', 'value' : True }, {'name': 'max_spikes_per_unit', 'type': 'int', 'value' : 10_000 }, {'name': 'alpha', 'type': 'float', 'value' : 0.7, 'limits':(0, 1.), 'step':0.05 }, {'name': 'scatter_size', 'type': 'float', 'value' : 2., 'step':0.5 }, @@ -29,6 +29,9 @@ def __init__(self, spike_data, y_label, controller=None, parent=None, backend="q eps = (self._data_max - self._data_min) / 100.0 self._data_max += eps self._max_count = None + self._lasso_vertices = {segment_index: [] for segment_index in range(controller.num_segments)} + # this is used in panel + self._current_selected = 0 ViewBase.__init__(self, controller=controller, parent=parent, backend=backend) @@ -42,7 +45,7 @@ def get_unit_data(self, unit_id, seg_index=0): hist_count, hist_bins = np.histogram(spike_data, bins=np.linspace(hist_min, hist_max, self.settings['num_bins'])) - if self.settings['auto_decimate'] and spike_times.size > self.settings['max_spikes_per_unit']: + if self.settings["auto_decimate"] and spike_times.size > self.settings['max_spikes_per_unit']: step = spike_times.size // self.settings['max_spikes_per_unit'] spike_times = spike_times[::step] spike_data = spike_data[::step] @@ -61,6 +64,62 @@ def get_selected_spikes_data(self, seg_index=0): return (spike_times, spike_data) + def split(self): + """ + Add a split to the curation data based on the lasso vertices. + """ + if self.controller.num_segments > 1: + # check that lasso vertices are defined for all segments + if not all(len(self._lasso_vertices[seg_index]) > 0 for seg_index in range(self.controller.num_segments)): + self.warning("Select areas for all segments.") + return + + # split is only possible if one unit is visible + visible_unit_ids = self.controller.get_visible_unit_ids() + if len(visible_unit_ids) != 1: + self.warning("Split is only possible if one unit is visible.") + return + + visible_unit_id = visible_unit_ids[0] + + fs = self.controller.sampling_frequency + indices = [] + offset = 0 + for segment_index, vertices in self._lasso_vertices.items(): + spike_inds = self.controller.get_spike_indices(visible_unit_id, seg_index=segment_index) + spike_times = self.controller.spikes["sample_index"][spike_inds] / fs + spike_data = self.spike_data[spike_inds] + + # spike inds within spike train + inds = np.arange(offset, offset + len(spike_inds), dtype=int) + + points = np.column_stack((spike_times, spike_data)) + indices_in_segment = [] + for polygon in vertices: + # Check if points are inside the polygon + inside = mpl_path(polygon).contains_points(points) + if np.any(inside): + # If any point is inside, we can proceed with the split + indices_in_segment.extend(inds[inside]) + indices.extend(indices_in_segment) + offset += len(spike_inds) + + self.controller.make_manual_split_if_possible( + unit_id=visible_unit_id, + indices=indices, + ) + + # Clear the lasso vertices after splitting + self._lasso_vertices = {segment_index: [] for segment_index in range(self.controller.num_segments)} + self.refresh() + self.notify_manual_curation_updated() + + + def on_unit_visibility_changed(self): + self._lasso_vertices = {segment_index: [] for segment_index in range(self.controller.num_segments)} + self._current_selected = 0 + self.refresh() + ## QT zone ## def _qt_make_layout(self): from .myqt import QT @@ -76,16 +135,20 @@ def _qt_make_layout(self): self.combo_seg.currentIndexChanged.connect(self.refresh) add_stretch_to_qtoolbar(tb) self.lasso_but = QT.QPushButton("select", checkable = True) - tb.addWidget(self.lasso_but) self.lasso_but.clicked.connect(self.enable_disable_lasso) - + if self.controller.curation: + self.split_but = QT.QPushButton("split") + tb.addWidget(self.split_but) + self.split_but.clicked.connect(self.split) + shortcut_split = QT.QShortcut(self.qt_widget) + shortcut_split.setKey(QT.QKeySequence("ctrl+s")) + shortcut_split.activated.connect(self.split) h = QT.QHBoxLayout() self.layout.addLayout(h) self.graphicsview = pg.GraphicsView() h.addWidget(self.graphicsview, 3) - self.graphicsview2 = pg.GraphicsView() h.addWidget(self.graphicsview2, 1) @@ -103,7 +166,6 @@ def _qt_make_layout(self): self.scatter_select.setZValue(1000) - def initialize_plot(self): import pyqtgraph as pg from .utils_qt import ViewBoxHandlingLasso @@ -148,7 +210,10 @@ def _qt_refresh(self): max_count = 1 for unit_id in self.controller.get_visible_unit_ids(): - spike_times, spike_data, hist_count, hist_bins, _ = self.get_unit_data(unit_id) + spike_times, spike_data, hist_count, hist_bins, _ = self.get_unit_data( + unit_id, + seg_index=self.combo_seg.currentIndex() + ) # make a copy of the color color = QT.QColor(self.get_unit_color(unit_id)) @@ -162,13 +227,13 @@ def _qt_refresh(self): max_count = max(max_count, np.max(hist_count)) self._max_count = max_count - seg_index = self.combo_seg.currentIndex() + seg_index = self.combo_seg.currentIndex() time_max = self.controller.get_num_samples(seg_index) / self.controller.sampling_frequency self.plot.setXRange( 0., time_max, padding = 0.0) self.plot2.setXRange(0, self._max_count, padding = 0.0) - spike_times, spike_data = self.get_selected_spikes_data() + spike_times, spike_data = self.get_selected_spikes_data(seg_index=self.combo_seg.currentIndex()) self.scatter_select.setData(spike_times, spike_data) def enable_disable_lasso(self, checked): @@ -183,7 +248,7 @@ def on_lasso_drawing(self, points): points = np.array(points) self.lasso.setData(points[:, 0], points[:, 1]) - def on_lasso_finished(self, points): + def on_lasso_finished(self, points, shift_held=False): self.lasso.setData([], []) vertices = np.array(points) @@ -200,40 +265,53 @@ def on_lasso_finished(self, points): # Only consider spikes from visible units visible_spikes = spikes_in_seg[visible_mask] if len(visible_spikes) == 0: - # Clear selection if no visible spikes - self.controller.set_indices_spike_selected([]) - self.refresh() - self.notify_spike_selection_changed() + # Clear selection if no visible spikes and shift not held + if not shift_held: + self.controller.set_indices_spike_selected([]) + self.refresh() + self.notify_spike_selection_changed() return - + spike_times = visible_spikes['sample_index'] / fs spike_data = self.spike_data[sl][visible_mask] - points = np.column_stack((spike_times, spike_data)) - inside = mpl_path(vertices).contains_points(points) - - # Clear selection if no spikes inside lasso - if not np.any(inside): - self.controller.set_indices_spike_selected([]) - self.refresh() - self.notify_spike_selection_changed() - return + scatter_data = np.column_stack((spike_times, spike_data)) + inside = mpl_path(vertices).contains_points(scatter_data) + + if shift_held: + # If shift is held, append the vertices to the current lasso vertices + self._lasso_vertices[seg_index].append(vertices) + else: + # If shift is not held, clear the existing lasso vertices for this segment + self._lasso_vertices[seg_index] = [vertices] - # Map back to original indices - visible_indices = np.nonzero(visible_mask)[0] - selected_indices = sl.start + visible_indices[inside] - self.controller.set_indices_spike_selected(selected_indices) + # print(f"Lasso selection for segment {seg_index} has {len(self._lasso_vertices[seg_index])} polygons.") + + # Handle selection based on whether shift is held + if np.any(inside): + # Map back to original indices + visible_indices = np.nonzero(visible_mask)[0] + new_selected_indices = sl.start + visible_indices[inside] + + if shift_held: + # Extend existing selection + current_selection = self.controller.get_indices_spike_selected() + extended_selection = np.unique(np.concatenate([current_selection, new_selected_indices])) + self.controller.set_indices_spike_selected(extended_selection) + else: + # Replace selection + self.controller.set_indices_spike_selected(new_selected_indices) + self.refresh() self.notify_spike_selection_changed() - ## Panel zone ## def _panel_make_layout(self): import panel as pn import bokeh.plotting as bpl from bokeh.models import ColumnDataSource, LassoSelectTool, Range1d - from .utils_panel import _bg_color, slow_lasso + from .utils_panel import _bg_color #, slow_lasso self.lasso_tool = LassoSelectTool() @@ -246,7 +324,11 @@ def _panel_make_layout(self): self.segment_selector.param.watch(self._panel_change_segment, 'value') self.select_toggle_button = pn.widgets.Toggle(name="Select") - self.select_toggle_button.param.watch(self._panel_on_select_button, 'value') + self.select_toggle_button.param.watch(self._panel_on_select_button, 'value') + + if self.controller.curation: + self.split_button = pn.widgets.Button(name="Split", button_type="primary") + self.split_button.on_click(self._panel_split) self.y_range = Range1d(self._data_min, self._data_max) self.scatter_source = ColumnDataSource(data={"x": [], "y": [], "color": []}) @@ -276,7 +358,8 @@ def _panel_make_layout(self): time_max = self.controller.get_num_samples(self.segment_index) / self.controller.sampling_frequency self.scatter_fig.x_range = Range1d(0., time_max) - slow_lasso(self.scatter_source, self._on_panel_lasso_selected) + # Add SelectionGeometry event handler to capture lasso vertices + self.scatter_fig.on_event('selectiongeometry', self._on_panel_selection_geometry) self.hist_fig = bpl.figure( tools="reset,wheel_zoom", @@ -292,8 +375,19 @@ def _panel_make_layout(self): self.hist_fig.xaxis.axis_label = "Count" self.hist_fig.x_range = Range1d(0, 1000) # Initial x range for histogram + toolbar_elements = [self.segment_selector, self.select_toggle_button] + if self.controller.curation: + toolbar_elements.append(self.split_button) + + if self.controller.curation: + from .utils_panel import KeyboardShortcut, KeyboardShortcuts + shortcuts = [KeyboardShortcut(key="s", name="split", ctrlKey=True)] + shortcuts_component = KeyboardShortcuts(shortcuts=shortcuts) + shortcuts_component.on_msg(self._panel_handle_shortcut) + toolbar_elements.append(shortcuts_component) + self.layout = pn.Column( - pn.Row(self.segment_selector, self.select_toggle_button, sizing_mode="stretch_width"), + pn.Row(*toolbar_elements, sizing_mode="stretch_width"), pn.Row( pn.Column( self.scatter_fig, @@ -374,30 +468,46 @@ def _panel_refresh(self): self.hist_fig.x_range.end = max_count def _panel_on_select_button(self, event): - if self.select_toggle_button.value and len(self.controller.get_visible_unit_ids()) == 1: + if self.select_toggle_button.value: self.scatter_fig.toolbar.active_drag = self.lasso_tool else: self.scatter_fig.toolbar.active_drag = None self.scatter_source.selected.indices = [] - self._on_panel_lasso_selected(None, None, None) + def _panel_change_segment(self, event): + self._current_selected = 0 self.segment_index = int(self.segment_selector.value.split()[-1]) time_max = self.controller.get_num_samples(self.segment_index) / self.controller.sampling_frequency self.scatter_fig.x_range.end = time_max self.refresh() - def _on_panel_lasso_selected(self, attr, old, new): + def _on_panel_selection_geometry(self, event): """ - Handle selection changes in the scatter plot. + Handle SelectionGeometry event to capture lasso polygon vertices. """ - if self.select_toggle_button.value: + if event.final: + xs = np.array(event.geometry["x"]) + ys = np.array(event.geometry["y"]) + polygon = np.column_stack((xs, ys)) + selected = self.scatter_source.selected.indices if len(selected) == 0: self.controller.set_indices_spike_selected([]) self.notify_spike_selection_changed() return + # Append the current polygon to the lasso vertices if shift is held + seg_index = self.segment_index + if len(selected) > self._current_selected: + self._current_selected = len(selected) + # Store the current polygon for the current segment + self._lasso_vertices[seg_index].append(polygon) + else: + self._lasso_vertices[seg_index] = [polygon] + + # print(f"Lasso selection for segment {self.segment_index} has {len(self._lasso_vertices[self.segment_index])} polygons.") + # Map back to original indices sl = self.controller.segment_slices[self.segment_index] spikes_in_seg = self.controller.spikes[sl] @@ -409,9 +519,16 @@ def _on_panel_lasso_selected(self, attr, old, new): # Map back to original indices visible_indices = np.nonzero(visible_mask)[0] selected_indices = sl.start + visible_indices[selected] - self.controller.set_indices_spike_selected(selected_indices) + already_selected = self.controller.get_indices_spike_selected() + all_selected = np.concatenate([already_selected, selected_indices]) + self.controller.set_indices_spike_selected(all_selected) self.notify_spike_selection_changed() + def _panel_split(self, event): + """ + Handle split button click in panel mode. + """ + self.split() def _panel_update_selected_spikes(self): # handle selected spikes @@ -426,7 +543,7 @@ def _panel_update_selected_spikes(self): visible_indices = sl.start + np.nonzero(visible_mask)[0] selected_indices = np.nonzero(np.isin(visible_indices, selected_spike_indices))[0] # set selected spikes in scatter plot - if self.settings["auto_decimate"]: + if self.settings["auto_decimate"] and len(selected_indices) > 0: selected_indices, = np.nonzero(np.isin(self.plotted_inds, selected_spike_indices)) self.scatter_source.selected.indices = list(selected_indices) else: @@ -447,3 +564,7 @@ def _panel_on_spike_selection_changed(self): # update selected spikes self._panel_update_selected_spikes() + def _panel_handle_shortcut(self, event): + if event.data == "split": + if len(self.controller.get_visible_unit_ids()) == 1: + self.split() \ No newline at end of file diff --git a/spikeinterface_gui/controller.py b/spikeinterface_gui/controller.py index ad653e6..b0f3978 100644 --- a/spikeinterface_gui/controller.py +++ b/spikeinterface_gui/controller.py @@ -12,9 +12,10 @@ import spikeinterface.qualitymetrics from spikeinterface.core.sorting_tools import spike_vector_to_indices from spikeinterface.core.core_tools import check_json +from spikeinterface.curation import validate_curation_dict from spikeinterface.widgets.utils import make_units_table_from_analyzer -from .curation_tools import adding_group, default_label_definitions, empty_curation_data +from .curation_tools import add_merge, default_label_definitions, empty_curation_data spike_dtype =[('sample_index', 'int64'), ('unit_index', 'int64'), ('channel_index', 'int64'), ('segment_index', 'int64'), @@ -260,13 +261,14 @@ def __init__(self, analyzer=None, backend="qt", parent=None, verbose=False, save spike_vector2 = self.analyzer.sorting.to_spike_vector(concatenated=False) # this is dict of list because per segment spike_indices[segment_index][unit_id] + spike_indices_abs = spike_vector_to_indices(spike_vector2, unit_ids, absolute_index=True) spike_indices = spike_vector_to_indices(spike_vector2, unit_ids) # this is flatten spike_per_seg = [s.size for s in spike_vector2] # dict[unit_id] -> all indices for this unit across segments self._spike_index_by_units = {} # dict[seg_index][unit_id] -> all indices for this unit for one segment - self._spike_index_by_segment_and_units = spike_indices + self._spike_index_by_segment_and_units = spike_indices_abs for unit_id in unit_ids: inds = [] for seg_ind in range(num_seg): @@ -319,7 +321,23 @@ def __init__(self, analyzer=None, backend="qt", parent=None, verbose=False, save if curation_data is None: self.curation_data = empty_curation_data.copy() else: - self.curation_data = curation_data + # validate the curation data + format_version = curation_data.get("format_version", None) + # assume version 2 if not present + if format_version is None: + raise ValueError("Curation data format version is missing and is required in the curation data.") + try: + validate_curation_dict(curation_data) + self.curation_data = curation_data + except Exception as e: + print(f"Invalid curation data. Initializing with empty curation data.\nError: {e}") + self.curation_data = empty_curation_data.copy() + if curation_data.get("merges") is None: + curation_data["merges"] = [] + if curation_data.get("splits") is None: + curation_data["splits"] = [] + if curation_data.get("removed") is None: + curation_data["removed"] = [] self.has_default_quality_labels = False if "label_definitions" not in self.curation_data: @@ -337,6 +355,8 @@ def __init__(self, analyzer=None, backend="qt", parent=None, verbose=False, save print('Curation quality labels are the default ones') self.has_default_quality_labels = True + # this is used to store the active split unit + self.active_split = None def check_is_view_possible(self, view_name): from .viewlist import possible_class_views @@ -442,6 +462,13 @@ def set_visible_unit_ids(self, visible_unit_ids): if len(visible_unit_ids) > lim: visible_unit_ids = visible_unit_ids[:lim] self._visible_unit_ids = list(visible_unit_ids) + self.active_split = None + if len(visible_unit_ids) == 1 and self.curation: + # check if unit is split + for split in self.curation_data['splits']: + if visible_unit_ids[0] == split['unit_id']: + self.active_split = split + break def get_visible_unit_ids(self): """Get list of visible unit_ids""" @@ -506,10 +533,21 @@ def get_indices_spike_visible(self): return self._spike_visible_indices def get_indices_spike_selected(self): + if self.active_split is not None: + # select the splitted spikes in the active split + split_unit_id = self.active_split['unit_id'] + spike_inds = self.get_spike_indices(split_unit_id, seg_index=None) + split_indices = self.active_split['indices'] + self._spike_selected_indices = np.array(spike_inds[split_indices], dtype='int64') return self._spike_selected_indices - + def set_indices_spike_selected(self, inds): self._spike_selected_indices = np.array(inds) + if len(self._spike_selected_indices) == 1: + # set time info + segment_index = self.spikes['segment_index'][self._spike_selected_indices[0]] + sample_index = self.spikes['sample_index'][self._spike_selected_indices[0]] + self.set_time(time=sample_index / self.sampling_frequency, segment_index=segment_index) def get_spike_indices(self, unit_id, seg_index=None): if seg_index is None: @@ -668,7 +706,7 @@ def curation_can_be_saved(self): def construct_final_curation(self): d = dict() - d["format_version"] = "1" + d["format_version"] = "2" d["unit_ids"] = self.unit_ids.tolist() d.update(self.curation_data.copy()) return d @@ -699,14 +737,14 @@ def make_manual_delete_if_possible(self, removed_unit_ids): if not self.curation: return - all_merged_units = sum(self.curation_data["merge_unit_groups"], []) + all_merged_units = sum([m["unit_ids"] for m in self.curation_data["merges"]], []) for unit_id in removed_unit_ids: - if unit_id in self.curation_data["removed_units"]: + if unit_id in self.curation_data["removed"]: continue # TODO: check if unit is already in a merge group if unit_id in all_merged_units: continue - self.curation_data["removed_units"].append(unit_id) + self.curation_data["removed"].append(unit_id) if self.verbose: print(f"Unit {unit_id} is removed from the curation data") @@ -718,10 +756,10 @@ def make_manual_restore(self, restore_unit_ids): return for unit_id in restore_unit_ids: - if unit_id in self.curation_data["removed_units"]: + if unit_id in self.curation_data["removed"]: if self.verbose: print(f"Unit {unit_id} is restored from the curation data") - self.curation_data["removed_units"].remove(unit_id) + self.curation_data["removed"].remove(unit_id) def make_manual_merge_if_possible(self, merge_unit_ids): """ @@ -740,22 +778,75 @@ def make_manual_merge_if_possible(self, merge_unit_ids): return False for unit_id in merge_unit_ids: - if unit_id in self.curation_data["removed_units"]: + if unit_id in self.curation_data["removed"]: + return False + + new_merges = add_merge(self.curation_data["merges"], merge_unit_ids) + self.curation_data["merges"] = new_merges + if self.verbose: + print(f"Merged unit group: {[str(u) for u in merge_unit_ids]}") + return True + + def make_manual_split_if_possible(self, unit_id, indices): + """ + Check if the a unit_id can be split into a new split in the curation_data. + + If unit_id is already in the removed list then the split is skipped. + If unit_id is already in some other split then the split is skipped. + """ + if not self.curation: + return False + + if unit_id in self.curation_data["removed"]: + return False + + # check if unit_id is already in a split + for split in self.curation_data["splits"]: + if split["unit_id"] == unit_id: return False - merged_groups = adding_group(self.curation_data["merge_unit_groups"], merge_unit_ids) - self.curation_data["merge_unit_groups"] = merged_groups + + new_split = { + "unit_id": unit_id, + "mode": "indices", + "indices": indices + } + self.curation_data["splits"].append(new_split) if self.verbose: - print(f"Merged unit group: {merge_unit_ids}") + print(f"Split unit {unit_id} with {len(indices)} spikes") return True - def make_manual_restore_merge(self, merge_group_indices): + def make_manual_restore_merge(self, merge_indices): + if not self.curation: + return + for merge_index in merge_indices: + if self.verbose: + print(f"Unmerged {self.curation_data['merges'][merge_index]['unit_ids']}") + self.curation_data["merges"].pop(merge_index) + + def make_manual_restore_split(self, split_indices): if not self.curation: return - merge_groups_to_remove = [self.curation_data["merge_unit_groups"][merge_group_index] for merge_group_index in merge_group_indices] - for merge_group in merge_groups_to_remove: + for split_index in split_indices: if self.verbose: - print(f"Unmerged merge group {merge_group}") - self.curation_data["merge_unit_groups"].remove(merge_group) + print(f"Unsplitting {self.curation_data['splits'][split_index]['unit_id']}") + self.curation_data["splits"].pop(split_index) + + def set_active_split_unit(self, unit_id): + """ + Set the active split unit_id. + This is used to set the label for the split unit. + """ + if not self.curation: + return + if unit_id is None: + self.active_split = None + else: + if unit_id in self.curation_data["removed"]: + print(f"Unit {unit_id} is removed, cannot set as active split unit") + return + active_split = [s for s in self.curation_data["splits"] if s["unit_id"] == unit_id] + if len(active_split) == 1: + self.active_split = active_split[0] def get_curation_label_definitions(self): # give only label definition with exclusive diff --git a/spikeinterface_gui/crosscorrelogramview.py b/spikeinterface_gui/crosscorrelogramview.py index 205c566..fcf5321 100644 --- a/spikeinterface_gui/crosscorrelogramview.py +++ b/spikeinterface_gui/crosscorrelogramview.py @@ -6,11 +6,10 @@ class CrossCorrelogramView(ViewBase): _supported_backend = ['qt', 'panel'] _depend_on = ["correlograms"] _settings = [ - {'name': 'window_ms', 'type': 'float', 'value' : 50. }, - {'name': 'bin_ms', 'type': 'float', 'value' : 1.0 }, - {'name': 'display_axis', 'type': 'bool', 'value' : True }, - {'name': 'max_visible', 'type': 'int', 'value' : 8 }, - ] + {'name': 'window_ms', 'type': 'float', 'value' : 50. }, + {'name': 'bin_ms', 'type': 'float', 'value' : 1.0 }, + {'name': 'display_axis', 'type': 'bool', 'value' : True }, + ] _need_compute = True def __init__(self, controller=None, parent=None, backend="qt"): @@ -26,6 +25,36 @@ def _on_settings_changed(self): def _compute(self): self.ccg, self.bins = self.controller.compute_correlograms( self.settings['window_ms'], self.settings['bin_ms']) + + def _compute_split_ccg(self): + """ + This method is used to compute the cross-correlogram for a split unit. + It is called when the user selects a split unit in the controller. + """ + from spikeinterface import NumpySorting + from spikeinterface.postprocessing import compute_correlograms + + if self.controller.active_split is None: + raise ValueError("No active split unit selected.") + + split_unit_id = self.controller.active_split["unit_id"] + spike_inds = self.controller.get_spike_indices(split_unit_id, seg_index=None) + split_indices = self.controller.active_split['indices'] + spikes_split_unit = self.controller.spikes[spike_inds] + unit_index = spikes_split_unit[0]["unit_index"] + # change unit_index for split indices + spikes_split_unit["unit_index"][split_indices] = unit_index + 1 + split_sorting = NumpySorting( + spikes=spikes_split_unit, + sampling_frequency=self.controller.sampling_frequency, + unit_ids=[f"{split_unit_id}-0", f"{split_unit_id}-1"] + ) + ccg, bins = compute_correlograms( + split_sorting, + window_ms=self.settings['window_ms'], + bin_ms=self.settings['bin_ms'] + ) + return ccg, bins ## Qt ## @@ -51,18 +80,32 @@ def _qt_refresh(self): return visible_unit_ids = self.controller.get_visible_unit_ids() - visible_unit_ids = visible_unit_ids[:self.settings['max_visible']] - - n = len(visible_unit_ids) - - unit_ids = list(self.controller.unit_ids) + if self.controller.active_split is None: + n = len(visible_unit_ids) + unit_ids = list(self.controller.unit_ids) + colors = { + unit_id: self.get_unit_color(unit_id) for unit_id in visible_unit_ids + } + ccg = self.ccg + bins = self.bins + else: + split_unit_id = visible_unit_ids[0] + n = 2 + unit_ids = [f"{split_unit_id}-0", f"{split_unit_id}-1"] + visible_unit_ids = unit_ids + ccg, bins = self._compute_split_ccg() + split_unit_color = self.get_unit_color(split_unit_id) + colors = { + f"{split_unit_id}-0": split_unit_color, + f"{split_unit_id}-1": split_unit_color, + } for r in range(n): for c in range(r, n): i = unit_ids.index(visible_unit_ids[r]) j = unit_ids.index(visible_unit_ids[c]) - count = self.ccg[i, j, :] + count = ccg[i, j, :] plot = pg.PlotItem() if not self.settings['display_axis']: @@ -71,16 +114,15 @@ def _qt_refresh(self): if r==c: unit_id = visible_unit_ids[r] - color = self.get_unit_color(unit_id) + color = colors[unit_id] else: color = (120,120,120,120) - curve = pg.PlotCurveItem(self.bins, count, stepMode='center', fillLevel=0, brush=color, pen=color) + curve = pg.PlotCurveItem(bins, count, stepMode='center', fillLevel=0, brush=color, pen=color) plot.addItem(curve) self.grid.addItem(plot, row=r, col=c) ## panel ## - def _panel_make_layout(self): import panel as pn import bokeh.plotting as bpl @@ -115,31 +157,40 @@ def _panel_refresh(self): return visible_unit_ids = self.controller.get_visible_unit_ids() - - # Show warning above the plot if too many visible units - if len(visible_unit_ids) > self.settings['max_visible']: - warning_msg = f"Only showing first {self.settings['max_visible']} units out of {len(visible_unit_ids)} visible units" - insert_warning(self, warning_msg) - self.is_warning_active = True - return - if self.is_warning_active: - clear_warning(self) - self.is_warning_active = False - - visible_unit_ids = visible_unit_ids[:self.settings['max_visible']] - - n = len(visible_unit_ids) - unit_ids = list(self.controller.unit_ids) + if self.controller.active_split is None: + n = len(visible_unit_ids) + unit_ids = list(self.controller.unit_ids) + colors = { + unit_id: self.get_unit_color(unit_id) for unit_id in visible_unit_ids + } + ccg = self.ccg + bins = self.bins + else: + split_unit_id = visible_unit_ids[0] + n = 2 + unit_ids = [f"{split_unit_id}-0", f"{split_unit_id}-1"] + visible_unit_ids = unit_ids + ccg, bins = self._compute_split_ccg() + split_unit_color = self.get_unit_color(split_unit_id) + colors = { + f"{split_unit_id}-0": split_unit_color, + f"{split_unit_id}-1": split_unit_color, + } + + first_fig = None for r in range(n): row_plots = [] for c in range(r, n): - i = unit_ids.index(visible_unit_ids[r]) j = unit_ids.index(visible_unit_ids[c]) - count = self.ccg[i, j, :] + count = ccg[i, j, :] # Create Bokeh figure - p = bpl.figure( + if first_fig is not None: + extra_kwargs = dict(x_range=first_fig.x_range) + else: + extra_kwargs = dict() + fig = bpl.figure( width=250, height=250, tools="pan,wheel_zoom,reset", @@ -148,29 +199,32 @@ def _panel_refresh(self): background_fill_color=_bg_color, border_fill_color=_bg_color, outline_line_color="white", + **extra_kwargs, ) - p.toolbar.logo = None + fig.toolbar.logo = None # Get color from controller if r == c: unit_id = visible_unit_ids[r] - color = self.get_unit_color(unit_id) + color = colors[unit_id] fill_alpha = 0.7 else: color = "lightgray" fill_alpha = 0.4 - p.quad( + fig.quad( top=count, bottom=0, - left=self.bins[:-1], - right=self.bins[1:], + left=bins[:-1], + right=bins[1:], fill_color=color, line_color=color, alpha=fill_alpha, ) + if first_fig is None: + first_fig = fig - row_plots.append(p) + row_plots.append(fig) # Fill row with None for proper spacing full_row = [None] * r + row_plots + [None] * (n - len(row_plots)) self.plots.append(full_row) diff --git a/spikeinterface_gui/curation_tools.py b/spikeinterface_gui/curation_tools.py index 2d92b41..d96964b 100644 --- a/spikeinterface_gui/curation_tools.py +++ b/spikeinterface_gui/curation_tools.py @@ -11,27 +11,29 @@ empty_curation_data = { "manual_labels": [], - "merge_unit_groups": [], - "removed_units": [] + "merges": [], + "splits": [], + "removes": [] } -def adding_group(previous_groups, new_group): +def add_merge(previous_merges, new_merge_unit_ids): # this is to ensure that np.str_ types are rendered as str - to_merge = [np.array(new_group).tolist()] + to_merge = [np.array(new_merge_unit_ids).tolist()] unchanged = [] - for c_prev in previous_groups: + for c_prev in previous_merges: is_unaffected = True - - for c_new in new_group: - if c_new in c_prev: + c_prev_unit_ids = c_prev["unit_ids"] + for c_new in new_merge_unit_ids: + if c_new in c_prev_unit_ids: is_unaffected = False - to_merge.append(c_prev) + to_merge.append(c_prev_unit_ids) break if is_unaffected: - unchanged.append(c_prev) - new_merge_group = [sum(to_merge, [])] - new_merge_group.extend(unchanged) - # Ensure the unicity - new_merge_group = [list(set(gp)) for gp in new_merge_group] - return new_merge_group + unchanged.append(c_prev_unit_ids) + + new_merge_units = [sum(to_merge, [])] + new_merge_units.extend(unchanged) + # Ensure the uniqueness + new_merges = [{"unit_ids": list(set(gp))} for gp in new_merge_units] + return new_merges diff --git a/spikeinterface_gui/curationview.py b/spikeinterface_gui/curationview.py index b6ed141..ed496b3 100644 --- a/spikeinterface_gui/curationview.py +++ b/spikeinterface_gui/curationview.py @@ -28,7 +28,7 @@ def restore_units(self): self.notify_manual_curation_updated() self.refresh() - def unmerge_groups(self): + def unmerge(self): if self.backend == 'qt': merge_indices = self._qt_get_merge_table_row() else: @@ -38,6 +38,16 @@ def unmerge_groups(self): self.notify_manual_curation_updated() self.refresh() + def unsplit(self): + if self.backend == 'qt': + split_indices = self._qt_get_split_table_row() + else: + split_indices = self._panel_get_split_table_row() + if split_indices is not None: + self.controller.make_manual_restore_split(split_indices) + self.notify_manual_curation_updated() + self.refresh() + ## Qt def _qt_make_layout(self): from .myqt import QT @@ -60,6 +70,26 @@ def _qt_make_layout(self): h = QT.QHBoxLayout() self.layout.addLayout(h) + + v = QT.QVBoxLayout() + h.addLayout(v) + v.addWidget(QT.QLabel("Deleted")) + self.table_delete = QT.QTableWidget(selectionMode=QT.QAbstractItemView.SingleSelection, + selectionBehavior=QT.QAbstractItemView.SelectRows) + v.addWidget(self.table_delete) + self.table_delete.setContextMenuPolicy(QT.Qt.CustomContextMenu) + self.table_delete.customContextMenuRequested.connect(self._qt_open_context_menu_delete) + self.table_delete.itemSelectionChanged.connect(self._qt_on_item_selection_changed_delete) + + + + self.delete_menu = QT.QMenu() + act = self.delete_menu.addAction('Restore') + act.triggered.connect(self.restore_units) + shortcut_restore = QT.QShortcut(self.qt_widget) + shortcut_restore.setKey(QT.QKeySequence("ctrl+r")) + shortcut_restore.activated.connect(self.restore_units) + v = QT.QVBoxLayout() h.addLayout(v) v.addWidget(QT.QLabel("Merges")) @@ -73,35 +103,32 @@ def _qt_make_layout(self): self.table_merge.itemSelectionChanged.connect(self._qt_on_item_selection_changed_merge) self.merge_menu = QT.QMenu() - act = self.merge_menu.addAction('Remove merge group') - act.triggered.connect(self.unmerge_groups) + act = self.merge_menu.addAction('Remove merge') + act.triggered.connect(self.unmerge) shortcut_unmerge = QT.QShortcut(self.qt_widget) shortcut_unmerge.setKey(QT.QKeySequence("ctrl+u")) - shortcut_unmerge.activated.connect(self.unmerge_groups) - + shortcut_unmerge.activated.connect(self.unmerge) v = QT.QVBoxLayout() h.addLayout(v) - v.addWidget(QT.QLabel("Deleted")) - self.table_delete = QT.QTableWidget(selectionMode=QT.QAbstractItemView.SingleSelection, + v.addWidget(QT.QLabel("Splits")) + self.table_split = QT.QTableWidget(selectionMode=QT.QAbstractItemView.SingleSelection, selectionBehavior=QT.QAbstractItemView.SelectRows) - v.addWidget(self.table_delete) - self.table_delete.setContextMenuPolicy(QT.Qt.CustomContextMenu) - self.table_delete.customContextMenuRequested.connect(self._qt_open_context_menu_delete) - self.table_delete.itemSelectionChanged.connect(self._qt_on_item_selection_changed_delete) - - - self.delete_menu = QT.QMenu() - act = self.delete_menu.addAction('Restore') - act.triggered.connect(self.restore_units) - shortcut_restore = QT.QShortcut(self.qt_widget) - shortcut_restore.setKey(QT.QKeySequence("ctrl+r")) - shortcut_restore.activated.connect(self.restore_units) + v.addWidget(self.table_split) + self.table_split.setContextMenuPolicy(QT.Qt.CustomContextMenu) + self.table_split.customContextMenuRequested.connect(self._qt_open_context_menu_split) + self.table_split.itemSelectionChanged.connect(self._qt_on_item_selection_changed_split) + self.split_menu = QT.QMenu() + act = self.split_menu.addAction('Remove split') + act.triggered.connect(self.unsplit) + shortcut_unsplit = QT.QShortcut(self.qt_widget) + shortcut_unsplit.setKey(QT.QKeySequence("ctrl+x")) + shortcut_unsplit.activated.connect(self.unsplit) def _qt_refresh(self): from .myqt import QT # Merged - merged_units = self.controller.curation_data["merge_unit_groups"] + merged_units = [m["unit_ids"] for m in self.controller.curation_data["merges"]] self.table_merge.clear() self.table_merge.setRowCount(len(merged_units)) self.table_merge.setColumnCount(1) @@ -114,8 +141,8 @@ def _qt_refresh(self): for i in range(self.table_merge.columnCount()): self.table_merge.resizeColumnToContents(i) - ## deleted - removed_units = self.controller.curation_data["removed_units"] + # Removed + removed_units = self.controller.curation_data["removed"] self.table_delete.clear() self.table_delete.setRowCount(len(removed_units)) self.table_delete.setColumnCount(1) @@ -133,6 +160,24 @@ def _qt_refresh(self): item.unit_id = unit_id self.table_delete.resizeColumnToContents(0) + # Splits + splits = self.controller.curation_data["splits"] + self.table_split.clear() + self.table_split.setRowCount(len(splits)) + self.table_split.setColumnCount(1) + self.table_split.setHorizontalHeaderLabels(["Split units"]) + self.table_split.setSortingEnabled(False) + for i, split in enumerate(splits): + unit_id = split["unit_id"] + num_indices = len(split["indices"]) + num_spikes = self.controller.num_spikes[unit_id] + num_splits = f"({num_indices}-{num_spikes - num_indices})" + item = QT.QTableWidgetItem(f"{unit_id} {num_splits}") + item.setFlags(QT.Qt.ItemIsEnabled|QT.Qt.ItemIsSelectable) + self.table_split.setItem(i, 0, item) + item.unit_id = unit_id + self.table_split.resizeColumnToContents(0) + def _qt_get_delete_table_selection(self): @@ -148,12 +193,22 @@ def _qt_get_merge_table_row(self): return None else: return [s.row() for s in selected_items] + + def _qt_get_split_table_row(self): + selected_items = self.table_split.selectedItems() + if len(selected_items) == 0: + return None + else: + return [s.row() for s in selected_items] def _qt_open_context_menu_delete(self): self.delete_menu.popup(self.qt_widget.cursor().pos()) def _qt_open_context_menu_merge(self): self.merge_menu.popup(self.qt_widget.cursor().pos()) + + def _qt_open_context_menu_split(self): + self.split_menu.popup(self.qt_widget.cursor().pos()) def _qt_on_item_selection_changed_merge(self): if len(self.table_merge.selectedIndexes()) == 0: @@ -161,16 +216,28 @@ def _qt_on_item_selection_changed_merge(self): dtype = self.controller.unit_ids.dtype ind = self.table_merge.selectedIndexes()[0].row() - visible_unit_ids = self.controller.curation_data["merge_unit_groups"][ind] + visible_unit_ids = [m["unit_ids"] for m in self.controller.curation_data["merges"]][ind] visible_unit_ids = [dtype.type(unit_id) for unit_id in visible_unit_ids] self.controller.set_visible_unit_ids(visible_unit_ids) self.notify_unit_visibility_changed() + def _qt_on_item_selection_changed_split(self): + if len(self.table_split.selectedIndexes()) == 0: + return + + dtype = self.controller.unit_ids.dtype + ind = self.table_split.selectedIndexes()[0].row() + split_unit_str = self.table_split.item(ind, 0).text() + split_unit_id = dtype.type(split_unit_str.split(" ")[0]) + self.controller.set_visible_unit_ids([split_unit_id]) + self.controller.set_active_split_unit(split_unit_id) + self.notify_unit_visibility_changed() + def _qt_on_item_selection_changed_delete(self): if len(self.table_delete.selectedIndexes()) == 0: return ind = self.table_delete.selectedIndexes()[0].row() - unit_id = self.controller.curation_data["removed_units"][ind] + unit_id = self.controller.curation_data["removed"][ind] self.controller.set_all_unit_visibility_off() # convert to the correct type unit_id = self.controller.unit_ids.dtype.type(unit_id) @@ -216,39 +283,56 @@ def _panel_make_layout(self): pn.extension("tabulator") # Create dataframe - merge_df = pd.DataFrame({"merge_groups": []}) - delete_df = pd.DataFrame({"deleted_unit_id": []}) + delete_df = pd.DataFrame({"removed": []}) + merge_df = pd.DataFrame({"merges": []}) + split_df = pd.DataFrame({"splits": []}) # Create tables + self.table_delete = SelectableTabulator( + delete_df, + show_index=False, + disabled=True, + sortable=False, + formatters={"removed": "plaintext"}, + sizing_mode="stretch_width", + # SelectableTabulator functions + parent_view=self, + # refresh_table_function=self.refresh, + conditional_shortcut=self._conditional_refresh_delete, + column_callbacks={"removed": self._panel_on_deleted_col}, + ) self.table_merge = SelectableTabulator( merge_df, show_index=False, disabled=True, sortable=False, - formatters={"merge_groups": "plaintext"}, + selectable=1, + formatters={"merges": "plaintext"}, sizing_mode="stretch_width", # SelectableTabulator functions parent_view=self, # refresh_table_function=self.refresh, conditional_shortcut=self._conditional_refresh_merge, - column_callbacks={"merge_groups": self._panel_on_merged_col}, + column_callbacks={"merges": self._panel_on_merged_col}, ) - self.table_delete = SelectableTabulator( - delete_df, + self.table_split = SelectableTabulator( + split_df, show_index=False, disabled=True, sortable=False, - formatters={"deleted_unit_id": "plaintext"}, + selectable=1, + formatters={"splits": "plaintext"}, sizing_mode="stretch_width", # SelectableTabulator functions parent_view=self, # refresh_table_function=self.refresh, - conditional_shortcut=self._conditional_refresh_delete, - column_callbacks={"deleted_unit_id": self._panel_on_deleted_col}, + conditional_shortcut=self._conditional_refresh_split, + column_callbacks={"splits": self._panel_on_split_col}, ) self.table_delete.param.watch(self._panel_update_unit_visibility, "selection") self.table_merge.param.watch(self._panel_update_unit_visibility, "selection") + self.table_split.param.watch(self._panel_update_unit_visibility, "selection") # Create buttons save_button = pn.widgets.Button( @@ -277,7 +361,7 @@ def _panel_make_layout(self): button_type="primary", height=30 ) - remove_merge_button.on_click(self._panel_unmerge_groups) + remove_merge_button.on_click(self._panel_unmerge) submit_button = pn.widgets.Button( name="Submit to parent", @@ -306,12 +390,14 @@ def _panel_make_layout(self): shortcuts = [ KeyboardShortcut(name="restore", key="r", ctrlKey=True), KeyboardShortcut(name="unmerge", key="u", ctrlKey=True), + KeyboardShortcut(name="unsplit", key="x", ctrlKey=True), ] shortcuts_component = KeyboardShortcuts(shortcuts=shortcuts) shortcuts_component.on_msg(self._panel_handle_shortcut) # Create main layout with proper sizing - sections = pn.Row(self.table_merge, self.table_delete, sizing_mode="stretch_width") + sections = pn.Row(self.table_delete, self.table_merge, self.table_split, + sizing_mode="stretch_width") self.layout = pn.Column( save_sections, buttons_curate, @@ -331,47 +417,68 @@ def _panel_make_layout(self): def _panel_refresh(self): import pandas as pd - # Merged - merged_units = self.controller.curation_data["merge_unit_groups"] + ## deleted + removed_units = self.controller.curation_data["removed"] + removed_units = [str(unit_id) for unit_id in removed_units] + df = pd.DataFrame({"removed": removed_units}) + self.table_delete.value = df + self.table_delete.selection = [] + + # Merged + merged_units = [m["unit_ids"] for m in self.controller.curation_data["merges"]] # for visualization, we make all row entries strings merged_units_str = [] for group in merged_units: # convert to string group = [str(unit_id) for unit_id in group] merged_units_str.append(" - ".join(group)) - df = pd.DataFrame({"merge_groups": merged_units_str}) + df = pd.DataFrame({"merges": merged_units_str}) self.table_merge.value = df self.table_merge.selection = [] - ## deleted - removed_units = self.controller.curation_data["removed_units"] - removed_units = [str(unit_id) for unit_id in removed_units] - df = pd.DataFrame({"deleted_unit_id": removed_units}) - self.table_delete.value = df - self.table_delete.selection = [] + # Splits + split_units_str = [] + num_spikes = self.controller.num_spikes + for split in self.controller.curation_data["splits"]: + unit_id = split["unit_id"] + num_indices = len(split["indices"]) + num_splits = f"({num_indices}-{num_spikes[unit_id] - num_indices})" + split_units_str.append(f"{unit_id} {num_splits}") + df = pd.DataFrame({"splits": split_units_str}) + self.table_split.value = df + self.table_split.selection = [] def _panel_update_unit_visibility(self, event): unit_dtype = self.controller.unit_ids.dtype if self.active_table == "delete": - visible_unit_ids = self.table_delete.value["deleted_unit_id"].values[self.table_delete.selection].tolist() + visible_unit_ids = self.table_delete.value["removed"].values[self.table_delete.selection].tolist() visible_unit_ids = [unit_dtype.type(unit_id) for unit_id in visible_unit_ids] self.controller.set_visible_unit_ids(visible_unit_ids) elif self.active_table == "merge": - merge_groups = self.table_merge.value["merge_groups"].values[self.table_merge.selection].tolist() + merge_groups = self.table_merge.value["merges"].values[self.table_merge.selection].tolist() # self.controller.set_all_unit_visibility_off() visible_unit_ids = [] for merge_group in merge_groups: merge_unit_ids = [unit_dtype.type(unit_id) for unit_id in merge_group.split(" - ")] visible_unit_ids.extend(merge_unit_ids) self.controller.set_visible_unit_ids(visible_unit_ids) + elif self.active_table == "split": + split_unit_str = self.table_split.value["splits"].values[self.table_split.selection].tolist() + if len(split_unit_str) == 1: + split_unit_str = split_unit_str[0] + split_unit = split_unit_str.split(" ")[0] + # self.controller.set_all_unit_visibility_off() + split_unit = unit_dtype.type(split_unit) + self.controller.set_visible_unit_ids([split_unit]) + self.controller.set_active_split_unit(split_unit) self.notify_unit_visibility_changed() def _panel_restore_units(self, event): self.restore_units() - def _panel_unmerge_groups(self, event): - self.unmerge_groups() + def _panel_unmerge(self, event): + self.unmerge() def _panel_save_in_analyzer(self, event): self.save_in_analyzer() @@ -392,7 +499,7 @@ def _panel_get_delete_table_selection(self): if len(selected_items) == 0: return None else: - return self.table_delete.value["deleted_unit_id"].values[selected_items].tolist() + return self.table_delete.value["removed"].values[selected_items].tolist() def _panel_get_merge_table_row(self): selected_items = self.table_merge.selection @@ -401,11 +508,20 @@ def _panel_get_merge_table_row(self): else: return selected_items + def _panel_get_split_table_row(self): + selected_items = self.table_split.selection + if len(selected_items) == 0: + return None + else: + return selected_items + def _panel_handle_shortcut(self, event): if event.data == "restore": self.restore_units() elif event.data == "unmerge": - self.unmerge_groups() + self.unmerge() + elif event.data == "unsplit": + pass def _panel_submit_to_parent(self, event): """Send the curation data to the parent window""" @@ -441,10 +557,17 @@ def _panel_submit_to_parent(self, event): def _panel_on_deleted_col(self, row): self.active_table = "delete" self.table_merge.selection = [] + self.table_split.selection = [] def _panel_on_merged_col(self, row): self.active_table = "merge" self.table_delete.selection = [] + self.table_split.selection = [] + + def _panel_on_split_col(self, row): + self.active_table = "split" + self.table_delete.selection = [] + self.table_merge.selection = [] def _conditional_refresh_merge(self): # Check if the view is active before refreshing @@ -460,6 +583,13 @@ def _conditional_refresh_delete(self): else: return False + def _conditional_refresh_split(self): + # Check if the view is active before refreshing + if self.is_view_active() and self.active_table == "split": + return True + else: + return False + CurationView._gui_help_txt = """ ## Curation View diff --git a/spikeinterface_gui/mergeview.py b/spikeinterface_gui/mergeview.py index 8f5fa76..6578ccd 100644 --- a/spikeinterface_gui/mergeview.py +++ b/spikeinterface_gui/mergeview.py @@ -80,7 +80,7 @@ def get_table_data(self, include_deleted=False): unit_ids = list(self.controller.unit_ids) for group_ids in self.proposed_merge_unit_groups: if not include_deleted and self.controller.curation: - deleted_unit_ids = self.controller.curation_data["removed_units"] + deleted_unit_ids = self.controller.curation_data["removed"] if any(unit_id in deleted_unit_ids for unit_id in group_ids): continue diff --git a/spikeinterface_gui/ndscatterview.py b/spikeinterface_gui/ndscatterview.py index 1bafa92..ebd5102 100644 --- a/spikeinterface_gui/ndscatterview.py +++ b/spikeinterface_gui/ndscatterview.py @@ -54,6 +54,8 @@ def __init__(self, controller=None, parent=None, backend="qt"): self.tour_step = 0 self.auto_update_limit = True + self._lasso_vertices = [] + ViewBase.__init__(self, controller=controller, parent=parent, backend=backend) @@ -107,9 +109,6 @@ def random_projection(self): # here we don't want to update the components because it's been done already! self.refresh(update_components=False) - def on_spike_selection_changed(self): - self.refresh() - def on_unit_visibility_changed(self): self.random_projection() @@ -121,6 +120,19 @@ def apply_dot(self, data): return projected def get_plotting_data(self, return_spike_indices=False): + """ + Get the data to plot in the scatter plot. + + Parameters + ---------- + return_spike_indices : bool, default: False + If True, also return the indices of the spikes for each unit. + + Returns + ------- + _type_ + _description_ + """ scatter_x = {} scatter_y = {} all_limits = [] @@ -134,7 +146,7 @@ def get_plotting_data(self, return_spike_indices=False): projected_2d = projected[:, :2] all_limits.append(float(np.percentile(np.abs(projected_2d), 95) * 2.)) if return_spike_indices: - spike_indices[unit_id] = mask + spike_indices[unit_id] = self.random_spikes_indices[mask] if len(all_limits) > 0 and self.auto_update_limit: self.limit = max(all_limits) @@ -323,7 +335,9 @@ def _qt_refresh(self, update_components=True, update_colors=True): self.plot.setYRange(-self.limit, self.limit) # self.graphicsview.repaint() - + + def _qt_on_spike_selection_changed(self): + self.refresh() def _qt_start_stop_tour(self, checked): if checked: @@ -346,17 +360,31 @@ def _qt_on_lasso_drawing(self, points): points = np.array(points) self.lasso.setData(points[:, 0], points[:, 1]) - def _qt_on_lasso_finished(self, points): + def _qt_on_lasso_finished(self, points, shift_held=False): self.lasso.setData([], []) vertices = np.array(points) + self._lasso_vertices.append(vertices) # inside lasso and visibles ind_visibles, = np.nonzero(np.isin(self.random_spikes_indices, self.controller.get_indices_spike_visible())) projected = self.apply_dot(self.data[ind_visibles, :]) inside = inside_poly(projected, vertices) - inds = self.random_spikes_indices[ind_visibles[inside]] - self.controller.set_indices_spike_selected(inds) + new_selected_inds = self.random_spikes_indices[ind_visibles[inside]] + + if shift_held: + # Extend existing selection + current_selection = self.controller.get_indices_spike_selected() + if len(current_selection) > 0 and len(new_selected_inds) > 0: + extended_selection = np.unique(np.concatenate([current_selection, new_selected_inds])) + self.controller.set_indices_spike_selected(extended_selection) + elif len(new_selected_inds) > 0: + # No current selection, just use new selection + self.controller.set_indices_spike_selected(new_selected_inds) + # If no new selection and shift held, keep existing selection unchanged + else: + # Replace selection (original behavior) + self.controller.set_indices_spike_selected(new_selected_inds) self.refresh() self.notify_spike_selection_changed() @@ -369,7 +397,7 @@ def _panel_make_layout(self): from bokeh.models import ColumnDataSource, LassoSelectTool, Range1d from bokeh.events import MouseWheel - from .utils_panel import _bg_color, slow_lasso + from .utils_panel import _bg_color self.lasso_tool = LassoSelectTool() @@ -393,12 +421,9 @@ def _panel_make_layout(self): # remove the bokeh mousewheel zoom and keep only this one self.scatter_fig.on_event(MouseWheel, self._panel_gain_zoom) - self.scatter_source = ColumnDataSource({"x": [], "y": [], "color": []}) - self.scatter_select_source = ColumnDataSource({"x": [], "y": [], "color": []}) + self.scatter_source = ColumnDataSource({"x": [], "y": [], "color": [], "spike_indices": []}) self.scatter = self.scatter_fig.scatter("x", "y", source=self.scatter_source, size=3, color="color", alpha=0.7) - self.scatter_select = self.scatter_fig.scatter("x", "y", source=self.scatter_select_source, - size=11, color="white", alpha=0.8) # toolbar self.next_face_button = pn.widgets.Button(name="Next Face", button_type="default", width=100) @@ -410,14 +435,14 @@ def _panel_make_layout(self): self.random_tour_button = pn.widgets.Toggle(name="Random Tour", button_type="default", width=100) self.random_tour_button.param.watch(self._panel_start_stop_tour, "value") - # self.select_toggle_button = pn.widgets.Toggle(name="Select") - # self.select_toggle_button.param.watch(self._panel_on_select_button, 'value') + self.select_toggle_button = pn.widgets.Toggle(name="Select") + self.select_toggle_button.param.watch(self._panel_on_select_button, 'value') - # TODO: add a lasso selection - # slow_lasso(self.scatter_source, self._on_panel_lasso_selected) + self.scatter_fig.on_event('selectiongeometry', self._on_panel_selection_geometry) self.toolbar = pn.Row( - self.next_face_button, self.random_button, self.random_tour_button, sizing_mode="stretch_both", + self.next_face_button, self.random_button, self.random_tour_button, self.select_toggle_button, + sizing_mode="stretch_both", styles={"flex": "0.15"} ) @@ -433,15 +458,16 @@ def _panel_make_layout(self): def _panel_refresh(self, update_components=True, update_colors=True): if update_components: self.update_selected_components() - scatter_x, scatter_y, selected_scatter_x, selected_scatter_y = self.get_plotting_data() + scatter_x, scatter_y, _, _, spike_indices = self.get_plotting_data(return_spike_indices=True) - xs, ys, colors = [], [], [] + xs, ys, colors, plotted_spike_indices = [], [], [], [] for unit_id in scatter_x.keys(): color = self.get_unit_color(unit_id) xs.extend(scatter_x[unit_id]) ys.extend(scatter_y[unit_id]) if update_colors: colors.extend([color] * len(scatter_x[unit_id])) + plotted_spike_indices.extend(spike_indices[unit_id]) if not update_colors: colors = self.scatter_source.data.get("color") @@ -450,23 +476,20 @@ def _panel_refresh(self, update_components=True, update_colors=True): "x": xs, "y": ys, "color": colors, + "spike_indices": plotted_spike_indices } - self.scatter_select_source.data = { - "x": selected_scatter_x, - "y": selected_scatter_y, - } - - # TODO: handle selection with lasso - # mask = np.isin(self.random_spikes_indices, self.controller.get_indices_spike_selected()) - # selected_indices = np.flatnonzero(mask) - # self.scatter_source.selected.indices = selected_indices.tolist() - self.scatter_fig.x_range.start = -self.limit self.scatter_fig.x_range.end = self.limit self.scatter_fig.y_range.start = -self.limit self.scatter_fig.y_range.end = self.limit + def _panel_on_spike_selection_changed(self): + # handle selection with lasso + plotted_spike_indices = self.scatter_source.data.get("spike_indices", []) + ind_selected, = np.nonzero(np.isin(plotted_spike_indices, self.controller.get_indices_spike_selected())) + self.scatter_source.selected.indices = ind_selected + def _panel_gain_zoom(self, event): from bokeh.models import Range1d @@ -501,24 +524,37 @@ def _panel_on_select_button(self, event): else: self.scatter_fig.toolbar.active_drag = None self.scatter_source.selected.indices = [] - # self._on_panel_lasso_selected(None, None, None) - - - # TODO: Handle lasso selection and updates - # def _on_panel_lasso_selected(self, attr, old, new): - # if len(self.scatter_source.selected.indices) == 0: - # self.notify_spike_selection_changed() - # self.refresh() - # return - - # # inside lasso and visibles - # inside = self.scatter_source.selected.indices - - # inds = self.random_spikes_indices[inside] - # self.controller.set_indices_spike_selected(inds) - # self.refresh() - # self.notify_spike_selection_changed() + def _on_panel_selection_geometry(self, event): + """ + Handle SelectionGeometry event to capture lasso polygon vertices. + """ + if event.final: + xs = np.array(event.geometry["x"]) + ys = np.array(event.geometry["y"]) + polygon = np.column_stack((xs, ys)) + + selected = self.scatter_source.selected.indices + + if len(selected) == 0: + self.notify_spike_selection_changed() + self.refresh() + return + + if len(selected) > self._current_selected: + self._current_selected = len(selected) + # Store the current polygon for the current segment + self._lasso_vertices.append(polygon) + else: + self._lasso_vertices = [polygon] + + # inside lasso and visibles + ind_visibles, = np.nonzero(np.isin(self.random_spikes_indices, self.controller.get_indices_spike_visible())) + inds = self.random_spikes_indices[ind_visibles[selected]] + self.controller.set_indices_spike_selected(inds) + + self.notify_spike_selection_changed() + self.refresh() def inside_poly(data, vertices): diff --git a/spikeinterface_gui/tests/test_mainwindow_panel.py b/spikeinterface_gui/tests/test_mainwindow_panel.py index e20e231..0cc96d5 100644 --- a/spikeinterface_gui/tests/test_mainwindow_panel.py +++ b/spikeinterface_gui/tests/test_mainwindow_panel.py @@ -32,7 +32,7 @@ def teardown_module(): clean_all(test_folder) -def test_mainwindow(start_app=False, verbose=True, curation=False, only_some_extensions=False, from_si_api=False): +def test_mainwindow(start_app=False, verbose=True, curation=False, only_some_extensions=False, from_si_api=False, port=0): analyzer = load_sorting_analyzer(test_folder / "sorting_analyzer") @@ -71,7 +71,7 @@ def test_mainwindow(start_app=False, verbose=True, curation=False, only_some_ext layout_preset='default', # skip_extensions=["waveforms", "principal_components", "template_similarity", "spike_amplitudes"], # address="10.69.168.40", - # port=5000, + port=port, ) return win @@ -100,9 +100,9 @@ def test_launcher(verbose=True): if not test_folder.is_dir(): setup_module() - # win = test_mainwindow(start_app=True, verbose=True, curation=True) + win = test_mainwindow(start_app=True, verbose=True, curation=True, port=5006) - test_launcher(verbose=True) + # test_launcher(verbose=True) # TO RUN with panel serve: # win = test_mainwindow(start_app=False, verbose=True, curation=True) diff --git a/spikeinterface_gui/tests/test_mainwindow_qt.py b/spikeinterface_gui/tests/test_mainwindow_qt.py index 14e65d2..d29f8ca 100644 --- a/spikeinterface_gui/tests/test_mainwindow_qt.py +++ b/spikeinterface_gui/tests/test_mainwindow_qt.py @@ -15,7 +15,7 @@ test_folder = Path(__file__).parent / 'my_dataset_small' -test_folder = Path(__file__).parent / 'my_dataset_big' +# test_folder = Path(__file__).parent / 'my_dataset_big' # test_folder = Path(__file__).parent / 'my_dataset_multiprobe' # yep is for testing @@ -110,7 +110,7 @@ def test_launcher(verbose=True): if __name__ == '__main__': if not test_folder.is_dir(): setup_module() - # win = test_mainwindow(start_app=True, verbose=True, curation=True) + win = test_mainwindow(start_app=True, verbose=True, curation=True) # win = test_mainwindow(start_app=True, verbose=True, curation=False) - test_launcher(verbose=True) + # test_launcher(verbose=True) diff --git a/spikeinterface_gui/tests/testingtools.py b/spikeinterface_gui/tests/testingtools.py index 4f1a7a0..92e1218 100644 --- a/spikeinterface_gui/tests/testingtools.py +++ b/spikeinterface_gui/tests/testingtools.py @@ -137,6 +137,7 @@ def make_analyzer_folder(test_folder, case="small", unit_dtype="str"): def make_curation_dict(analyzer): unit_ids = analyzer.unit_ids.tolist() curation_dict = { + "format_version": "2", "unit_ids": unit_ids, "label_definitions": { "quality":{ @@ -153,8 +154,8 @@ def make_curation_dict(analyzer): {'unit_id': unit_ids[2], "putative_type": ["exitatory"]}, {'unit_id': unit_ids[3], "quality": ["noise"], "putative_type": ["inhibitory"]}, ], - "merge_unit_groups": [unit_ids[:3], unit_ids[3:5]], - "removed_units": unit_ids[5:8], + "merges": [{"unit_ids": unit_ids[:3]}, {"unit_ids": unit_ids[3:5]}], + "removed": unit_ids[5:8], } return curation_dict diff --git a/spikeinterface_gui/utils_qt.py b/spikeinterface_gui/utils_qt.py index a5b5044..a815adc 100644 --- a/spikeinterface_gui/utils_qt.py +++ b/spikeinterface_gui/utils_qt.py @@ -47,12 +47,13 @@ def add_stretch_to_qtoolbar(tb): class ViewBoxHandlingLasso(pg.ViewBox): doubleclicked = QT.pyqtSignal() lasso_drawing = QT.pyqtSignal(object) - lasso_finished = QT.pyqtSignal(object) + lasso_finished = QT.pyqtSignal(object, bool) # Added bool parameter for shift_held def __init__(self, *args, **kwds): pg.ViewBox.__init__(self, *args, **kwds) self.drag_points = [] self.lasso_active = False + self.shift_held = False def mouseDoubleClickEvent(self, ev): self.doubleclicked.emit() @@ -68,12 +69,14 @@ def mouseDragEvent(self, ev): if ev.isStart(): self.drag_points = [] + # Check if shift is held at the start of the drag + self.shift_held = ev.modifiers() == QT.Qt.ShiftModifier pos = self.mapToView(ev.pos()) self.drag_points.append([pos.x(), pos.y()]) if ev.isFinish(): - self.lasso_finished.emit(self.drag_points) + self.lasso_finished.emit(self.drag_points, self.shift_held) else: self.lasso_drawing.emit(self.drag_points) @@ -123,12 +126,13 @@ class ViewBoxHandlingLassoAndGain(pg.ViewBox): doubleclicked = QT.pyqtSignal() gain_zoom = QT.pyqtSignal(float) lasso_drawing = QT.pyqtSignal(object) - lasso_finished = QT.pyqtSignal(object) + lasso_finished = QT.pyqtSignal(object, bool) # Added bool parameter for shift_held def __init__(self, *args, **kwds): pg.ViewBox.__init__(self, *args, **kwds) self.disableAutoRange() self.drag_points = [] + self.shift_held = False def mouseClickEvent(self, ev): ev.accept() @@ -152,12 +156,14 @@ def mouseDragEvent(self, ev): if ev.isStart(): self.drag_points = [] + # Check if shift is held at the start of the drag + self.shift_held = ev.modifiers() == QT.Qt.ShiftModifier pos = self.mapToView(ev.pos()) self.drag_points.append([pos.x(), pos.y()]) if ev.isFinish(): - self.lasso_finished.emit(self.drag_points) + self.lasso_finished.emit(self.drag_points, self.shift_held) else: self.lasso_drawing.emit(self.drag_points) diff --git a/spikeinterface_gui/view_base.py b/spikeinterface_gui/view_base.py index 3efec5a..8aab966 100644 --- a/spikeinterface_gui/view_base.py +++ b/spikeinterface_gui/view_base.py @@ -113,6 +113,12 @@ def _refresh(self, **kwargs): self._qt_refresh(**kwargs) elif self.backend == "panel": self._panel_refresh(**kwargs) + + def warning(self, warning_msg): + if self.backend == "qt": + self._qt_insert_warning(warning_msg) + elif self.backend == "panel": + self._panel_insert_warning(warning_msg) def get_unit_color(self, unit_id): if self.backend == "qt": @@ -212,6 +218,13 @@ def _qt_on_time_info_updated(self): def _qt_on_unit_color_changed(self): self.refresh() + def _qt_insert_warning(self, warning_msg): + from .myqt import QT + + alert = QT.QMessageBox(QT.QMessageBox.Warning, "Warning", warning_msg, parent=self.qt_widget) + alert.setStandardButtons(QT.QMessageBox.Ok) + alert.exec_() + ## PANEL ## def _panel_make_layout(self): raise NotImplementedError @@ -237,3 +250,15 @@ def _panel_on_time_info_updated(self): def _panel_on_unit_color_changed(self): self.refresh() + + def _panel_insert_warning(self, warning_msg): + import panel as pn + + alert_markdown = pn.pane.Markdown(f"⚠️⚠️⚠️ {warning_msg} ⚠️⚠️⚠️", styles={'color': 'red', 'font-size': '16px'}) + close_button = pn.widgets.Button(name="X") + close_button.on_click(self._panel_clear_warning) + row = pn.Row(alert_markdown, close_button, sizing_mode='stretch_width') + self.layout.insert(0, row) + + def _panel_clear_warning(self, event): + self.layout.pop(0)