diff --git a/spikeinterface_gui/basescatterview.py b/spikeinterface_gui/basescatterview.py
index cd9cc19..7879a70 100644
--- a/spikeinterface_gui/basescatterview.py
+++ b/spikeinterface_gui/basescatterview.py
@@ -8,7 +8,7 @@ class BaseScatterView(ViewBase):
_supported_backend = ['qt', 'panel']
_depend_on = None
_settings = [
- {'name': 'auto_decimate', 'type': 'bool', 'value' : True },
+ {'name': "auto_decimate", 'type': 'bool', 'value' : True },
{'name': 'max_spikes_per_unit', 'type': 'int', 'value' : 10_000 },
{'name': 'alpha', 'type': 'float', 'value' : 0.7, 'limits':(0, 1.), 'step':0.05 },
{'name': 'scatter_size', 'type': 'float', 'value' : 2., 'step':0.5 },
@@ -29,6 +29,9 @@ def __init__(self, spike_data, y_label, controller=None, parent=None, backend="q
eps = (self._data_max - self._data_min) / 100.0
self._data_max += eps
self._max_count = None
+ self._lasso_vertices = {segment_index: [] for segment_index in range(controller.num_segments)}
+ # this is used in panel
+ self._current_selected = 0
ViewBase.__init__(self, controller=controller, parent=parent, backend=backend)
@@ -42,7 +45,7 @@ def get_unit_data(self, unit_id, seg_index=0):
hist_count, hist_bins = np.histogram(spike_data, bins=np.linspace(hist_min, hist_max, self.settings['num_bins']))
- if self.settings['auto_decimate'] and spike_times.size > self.settings['max_spikes_per_unit']:
+ if self.settings["auto_decimate"] and spike_times.size > self.settings['max_spikes_per_unit']:
step = spike_times.size // self.settings['max_spikes_per_unit']
spike_times = spike_times[::step]
spike_data = spike_data[::step]
@@ -61,6 +64,62 @@ def get_selected_spikes_data(self, seg_index=0):
return (spike_times, spike_data)
+ def split(self):
+ """
+ Add a split to the curation data based on the lasso vertices.
+ """
+ if self.controller.num_segments > 1:
+ # check that lasso vertices are defined for all segments
+ if not all(len(self._lasso_vertices[seg_index]) > 0 for seg_index in range(self.controller.num_segments)):
+ self.warning("Select areas for all segments.")
+ return
+
+ # split is only possible if one unit is visible
+ visible_unit_ids = self.controller.get_visible_unit_ids()
+ if len(visible_unit_ids) != 1:
+ self.warning("Split is only possible if one unit is visible.")
+ return
+
+ visible_unit_id = visible_unit_ids[0]
+
+ fs = self.controller.sampling_frequency
+ indices = []
+ offset = 0
+ for segment_index, vertices in self._lasso_vertices.items():
+ spike_inds = self.controller.get_spike_indices(visible_unit_id, seg_index=segment_index)
+ spike_times = self.controller.spikes["sample_index"][spike_inds] / fs
+ spike_data = self.spike_data[spike_inds]
+
+ # spike inds within spike train
+ inds = np.arange(offset, offset + len(spike_inds), dtype=int)
+
+ points = np.column_stack((spike_times, spike_data))
+ indices_in_segment = []
+ for polygon in vertices:
+ # Check if points are inside the polygon
+ inside = mpl_path(polygon).contains_points(points)
+ if np.any(inside):
+ # If any point is inside, we can proceed with the split
+ indices_in_segment.extend(inds[inside])
+ indices.extend(indices_in_segment)
+ offset += len(spike_inds)
+
+ self.controller.make_manual_split_if_possible(
+ unit_id=visible_unit_id,
+ indices=indices,
+ )
+
+ # Clear the lasso vertices after splitting
+ self._lasso_vertices = {segment_index: [] for segment_index in range(self.controller.num_segments)}
+ self.refresh()
+ self.notify_manual_curation_updated()
+
+
+ def on_unit_visibility_changed(self):
+ self._lasso_vertices = {segment_index: [] for segment_index in range(self.controller.num_segments)}
+ self._current_selected = 0
+ self.refresh()
+
## QT zone ##
def _qt_make_layout(self):
from .myqt import QT
@@ -76,16 +135,20 @@ def _qt_make_layout(self):
self.combo_seg.currentIndexChanged.connect(self.refresh)
add_stretch_to_qtoolbar(tb)
self.lasso_but = QT.QPushButton("select", checkable = True)
-
tb.addWidget(self.lasso_but)
self.lasso_but.clicked.connect(self.enable_disable_lasso)
-
+ if self.controller.curation:
+ self.split_but = QT.QPushButton("split")
+ tb.addWidget(self.split_but)
+ self.split_but.clicked.connect(self.split)
+ shortcut_split = QT.QShortcut(self.qt_widget)
+ shortcut_split.setKey(QT.QKeySequence("ctrl+s"))
+ shortcut_split.activated.connect(self.split)
h = QT.QHBoxLayout()
self.layout.addLayout(h)
self.graphicsview = pg.GraphicsView()
h.addWidget(self.graphicsview, 3)
-
self.graphicsview2 = pg.GraphicsView()
h.addWidget(self.graphicsview2, 1)
@@ -103,7 +166,6 @@ def _qt_make_layout(self):
self.scatter_select.setZValue(1000)
-
def initialize_plot(self):
import pyqtgraph as pg
from .utils_qt import ViewBoxHandlingLasso
@@ -148,7 +210,10 @@ def _qt_refresh(self):
max_count = 1
for unit_id in self.controller.get_visible_unit_ids():
- spike_times, spike_data, hist_count, hist_bins, _ = self.get_unit_data(unit_id)
+ spike_times, spike_data, hist_count, hist_bins, _ = self.get_unit_data(
+ unit_id,
+ seg_index=self.combo_seg.currentIndex()
+ )
# make a copy of the color
color = QT.QColor(self.get_unit_color(unit_id))
@@ -162,13 +227,13 @@ def _qt_refresh(self):
max_count = max(max_count, np.max(hist_count))
self._max_count = max_count
- seg_index = self.combo_seg.currentIndex()
+ seg_index = self.combo_seg.currentIndex()
time_max = self.controller.get_num_samples(seg_index) / self.controller.sampling_frequency
self.plot.setXRange( 0., time_max, padding = 0.0)
self.plot2.setXRange(0, self._max_count, padding = 0.0)
- spike_times, spike_data = self.get_selected_spikes_data()
+ spike_times, spike_data = self.get_selected_spikes_data(seg_index=self.combo_seg.currentIndex())
self.scatter_select.setData(spike_times, spike_data)
def enable_disable_lasso(self, checked):
@@ -183,7 +248,7 @@ def on_lasso_drawing(self, points):
points = np.array(points)
self.lasso.setData(points[:, 0], points[:, 1])
- def on_lasso_finished(self, points):
+ def on_lasso_finished(self, points, shift_held=False):
self.lasso.setData([], [])
vertices = np.array(points)
@@ -200,40 +265,53 @@ def on_lasso_finished(self, points):
# Only consider spikes from visible units
visible_spikes = spikes_in_seg[visible_mask]
if len(visible_spikes) == 0:
- # Clear selection if no visible spikes
- self.controller.set_indices_spike_selected([])
- self.refresh()
- self.notify_spike_selection_changed()
+ # Clear selection if no visible spikes and shift not held
+ if not shift_held:
+ self.controller.set_indices_spike_selected([])
+ self.refresh()
+ self.notify_spike_selection_changed()
return
-
+
spike_times = visible_spikes['sample_index'] / fs
spike_data = self.spike_data[sl][visible_mask]
- points = np.column_stack((spike_times, spike_data))
- inside = mpl_path(vertices).contains_points(points)
-
- # Clear selection if no spikes inside lasso
- if not np.any(inside):
- self.controller.set_indices_spike_selected([])
- self.refresh()
- self.notify_spike_selection_changed()
- return
+ scatter_data = np.column_stack((spike_times, spike_data))
+ inside = mpl_path(vertices).contains_points(scatter_data)
+
+ if shift_held:
+ # If shift is held, append the vertices to the current lasso vertices
+ self._lasso_vertices[seg_index].append(vertices)
+ else:
+ # If shift is not held, clear the existing lasso vertices for this segment
+ self._lasso_vertices[seg_index] = [vertices]
- # Map back to original indices
- visible_indices = np.nonzero(visible_mask)[0]
- selected_indices = sl.start + visible_indices[inside]
- self.controller.set_indices_spike_selected(selected_indices)
+ # print(f"Lasso selection for segment {seg_index} has {len(self._lasso_vertices[seg_index])} polygons.")
+
+ # Handle selection based on whether shift is held
+ if np.any(inside):
+ # Map back to original indices
+ visible_indices = np.nonzero(visible_mask)[0]
+ new_selected_indices = sl.start + visible_indices[inside]
+
+ if shift_held:
+ # Extend existing selection
+ current_selection = self.controller.get_indices_spike_selected()
+ extended_selection = np.unique(np.concatenate([current_selection, new_selected_indices]))
+ self.controller.set_indices_spike_selected(extended_selection)
+ else:
+ # Replace selection
+ self.controller.set_indices_spike_selected(new_selected_indices)
+
self.refresh()
self.notify_spike_selection_changed()
-
## Panel zone ##
def _panel_make_layout(self):
import panel as pn
import bokeh.plotting as bpl
from bokeh.models import ColumnDataSource, LassoSelectTool, Range1d
- from .utils_panel import _bg_color, slow_lasso
+ from .utils_panel import _bg_color #, slow_lasso
self.lasso_tool = LassoSelectTool()
@@ -246,7 +324,11 @@ def _panel_make_layout(self):
self.segment_selector.param.watch(self._panel_change_segment, 'value')
self.select_toggle_button = pn.widgets.Toggle(name="Select")
- self.select_toggle_button.param.watch(self._panel_on_select_button, 'value')
+ self.select_toggle_button.param.watch(self._panel_on_select_button, 'value')
+
+ if self.controller.curation:
+ self.split_button = pn.widgets.Button(name="Split", button_type="primary")
+ self.split_button.on_click(self._panel_split)
self.y_range = Range1d(self._data_min, self._data_max)
self.scatter_source = ColumnDataSource(data={"x": [], "y": [], "color": []})
@@ -276,7 +358,8 @@ def _panel_make_layout(self):
time_max = self.controller.get_num_samples(self.segment_index) / self.controller.sampling_frequency
self.scatter_fig.x_range = Range1d(0., time_max)
- slow_lasso(self.scatter_source, self._on_panel_lasso_selected)
+ # Add SelectionGeometry event handler to capture lasso vertices
+ self.scatter_fig.on_event('selectiongeometry', self._on_panel_selection_geometry)
self.hist_fig = bpl.figure(
tools="reset,wheel_zoom",
@@ -292,8 +375,19 @@ def _panel_make_layout(self):
self.hist_fig.xaxis.axis_label = "Count"
self.hist_fig.x_range = Range1d(0, 1000) # Initial x range for histogram
+ toolbar_elements = [self.segment_selector, self.select_toggle_button]
+ if self.controller.curation:
+ toolbar_elements.append(self.split_button)
+
+ if self.controller.curation:
+ from .utils_panel import KeyboardShortcut, KeyboardShortcuts
+ shortcuts = [KeyboardShortcut(key="s", name="split", ctrlKey=True)]
+ shortcuts_component = KeyboardShortcuts(shortcuts=shortcuts)
+ shortcuts_component.on_msg(self._panel_handle_shortcut)
+ toolbar_elements.append(shortcuts_component)
+
self.layout = pn.Column(
- pn.Row(self.segment_selector, self.select_toggle_button, sizing_mode="stretch_width"),
+ pn.Row(*toolbar_elements, sizing_mode="stretch_width"),
pn.Row(
pn.Column(
self.scatter_fig,
@@ -374,30 +468,46 @@ def _panel_refresh(self):
self.hist_fig.x_range.end = max_count
def _panel_on_select_button(self, event):
- if self.select_toggle_button.value and len(self.controller.get_visible_unit_ids()) == 1:
+ if self.select_toggle_button.value:
self.scatter_fig.toolbar.active_drag = self.lasso_tool
else:
self.scatter_fig.toolbar.active_drag = None
self.scatter_source.selected.indices = []
- self._on_panel_lasso_selected(None, None, None)
+
def _panel_change_segment(self, event):
+ self._current_selected = 0
self.segment_index = int(self.segment_selector.value.split()[-1])
time_max = self.controller.get_num_samples(self.segment_index) / self.controller.sampling_frequency
self.scatter_fig.x_range.end = time_max
self.refresh()
- def _on_panel_lasso_selected(self, attr, old, new):
+ def _on_panel_selection_geometry(self, event):
"""
- Handle selection changes in the scatter plot.
+ Handle SelectionGeometry event to capture lasso polygon vertices.
"""
- if self.select_toggle_button.value:
+ if event.final:
+ xs = np.array(event.geometry["x"])
+ ys = np.array(event.geometry["y"])
+ polygon = np.column_stack((xs, ys))
+
selected = self.scatter_source.selected.indices
if len(selected) == 0:
self.controller.set_indices_spike_selected([])
self.notify_spike_selection_changed()
return
+ # Append the current polygon to the lasso vertices if shift is held
+ seg_index = self.segment_index
+ if len(selected) > self._current_selected:
+ self._current_selected = len(selected)
+ # Store the current polygon for the current segment
+ self._lasso_vertices[seg_index].append(polygon)
+ else:
+ self._lasso_vertices[seg_index] = [polygon]
+
+ # print(f"Lasso selection for segment {self.segment_index} has {len(self._lasso_vertices[self.segment_index])} polygons.")
+
# Map back to original indices
sl = self.controller.segment_slices[self.segment_index]
spikes_in_seg = self.controller.spikes[sl]
@@ -409,9 +519,16 @@ def _on_panel_lasso_selected(self, attr, old, new):
# Map back to original indices
visible_indices = np.nonzero(visible_mask)[0]
selected_indices = sl.start + visible_indices[selected]
- self.controller.set_indices_spike_selected(selected_indices)
+ already_selected = self.controller.get_indices_spike_selected()
+ all_selected = np.concatenate([already_selected, selected_indices])
+ self.controller.set_indices_spike_selected(all_selected)
self.notify_spike_selection_changed()
+ def _panel_split(self, event):
+ """
+ Handle split button click in panel mode.
+ """
+ self.split()
def _panel_update_selected_spikes(self):
# handle selected spikes
@@ -426,7 +543,7 @@ def _panel_update_selected_spikes(self):
visible_indices = sl.start + np.nonzero(visible_mask)[0]
selected_indices = np.nonzero(np.isin(visible_indices, selected_spike_indices))[0]
# set selected spikes in scatter plot
- if self.settings["auto_decimate"]:
+ if self.settings["auto_decimate"] and len(selected_indices) > 0:
selected_indices, = np.nonzero(np.isin(self.plotted_inds, selected_spike_indices))
self.scatter_source.selected.indices = list(selected_indices)
else:
@@ -447,3 +564,7 @@ def _panel_on_spike_selection_changed(self):
# update selected spikes
self._panel_update_selected_spikes()
+ def _panel_handle_shortcut(self, event):
+ if event.data == "split":
+ if len(self.controller.get_visible_unit_ids()) == 1:
+ self.split()
\ No newline at end of file
diff --git a/spikeinterface_gui/controller.py b/spikeinterface_gui/controller.py
index ad653e6..b0f3978 100644
--- a/spikeinterface_gui/controller.py
+++ b/spikeinterface_gui/controller.py
@@ -12,9 +12,10 @@
import spikeinterface.qualitymetrics
from spikeinterface.core.sorting_tools import spike_vector_to_indices
from spikeinterface.core.core_tools import check_json
+from spikeinterface.curation import validate_curation_dict
from spikeinterface.widgets.utils import make_units_table_from_analyzer
-from .curation_tools import adding_group, default_label_definitions, empty_curation_data
+from .curation_tools import add_merge, default_label_definitions, empty_curation_data
spike_dtype =[('sample_index', 'int64'), ('unit_index', 'int64'),
('channel_index', 'int64'), ('segment_index', 'int64'),
@@ -260,13 +261,14 @@ def __init__(self, analyzer=None, backend="qt", parent=None, verbose=False, save
spike_vector2 = self.analyzer.sorting.to_spike_vector(concatenated=False)
# this is dict of list because per segment spike_indices[segment_index][unit_id]
+ spike_indices_abs = spike_vector_to_indices(spike_vector2, unit_ids, absolute_index=True)
spike_indices = spike_vector_to_indices(spike_vector2, unit_ids)
# this is flatten
spike_per_seg = [s.size for s in spike_vector2]
# dict[unit_id] -> all indices for this unit across segments
self._spike_index_by_units = {}
# dict[seg_index][unit_id] -> all indices for this unit for one segment
- self._spike_index_by_segment_and_units = spike_indices
+ self._spike_index_by_segment_and_units = spike_indices_abs
for unit_id in unit_ids:
inds = []
for seg_ind in range(num_seg):
@@ -319,7 +321,23 @@ def __init__(self, analyzer=None, backend="qt", parent=None, verbose=False, save
if curation_data is None:
self.curation_data = empty_curation_data.copy()
else:
- self.curation_data = curation_data
+ # validate the curation data
+ format_version = curation_data.get("format_version", None)
+ # assume version 2 if not present
+ if format_version is None:
+ raise ValueError("Curation data format version is missing and is required in the curation data.")
+ try:
+ validate_curation_dict(curation_data)
+ self.curation_data = curation_data
+ except Exception as e:
+ print(f"Invalid curation data. Initializing with empty curation data.\nError: {e}")
+ self.curation_data = empty_curation_data.copy()
+ if curation_data.get("merges") is None:
+ curation_data["merges"] = []
+ if curation_data.get("splits") is None:
+ curation_data["splits"] = []
+ if curation_data.get("removed") is None:
+ curation_data["removed"] = []
self.has_default_quality_labels = False
if "label_definitions" not in self.curation_data:
@@ -337,6 +355,8 @@ def __init__(self, analyzer=None, backend="qt", parent=None, verbose=False, save
print('Curation quality labels are the default ones')
self.has_default_quality_labels = True
+ # this is used to store the active split unit
+ self.active_split = None
def check_is_view_possible(self, view_name):
from .viewlist import possible_class_views
@@ -442,6 +462,13 @@ def set_visible_unit_ids(self, visible_unit_ids):
if len(visible_unit_ids) > lim:
visible_unit_ids = visible_unit_ids[:lim]
self._visible_unit_ids = list(visible_unit_ids)
+ self.active_split = None
+ if len(visible_unit_ids) == 1 and self.curation:
+ # check if unit is split
+ for split in self.curation_data['splits']:
+ if visible_unit_ids[0] == split['unit_id']:
+ self.active_split = split
+ break
def get_visible_unit_ids(self):
"""Get list of visible unit_ids"""
@@ -506,10 +533,21 @@ def get_indices_spike_visible(self):
return self._spike_visible_indices
def get_indices_spike_selected(self):
+ if self.active_split is not None:
+ # select the splitted spikes in the active split
+ split_unit_id = self.active_split['unit_id']
+ spike_inds = self.get_spike_indices(split_unit_id, seg_index=None)
+ split_indices = self.active_split['indices']
+ self._spike_selected_indices = np.array(spike_inds[split_indices], dtype='int64')
return self._spike_selected_indices
-
+
def set_indices_spike_selected(self, inds):
self._spike_selected_indices = np.array(inds)
+ if len(self._spike_selected_indices) == 1:
+ # set time info
+ segment_index = self.spikes['segment_index'][self._spike_selected_indices[0]]
+ sample_index = self.spikes['sample_index'][self._spike_selected_indices[0]]
+ self.set_time(time=sample_index / self.sampling_frequency, segment_index=segment_index)
def get_spike_indices(self, unit_id, seg_index=None):
if seg_index is None:
@@ -668,7 +706,7 @@ def curation_can_be_saved(self):
def construct_final_curation(self):
d = dict()
- d["format_version"] = "1"
+ d["format_version"] = "2"
d["unit_ids"] = self.unit_ids.tolist()
d.update(self.curation_data.copy())
return d
@@ -699,14 +737,14 @@ def make_manual_delete_if_possible(self, removed_unit_ids):
if not self.curation:
return
- all_merged_units = sum(self.curation_data["merge_unit_groups"], [])
+ all_merged_units = sum([m["unit_ids"] for m in self.curation_data["merges"]], [])
for unit_id in removed_unit_ids:
- if unit_id in self.curation_data["removed_units"]:
+ if unit_id in self.curation_data["removed"]:
continue
# TODO: check if unit is already in a merge group
if unit_id in all_merged_units:
continue
- self.curation_data["removed_units"].append(unit_id)
+ self.curation_data["removed"].append(unit_id)
if self.verbose:
print(f"Unit {unit_id} is removed from the curation data")
@@ -718,10 +756,10 @@ def make_manual_restore(self, restore_unit_ids):
return
for unit_id in restore_unit_ids:
- if unit_id in self.curation_data["removed_units"]:
+ if unit_id in self.curation_data["removed"]:
if self.verbose:
print(f"Unit {unit_id} is restored from the curation data")
- self.curation_data["removed_units"].remove(unit_id)
+ self.curation_data["removed"].remove(unit_id)
def make_manual_merge_if_possible(self, merge_unit_ids):
"""
@@ -740,22 +778,75 @@ def make_manual_merge_if_possible(self, merge_unit_ids):
return False
for unit_id in merge_unit_ids:
- if unit_id in self.curation_data["removed_units"]:
+ if unit_id in self.curation_data["removed"]:
+ return False
+
+ new_merges = add_merge(self.curation_data["merges"], merge_unit_ids)
+ self.curation_data["merges"] = new_merges
+ if self.verbose:
+ print(f"Merged unit group: {[str(u) for u in merge_unit_ids]}")
+ return True
+
+ def make_manual_split_if_possible(self, unit_id, indices):
+ """
+ Check if the a unit_id can be split into a new split in the curation_data.
+
+ If unit_id is already in the removed list then the split is skipped.
+ If unit_id is already in some other split then the split is skipped.
+ """
+ if not self.curation:
+ return False
+
+ if unit_id in self.curation_data["removed"]:
+ return False
+
+ # check if unit_id is already in a split
+ for split in self.curation_data["splits"]:
+ if split["unit_id"] == unit_id:
return False
- merged_groups = adding_group(self.curation_data["merge_unit_groups"], merge_unit_ids)
- self.curation_data["merge_unit_groups"] = merged_groups
+
+ new_split = {
+ "unit_id": unit_id,
+ "mode": "indices",
+ "indices": indices
+ }
+ self.curation_data["splits"].append(new_split)
if self.verbose:
- print(f"Merged unit group: {merge_unit_ids}")
+ print(f"Split unit {unit_id} with {len(indices)} spikes")
return True
- def make_manual_restore_merge(self, merge_group_indices):
+ def make_manual_restore_merge(self, merge_indices):
+ if not self.curation:
+ return
+ for merge_index in merge_indices:
+ if self.verbose:
+ print(f"Unmerged {self.curation_data['merges'][merge_index]['unit_ids']}")
+ self.curation_data["merges"].pop(merge_index)
+
+ def make_manual_restore_split(self, split_indices):
if not self.curation:
return
- merge_groups_to_remove = [self.curation_data["merge_unit_groups"][merge_group_index] for merge_group_index in merge_group_indices]
- for merge_group in merge_groups_to_remove:
+ for split_index in split_indices:
if self.verbose:
- print(f"Unmerged merge group {merge_group}")
- self.curation_data["merge_unit_groups"].remove(merge_group)
+ print(f"Unsplitting {self.curation_data['splits'][split_index]['unit_id']}")
+ self.curation_data["splits"].pop(split_index)
+
+ def set_active_split_unit(self, unit_id):
+ """
+ Set the active split unit_id.
+ This is used to set the label for the split unit.
+ """
+ if not self.curation:
+ return
+ if unit_id is None:
+ self.active_split = None
+ else:
+ if unit_id in self.curation_data["removed"]:
+ print(f"Unit {unit_id} is removed, cannot set as active split unit")
+ return
+ active_split = [s for s in self.curation_data["splits"] if s["unit_id"] == unit_id]
+ if len(active_split) == 1:
+ self.active_split = active_split[0]
def get_curation_label_definitions(self):
# give only label definition with exclusive
diff --git a/spikeinterface_gui/crosscorrelogramview.py b/spikeinterface_gui/crosscorrelogramview.py
index 205c566..fcf5321 100644
--- a/spikeinterface_gui/crosscorrelogramview.py
+++ b/spikeinterface_gui/crosscorrelogramview.py
@@ -6,11 +6,10 @@ class CrossCorrelogramView(ViewBase):
_supported_backend = ['qt', 'panel']
_depend_on = ["correlograms"]
_settings = [
- {'name': 'window_ms', 'type': 'float', 'value' : 50. },
- {'name': 'bin_ms', 'type': 'float', 'value' : 1.0 },
- {'name': 'display_axis', 'type': 'bool', 'value' : True },
- {'name': 'max_visible', 'type': 'int', 'value' : 8 },
- ]
+ {'name': 'window_ms', 'type': 'float', 'value' : 50. },
+ {'name': 'bin_ms', 'type': 'float', 'value' : 1.0 },
+ {'name': 'display_axis', 'type': 'bool', 'value' : True },
+ ]
_need_compute = True
def __init__(self, controller=None, parent=None, backend="qt"):
@@ -26,6 +25,36 @@ def _on_settings_changed(self):
def _compute(self):
self.ccg, self.bins = self.controller.compute_correlograms(
self.settings['window_ms'], self.settings['bin_ms'])
+
+ def _compute_split_ccg(self):
+ """
+ This method is used to compute the cross-correlogram for a split unit.
+ It is called when the user selects a split unit in the controller.
+ """
+ from spikeinterface import NumpySorting
+ from spikeinterface.postprocessing import compute_correlograms
+
+ if self.controller.active_split is None:
+ raise ValueError("No active split unit selected.")
+
+ split_unit_id = self.controller.active_split["unit_id"]
+ spike_inds = self.controller.get_spike_indices(split_unit_id, seg_index=None)
+ split_indices = self.controller.active_split['indices']
+ spikes_split_unit = self.controller.spikes[spike_inds]
+ unit_index = spikes_split_unit[0]["unit_index"]
+ # change unit_index for split indices
+ spikes_split_unit["unit_index"][split_indices] = unit_index + 1
+ split_sorting = NumpySorting(
+ spikes=spikes_split_unit,
+ sampling_frequency=self.controller.sampling_frequency,
+ unit_ids=[f"{split_unit_id}-0", f"{split_unit_id}-1"]
+ )
+ ccg, bins = compute_correlograms(
+ split_sorting,
+ window_ms=self.settings['window_ms'],
+ bin_ms=self.settings['bin_ms']
+ )
+ return ccg, bins
## Qt ##
@@ -51,18 +80,32 @@ def _qt_refresh(self):
return
visible_unit_ids = self.controller.get_visible_unit_ids()
- visible_unit_ids = visible_unit_ids[:self.settings['max_visible']]
-
- n = len(visible_unit_ids)
-
- unit_ids = list(self.controller.unit_ids)
+ if self.controller.active_split is None:
+ n = len(visible_unit_ids)
+ unit_ids = list(self.controller.unit_ids)
+ colors = {
+ unit_id: self.get_unit_color(unit_id) for unit_id in visible_unit_ids
+ }
+ ccg = self.ccg
+ bins = self.bins
+ else:
+ split_unit_id = visible_unit_ids[0]
+ n = 2
+ unit_ids = [f"{split_unit_id}-0", f"{split_unit_id}-1"]
+ visible_unit_ids = unit_ids
+ ccg, bins = self._compute_split_ccg()
+ split_unit_color = self.get_unit_color(split_unit_id)
+ colors = {
+ f"{split_unit_id}-0": split_unit_color,
+ f"{split_unit_id}-1": split_unit_color,
+ }
for r in range(n):
for c in range(r, n):
i = unit_ids.index(visible_unit_ids[r])
j = unit_ids.index(visible_unit_ids[c])
- count = self.ccg[i, j, :]
+ count = ccg[i, j, :]
plot = pg.PlotItem()
if not self.settings['display_axis']:
@@ -71,16 +114,15 @@ def _qt_refresh(self):
if r==c:
unit_id = visible_unit_ids[r]
- color = self.get_unit_color(unit_id)
+ color = colors[unit_id]
else:
color = (120,120,120,120)
- curve = pg.PlotCurveItem(self.bins, count, stepMode='center', fillLevel=0, brush=color, pen=color)
+ curve = pg.PlotCurveItem(bins, count, stepMode='center', fillLevel=0, brush=color, pen=color)
plot.addItem(curve)
self.grid.addItem(plot, row=r, col=c)
## panel ##
-
def _panel_make_layout(self):
import panel as pn
import bokeh.plotting as bpl
@@ -115,31 +157,40 @@ def _panel_refresh(self):
return
visible_unit_ids = self.controller.get_visible_unit_ids()
-
- # Show warning above the plot if too many visible units
- if len(visible_unit_ids) > self.settings['max_visible']:
- warning_msg = f"Only showing first {self.settings['max_visible']} units out of {len(visible_unit_ids)} visible units"
- insert_warning(self, warning_msg)
- self.is_warning_active = True
- return
- if self.is_warning_active:
- clear_warning(self)
- self.is_warning_active = False
-
- visible_unit_ids = visible_unit_ids[:self.settings['max_visible']]
-
- n = len(visible_unit_ids)
- unit_ids = list(self.controller.unit_ids)
+ if self.controller.active_split is None:
+ n = len(visible_unit_ids)
+ unit_ids = list(self.controller.unit_ids)
+ colors = {
+ unit_id: self.get_unit_color(unit_id) for unit_id in visible_unit_ids
+ }
+ ccg = self.ccg
+ bins = self.bins
+ else:
+ split_unit_id = visible_unit_ids[0]
+ n = 2
+ unit_ids = [f"{split_unit_id}-0", f"{split_unit_id}-1"]
+ visible_unit_ids = unit_ids
+ ccg, bins = self._compute_split_ccg()
+ split_unit_color = self.get_unit_color(split_unit_id)
+ colors = {
+ f"{split_unit_id}-0": split_unit_color,
+ f"{split_unit_id}-1": split_unit_color,
+ }
+
+ first_fig = None
for r in range(n):
row_plots = []
for c in range(r, n):
-
i = unit_ids.index(visible_unit_ids[r])
j = unit_ids.index(visible_unit_ids[c])
- count = self.ccg[i, j, :]
+ count = ccg[i, j, :]
# Create Bokeh figure
- p = bpl.figure(
+ if first_fig is not None:
+ extra_kwargs = dict(x_range=first_fig.x_range)
+ else:
+ extra_kwargs = dict()
+ fig = bpl.figure(
width=250,
height=250,
tools="pan,wheel_zoom,reset",
@@ -148,29 +199,32 @@ def _panel_refresh(self):
background_fill_color=_bg_color,
border_fill_color=_bg_color,
outline_line_color="white",
+ **extra_kwargs,
)
- p.toolbar.logo = None
+ fig.toolbar.logo = None
# Get color from controller
if r == c:
unit_id = visible_unit_ids[r]
- color = self.get_unit_color(unit_id)
+ color = colors[unit_id]
fill_alpha = 0.7
else:
color = "lightgray"
fill_alpha = 0.4
- p.quad(
+ fig.quad(
top=count,
bottom=0,
- left=self.bins[:-1],
- right=self.bins[1:],
+ left=bins[:-1],
+ right=bins[1:],
fill_color=color,
line_color=color,
alpha=fill_alpha,
)
+ if first_fig is None:
+ first_fig = fig
- row_plots.append(p)
+ row_plots.append(fig)
# Fill row with None for proper spacing
full_row = [None] * r + row_plots + [None] * (n - len(row_plots))
self.plots.append(full_row)
diff --git a/spikeinterface_gui/curation_tools.py b/spikeinterface_gui/curation_tools.py
index 2d92b41..d96964b 100644
--- a/spikeinterface_gui/curation_tools.py
+++ b/spikeinterface_gui/curation_tools.py
@@ -11,27 +11,29 @@
empty_curation_data = {
"manual_labels": [],
- "merge_unit_groups": [],
- "removed_units": []
+ "merges": [],
+ "splits": [],
+ "removes": []
}
-def adding_group(previous_groups, new_group):
+def add_merge(previous_merges, new_merge_unit_ids):
# this is to ensure that np.str_ types are rendered as str
- to_merge = [np.array(new_group).tolist()]
+ to_merge = [np.array(new_merge_unit_ids).tolist()]
unchanged = []
- for c_prev in previous_groups:
+ for c_prev in previous_merges:
is_unaffected = True
-
- for c_new in new_group:
- if c_new in c_prev:
+ c_prev_unit_ids = c_prev["unit_ids"]
+ for c_new in new_merge_unit_ids:
+ if c_new in c_prev_unit_ids:
is_unaffected = False
- to_merge.append(c_prev)
+ to_merge.append(c_prev_unit_ids)
break
if is_unaffected:
- unchanged.append(c_prev)
- new_merge_group = [sum(to_merge, [])]
- new_merge_group.extend(unchanged)
- # Ensure the unicity
- new_merge_group = [list(set(gp)) for gp in new_merge_group]
- return new_merge_group
+ unchanged.append(c_prev_unit_ids)
+
+ new_merge_units = [sum(to_merge, [])]
+ new_merge_units.extend(unchanged)
+ # Ensure the uniqueness
+ new_merges = [{"unit_ids": list(set(gp))} for gp in new_merge_units]
+ return new_merges
diff --git a/spikeinterface_gui/curationview.py b/spikeinterface_gui/curationview.py
index b6ed141..ed496b3 100644
--- a/spikeinterface_gui/curationview.py
+++ b/spikeinterface_gui/curationview.py
@@ -28,7 +28,7 @@ def restore_units(self):
self.notify_manual_curation_updated()
self.refresh()
- def unmerge_groups(self):
+ def unmerge(self):
if self.backend == 'qt':
merge_indices = self._qt_get_merge_table_row()
else:
@@ -38,6 +38,16 @@ def unmerge_groups(self):
self.notify_manual_curation_updated()
self.refresh()
+ def unsplit(self):
+ if self.backend == 'qt':
+ split_indices = self._qt_get_split_table_row()
+ else:
+ split_indices = self._panel_get_split_table_row()
+ if split_indices is not None:
+ self.controller.make_manual_restore_split(split_indices)
+ self.notify_manual_curation_updated()
+ self.refresh()
+
## Qt
def _qt_make_layout(self):
from .myqt import QT
@@ -60,6 +70,26 @@ def _qt_make_layout(self):
h = QT.QHBoxLayout()
self.layout.addLayout(h)
+
+ v = QT.QVBoxLayout()
+ h.addLayout(v)
+ v.addWidget(QT.QLabel("Deleted"))
+ self.table_delete = QT.QTableWidget(selectionMode=QT.QAbstractItemView.SingleSelection,
+ selectionBehavior=QT.QAbstractItemView.SelectRows)
+ v.addWidget(self.table_delete)
+ self.table_delete.setContextMenuPolicy(QT.Qt.CustomContextMenu)
+ self.table_delete.customContextMenuRequested.connect(self._qt_open_context_menu_delete)
+ self.table_delete.itemSelectionChanged.connect(self._qt_on_item_selection_changed_delete)
+
+
+
+ self.delete_menu = QT.QMenu()
+ act = self.delete_menu.addAction('Restore')
+ act.triggered.connect(self.restore_units)
+ shortcut_restore = QT.QShortcut(self.qt_widget)
+ shortcut_restore.setKey(QT.QKeySequence("ctrl+r"))
+ shortcut_restore.activated.connect(self.restore_units)
+
v = QT.QVBoxLayout()
h.addLayout(v)
v.addWidget(QT.QLabel("Merges"))
@@ -73,35 +103,32 @@ def _qt_make_layout(self):
self.table_merge.itemSelectionChanged.connect(self._qt_on_item_selection_changed_merge)
self.merge_menu = QT.QMenu()
- act = self.merge_menu.addAction('Remove merge group')
- act.triggered.connect(self.unmerge_groups)
+ act = self.merge_menu.addAction('Remove merge')
+ act.triggered.connect(self.unmerge)
shortcut_unmerge = QT.QShortcut(self.qt_widget)
shortcut_unmerge.setKey(QT.QKeySequence("ctrl+u"))
- shortcut_unmerge.activated.connect(self.unmerge_groups)
-
+ shortcut_unmerge.activated.connect(self.unmerge)
v = QT.QVBoxLayout()
h.addLayout(v)
- v.addWidget(QT.QLabel("Deleted"))
- self.table_delete = QT.QTableWidget(selectionMode=QT.QAbstractItemView.SingleSelection,
+ v.addWidget(QT.QLabel("Splits"))
+ self.table_split = QT.QTableWidget(selectionMode=QT.QAbstractItemView.SingleSelection,
selectionBehavior=QT.QAbstractItemView.SelectRows)
- v.addWidget(self.table_delete)
- self.table_delete.setContextMenuPolicy(QT.Qt.CustomContextMenu)
- self.table_delete.customContextMenuRequested.connect(self._qt_open_context_menu_delete)
- self.table_delete.itemSelectionChanged.connect(self._qt_on_item_selection_changed_delete)
-
-
- self.delete_menu = QT.QMenu()
- act = self.delete_menu.addAction('Restore')
- act.triggered.connect(self.restore_units)
- shortcut_restore = QT.QShortcut(self.qt_widget)
- shortcut_restore.setKey(QT.QKeySequence("ctrl+r"))
- shortcut_restore.activated.connect(self.restore_units)
+ v.addWidget(self.table_split)
+ self.table_split.setContextMenuPolicy(QT.Qt.CustomContextMenu)
+ self.table_split.customContextMenuRequested.connect(self._qt_open_context_menu_split)
+ self.table_split.itemSelectionChanged.connect(self._qt_on_item_selection_changed_split)
+ self.split_menu = QT.QMenu()
+ act = self.split_menu.addAction('Remove split')
+ act.triggered.connect(self.unsplit)
+ shortcut_unsplit = QT.QShortcut(self.qt_widget)
+ shortcut_unsplit.setKey(QT.QKeySequence("ctrl+x"))
+ shortcut_unsplit.activated.connect(self.unsplit)
def _qt_refresh(self):
from .myqt import QT
# Merged
- merged_units = self.controller.curation_data["merge_unit_groups"]
+ merged_units = [m["unit_ids"] for m in self.controller.curation_data["merges"]]
self.table_merge.clear()
self.table_merge.setRowCount(len(merged_units))
self.table_merge.setColumnCount(1)
@@ -114,8 +141,8 @@ def _qt_refresh(self):
for i in range(self.table_merge.columnCount()):
self.table_merge.resizeColumnToContents(i)
- ## deleted
- removed_units = self.controller.curation_data["removed_units"]
+ # Removed
+ removed_units = self.controller.curation_data["removed"]
self.table_delete.clear()
self.table_delete.setRowCount(len(removed_units))
self.table_delete.setColumnCount(1)
@@ -133,6 +160,24 @@ def _qt_refresh(self):
item.unit_id = unit_id
self.table_delete.resizeColumnToContents(0)
+ # Splits
+ splits = self.controller.curation_data["splits"]
+ self.table_split.clear()
+ self.table_split.setRowCount(len(splits))
+ self.table_split.setColumnCount(1)
+ self.table_split.setHorizontalHeaderLabels(["Split units"])
+ self.table_split.setSortingEnabled(False)
+ for i, split in enumerate(splits):
+ unit_id = split["unit_id"]
+ num_indices = len(split["indices"])
+ num_spikes = self.controller.num_spikes[unit_id]
+ num_splits = f"({num_indices}-{num_spikes - num_indices})"
+ item = QT.QTableWidgetItem(f"{unit_id} {num_splits}")
+ item.setFlags(QT.Qt.ItemIsEnabled|QT.Qt.ItemIsSelectable)
+ self.table_split.setItem(i, 0, item)
+ item.unit_id = unit_id
+ self.table_split.resizeColumnToContents(0)
+
def _qt_get_delete_table_selection(self):
@@ -148,12 +193,22 @@ def _qt_get_merge_table_row(self):
return None
else:
return [s.row() for s in selected_items]
+
+ def _qt_get_split_table_row(self):
+ selected_items = self.table_split.selectedItems()
+ if len(selected_items) == 0:
+ return None
+ else:
+ return [s.row() for s in selected_items]
def _qt_open_context_menu_delete(self):
self.delete_menu.popup(self.qt_widget.cursor().pos())
def _qt_open_context_menu_merge(self):
self.merge_menu.popup(self.qt_widget.cursor().pos())
+
+ def _qt_open_context_menu_split(self):
+ self.split_menu.popup(self.qt_widget.cursor().pos())
def _qt_on_item_selection_changed_merge(self):
if len(self.table_merge.selectedIndexes()) == 0:
@@ -161,16 +216,28 @@ def _qt_on_item_selection_changed_merge(self):
dtype = self.controller.unit_ids.dtype
ind = self.table_merge.selectedIndexes()[0].row()
- visible_unit_ids = self.controller.curation_data["merge_unit_groups"][ind]
+ visible_unit_ids = [m["unit_ids"] for m in self.controller.curation_data["merges"]][ind]
visible_unit_ids = [dtype.type(unit_id) for unit_id in visible_unit_ids]
self.controller.set_visible_unit_ids(visible_unit_ids)
self.notify_unit_visibility_changed()
+ def _qt_on_item_selection_changed_split(self):
+ if len(self.table_split.selectedIndexes()) == 0:
+ return
+
+ dtype = self.controller.unit_ids.dtype
+ ind = self.table_split.selectedIndexes()[0].row()
+ split_unit_str = self.table_split.item(ind, 0).text()
+ split_unit_id = dtype.type(split_unit_str.split(" ")[0])
+ self.controller.set_visible_unit_ids([split_unit_id])
+ self.controller.set_active_split_unit(split_unit_id)
+ self.notify_unit_visibility_changed()
+
def _qt_on_item_selection_changed_delete(self):
if len(self.table_delete.selectedIndexes()) == 0:
return
ind = self.table_delete.selectedIndexes()[0].row()
- unit_id = self.controller.curation_data["removed_units"][ind]
+ unit_id = self.controller.curation_data["removed"][ind]
self.controller.set_all_unit_visibility_off()
# convert to the correct type
unit_id = self.controller.unit_ids.dtype.type(unit_id)
@@ -216,39 +283,56 @@ def _panel_make_layout(self):
pn.extension("tabulator")
# Create dataframe
- merge_df = pd.DataFrame({"merge_groups": []})
- delete_df = pd.DataFrame({"deleted_unit_id": []})
+ delete_df = pd.DataFrame({"removed": []})
+ merge_df = pd.DataFrame({"merges": []})
+ split_df = pd.DataFrame({"splits": []})
# Create tables
+ self.table_delete = SelectableTabulator(
+ delete_df,
+ show_index=False,
+ disabled=True,
+ sortable=False,
+ formatters={"removed": "plaintext"},
+ sizing_mode="stretch_width",
+ # SelectableTabulator functions
+ parent_view=self,
+ # refresh_table_function=self.refresh,
+ conditional_shortcut=self._conditional_refresh_delete,
+ column_callbacks={"removed": self._panel_on_deleted_col},
+ )
self.table_merge = SelectableTabulator(
merge_df,
show_index=False,
disabled=True,
sortable=False,
- formatters={"merge_groups": "plaintext"},
+ selectable=1,
+ formatters={"merges": "plaintext"},
sizing_mode="stretch_width",
# SelectableTabulator functions
parent_view=self,
# refresh_table_function=self.refresh,
conditional_shortcut=self._conditional_refresh_merge,
- column_callbacks={"merge_groups": self._panel_on_merged_col},
+ column_callbacks={"merges": self._panel_on_merged_col},
)
- self.table_delete = SelectableTabulator(
- delete_df,
+ self.table_split = SelectableTabulator(
+ split_df,
show_index=False,
disabled=True,
sortable=False,
- formatters={"deleted_unit_id": "plaintext"},
+ selectable=1,
+ formatters={"splits": "plaintext"},
sizing_mode="stretch_width",
# SelectableTabulator functions
parent_view=self,
# refresh_table_function=self.refresh,
- conditional_shortcut=self._conditional_refresh_delete,
- column_callbacks={"deleted_unit_id": self._panel_on_deleted_col},
+ conditional_shortcut=self._conditional_refresh_split,
+ column_callbacks={"splits": self._panel_on_split_col},
)
self.table_delete.param.watch(self._panel_update_unit_visibility, "selection")
self.table_merge.param.watch(self._panel_update_unit_visibility, "selection")
+ self.table_split.param.watch(self._panel_update_unit_visibility, "selection")
# Create buttons
save_button = pn.widgets.Button(
@@ -277,7 +361,7 @@ def _panel_make_layout(self):
button_type="primary",
height=30
)
- remove_merge_button.on_click(self._panel_unmerge_groups)
+ remove_merge_button.on_click(self._panel_unmerge)
submit_button = pn.widgets.Button(
name="Submit to parent",
@@ -306,12 +390,14 @@ def _panel_make_layout(self):
shortcuts = [
KeyboardShortcut(name="restore", key="r", ctrlKey=True),
KeyboardShortcut(name="unmerge", key="u", ctrlKey=True),
+ KeyboardShortcut(name="unsplit", key="x", ctrlKey=True),
]
shortcuts_component = KeyboardShortcuts(shortcuts=shortcuts)
shortcuts_component.on_msg(self._panel_handle_shortcut)
# Create main layout with proper sizing
- sections = pn.Row(self.table_merge, self.table_delete, sizing_mode="stretch_width")
+ sections = pn.Row(self.table_delete, self.table_merge, self.table_split,
+ sizing_mode="stretch_width")
self.layout = pn.Column(
save_sections,
buttons_curate,
@@ -331,47 +417,68 @@ def _panel_make_layout(self):
def _panel_refresh(self):
import pandas as pd
- # Merged
- merged_units = self.controller.curation_data["merge_unit_groups"]
+ ## deleted
+ removed_units = self.controller.curation_data["removed"]
+ removed_units = [str(unit_id) for unit_id in removed_units]
+ df = pd.DataFrame({"removed": removed_units})
+ self.table_delete.value = df
+ self.table_delete.selection = []
+
+ # Merged
+ merged_units = [m["unit_ids"] for m in self.controller.curation_data["merges"]]
# for visualization, we make all row entries strings
merged_units_str = []
for group in merged_units:
# convert to string
group = [str(unit_id) for unit_id in group]
merged_units_str.append(" - ".join(group))
- df = pd.DataFrame({"merge_groups": merged_units_str})
+ df = pd.DataFrame({"merges": merged_units_str})
self.table_merge.value = df
self.table_merge.selection = []
- ## deleted
- removed_units = self.controller.curation_data["removed_units"]
- removed_units = [str(unit_id) for unit_id in removed_units]
- df = pd.DataFrame({"deleted_unit_id": removed_units})
- self.table_delete.value = df
- self.table_delete.selection = []
+ # Splits
+ split_units_str = []
+ num_spikes = self.controller.num_spikes
+ for split in self.controller.curation_data["splits"]:
+ unit_id = split["unit_id"]
+ num_indices = len(split["indices"])
+ num_splits = f"({num_indices}-{num_spikes[unit_id] - num_indices})"
+ split_units_str.append(f"{unit_id} {num_splits}")
+ df = pd.DataFrame({"splits": split_units_str})
+ self.table_split.value = df
+ self.table_split.selection = []
def _panel_update_unit_visibility(self, event):
unit_dtype = self.controller.unit_ids.dtype
if self.active_table == "delete":
- visible_unit_ids = self.table_delete.value["deleted_unit_id"].values[self.table_delete.selection].tolist()
+ visible_unit_ids = self.table_delete.value["removed"].values[self.table_delete.selection].tolist()
visible_unit_ids = [unit_dtype.type(unit_id) for unit_id in visible_unit_ids]
self.controller.set_visible_unit_ids(visible_unit_ids)
elif self.active_table == "merge":
- merge_groups = self.table_merge.value["merge_groups"].values[self.table_merge.selection].tolist()
+ merge_groups = self.table_merge.value["merges"].values[self.table_merge.selection].tolist()
# self.controller.set_all_unit_visibility_off()
visible_unit_ids = []
for merge_group in merge_groups:
merge_unit_ids = [unit_dtype.type(unit_id) for unit_id in merge_group.split(" - ")]
visible_unit_ids.extend(merge_unit_ids)
self.controller.set_visible_unit_ids(visible_unit_ids)
+ elif self.active_table == "split":
+ split_unit_str = self.table_split.value["splits"].values[self.table_split.selection].tolist()
+ if len(split_unit_str) == 1:
+ split_unit_str = split_unit_str[0]
+ split_unit = split_unit_str.split(" ")[0]
+ # self.controller.set_all_unit_visibility_off()
+ split_unit = unit_dtype.type(split_unit)
+ self.controller.set_visible_unit_ids([split_unit])
+ self.controller.set_active_split_unit(split_unit)
self.notify_unit_visibility_changed()
def _panel_restore_units(self, event):
self.restore_units()
- def _panel_unmerge_groups(self, event):
- self.unmerge_groups()
+ def _panel_unmerge(self, event):
+ self.unmerge()
def _panel_save_in_analyzer(self, event):
self.save_in_analyzer()
@@ -392,7 +499,7 @@ def _panel_get_delete_table_selection(self):
if len(selected_items) == 0:
return None
else:
- return self.table_delete.value["deleted_unit_id"].values[selected_items].tolist()
+ return self.table_delete.value["removed"].values[selected_items].tolist()
def _panel_get_merge_table_row(self):
selected_items = self.table_merge.selection
@@ -401,11 +508,20 @@ def _panel_get_merge_table_row(self):
else:
return selected_items
+ def _panel_get_split_table_row(self):
+ selected_items = self.table_split.selection
+ if len(selected_items) == 0:
+ return None
+ else:
+ return selected_items
+
def _panel_handle_shortcut(self, event):
if event.data == "restore":
self.restore_units()
elif event.data == "unmerge":
- self.unmerge_groups()
+ self.unmerge()
+ elif event.data == "unsplit":
+ pass
def _panel_submit_to_parent(self, event):
"""Send the curation data to the parent window"""
@@ -441,10 +557,17 @@ def _panel_submit_to_parent(self, event):
def _panel_on_deleted_col(self, row):
self.active_table = "delete"
self.table_merge.selection = []
+ self.table_split.selection = []
def _panel_on_merged_col(self, row):
self.active_table = "merge"
self.table_delete.selection = []
+ self.table_split.selection = []
+
+ def _panel_on_split_col(self, row):
+ self.active_table = "split"
+ self.table_delete.selection = []
+ self.table_merge.selection = []
def _conditional_refresh_merge(self):
# Check if the view is active before refreshing
@@ -460,6 +583,13 @@ def _conditional_refresh_delete(self):
else:
return False
+ def _conditional_refresh_split(self):
+ # Check if the view is active before refreshing
+ if self.is_view_active() and self.active_table == "split":
+ return True
+ else:
+ return False
+
CurationView._gui_help_txt = """
## Curation View
diff --git a/spikeinterface_gui/mergeview.py b/spikeinterface_gui/mergeview.py
index 8f5fa76..6578ccd 100644
--- a/spikeinterface_gui/mergeview.py
+++ b/spikeinterface_gui/mergeview.py
@@ -80,7 +80,7 @@ def get_table_data(self, include_deleted=False):
unit_ids = list(self.controller.unit_ids)
for group_ids in self.proposed_merge_unit_groups:
if not include_deleted and self.controller.curation:
- deleted_unit_ids = self.controller.curation_data["removed_units"]
+ deleted_unit_ids = self.controller.curation_data["removed"]
if any(unit_id in deleted_unit_ids for unit_id in group_ids):
continue
diff --git a/spikeinterface_gui/ndscatterview.py b/spikeinterface_gui/ndscatterview.py
index 1bafa92..ebd5102 100644
--- a/spikeinterface_gui/ndscatterview.py
+++ b/spikeinterface_gui/ndscatterview.py
@@ -54,6 +54,8 @@ def __init__(self, controller=None, parent=None, backend="qt"):
self.tour_step = 0
self.auto_update_limit = True
+ self._lasso_vertices = []
+
ViewBase.__init__(self, controller=controller, parent=parent, backend=backend)
@@ -107,9 +109,6 @@ def random_projection(self):
# here we don't want to update the components because it's been done already!
self.refresh(update_components=False)
- def on_spike_selection_changed(self):
- self.refresh()
-
def on_unit_visibility_changed(self):
self.random_projection()
@@ -121,6 +120,19 @@ def apply_dot(self, data):
return projected
def get_plotting_data(self, return_spike_indices=False):
+ """
+ Get the data to plot in the scatter plot.
+
+ Parameters
+ ----------
+ return_spike_indices : bool, default: False
+ If True, also return the indices of the spikes for each unit.
+
+ Returns
+ -------
+ _type_
+ _description_
+ """
scatter_x = {}
scatter_y = {}
all_limits = []
@@ -134,7 +146,7 @@ def get_plotting_data(self, return_spike_indices=False):
projected_2d = projected[:, :2]
all_limits.append(float(np.percentile(np.abs(projected_2d), 95) * 2.))
if return_spike_indices:
- spike_indices[unit_id] = mask
+ spike_indices[unit_id] = self.random_spikes_indices[mask]
if len(all_limits) > 0 and self.auto_update_limit:
self.limit = max(all_limits)
@@ -323,7 +335,9 @@ def _qt_refresh(self, update_components=True, update_colors=True):
self.plot.setYRange(-self.limit, self.limit)
# self.graphicsview.repaint()
-
+
+ def _qt_on_spike_selection_changed(self):
+ self.refresh()
def _qt_start_stop_tour(self, checked):
if checked:
@@ -346,17 +360,31 @@ def _qt_on_lasso_drawing(self, points):
points = np.array(points)
self.lasso.setData(points[:, 0], points[:, 1])
- def _qt_on_lasso_finished(self, points):
+ def _qt_on_lasso_finished(self, points, shift_held=False):
self.lasso.setData([], [])
vertices = np.array(points)
+ self._lasso_vertices.append(vertices)
# inside lasso and visibles
ind_visibles, = np.nonzero(np.isin(self.random_spikes_indices, self.controller.get_indices_spike_visible()))
projected = self.apply_dot(self.data[ind_visibles, :])
inside = inside_poly(projected, vertices)
- inds = self.random_spikes_indices[ind_visibles[inside]]
- self.controller.set_indices_spike_selected(inds)
+ new_selected_inds = self.random_spikes_indices[ind_visibles[inside]]
+
+ if shift_held:
+ # Extend existing selection
+ current_selection = self.controller.get_indices_spike_selected()
+ if len(current_selection) > 0 and len(new_selected_inds) > 0:
+ extended_selection = np.unique(np.concatenate([current_selection, new_selected_inds]))
+ self.controller.set_indices_spike_selected(extended_selection)
+ elif len(new_selected_inds) > 0:
+ # No current selection, just use new selection
+ self.controller.set_indices_spike_selected(new_selected_inds)
+ # If no new selection and shift held, keep existing selection unchanged
+ else:
+ # Replace selection (original behavior)
+ self.controller.set_indices_spike_selected(new_selected_inds)
self.refresh()
self.notify_spike_selection_changed()
@@ -369,7 +397,7 @@ def _panel_make_layout(self):
from bokeh.models import ColumnDataSource, LassoSelectTool, Range1d
from bokeh.events import MouseWheel
- from .utils_panel import _bg_color, slow_lasso
+ from .utils_panel import _bg_color
self.lasso_tool = LassoSelectTool()
@@ -393,12 +421,9 @@ def _panel_make_layout(self):
# remove the bokeh mousewheel zoom and keep only this one
self.scatter_fig.on_event(MouseWheel, self._panel_gain_zoom)
- self.scatter_source = ColumnDataSource({"x": [], "y": [], "color": []})
- self.scatter_select_source = ColumnDataSource({"x": [], "y": [], "color": []})
+ self.scatter_source = ColumnDataSource({"x": [], "y": [], "color": [], "spike_indices": []})
self.scatter = self.scatter_fig.scatter("x", "y", source=self.scatter_source, size=3, color="color", alpha=0.7)
- self.scatter_select = self.scatter_fig.scatter("x", "y", source=self.scatter_select_source,
- size=11, color="white", alpha=0.8)
# toolbar
self.next_face_button = pn.widgets.Button(name="Next Face", button_type="default", width=100)
@@ -410,14 +435,14 @@ def _panel_make_layout(self):
self.random_tour_button = pn.widgets.Toggle(name="Random Tour", button_type="default", width=100)
self.random_tour_button.param.watch(self._panel_start_stop_tour, "value")
- # self.select_toggle_button = pn.widgets.Toggle(name="Select")
- # self.select_toggle_button.param.watch(self._panel_on_select_button, 'value')
+ self.select_toggle_button = pn.widgets.Toggle(name="Select")
+ self.select_toggle_button.param.watch(self._panel_on_select_button, 'value')
- # TODO: add a lasso selection
- # slow_lasso(self.scatter_source, self._on_panel_lasso_selected)
+ self.scatter_fig.on_event('selectiongeometry', self._on_panel_selection_geometry)
self.toolbar = pn.Row(
- self.next_face_button, self.random_button, self.random_tour_button, sizing_mode="stretch_both",
+ self.next_face_button, self.random_button, self.random_tour_button, self.select_toggle_button,
+ sizing_mode="stretch_both",
styles={"flex": "0.15"}
)
@@ -433,15 +458,16 @@ def _panel_make_layout(self):
def _panel_refresh(self, update_components=True, update_colors=True):
if update_components:
self.update_selected_components()
- scatter_x, scatter_y, selected_scatter_x, selected_scatter_y = self.get_plotting_data()
+ scatter_x, scatter_y, _, _, spike_indices = self.get_plotting_data(return_spike_indices=True)
- xs, ys, colors = [], [], []
+ xs, ys, colors, plotted_spike_indices = [], [], [], []
for unit_id in scatter_x.keys():
color = self.get_unit_color(unit_id)
xs.extend(scatter_x[unit_id])
ys.extend(scatter_y[unit_id])
if update_colors:
colors.extend([color] * len(scatter_x[unit_id]))
+ plotted_spike_indices.extend(spike_indices[unit_id])
if not update_colors:
colors = self.scatter_source.data.get("color")
@@ -450,23 +476,20 @@ def _panel_refresh(self, update_components=True, update_colors=True):
"x": xs,
"y": ys,
"color": colors,
+ "spike_indices": plotted_spike_indices
}
- self.scatter_select_source.data = {
- "x": selected_scatter_x,
- "y": selected_scatter_y,
- }
-
- # TODO: handle selection with lasso
- # mask = np.isin(self.random_spikes_indices, self.controller.get_indices_spike_selected())
- # selected_indices = np.flatnonzero(mask)
- # self.scatter_source.selected.indices = selected_indices.tolist()
-
self.scatter_fig.x_range.start = -self.limit
self.scatter_fig.x_range.end = self.limit
self.scatter_fig.y_range.start = -self.limit
self.scatter_fig.y_range.end = self.limit
+ def _panel_on_spike_selection_changed(self):
+ # handle selection with lasso
+ plotted_spike_indices = self.scatter_source.data.get("spike_indices", [])
+ ind_selected, = np.nonzero(np.isin(plotted_spike_indices, self.controller.get_indices_spike_selected()))
+ self.scatter_source.selected.indices = ind_selected
+
def _panel_gain_zoom(self, event):
from bokeh.models import Range1d
@@ -501,24 +524,37 @@ def _panel_on_select_button(self, event):
else:
self.scatter_fig.toolbar.active_drag = None
self.scatter_source.selected.indices = []
- # self._on_panel_lasso_selected(None, None, None)
-
-
- # TODO: Handle lasso selection and updates
- # def _on_panel_lasso_selected(self, attr, old, new):
- # if len(self.scatter_source.selected.indices) == 0:
- # self.notify_spike_selection_changed()
- # self.refresh()
- # return
-
- # # inside lasso and visibles
- # inside = self.scatter_source.selected.indices
-
- # inds = self.random_spikes_indices[inside]
- # self.controller.set_indices_spike_selected(inds)
- # self.refresh()
- # self.notify_spike_selection_changed()
+ def _on_panel_selection_geometry(self, event):
+ """
+ Handle SelectionGeometry event to capture lasso polygon vertices.
+ """
+ if event.final:
+ xs = np.array(event.geometry["x"])
+ ys = np.array(event.geometry["y"])
+ polygon = np.column_stack((xs, ys))
+
+ selected = self.scatter_source.selected.indices
+
+ if len(selected) == 0:
+ self.notify_spike_selection_changed()
+ self.refresh()
+ return
+
+ if len(selected) > self._current_selected:
+ self._current_selected = len(selected)
+ # Store the current polygon for the current segment
+ self._lasso_vertices.append(polygon)
+ else:
+ self._lasso_vertices = [polygon]
+
+ # inside lasso and visibles
+ ind_visibles, = np.nonzero(np.isin(self.random_spikes_indices, self.controller.get_indices_spike_visible()))
+ inds = self.random_spikes_indices[ind_visibles[selected]]
+ self.controller.set_indices_spike_selected(inds)
+
+ self.notify_spike_selection_changed()
+ self.refresh()
def inside_poly(data, vertices):
diff --git a/spikeinterface_gui/tests/test_mainwindow_panel.py b/spikeinterface_gui/tests/test_mainwindow_panel.py
index e20e231..0cc96d5 100644
--- a/spikeinterface_gui/tests/test_mainwindow_panel.py
+++ b/spikeinterface_gui/tests/test_mainwindow_panel.py
@@ -32,7 +32,7 @@ def teardown_module():
clean_all(test_folder)
-def test_mainwindow(start_app=False, verbose=True, curation=False, only_some_extensions=False, from_si_api=False):
+def test_mainwindow(start_app=False, verbose=True, curation=False, only_some_extensions=False, from_si_api=False, port=0):
analyzer = load_sorting_analyzer(test_folder / "sorting_analyzer")
@@ -71,7 +71,7 @@ def test_mainwindow(start_app=False, verbose=True, curation=False, only_some_ext
layout_preset='default',
# skip_extensions=["waveforms", "principal_components", "template_similarity", "spike_amplitudes"],
# address="10.69.168.40",
- # port=5000,
+ port=port,
)
return win
@@ -100,9 +100,9 @@ def test_launcher(verbose=True):
if not test_folder.is_dir():
setup_module()
- # win = test_mainwindow(start_app=True, verbose=True, curation=True)
+ win = test_mainwindow(start_app=True, verbose=True, curation=True, port=5006)
- test_launcher(verbose=True)
+ # test_launcher(verbose=True)
# TO RUN with panel serve:
# win = test_mainwindow(start_app=False, verbose=True, curation=True)
diff --git a/spikeinterface_gui/tests/test_mainwindow_qt.py b/spikeinterface_gui/tests/test_mainwindow_qt.py
index 14e65d2..d29f8ca 100644
--- a/spikeinterface_gui/tests/test_mainwindow_qt.py
+++ b/spikeinterface_gui/tests/test_mainwindow_qt.py
@@ -15,7 +15,7 @@
test_folder = Path(__file__).parent / 'my_dataset_small'
-test_folder = Path(__file__).parent / 'my_dataset_big'
+# test_folder = Path(__file__).parent / 'my_dataset_big'
# test_folder = Path(__file__).parent / 'my_dataset_multiprobe'
# yep is for testing
@@ -110,7 +110,7 @@ def test_launcher(verbose=True):
if __name__ == '__main__':
if not test_folder.is_dir():
setup_module()
- # win = test_mainwindow(start_app=True, verbose=True, curation=True)
+ win = test_mainwindow(start_app=True, verbose=True, curation=True)
# win = test_mainwindow(start_app=True, verbose=True, curation=False)
- test_launcher(verbose=True)
+ # test_launcher(verbose=True)
diff --git a/spikeinterface_gui/tests/testingtools.py b/spikeinterface_gui/tests/testingtools.py
index 4f1a7a0..92e1218 100644
--- a/spikeinterface_gui/tests/testingtools.py
+++ b/spikeinterface_gui/tests/testingtools.py
@@ -137,6 +137,7 @@ def make_analyzer_folder(test_folder, case="small", unit_dtype="str"):
def make_curation_dict(analyzer):
unit_ids = analyzer.unit_ids.tolist()
curation_dict = {
+ "format_version": "2",
"unit_ids": unit_ids,
"label_definitions": {
"quality":{
@@ -153,8 +154,8 @@ def make_curation_dict(analyzer):
{'unit_id': unit_ids[2], "putative_type": ["exitatory"]},
{'unit_id': unit_ids[3], "quality": ["noise"], "putative_type": ["inhibitory"]},
],
- "merge_unit_groups": [unit_ids[:3], unit_ids[3:5]],
- "removed_units": unit_ids[5:8],
+ "merges": [{"unit_ids": unit_ids[:3]}, {"unit_ids": unit_ids[3:5]}],
+ "removed": unit_ids[5:8],
}
return curation_dict
diff --git a/spikeinterface_gui/utils_qt.py b/spikeinterface_gui/utils_qt.py
index a5b5044..a815adc 100644
--- a/spikeinterface_gui/utils_qt.py
+++ b/spikeinterface_gui/utils_qt.py
@@ -47,12 +47,13 @@ def add_stretch_to_qtoolbar(tb):
class ViewBoxHandlingLasso(pg.ViewBox):
doubleclicked = QT.pyqtSignal()
lasso_drawing = QT.pyqtSignal(object)
- lasso_finished = QT.pyqtSignal(object)
+ lasso_finished = QT.pyqtSignal(object, bool) # Added bool parameter for shift_held
def __init__(self, *args, **kwds):
pg.ViewBox.__init__(self, *args, **kwds)
self.drag_points = []
self.lasso_active = False
+ self.shift_held = False
def mouseDoubleClickEvent(self, ev):
self.doubleclicked.emit()
@@ -68,12 +69,14 @@ def mouseDragEvent(self, ev):
if ev.isStart():
self.drag_points = []
+ # Check if shift is held at the start of the drag
+ self.shift_held = ev.modifiers() == QT.Qt.ShiftModifier
pos = self.mapToView(ev.pos())
self.drag_points.append([pos.x(), pos.y()])
if ev.isFinish():
- self.lasso_finished.emit(self.drag_points)
+ self.lasso_finished.emit(self.drag_points, self.shift_held)
else:
self.lasso_drawing.emit(self.drag_points)
@@ -123,12 +126,13 @@ class ViewBoxHandlingLassoAndGain(pg.ViewBox):
doubleclicked = QT.pyqtSignal()
gain_zoom = QT.pyqtSignal(float)
lasso_drawing = QT.pyqtSignal(object)
- lasso_finished = QT.pyqtSignal(object)
+ lasso_finished = QT.pyqtSignal(object, bool) # Added bool parameter for shift_held
def __init__(self, *args, **kwds):
pg.ViewBox.__init__(self, *args, **kwds)
self.disableAutoRange()
self.drag_points = []
+ self.shift_held = False
def mouseClickEvent(self, ev):
ev.accept()
@@ -152,12 +156,14 @@ def mouseDragEvent(self, ev):
if ev.isStart():
self.drag_points = []
+ # Check if shift is held at the start of the drag
+ self.shift_held = ev.modifiers() == QT.Qt.ShiftModifier
pos = self.mapToView(ev.pos())
self.drag_points.append([pos.x(), pos.y()])
if ev.isFinish():
- self.lasso_finished.emit(self.drag_points)
+ self.lasso_finished.emit(self.drag_points, self.shift_held)
else:
self.lasso_drawing.emit(self.drag_points)
diff --git a/spikeinterface_gui/view_base.py b/spikeinterface_gui/view_base.py
index 3efec5a..8aab966 100644
--- a/spikeinterface_gui/view_base.py
+++ b/spikeinterface_gui/view_base.py
@@ -113,6 +113,12 @@ def _refresh(self, **kwargs):
self._qt_refresh(**kwargs)
elif self.backend == "panel":
self._panel_refresh(**kwargs)
+
+ def warning(self, warning_msg):
+ if self.backend == "qt":
+ self._qt_insert_warning(warning_msg)
+ elif self.backend == "panel":
+ self._panel_insert_warning(warning_msg)
def get_unit_color(self, unit_id):
if self.backend == "qt":
@@ -212,6 +218,13 @@ def _qt_on_time_info_updated(self):
def _qt_on_unit_color_changed(self):
self.refresh()
+ def _qt_insert_warning(self, warning_msg):
+ from .myqt import QT
+
+ alert = QT.QMessageBox(QT.QMessageBox.Warning, "Warning", warning_msg, parent=self.qt_widget)
+ alert.setStandardButtons(QT.QMessageBox.Ok)
+ alert.exec_()
+
## PANEL ##
def _panel_make_layout(self):
raise NotImplementedError
@@ -237,3 +250,15 @@ def _panel_on_time_info_updated(self):
def _panel_on_unit_color_changed(self):
self.refresh()
+
+ def _panel_insert_warning(self, warning_msg):
+ import panel as pn
+
+ alert_markdown = pn.pane.Markdown(f"⚠️⚠️⚠️ {warning_msg} ⚠️⚠️⚠️", styles={'color': 'red', 'font-size': '16px'})
+ close_button = pn.widgets.Button(name="X")
+ close_button.on_click(self._panel_clear_warning)
+ row = pn.Row(alert_markdown, close_button, sizing_mode='stretch_width')
+ self.layout.insert(0, row)
+
+ def _panel_clear_warning(self, event):
+ self.layout.pop(0)