Skip to content

Commit dbcf818

Browse files
committed
Tidy up, general checks.
1 parent 02ec55a commit dbcf818

File tree

2 files changed

+103
-59
lines changed

2 files changed

+103
-59
lines changed

src/spikeinterface/working/load_kilosort_utils.py

Lines changed: 79 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -7,21 +7,25 @@
77

88
from scipy import stats
99

10-
# TODO: spike_times -> spike_indexes
10+
# TODO: spike_times -> spike_indices
1111
"""
1212
Notes
1313
-----
1414
- not everything is used for current purposes
1515
- some things might be useful in the future for making a sorting analyzer, e.g. computing the template amplitude as the average of spike amplitudes.
1616
"""
1717

18+
########################################################################################################################
19+
# Get Spike Data
20+
########################################################################################################################
21+
1822

1923
def compute_spike_amplitude_and_depth(
2024
sorter_output: str | Path,
2125
localised_spikes_only,
2226
exclude_noise,
2327
gain: float | None = None,
24-
localised_spikes_channel_cutoff: int = None, # TODO
28+
localised_spikes_channel_cutoff: int = None,
2529
) -> tuple[np.ndarray, ...]:
2630
"""
2731
Compute the amplitude and depth of all detected spikes from the kilosort output.
@@ -46,8 +50,8 @@ def compute_spike_amplitude_and_depth(
4650
4751
Returns
4852
-------
49-
spike_indexes : np.ndarray
50-
(num_spikes,) array of spike indexes.
53+
spike_indices : np.ndarray
54+
(num_spikes,) array of spike indices.
5155
spike_amplitudes : np.ndarray
5256
(num_spikes,) array of corresponding spike amplitudes.
5357
spike_depths : np.ndarray
@@ -66,7 +70,7 @@ def compute_spike_amplitude_and_depth(
6670
if isinstance(sorter_output, str):
6771
sorter_output = Path(sorter_output)
6872

69-
params = _load_ks_dir(sorter_output, load_pcs=True, exclude_noise=exclude_noise)
73+
params = load_ks_dir(sorter_output, load_pcs=True, exclude_noise=exclude_noise)
7074

7175
if localised_spikes_only:
7276
localised_templates = []
@@ -81,10 +85,52 @@ def compute_spike_amplitude_and_depth(
8185

8286
localised_template_by_spike = np.isin(params["spike_templates"], localised_templates)
8387

84-
_strip_spikes(params, localised_template_by_spike)
88+
params["spike_templates"] = params["spike_templates"][localised_template_by_spike]
89+
params["spike_indices"] = params["spike_indices"][localised_template_by_spike]
90+
params["spike_clusters"] = params["spike_clusters"][localised_template_by_spike]
91+
params["temp_scaling_amplitudes"] = params["temp_scaling_amplitudes"][localised_template_by_spike]
92+
params["pc_features"] = params["pc_features"][localised_template_by_spike]
93+
94+
spike_locations, spike_max_sites = _get_locations_from_pc_features(params)
95+
96+
# Amplitude is calculated for each spike as the template amplitude
97+
# multiplied by the `template_scaling_amplitudes`.
98+
template_amplitudes_unscaled, *_ = get_unwhite_template_info(
99+
params["templates"],
100+
params["whitening_matrix_inv"],
101+
params["channel_positions"],
102+
)
103+
spike_amplitudes = template_amplitudes_unscaled[params["spike_templates"]] * params["temp_scaling_amplitudes"]
104+
105+
if gain is not None:
106+
spike_amplitudes *= gain
85107

108+
compute_template_amplitudes_from_spikes(params["templates"], params["spike_templates"], spike_amplitudes)
109+
110+
if localised_spikes_only:
111+
# Interpolate the channel ids to location.
112+
# Remove spikes > 5 um from average position
113+
# Above we already removed non-localized templates, but that on its own is insufficient.
114+
# Note for IMEC probe adding a constant term kills the regression making the regressors rank deficient
115+
# TODO: a couple of approaches. 1) do everything in 3D, draw a sphere around prediction, take spikes only within the sphere
116+
# 2) do separate for x, y. But resolution will be much lower, making things noisier, also harder to determine threshold.
117+
# 3) just use depth. Probably go for that. check with others.
118+
spike_depths = spike_locations[:, 1]
119+
b = stats.linregress(spike_depths, spike_max_sites).slope
120+
i = np.abs(spike_max_sites - b * spike_depths) <= 5
121+
122+
params["spike_indices"] = params["spike_indices"][i]
123+
spike_amplitudes = spike_amplitudes[i]
124+
spike_locations = spike_locations[i, :]
125+
spike_max_sites = spike_max_sites[i]
126+
127+
return params["spike_indices"], spike_amplitudes, spike_locations, spike_max_sites
128+
129+
130+
def _get_locations_from_pc_features(params):
131+
""" """
86132
# Compute spike depths
87-
pc_features = params["pc_features"][:, 0, :] # Do this compute
133+
pc_features = params["pc_features"][:, 0, :]
88134
pc_features[pc_features < 0] = 0
89135

90136
# Some spikes do not load at all onto the first PC. To avoid biasing the
@@ -109,58 +155,28 @@ def compute_spike_amplitude_and_depth(
109155
"to extend this code section to handle more components."
110156
)
111157

112-
# Get the channel indexes corresponding to the 32 channels from the PC.
158+
# Get the channel indices corresponding to the 32 channels from the PC.
113159
spike_features_indices = params["pc_features_indices"][params["spike_templates"], :]
114160

115161
# Compute the spike locations as the center of mass of the PC scores
116162
spike_feature_coords = params["channel_positions"][spike_features_indices, :]
117-
norm_weights = pc_features / np.sum(pc_features, axis=1)[:, np.newaxis] # TOOD: see why they use square
163+
norm_weights = (
164+
pc_features / np.sum(pc_features, axis=1)[:, np.newaxis]
165+
) # TODO: discuss use of square. Probably do not use, to keep in line with COM in SI.
118166
spike_locations = spike_feature_coords * norm_weights[:, :, np.newaxis]
119167
spike_locations = np.sum(spike_locations, axis=1)
120168

121169
# TODO: the max site per spike is now computed from the PCs, not taken as the channel max site as it was previously
122-
spike_sites = spike_features_indices[np.arange(spike_features_indices.shape[0]), np.argmax(norm_weights, axis=1)]
170+
spike_max_sites = spike_features_indices[
171+
np.arange(spike_features_indices.shape[0]), np.argmax(norm_weights, axis=1)
172+
]
123173

124-
# Amplitude is calculated for each spike as the template amplitude
125-
# multiplied by the `template_scaling_amplitudes`.
126-
template_amplitudes_unscaled, *_ = get_unwhite_template_info(
127-
params["templates"],
128-
params["whitening_matrix_inv"],
129-
params["channel_positions"],
130-
)
131-
spike_amplitudes = template_amplitudes_unscaled[params["spike_templates"]] * params["temp_scaling_amplitudes"]
132-
133-
if gain is not None:
134-
spike_amplitudes *= gain
135-
136-
if localised_spikes_only:
137-
# Interpolate the channel ids to location.
138-
# Remove spikes > 5 um from average position
139-
# Above we already removed non-localized templates, but that on its own is insufficient.
140-
# Note for IMEC probe adding a constant term kills the regression making the regressors rank deficient
141-
# TODO: a couple of approaches. 1) do everything in 3D, draw a sphere around prediction, take spikes only within the sphere
142-
# 2) do separate for x, y. But resolution will be much lower, making things noisier, also harder to determine threshold.
143-
# 3) just use depth. Probably go for that. check with others.
144-
spike_depths = spike_locations[:, 1]
145-
b = stats.linregress(spike_depths, spike_sites).slope
146-
i = np.abs(spike_sites - b * spike_depths) <= 5 # TODO: need to expose this
174+
return spike_locations, spike_max_sites
147175

148-
params["spike_indexes"] = params["spike_indexes"][i]
149-
spike_amplitudes = spike_amplitudes[i]
150-
spike_locations = spike_locations[i, :]
151176

152-
return params["spike_indexes"], spike_amplitudes, spike_locations, spike_sites
153-
154-
155-
def _strip_spikes_in_place(params, indices):
156-
""" """
157-
params["spike_templates"] = params["spike_templates"][
158-
indices
159-
] # TODO: make an function for this. because we do this a lot
160-
params["spike_indexes"] = params["spike_indexes"][indices]
161-
params["spike_clusters"] = params["spike_clusters"][indices]
162-
params["temp_scaling_amplitudes"] = params["temp_scaling_amplitudes"][indices]
163-
params["pc_features"] = params["pc_features"][indices] # TODO: be conciststetn! change indees to indices
177+
########################################################################################################################
178+
# Get Template Data
179+
########################################################################################################################
164180

165181

166182
def get_unwhite_template_info(
@@ -213,7 +229,7 @@ def get_unwhite_template_info(
213229

214230
template_amplitudes_unscaled = np.max(template_amplitudes_per_channel, axis=1)
215231

216-
# Zero any small channel amplitudes
232+
# Zero any small channel amplitudes. TODO: this step has been removed (see the commented-out lines below).
217233
# threshold_values = 0.3 * template_amplitudes_unscaled TODO: remove this to be more general. Agree?
218234
# template_amplitudes_per_channel[template_amplitudes_per_channel < threshold_values[:, np.newaxis]] = 0
219235

@@ -253,9 +269,11 @@ def get_unwhite_template_info(
253269
)
254270

255271

256-
def compute_template_amplitudes_from_spikes():
257-
# Take the average of all spike amplitudes to get actual template amplitudes
258-
# (since tempScalingAmps are equal mean for all templates)
272+
def compute_template_amplitudes_from_spikes(templates, spike_templates, spike_amplitudes):
273+
"""
274+
Take the average of all spike amplitudes to get actual template amplitudes
275+
(since tempScalingAmps are equal mean for all templates)
276+
"""
259277
num_indices = templates.shape[0]
260278
sum_per_index = np.zeros(num_indices, dtype=np.float64)
261279
np.add.at(sum_per_index, spike_templates, spike_amplitudes)
@@ -264,7 +282,12 @@ def compute_template_amplitudes_from_spikes():
264282
return template_amplitudes
265283

266284

267-
def _load_ks_dir(sorter_output: Path, exclude_noise: bool = True, load_pcs: bool = False) -> dict:
285+
########################################################################################################################
286+
# Load Parameters from KS Directory
287+
########################################################################################################################
288+
289+
290+
def load_ks_dir(sorter_output: Path, exclude_noise: bool = True, load_pcs: bool = False) -> dict:
268291
"""
269292
Loads the output of Kilosort into a `params` dict.
270293
@@ -300,7 +323,7 @@ def _load_ks_dir(sorter_output: Path, exclude_noise: bool = True, load_pcs: bool
300323

301324
params = read_python(sorter_output / "params.py")
302325

303-
spike_indexes = np.load(sorter_output / "spike_times.npy")
326+
spike_indices = np.load(sorter_output / "spike_times.npy")
304327
spike_templates = np.load(sorter_output / "spike_templates.npy")
305328

306329
if (clusters_path := sorter_output / "spike_clusters.csv").is_dir():
@@ -328,7 +351,7 @@ def _load_ks_dir(sorter_output: Path, exclude_noise: bool = True, load_pcs: bool
328351
noise_cluster_ids = cluster_ids[cluster_groups == 0]
329352
not_noise_clusters_by_spike = ~np.isin(spike_clusters.ravel(), noise_cluster_ids)
330353

331-
spike_indexes = spike_indexes[not_noise_clusters_by_spike]
354+
spike_indices = spike_indices[not_noise_clusters_by_spike]
332355
spike_templates = spike_templates[not_noise_clusters_by_spike]
333356
temp_scaling_amplitudes = temp_scaling_amplitudes[not_noise_clusters_by_spike]
334357

@@ -343,7 +366,7 @@ def _load_ks_dir(sorter_output: Path, exclude_noise: bool = True, load_pcs: bool
343366
cluster_groups = 3 * np.ones(cluster_ids.size)
344367

345368
new_params = {
346-
"spike_indexes": spike_indexes.squeeze(),
369+
"spike_indices": spike_indices.squeeze(),
347370
"spike_templates": spike_templates.squeeze(),
348371
"spike_clusters": spike_clusters.squeeze(),
349372
"pc_features": pc_features,

src/spikeinterface/working/plot_kilosort_drift_map.py

Lines changed: 24 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,17 @@
11
from pathlib import Path
2-
from spikeinterface.widgets.base import BaseWidget, to_attr
32
import matplotlib.axis
43
import scipy.signal
5-
from spikeinterface.core import read_python
4+
5+
# from spikeinterface.core import read_python
66
import numpy as np
77
import pandas as pd
88

99
import matplotlib.pyplot as plt
1010
from scipy import stats
1111
import load_kilosort_utils
1212

13+
from spikeinterface.widgets.base import BaseWidget, to_attr
14+
1315

1416
class KilosortDriftMapWidget(BaseWidget):
1517
"""
@@ -399,5 +401,24 @@ def _filter_large_amplitude_spikes(
399401
return spike_times, spike_amplitudes, spike_depths
400402

401403

402-
KilosortDriftMapWidget(r"D:\data\New folder\CA_528_1\imec0_ks2")
404+
KilosortDriftMapWidget(
405+
"/Users/joeziminski/data/bombcelll/sorter_output",
406+
only_include_large_amplitude_spikes=False,
407+
localised_spikes_only=True,
408+
)
403409
plt.show()
410+
411+
"""
412+
sorter_output: str | Path,
413+
only_include_large_amplitude_spikes: bool = True,
414+
decimate: None | int = None,
415+
add_histogram_plot: bool = False,
416+
add_histogram_peaks_and_boundaries: bool = True,
417+
add_drift_events: bool = True,
418+
weight_histogram_by_amplitude: bool = False,
419+
localised_spikes_only: bool = False,
420+
exclude_noise: bool = False,
421+
gain: float | None = None,
422+
large_amplitude_only_segment_size: float = 800.0,
423+
localised_spikes_channel_cutoff: int = 20,
424+
"""

0 commit comments

Comments
 (0)