Skip to content

Commit 02ec55a

Browse files
committed
Some refactoring, tidying up.
1 parent 146179c commit 02ec55a

File tree

2 files changed

+496
-82
lines changed

2 files changed

+496
-82
lines changed

src/spikeinterface/working/load_kilosort_utils.py

Lines changed: 93 additions & 82 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,12 @@
88
from scipy import stats
99

1010
# TODO: spike_times -> spike_indexes
11+
"""
12+
Notes
13+
-----
14+
- not everything is used for current purposes
15+
- things might be useful in future for making a sorting analyzer - compute template amplitude as average of spike amplitude.
16+
"""
1117

1218

1319
def compute_spike_amplitude_and_depth(
@@ -75,53 +81,58 @@ def compute_spike_amplitude_and_depth(
7581

7682
localised_template_by_spike = np.isin(params["spike_templates"], localised_templates)
7783

78-
params["spike_templates"] = params["spike_templates"][localised_template_by_spike]
79-
params["spike_indexes"] = params["spike_indexes"][localised_template_by_spike]
80-
params["spike_clusters"] = params["spike_clusters"][localised_template_by_spike]
81-
params["temp_scaling_amplitudes"] = params["temp_scaling_amplitudes"][localised_template_by_spike]
82-
params["pc_features"] = params["pc_features"][localised_template_by_spike]
84+
_strip_spikes(params, localised_template_by_spike)
8385

8486
# Compute spike depths
85-
pc_features = params["pc_features"][:, 0, :]
87+
pc_features = params["pc_features"][:, 0, :] # Do this compute
8688
pc_features[pc_features < 0] = 0
8789

88-
# Get the channel indexes corresponding to the 32 channels from the PC.
89-
spike_features_indices = params["pc_features_indices"][params["spike_templates"], :]
90+
# Some spikes do not load at all onto the first PC. To avoid biasing the
91+
# dataset by removing these, we repeat the above for the next PC,
92+
# to compute distances for neurons that do not load onto the 1st PC.
93+
# This is not ideal at all, it would be much better to a) find the
94+
# max value for each channel on each of the PCs (i.e. basis vectors).
95+
# Then recompute the estimated waveform peak on each channel by
96+
# summing the PCs by their respective weights. However, the PC basis
97+
# vectors themselves do not appear to be output by KS.
98+
no_pc1_signal_spikes = np.where(np.sum(pc_features, axis=1) == 0)
99+
100+
pc_features_2 = params["pc_features"][:, 1, :]
101+
pc_features_2[pc_features_2 < 0] = 0
90102

91-
ycoords = params["channel_positions"][:, 1]
92-
spike_feature_ycoords = ycoords[spike_features_indices]
103+
pc_features[no_pc1_signal_spikes] = pc_features_2[no_pc1_signal_spikes]
93104

94-
spike_depths = np.sum(spike_feature_ycoords * pc_features**2, axis=1) / np.sum(pc_features**2, axis=1)
105+
if any(np.sum(pc_features, axis=1) == 0):
106+
raise RuntimeError(
107+
"Some spikes do not load at all onto the first"
108+
"or second principal component. It is necessary"
109+
"to extend this code section to handle more components."
110+
)
95111

112+
# Get the channel indexes corresponding to the 32 channels from the PC.
113+
spike_features_indices = params["pc_features_indices"][params["spike_templates"], :]
114+
115+
# Compute the spike locations as the center of mass of the PC scores
96116
spike_feature_coords = params["channel_positions"][spike_features_indices, :]
97117
norm_weights = pc_features / np.sum(pc_features, axis=1)[:, np.newaxis]  # TODO: see why they use square
98-
weighted_locs = spike_feature_coords * norm_weights[:, :, np.newaxis]
99-
weighted_locs = np.sum(weighted_locs, axis=1)
118+
spike_locations = spike_feature_coords * norm_weights[:, :, np.newaxis]
119+
spike_locations = np.sum(spike_locations, axis=1)
120+
121+
# TODO: now max site per spike is computed from PCs, not as the channel max site as previous
122+
spike_sites = spike_features_indices[np.arange(spike_features_indices.shape[0]), np.argmax(norm_weights, axis=1)]
123+
100124
# Amplitude is calculated for each spike as the template amplitude
101125
# multiplied by the `template_scaling_amplitudes`.
102-
103-
# Compute amplitudes, scale if required and drop un-localised spikes before returning.
104-
spike_amplitudes, _, _, _, unwhite_templates, *_ = get_template_info_and_spike_amplitudes(
126+
template_amplitudes_unscaled, *_ = get_unwhite_template_info(
105127
params["templates"],
106128
params["whitening_matrix_inv"],
107-
ycoords,
108-
params["spike_templates"],
109-
params["temp_scaling_amplitudes"],
129+
params["channel_positions"],
110130
)
131+
spike_amplitudes = template_amplitudes_unscaled[params["spike_templates"]] * params["temp_scaling_amplitudes"]
111132

112133
if gain is not None:
113134
spike_amplitudes *= gain
114135

115-
max_site = np.argmax(
116-
np.max(np.abs(templates), axis=1), axis=1
117-
) # TODO: combine this with above function. Maybe the above function can be templates only, and everything spike-related is here.
118-
max_site = np.argmax(np.max(np.abs(unwhite_templates), axis=1), axis=1)
119-
spike_sites = max_site[params["spike_templates"]]
120-
121-
# TODO: here the max site is the same for all spikes from the same template.
122-
# is this the case for spikeinterface? Should we estimate max-site per spike from
123-
# the PCs?
124-
125136
if localised_spikes_only:
126137
# Interpolate the channel ids to location.
127138
# Remove spikes > 5 um from average position
@@ -130,23 +141,32 @@ def compute_spike_amplitude_and_depth(
130141
# TODO: a couple of approaches. 1) do everything in 3D, draw a sphere around prediction, take spikes only within the sphere
131142
# 2) do separate for x, y. But resolution will be much lower, making things noisier, also harder to determine threshold.
132143
# 3) just use depth. Probably go for that. check with others.
133-
spike_depths = weighted_locs[:, 1]
144+
spike_depths = spike_locations[:, 1]
134145
b = stats.linregress(spike_depths, spike_sites).slope
135146
i = np.abs(spike_sites - b * spike_depths) <= 5 # TODO: need to expose this
136147

137148
params["spike_indexes"] = params["spike_indexes"][i]
138149
spike_amplitudes = spike_amplitudes[i]
139-
weighted_locs = weighted_locs[i, :]
150+
spike_locations = spike_locations[i, :]
151+
152+
return params["spike_indexes"], spike_amplitudes, spike_locations, spike_sites
140153

141-
return params["spike_indexes"], spike_amplitudes, weighted_locs, spike_sites # TODO: rename everything
142154

155+
def _strip_spikes_in_place(params, indices):
156+
""" """
157+
params["spike_templates"] = params["spike_templates"][
158+
indices
159+
]  # TODO: make a function for this, because we do this a lot
160+
params["spike_indexes"] = params["spike_indexes"][indices]
161+
params["spike_clusters"] = params["spike_clusters"][indices]
162+
params["temp_scaling_amplitudes"] = params["temp_scaling_amplitudes"][indices]
163+
params["pc_features"] = params["pc_features"][indices] # TODO: be conciststetn! change indees to indices
143164

144-
def get_template_info_and_spike_amplitudes(
165+
166+
def get_unwhite_template_info(
145167
templates: np.ndarray,
146168
inverse_whitening_matrix: np.ndarray,
147-
ycoords: np.ndarray,
148-
spike_templates: np.ndarray,
149-
template_scaling_amplitudes: np.ndarray,
169+
channel_positions: np.ndarray,
150170
) -> tuple[np.ndarray, ...]:
151171
"""
152172
Calculate the amplitude and depths of (unwhitened) templates and spikes.
@@ -163,28 +183,20 @@ def get_template_info_and_spike_amplitudes(
163183
inverse_whitening_matrix: np.ndarray
164184
Inverse of the whitening matrix used in KS preprocessing, used to
165185
unwhiten templates.
166-
ycoords : np.ndarray
167-
(num_channels,) array of the y-axis (depth) channel positions.
168-
spike_templates : np.ndarray
169-
(num_spikes,) array indicating the template associated with each spike.
170-
template_scaling_amplitudes : np.ndarray
171-
(num_spikes,) array holding the scaling amplitudes, by which the
172-
template was scaled to match each spike.
186+
channel_positions : np.ndarray
187+
(num_channels, 2) array of the x, y channel positions.
173188
174189
Returns
175190
-------
176-
spike_amplitudes : np.ndarray
177-
(num_spikes,) array of the amplitude of each spike.
178-
spike_depths : np.ndarray
179-
(num_spikes,) array of the depth (probe y-axis) of each spike. Note
180-
this is just the template depth for each spike (i.e. depth of all spikes
181-
from the same cluster are identical).
182-
template_amplitudes : np.ndarray
183-
(num_templates,) Amplitude of each template, calculated as average of spike amplitudes.
184-
template_depths : np.ndarray
185-
(num_templates,) array of the depth of each template.
191+
template_amplitudes_unscaled : np.ndarray
192+
(num_templates,) array of the unscaled template amplitudes. These can be
193+
used to calculate spike amplitude with `template_amplitude_scalings`.
194+
template_locations : np.ndarray
195+
(num_templates, 2) array of the x, y positions (center of mass) of each template.
186196
unwhite_templates : np.ndarray
187197
Unwhitened templates (num_clusters, num_samples, num_channels).
198+
template_max_site : np.array
199+
The channel (site) with maximum loading for the unwhitened template.
188200
trough_peak_durations : np.ndarray
189201
(num_templates, ) array of durations from trough to peak for each template waveform
190202
waveforms : np.ndarray
@@ -195,43 +207,31 @@ def get_template_info_and_spike_amplitudes(
195207
for idx, template in enumerate(templates):
196208
unwhite_templates[idx, :, :] = templates[idx, :, :] @ inverse_whitening_matrix
197209

198-
# First, calculate the depth of each template from the amplitude
199-
# on each channel by the center of mass method.
200-
201210
# Take the max amplitude for each channel, then use the channel
202-
# with most signal as template amplitude. Zero any small channel amplitudes.
211+
# with most signal as template amplitude.
203212
template_amplitudes_per_channel = np.max(unwhite_templates, axis=1) - np.min(unwhite_templates, axis=1)
204213

205214
template_amplitudes_unscaled = np.max(template_amplitudes_per_channel, axis=1)
206215

207-
threshold_values = 0.3 * template_amplitudes_unscaled
208-
template_amplitudes_per_channel[template_amplitudes_per_channel < threshold_values[:, np.newaxis]] = 0
216+
# Zero any small channel amplitudes
217+
# threshold_values = 0.3 * template_amplitudes_unscaled TODO: remove this to be more general. Agree?
218+
# template_amplitudes_per_channel[template_amplitudes_per_channel < threshold_values[:, np.newaxis]] = 0
209219

210220
# Calculate the template depth as the center of mass based on channel amplitudes
211-
template_depths = np.sum(template_amplitudes_per_channel * ycoords[np.newaxis, :], axis=1) / np.sum(
212-
template_amplitudes_per_channel, axis=1
213-
)
214-
215-
# Next, find the depth of each spike based on its template. Recompute the template
216-
# amplitudes as the average of the spike amplitudes ('since
217-
# tempScalingAmps are equal mean for all templates')
218-
spike_amplitudes = template_amplitudes_unscaled[spike_templates] * template_scaling_amplitudes
219-
220-
# Take the average of all spike amplitudes to get actual template amplitudes
221-
# (since tempScalingAmps are equal mean for all templates)
222-
num_indices = templates.shape[0]
223-
sum_per_index = np.zeros(num_indices, dtype=np.float64)
224-
np.add.at(sum_per_index, spike_templates, spike_amplitudes)
225-
counts = np.bincount(spike_templates, minlength=num_indices)
226-
template_amplitudes = np.divide(sum_per_index, counts, out=np.zeros_like(sum_per_index), where=counts != 0)
221+
weights = template_amplitudes_per_channel / np.sum(template_amplitudes_per_channel, axis=1)[:, np.newaxis]
222+
template_locations = weights @ channel_positions
227223

228224
# Get channel with the largest amplitude (take that as the waveform)
229-
max_site = np.argmax(np.max(np.abs(templates), axis=1), axis=1)
225+
template_max_site = np.argmax(
226+
np.max(np.abs(unwhite_templates), axis=1), axis=1
227+
) # TODO: i changed this to use unwhitened templates instead of templates. This okay?
230228

231229
# Use template channel with max signal as waveform
232-
waveforms = np.empty(templates.shape[:2])
233-
for idx, template in enumerate(templates):
234-
waveforms[idx, :] = templates[idx, :, max_site[idx]]
230+
waveforms = np.empty(
231+
unwhite_templates.shape[:2]
232+
) # TODO: i changed this to use unwhitened templates instead of templates. This okay?
233+
for idx, template in enumerate(unwhite_templates):
234+
waveforms[idx, :] = unwhite_templates[idx, :, template_max_site[idx]]
235235

236236
# Get trough-to-peak time for each template. Find the trough as the
237237
# minimum signal for the template waveform. The duration (in
@@ -244,15 +244,26 @@ def get_template_info_and_spike_amplitudes(
244244
trough_peak_durations[idx] = np.argmax(tmp_max[waveform_trough[idx] :])
245245

246246
return (
247-
spike_amplitudes,
248-
template_depths,
249-
template_amplitudes,
247+
template_amplitudes_unscaled,
248+
template_locations,
249+
template_max_site,
250250
unwhite_templates,
251251
trough_peak_durations,
252252
waveforms,
253253
)
254254

255255

256+
def compute_template_amplitudes_from_spikes():
257+
# Take the average of all spike amplitudes to get actual template amplitudes
258+
# (since tempScalingAmps are equal mean for all templates)
259+
num_indices = templates.shape[0]
260+
sum_per_index = np.zeros(num_indices, dtype=np.float64)
261+
np.add.at(sum_per_index, spike_templates, spike_amplitudes)
262+
counts = np.bincount(spike_templates, minlength=num_indices)
263+
template_amplitudes = np.divide(sum_per_index, counts, out=np.zeros_like(sum_per_index), where=counts != 0)
264+
return template_amplitudes
265+
266+
256267
def _load_ks_dir(sorter_output: Path, exclude_noise: bool = True, load_pcs: bool = False) -> dict:
257268
"""
258269
Loads the output of Kilosort into a `params` dict.

0 commit comments

Comments
 (0)