diff --git a/src/spikeinterface/working/load_kilosort_utils.py b/src/spikeinterface/working/load_kilosort_utils.py
new file mode 100644
index 0000000000..1bd9178d34
--- /dev/null
+++ b/src/spikeinterface/working/load_kilosort_utils.py
@@ -0,0 +1,416 @@
+from __future__ import annotations
+
+from pathlib import Path
+from spikeinterface.core import read_python
+import numpy as np
+import pandas as pd
+
+from scipy import stats
+
+"""
+Notes
+-----
+- not everything is used for current purposes
+- things might be useful in future for making a sorting analyzer - compute template amplitude as average of spike amplitude.
+
+TODO: testing against different kilosort versions
+"""
+
+########################################################################################################################
+# Get Spike Data
+########################################################################################################################
+
+
+def compute_spike_amplitude_and_depth(
+    params: dict,
+    localised_spikes_only: bool,
+    gain: float | None = None,
+    localised_spikes_channel_cutoff: int | None = None,
+) -> tuple[np.ndarray, ...]:
+    """
+    Compute the indices, amplitudes and locations of all detected spikes from the kilosort output.
+
+    This function is based on code in Cortex Lab's `spikes` repository,
+    https://github.com/cortex-lab/spikes
+
+    Parameters
+    ----------
+    params : dict
+        `params` as loaded from the kilosort output directory (see `load_ks_dir()`).
+    localised_spikes_only : bool
+        If `True`, only spikes with a small spatial footprint (i.e. all channels above 1/2 of the
+        maximum-loading channel amplitude span at most `localised_spikes_channel_cutoff` channels)
+        and which are close to the average depth for the cluster are returned.
+    gain : float | None
+        If a float is provided, `spike_amplitudes` will be scaled by this gain.
+    localised_spikes_channel_cutoff : int | None
+        Only used if `localised_spikes_only` is `True`. Templates whose above-threshold channels
+        span more than this number of channels are considered non-localised and their spikes
+        are removed.
+
+    Returns
+    -------
+    spike_indices : np.ndarray
+        (num_spikes,) array of spike indices.
+    spike_amplitudes : np.ndarray
+        (num_spikes,) array of corresponding spike amplitudes.
+    spike_locations : np.ndarray
+        (num_spikes, 2) array of corresponding spike locations (x, y) estimated using
+        center of mass from the first PC (or, second PC if no signal on first PC).
+        See `_get_locations_from_pc_features()` for details.
+    spike_max_sites : np.ndarray
+        (num_spikes,) array of the maximum-loading channel per spike.
+    """
+    if params["pc_features"] is None:
+        raise ValueError("`pc_features` must be loaded into params. 
Use `load_ks_dir` with `load_pcs=True`.") + + if localised_spikes_only: + localised_templates = [] + + for idx, template in enumerate(params["templates"]): + max_channel = np.max(np.abs(params["templates"][idx, :, :])) + channels_over_threshold = np.max(np.abs(params["templates"][idx, :, :]), axis=0) > 0.5 * max_channel + channel_ids_over_threshold = np.where(channels_over_threshold)[0] + + if np.ptp(channel_ids_over_threshold) <= localised_spikes_channel_cutoff: + localised_templates.append(idx) + + localised_template_by_spike = np.isin(params["spike_templates"], localised_templates) + + params["spike_templates"] = params["spike_templates"][localised_template_by_spike] + params["spike_indices"] = params["spike_indices"][localised_template_by_spike] + params["spike_clusters"] = params["spike_clusters"][localised_template_by_spike] + params["temp_scaling_amplitudes"] = params["temp_scaling_amplitudes"][localised_template_by_spike] + params["pc_features"] = params["pc_features"][localised_template_by_spike] + + # Compute the spike locations and maximum-loading channel per spike + spike_locations, spike_max_sites = _get_locations_from_pc_features(params) + + # Amplitude is calculated for each spike as the template amplitude + # multiplied by the `template_scaling_amplitudes`. + template_amplitudes_unscaled, *_ = get_unwhite_template_info( + params["templates"], + params["whitening_matrix_inv"], + params["channel_positions"], + ) + spike_amplitudes = template_amplitudes_unscaled[params["spike_templates"]] * params["temp_scaling_amplitudes"] + + if gain is not None: + spike_amplitudes *= gain + + if localised_spikes_only: + # Interpolate the channel ids to location. + # Remove spikes > 5 um from average position + # Above we already removed non-localized templates, but that on its own is insufficient. + # Note for IMEC probe adding a constant term kills the regression making the regressors rank deficient + spike_depths = spike_locations[:, 1] + b = stats.linregress(spike_depths, spike_max_sites).slope + i = np.abs(spike_max_sites - b * spike_depths) <= 5 + + params["spike_indices"] = params["spike_indices"][i] + spike_amplitudes = spike_amplitudes[i] + spike_locations = spike_locations[i, :] + spike_max_sites = spike_max_sites[i] + + return params["spike_indices"], spike_amplitudes, spike_locations, spike_max_sites + + +def _get_locations_from_pc_features(params): + """ + Compute locations from the waveform principal component scores. + + Notes + ----- + My understanding so far. KS1 paper; The individual spike waveforms are decomposed into + 'private PCs'. Let the waveform matrix W be time (t) x channel (c). PCA + decompoisition is performed to compute c basis waveforms. Scores for each + channel onto the top three PCs are stored (these recover the waveform well. + + This function is based on code in Cortex Lab's `spikes` repository, + https://github.com/cortex-lab/spikes + """ + pc_features = params["pc_features"][:, 0, :].copy() + pc_features[pc_features < 0] = 0 + + if np.any(np.sum(pc_features, axis=1) == 0): + # TODO: 1) handle this case for pc_features + # 2) instead use the template_features for all other versions. + raise RuntimeError( + "Some spikes do not load at all onto the first" + "or second principal component. It is necessary" + "to extend this code section to handle more components." + ) + + # Get the channel indices corresponding to the channels from the PC. 
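+    # (`pc_features_indices` has shape (num_templates, num_pc_channels); indexing it with the
+    # per-spike template id gives, for each spike, the channel ids its PC scores refer to.
+    # The spike location computed below is then the PC-weighted center of mass of those
+    # channel positions, roughly: location_i = sum over c of norm_weights[i, c] * spike_feature_coords[i, c, :].)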
+ spike_features_indices = params["pc_features_indices"][params["spike_templates"], :] + + # Compute the spike locations as the center of mass of the PC scores + spike_feature_coords = params["channel_positions"][spike_features_indices, :] + norm_weights = pc_features / np.sum(pc_features, axis=1)[:, np.newaxis] + + spike_locations = spike_feature_coords * norm_weights[:, :, np.newaxis] + spike_locations = np.sum(spike_locations, axis=1) + + # Find the max site as the channel with the largest PC weight. + spike_max_sites = spike_features_indices[ + np.arange(spike_features_indices.shape[0]), np.argmax(norm_weights, axis=1) + ] + + return spike_locations, spike_max_sites + + +######################################################################################################################## +# Get Template Data +######################################################################################################################## + + +def get_unwhite_template_info( + templates: np.ndarray, + inverse_whitening_matrix: np.ndarray, + channel_positions: np.ndarray, +) -> tuple[np.ndarray, ...]: + """ + Calculate the amplitude and depths of (unwhitened) templates and spikes. + Amplitude is calculated for each spike as the template amplitude + multiplied by the `template_scaling_amplitudes`. + + This function is based on code in Cortex Lab's `spikes` repository, + https://github.com/cortex-lab/spikes + + Parameters + ---------- + templates : np.ndarray + (num_clusters, num_samples, num_channels) array of templates. + inverse_whitening_matrix: np.ndarray + Inverse of the whitening matrix used in KS preprocessing, used to + unwhiten templates. + channel_positions : np.ndarray + (num_channels, 2) array of the x, y channel positions. + + Returns + ------- + template_amplitudes_unscaled : np.ndarray + (num_templates,) array of the unscaled tempalte amplitudes. These can be + used to calculate spike amplitude with `template_amplitude_scalings`. + template_locations : np.ndarray + (num_templates, 2) array of the x, y positions (center of mass) of each template. + unwhite_templates : np.ndarray + Unwhitened templates (num_clusters, num_samples, num_channels). + template_max_site : np.array + The maximum loading spike for the unwhitened template. + trough_peak_durations : np.ndarray + (num_templates, ) array of durations from trough to peak for each template waveform + waveforms : np.ndarray + (num_templates, num_samples) Waveform of each template, taken as the signal on the maximum loading channel. + """ + # Unwhiten the template waveforms + unwhite_templates = np.zeros_like(templates) + for idx, template in enumerate(templates): + unwhite_templates[idx, :, :] = templates[idx, :, :] @ inverse_whitening_matrix + + # Take the max amplitude for each channel, then use the channel + # with most signal as template amplitude. 
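+    # (This is the per-channel peak-to-peak of the unwhitened template over time, shape
+    # (num_templates, num_channels). The template amplitude is the maximum of this over
+    # channels, and the same per-channel amplitudes are reused below as center-of-mass
+    # weights for the template locations.)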
+    template_amplitudes_per_channel = np.max(unwhite_templates, axis=1) - np.min(unwhite_templates, axis=1)
+
+    template_amplitudes_unscaled = np.max(template_amplitudes_per_channel, axis=1)
+
+    # Calculate the template depth as the center of mass based on channel amplitudes
+    weights = template_amplitudes_per_channel / np.sum(template_amplitudes_per_channel, axis=1)[:, np.newaxis]
+    template_locations = weights @ channel_positions
+
+    # Get channel with the largest amplitude (take that as the waveform)
+    template_max_site = np.argmax(np.max(np.abs(unwhite_templates), axis=1), axis=1)
+
+    # Use template channel with max signal as waveform
+    waveforms = np.empty(unwhite_templates.shape[:2])
+
+    for idx, template in enumerate(unwhite_templates):
+        waveforms[idx, :] = unwhite_templates[idx, :, template_max_site[idx]]
+
+    # Get trough-to-peak time for each template. Find the trough as the
+    # minimum signal for the template waveform. The duration (in
+    # samples) is the num samples from trough to the largest value
+    # following the trough.
+    waveform_trough = np.argmin(waveforms, axis=1)
+
+    trough_peak_durations = np.zeros(waveforms.shape[0])
+    for idx, tmp_max in enumerate(waveforms):
+        trough_peak_durations[idx] = np.argmax(tmp_max[waveform_trough[idx] :])
+
+    return (
+        template_amplitudes_unscaled,
+        template_locations,
+        template_max_site,
+        unwhite_templates,
+        trough_peak_durations,
+        waveforms,
+    )
+
+
+def compute_template_amplitudes_from_spikes(templates, spike_templates, spike_amplitudes):
+    """
+    Take the average of all spike amplitudes to get actual template amplitudes
+    (since the `tempScalingAmps` have equal mean for all templates).
+
+    This function is ported from Cortex Lab's `spikes` repository,
+    https://github.com/cortex-lab/spikes
+    """
+    num_indices = templates.shape[0]
+    sum_per_index = np.zeros(num_indices, dtype=np.float64)
+    np.add.at(sum_per_index, spike_templates, spike_amplitudes)
+    counts = np.bincount(spike_templates, minlength=num_indices)
+    template_amplitudes = np.divide(sum_per_index, counts, out=np.zeros_like(sum_per_index), where=counts != 0)
+    return template_amplitudes
+
+
+########################################################################################################################
+# Load Parameters from KS Directory
+########################################################################################################################
+
+
+def load_ks_dir(sorter_output: Path, exclude_noise: bool = True, load_pcs: bool = False) -> dict:
+    """
+    Load the output of Kilosort into a `params` dict.
+
+    This function was ported from Cortex Lab's `spikes` repository MATLAB
+    code, https://github.com/cortex-lab/spikes
+
+    Parameters
+    ----------
+    sorter_output : Path
+        Path to the kilosort sorting output folder.
+    exclude_noise : bool
+        If `True`, units labelled as "noise" are removed from all
+        returned arrays (i.e. both units and associated spikes are dropped).
+    load_pcs : bool
+        If `True`, principal component (PC) features are loaded.
+
+    Returns
+    -------
+    params : dict
+        A dictionary of parameters combining both the kilosort `params.py`
+        file as well as data loaded from the `npy` files. The contents of the `npy`
+        files are described in the Phy documentation.
+
+    Notes
+    -----
+    When merging and splitting in `Phy`, all changes are made to the
+    `spike_clusters.npy` (cluster assignment per spike) and `cluster_groups`
+    csv/tsv, which contains the quality assignment (e.g. "noise") for each cluster.
+    Because this function strips the spikes and units based only on these two
+    data structures, it will work correctly following manual reassignment in Phy.
+    """
+    if isinstance(sorter_output, str):
+        sorter_output = Path(sorter_output)
+
+    params = read_python(sorter_output / "params.py")
+
+    spike_indices = np.load(sorter_output / "spike_times.npy")
+    spike_templates = np.load(sorter_output / "spike_templates.npy")
+
+    if (clusters_path := sorter_output / "spike_clusters.npy").is_file():
+        spike_clusters = np.load(clusters_path)
+    else:
+        spike_clusters = spike_templates.copy()
+
+    temp_scaling_amplitudes = np.load(sorter_output / "amplitudes.npy")
+
+    if load_pcs:
+        pc_features = np.load(sorter_output / "pc_features.npy")
+        pc_features_indices = np.load(sorter_output / "pc_feature_ind.npy")
+
+        if (sorter_output / "template_features.npy").is_file():
+            template_features = np.load(sorter_output / "template_features.npy")
+            template_features_indices = np.load(sorter_output / "templates_ind.npy")
+        else:
+            template_features = template_features_indices = None
+    else:
+        pc_features = pc_features_indices = None
+        template_features = template_features_indices = None
+
+    # This assumes there will never be both .csv and .tsv cluster group files
+    # in the same sorter output (this should never happen, there should only ever be one).
+    # Though it can be saved as .tsv, the .csv also appears to be tab-separated as far as pandas is concerned.
+    if exclude_noise and (
+        (cluster_path := sorter_output / "cluster_groups.csv").is_file()
+        or (cluster_path := sorter_output / "cluster_group.tsv").is_file()
+    ):
+        cluster_ids, cluster_groups = _load_cluster_groups(cluster_path)
+
+        noise_cluster_ids = cluster_ids[cluster_groups == 0]
+        not_noise_clusters_by_spike = ~np.isin(spike_clusters.ravel(), noise_cluster_ids)
+
+        spike_indices = spike_indices[not_noise_clusters_by_spike]
+        spike_templates = spike_templates[not_noise_clusters_by_spike]
+        temp_scaling_amplitudes = temp_scaling_amplitudes[not_noise_clusters_by_spike]
+
+        if load_pcs:
+            pc_features = pc_features[not_noise_clusters_by_spike, :, :]
+            if template_features is not None:
+                template_features = template_features[not_noise_clusters_by_spike, :, :]
+
+        spike_clusters = spike_clusters[not_noise_clusters_by_spike]
+        cluster_ids = cluster_ids[cluster_groups != 0]
+        cluster_groups = cluster_groups[cluster_groups != 0]
+    else:
+        cluster_ids = np.unique(spike_clusters)
+        cluster_groups = 3 * np.ones(cluster_ids.size)
+
+    new_params = {
+        "spike_indices": spike_indices.squeeze(),
+        "spike_templates": spike_templates.squeeze(),
+        "spike_clusters": spike_clusters.squeeze(),
+        "pc_features": pc_features,
+        "pc_features_indices": pc_features_indices,
+        "template_features": template_features,
+        "template_features_indices": template_features_indices,
+        "temp_scaling_amplitudes": temp_scaling_amplitudes.squeeze(),
+        "cluster_ids": cluster_ids,
+        "cluster_groups": cluster_groups,
+        "channel_positions": np.load(sorter_output / "channel_positions.npy"),
+        "templates": np.load(sorter_output / "templates.npy"),
+        "whitening_matrix_inv": np.load(sorter_output / "whitening_mat_inv.npy"),
+    }
+    params.update(new_params)
+
+    return params
+
+
+def _load_cluster_groups(cluster_path: Path) -> tuple[np.ndarray, ...]:
+    """
+    Load the kilosort `cluster_groups` file, which contains a table of
+    quality assignments, one per unit. These can be "noise", "mua", "good"
+    or "unsorted".
+ + There is some slight formatting differences between the `.tsv` and `.csv` + versions, presumably from different kilosort versions. + + This function was ported from Cortex Lab's `spikes` repository MATLAB code, + https://github.com/cortex-lab/spikes + + Parameters + ---------- + cluster_path : Path + The full filepath to the `cluster_groups` tsv or csv file. + + Returns + ------- + cluster_ids : np.ndarray + (num_clusters,) Array of (integer) unit IDs. + + cluster_groups : np.ndarray + (num_clusters,) Array of (integer) unit quality assignments, see code + below for mapping to "noise", "mua", "good" and "unsorted". + """ + cluster_groups_table = pd.read_csv(cluster_path, sep="\t") + + group_key = cluster_groups_table.columns[1] # "groups" (csv) or "KSLabel" (tsv) + + for key, _id in zip( + ["noise", "mua", "good", "unsorted"], + ["0", "1", "2", "3"], # required as str to avoid pandas replace downcast FutureWarning + ): + cluster_groups_table[group_key] = cluster_groups_table[group_key].replace(key, _id) + + cluster_ids = cluster_groups_table["cluster_id"].to_numpy() + cluster_groups = cluster_groups_table[group_key].astype(int).to_numpy() + + return cluster_ids, cluster_groups diff --git a/src/spikeinterface/working/plot_kilosort_drift_map.py b/src/spikeinterface/working/plot_kilosort_drift_map.py new file mode 100644 index 0000000000..e61b7bddd9 --- /dev/null +++ b/src/spikeinterface/working/plot_kilosort_drift_map.py @@ -0,0 +1,426 @@ +from pathlib import Path +import matplotlib.axis +import scipy.signal + +# from spikeinterface.core import read_python +import numpy as np +import pandas as pd + +import matplotlib.pyplot as plt +from scipy import stats +import load_kilosort_utils + +from spikeinterface.widgets.base import BaseWidget, to_attr + + +class KilosortDriftMapWidget(BaseWidget): + """ + Create a drift map plot in the kilosort style. This is ported from Nick Steinmetz's + `spikes` repository MATLAB code, https://github.com/cortex-lab/spikes. + By default, a raster plot is drawn with the y-axis is spike depth and + x-axis is time. Optionally, a corresponding 2D activity histogram can be + added as a subplot (spatial bins, spike counts) with optional + peak coloring and drift event detection (see below). + Parameters + ---------- + sorter_output : str | Path, + Path to the kilosort output folder. + only_include_large_amplitude_spikes : bool + If `True`, only spikes with larger amplitudes are included. For + details, see `_filter_large_amplitude_spikes()`. + decimate : None | int + If an integer n, only every nth spike is kept from the plot. Useful for improving + performance when there are many spikes. If `None`, spikes will not be decimated. + add_histogram_plot : bool + If `True`, an activity histogram will be added to a new subplot to the + left of the drift map. + add_histogram_peaks_and_boundaries : bool + If `True`, activity histogram peaks are detected and colored red if + isolated according to start/end boundaries of the peak (blue otherwise). + add_drift_events : bool + If `True`, drift events will be plot on the raster map. Required + `add_histogram_plot` and `add_histogram_peaks_and_boundaries` to run. + weight_histogram_by_amplitude : bool + If `True`, histogram counts will be weighted by spike amplitude. + localised_spikes_only : bool + If `True`, only spatially isolated spikes will be included. + exclude_noise : bool + If `True`, units labelled as noise in the `cluster_groups` file + will be excluded. 
+ gain : float | None + If not `None`, amplitudes will be scaled by the supplied gain. + large_amplitude_only_segment_size: float + If `only_include_large_amplitude_spikes` is `True`, the probe is split into + segments to compute mean and std used as threshold. This sets the size of the + segments in um. + localised_spikes_channel_cutoff: int + If `localised_spikes_only` is `True`, spikes that have more than half of the + maximum loading channel over a range of > n channels are removed. + This sets the number of channels. + """ + + def __init__( + self, + sorter_output: str | Path, + only_include_large_amplitude_spikes: bool = True, + decimate: None | int = None, + add_histogram_plot: bool = False, + add_histogram_peaks_and_boundaries: bool = True, + add_drift_events: bool = True, + weight_histogram_by_amplitude: bool = False, + localised_spikes_only: bool = False, + exclude_noise: bool = False, + gain: float | None = None, + large_amplitude_only_segment_size: float = 800.0, + localised_spikes_channel_cutoff: int = 20, + ): + if not isinstance(sorter_output, Path): + sorter_output = Path(sorter_output) + + if not sorter_output.is_dir(): + raise ValueError(f"No output folder found at {sorter_output}") + + if not (sorter_output / "params.py").is_file(): + raise ValueError( + "The `sorting_output` path is not a valid kilosort output" + "folder. It does not contain a `params.py` file`." + ) + + plot_data = dict( + sorter_output=sorter_output, + only_include_large_amplitude_spikes=only_include_large_amplitude_spikes, + decimate=decimate, + add_histogram_plot=add_histogram_plot, + add_histogram_peaks_and_boundaries=add_histogram_peaks_and_boundaries, + add_drift_events=add_drift_events, + weight_histogram_by_amplitude=weight_histogram_by_amplitude, + localised_spikes_only=localised_spikes_only, + exclude_noise=exclude_noise, + gain=gain, + large_amplitude_only_segment_size=large_amplitude_only_segment_size, + localised_spikes_channel_cutoff=localised_spikes_channel_cutoff, + ) + BaseWidget.__init__(self, plot_data, backend="matplotlib") + + def plot_matplotlib(self, data_plot: dict, **unused_kwargs) -> None: + + dp = to_attr(data_plot) + + params = load_kilosort_utils.load_ks_dir(dp.sorter_output, load_pcs=True, exclude_noise=dp.exclude_noise) + + spike_indexes, spike_amplitudes, spike_locations, _ = load_kilosort_utils.compute_spike_amplitude_and_depth( + params, dp.localised_spikes_only, dp.gain, dp.localised_spikes_channel_cutoff + ) + spike_times = spike_indexes / 30000 + spike_depths = spike_locations[:, 1] + + # Calculate the amplitude range for plotting first, so the scale is always the + # same across all options (e.g. decimation) which helps with interpretability. 
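+        # (When the large-amplitude filter is enabled, the full min/max range over all spikes is
+        # used; otherwise the 1st-90th percentile range is used, presumably so that a few large
+        # outliers do not compress the gray scale of the raster.)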
+ if dp.only_include_large_amplitude_spikes: + amplitude_range_all_spikes = ( + spike_amplitudes.min(), + spike_amplitudes.max(), + ) + else: + amplitude_range_all_spikes = np.percentile(spike_amplitudes, (1, 90)) + + if dp.decimate: + spike_times = spike_times[:: dp.decimate] + spike_amplitudes = spike_amplitudes[:: dp.decimate] + spike_depths = spike_depths[:: dp.decimate] + + if dp.only_include_large_amplitude_spikes: + spike_times, spike_amplitudes, spike_depths = self._filter_large_amplitude_spikes( + spike_times, spike_amplitudes, spike_depths, dp.large_amplitude_only_segment_size + ) + + # Setup axis and plot the raster drift map + fig = plt.figure(figsize=(10, 10 * (6 / 8))) + + if dp.add_histogram_plot: + gs = fig.add_gridspec(1, 2, width_ratios=[1, 5]) + hist_axis = fig.add_subplot(gs[0]) + raster_axis = fig.add_subplot(gs[1], sharey=hist_axis) + else: + raster_axis = fig.add_subplot() + + self._plot_kilosort_drift_map_raster( + spike_times, + spike_amplitudes, + spike_depths, + amplitude_range_all_spikes, + axis=raster_axis, + ) + + if not dp.add_histogram_plot: + raster_axis.set_xlabel("time") + raster_axis.set_ylabel("y position") + self.axes = [raster_axis] + return + + # If the histogram plot is requested, plot it alongside + # it's peak colouring, bounds display and drift point display. + hist_axis.set_xlabel("count") + raster_axis.set_xlabel("time") + hist_axis.set_ylabel("y position") + + bin_centers, counts = self._compute_activity_histogram( + spike_amplitudes, spike_depths, dp.weight_histogram_by_amplitude + ) + hist_axis.plot(counts, bin_centers, color="black", linewidth=1) + + if dp.add_histogram_peaks_and_boundaries: + drift_events = self._color_histogram_peaks_and_detect_drift_events( + spike_times, spike_depths, counts, bin_centers, hist_axis + ) + + if dp.add_drift_events and np.any(drift_events): + raster_axis.scatter(drift_events[:, 0], drift_events[:, 1], facecolors="r", edgecolors="none") + for i, _ in enumerate(drift_events): + raster_axis.text( + drift_events[i, 0] + 1, drift_events[i, 1], str(np.round(drift_events[i, 2])), color="r" + ) + self.axes = [hist_axis, raster_axis] + + def _plot_kilosort_drift_map_raster( + self, + spike_times: np.ndarray, + spike_amplitudes: np.ndarray, + spike_depths: np.ndarray, + amplitude_range: np.ndarray | tuple, + axis: matplotlib.axes.Axes, + ) -> None: + """ + Plot a drift raster plot in the kilosort style. + This function was ported from Nick Steinmetz's `spikes` repository + MATLAB code, https://github.com/cortex-lab/spikes + Parameters + ---------- + spike_times : np.ndarray + (num_spikes,) array of spike times. + spike_amplitudes : np.ndarray + (num_spikes,) array of corresponding spike amplitudes. + spike_depths : np.ndarray + (num_spikes,) array of corresponding spike depths. + amplitude_range : np.ndarray | tuple + (2,) array of min, max amplitude values for color binning. + axis : matplotlib.axes.Axes + Matplotlib axes object on which to plot the drift map. 
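+
+        Notes
+        -----
+        Spike amplitudes are split into `n_color_bins` bins spanning `amplitude_range`, and each
+        bin is drawn as a grayscale scatter with larger amplitudes in darker gray. A rough sketch
+        of the per-bin selection used below (illustrative names only):
+
+            in_bin = (spike_amplitudes >= color_bins[i]) & (spike_amplitudes <= color_bins[i + 1])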
+ """ + n_color_bins = 20 + marker_size = 0.5 + + color_bins = np.linspace(amplitude_range[0], amplitude_range[1], n_color_bins) + + colors = plt.get_cmap("gray")(np.linspace(0, 1, n_color_bins))[::-1] + + for bin_idx in range(n_color_bins - 1): + + spikes_in_amplitude_bin = np.logical_and( + spike_amplitudes >= color_bins[bin_idx], spike_amplitudes <= color_bins[bin_idx + 1] + ) + axis.scatter( + spike_times[spikes_in_amplitude_bin], + spike_depths[spikes_in_amplitude_bin], + color=colors[bin_idx], + s=marker_size, + antialiased=True, + ) + + def _compute_activity_histogram( + self, spike_amplitudes: np.ndarray, spike_depths: np.ndarray, weight_histogram_by_amplitude: bool + ) -> tuple[np.ndarray, ...]: + """ + Compute the activity histogram for the kilosort drift map's left-side plot. + Parameters + ---------- + spike_amplitudes : np.ndarray + (num_spikes,) array of spike amplitudes. + spike_depths : np.ndarray + (num_spikes,) array of spike depths. + weight_histogram_by_amplitude : bool + If `True`, the spike amplitudes are taken into consideration when generating the + histogram. The amplitudes are scaled to the range [0, 1] then summed for each bin, + to generate the histogram values. If `False`, counts (i.e. num spikes per bin) + are used. + Returns + ------- + bin_centers : np.ndarray + The spatial bin centers (probe depth) for the histogram. + values : np.ndarray + The histogram values. If `weight_histogram_by_amplitude` is `False`, these + values represent are counts, otherwise they are counts weighted by amplitude. + """ + assert ( + spike_amplitudes.dtype == np.float64 + ), "`spike amplitudes should be high precision as many values are summed." + + bin_um = 2 + bins = np.arange(spike_depths.min() - bin_um, spike_depths.max() + bin_um, bin_um) + values, bins = np.histogram(spike_depths, bins=bins) + bin_centers = (bins[:-1] + bins[1:]) / 2 + + if weight_histogram_by_amplitude: + bin_indices = np.digitize(spike_depths, bins, right=True) - 1 + values = np.zeros(bin_indices.max() + 1, dtype=np.float64) + scaled_spike_amplitudes = (spike_amplitudes - spike_amplitudes.min()) / np.ptp(spike_amplitudes) + np.add.at(values, bin_indices, scaled_spike_amplitudes) + + return bin_centers, values + + def _color_histogram_peaks_and_detect_drift_events( + self, + spike_times: np.ndarray, + spike_depths: np.ndarray, + counts: np.ndarray, + bin_centers: np.ndarray, + hist_axis: matplotlib.axes.Axes, + ) -> np.ndarray: + """ + Given an activity histogram, color the peaks red (isolated peak) or + blue (peak overlaps with other peaks) and compute spatial drift + events for isolated peaks across time bins. + This function was ported from Nick Steinmetz's `spikes` repository + MATLAB code, https://github.com/cortex-lab/spikes + Parameters + ---------- + spike_times : np.ndarray + (num_spikes,) array of spike times. + spike_depths : np.ndarray + (num_spikes,) array of corresponding spike depths. + counts : np.ndarray + (num_bins,) array of histogram bin counts. + bin_centers : np.ndarray + (num_bins,) array of histogram bin centers. + hist_axis : matplotlib.axes.Axes + Axes on which the histogram is plot, to add peaks. + Returns + ------- + drift_events : np.ndarray + A (num_drift_events, 3) array of drift events. The columns are + (time_position, spatial_position, drift_value). The drift + value is computed per time, spatial bin as the difference between + the median position of spikes in the bin, and the bin center. 
+ """ + all_peak_indexes = scipy.signal.find_peaks( + counts, + )[0] + + # Filter low-frequency peaks, so they are not included in the + # step to determine whether peaks are overlapping (new step + # introduced in the port to python) + bin_above_freq_threshold = counts[all_peak_indexes] > 0.3 * spike_times[-1] + filtered_peak_indexes = all_peak_indexes[bin_above_freq_threshold] + + drift_events = [] + for idx, peak_index in enumerate(filtered_peak_indexes): + + peak_count = counts[peak_index] + + # Find the start and end of peak min/max bounds (5% of amplitude) + start_position = np.where(counts[:peak_index] < peak_count * 0.05)[0].max() + end_position = np.where(counts[peak_index:] < peak_count * 0.05)[0].min() + peak_index + + if ( # bounds include another, different histogram peak + idx > 0 + and start_position < filtered_peak_indexes[idx - 1] + or idx < filtered_peak_indexes.size - 1 + and end_position > filtered_peak_indexes[idx + 1] + ): + hist_axis.scatter(peak_count, bin_centers[peak_index], facecolors="none", edgecolors="blue") + continue + + else: + for position in [start_position, end_position]: + hist_axis.axhline(bin_centers[position], 0, counts.max(), color="grey", linestyle="--") + hist_axis.scatter(peak_count, bin_centers[peak_index], facecolors="none", edgecolors="red") + + # For isolated histogram peaks, detect the drift events, defined as + # difference between spatial bin center and median spike depth in the bin + # over 6 um (in time / spatial bins with at least 10 spikes). + depth_in_window = np.logical_and( + spike_depths > bin_centers[start_position], + spike_depths < bin_centers[end_position], + ) + current_spike_depths = spike_depths[depth_in_window] + current_spike_times = spike_times[depth_in_window] + + window_s = 10 + + all_time_bins = np.arange(0, np.ceil(spike_times[-1]).astype(int), window_s) + for time_bin in all_time_bins: + + spike_in_time_bin = np.logical_and( + current_spike_times >= time_bin, current_spike_times <= time_bin + window_s + ) + drift_size = bin_centers[peak_index] - np.median(current_spike_depths[spike_in_time_bin]) + + # 6 um is the hardcoded threshold for drift, and we want at least 10 spikes for the median calculation + bin_has_drift = np.abs(drift_size) > 6 and np.sum(spike_in_time_bin, dtype=np.int16) > 10 + if bin_has_drift: + drift_events.append((time_bin + window_s / 2, bin_centers[peak_index], drift_size)) + + drift_events = np.array(drift_events) + + return drift_events + + def _filter_large_amplitude_spikes( + self, + spike_times: np.ndarray, + spike_amplitudes: np.ndarray, + spike_depths: np.ndarray, + large_amplitude_only_segment_size, + ) -> tuple[np.ndarray, ...]: + """ + Return spike properties with only the largest-amplitude spikes included. The probe + is split into egments, and within each segment the mean and std computed. + Any spike less than 1.5x the standard deviation in amplitude of it's segment is excluded + Splitting the probe is only done for the exclusion step, the returned array are flat. + Takes as input arrays `spike_times`, `spike_depths` and `spike_amplitudes` and returns + copies of these arrays containing only the large amplitude spikes. 
+ """ + spike_bool = np.zeros_like(spike_amplitudes, dtype=bool) + + segment_size_um = large_amplitude_only_segment_size + + probe_segments_left_edges = np.arange(np.floor(spike_depths.max() / segment_size_um) + 1) * segment_size_um + + for segment_left_edge in probe_segments_left_edges: + segment_right_edge = segment_left_edge + segment_size_um + + spikes_in_seg = np.where( + np.logical_and(spike_depths >= segment_left_edge, spike_depths < segment_right_edge) + )[0] + spike_amps_in_seg = spike_amplitudes[spikes_in_seg] + is_high_amplitude = spike_amps_in_seg > np.mean(spike_amps_in_seg) + 1.5 * np.std(spike_amps_in_seg, ddof=1) + + spike_bool[spikes_in_seg] = is_high_amplitude + + spike_times = spike_times[spike_bool] + spike_amplitudes = spike_amplitudes[spike_bool] + spike_depths = spike_depths[spike_bool] + + return spike_times, spike_amplitudes, spike_depths + + +KilosortDriftMapWidget( + "/Users/joeziminski/data/bombcelll/sorter_output", + only_include_large_amplitude_spikes=False, + localised_spikes_only=True, +) +plt.show() + +""" + sorter_output: str | Path, + only_include_large_amplitude_spikes: bool = True, + decimate: None | int = None, + add_histogram_plot: bool = False, + add_histogram_peaks_and_boundaries: bool = True, + add_drift_events: bool = True, + weight_histogram_by_amplitude: bool = False, + localised_spikes_only: bool = False, + exclude_noise: bool = False, + gain: float | None = None, + large_amplitude_only_segment_size: float = 800.0, + localised_spikes_channel_cutoff: int = 20, +""" diff --git a/src/spikeinterface/working/test_peaks_from_ks.py b/src/spikeinterface/working/test_peaks_from_ks.py new file mode 100644 index 0000000000..438268baf6 --- /dev/null +++ b/src/spikeinterface/working/test_peaks_from_ks.py @@ -0,0 +1,77 @@ +from __future__ import annotations + +import spikeinterface.full as si +from spikeinterface.sortingcomponents.peak_detection import detect_peaks +from spikeinterface.sortingcomponents.peak_localization import localize_peaks +import numpy as np +from spikeinterface.core.node_pipeline import ( + base_peak_dtype, +) +from spikeinterface.postprocessing.unit_locations import ( + dtype_localize_by_method, +) +import matplotlib.pyplot as plt +from load_kilosort_utils import compute_spike_amplitude_and_depth + + +recording, sorting = si.generate_ground_truth_recording( + durations=[30.0], + sampling_frequency=30000.0, +) +# job_kwargs = dict(n_jobs=2, chunk_size=10000, progress_bar=True) +job_kwargs = dict(n_jobs=1, chunk_size=10000, progress_bar=True) + +if False: + peaks_ = detect_peaks( + recording, method="locally_exclusive", peak_sign="neg", detect_threshold=5, exclude_sweep_ms=0.1, **job_kwargs + ) + + list_locations = [] + + peak_locations = localize_peaks(recording, peaks_, method="center_of_mass", **job_kwargs) +""" +dtype=[('sample_index', '