8
8
from scipy import stats
9
9
10
10
# TODO: spike_times -> spike_indexes
11
+ """
12
+ Notes
13
+ -----
14
+ - not everything is used for current purposes
15
+ - things might be useful in future for making a sorting analyzer - compute template amplitude as average of spike amplitude.
16
+ """
11
17
12
18
13
19
def compute_spike_amplitude_and_depth (
@@ -75,53 +81,58 @@ def compute_spike_amplitude_and_depth(
75
81
76
82
localised_template_by_spike = np .isin (params ["spike_templates" ], localised_templates )
77
83
78
- params ["spike_templates" ] = params ["spike_templates" ][localised_template_by_spike ]
79
- params ["spike_indexes" ] = params ["spike_indexes" ][localised_template_by_spike ]
80
- params ["spike_clusters" ] = params ["spike_clusters" ][localised_template_by_spike ]
81
- params ["temp_scaling_amplitudes" ] = params ["temp_scaling_amplitudes" ][localised_template_by_spike ]
82
- params ["pc_features" ] = params ["pc_features" ][localised_template_by_spike ]
84
+ _strip_spikes (params , localised_template_by_spike )
83
85
84
86
# Compute spike depths
85
- pc_features = params ["pc_features" ][:, 0 , :]
87
+ pc_features = params ["pc_features" ][:, 0 , :] # Do this compute
86
88
pc_features [pc_features < 0 ] = 0
87
89
88
- # Get the channel indexes corresponding to the 32 channels from the PC.
89
- spike_features_indices = params ["pc_features_indices" ][params ["spike_templates" ], :]
90
+ # Some spikes do not load at all onto the first PC. To avoid biasing the
91
+ # dataset by removing these, we repeat the above for the next PC,
92
+ # to compute distances for neurons that do not load onto the 1st PC.
93
+ # This is not ideal at all, it would be much better to a) find the
94
+ # max value for each channel on each of the PCs (i.e. basis vectors).
95
+ # Then recompute the estimated waveform peak on each channel by
96
+ # summing the PCs by their respective weights. However, the PC basis
97
+ # vectors themselves do not appear to be output by KS.
98
+ no_pc1_signal_spikes = np .where (np .sum (pc_features , axis = 1 ) == 0 )
99
+
100
+ pc_features_2 = params ["pc_features" ][:, 1 , :]
101
+ pc_features_2 [pc_features_2 < 0 ] = 0
90
102
91
- ycoords = params ["channel_positions" ][:, 1 ]
92
- spike_feature_ycoords = ycoords [spike_features_indices ]
103
+ pc_features [no_pc1_signal_spikes ] = pc_features_2 [no_pc1_signal_spikes ]
93
104
94
- spike_depths = np .sum (spike_feature_ycoords * pc_features ** 2 , axis = 1 ) / np .sum (pc_features ** 2 , axis = 1 )
105
+ if any (np .sum (pc_features , axis = 1 ) == 0 ):
106
+ raise RuntimeError (
107
+ "Some spikes do not load at all onto the first"
108
+ "or second principal component. It is necessary"
109
+ "to extend this code section to handle more components."
110
+ )
95
111
112
+ # Get the channel indexes corresponding to the 32 channels from the PC.
113
+ spike_features_indices = params ["pc_features_indices" ][params ["spike_templates" ], :]
114
+
115
+ # Compute the spike locations as the center of mass of the PC scores
96
116
spike_feature_coords = params ["channel_positions" ][spike_features_indices , :]
97
117
norm_weights = pc_features / np .sum (pc_features , axis = 1 )[:, np .newaxis ] # TOOD: see why they use square
98
- weighted_locs = spike_feature_coords * norm_weights [:, :, np .newaxis ]
99
- weighted_locs = np .sum (weighted_locs , axis = 1 )
118
+ spike_locations = spike_feature_coords * norm_weights [:, :, np .newaxis ]
119
+ spike_locations = np .sum (spike_locations , axis = 1 )
120
+
121
+ # TODO: now max site per spike is computed from PCs, not as the channel max site as previous
122
+ spike_sites = spike_features_indices [np .arange (spike_features_indices .shape [0 ]), np .argmax (norm_weights , axis = 1 )]
123
+
100
124
# Amplitude is calculated for each spike as the template amplitude
101
125
# multiplied by the `template_scaling_amplitudes`.
102
-
103
- # Compute amplitudes, scale if required and drop un-localised spikes before returning.
104
- spike_amplitudes , _ , _ , _ , unwhite_templates , * _ = get_template_info_and_spike_amplitudes (
126
+ template_amplitudes_unscaled , * _ = get_unwhite_template_info (
105
127
params ["templates" ],
106
128
params ["whitening_matrix_inv" ],
107
- ycoords ,
108
- params ["spike_templates" ],
109
- params ["temp_scaling_amplitudes" ],
129
+ params ["channel_positions" ],
110
130
)
131
+ spike_amplitudes = template_amplitudes_unscaled [params ["spike_templates" ]] * params ["temp_scaling_amplitudes" ]
111
132
112
133
if gain is not None :
113
134
spike_amplitudes *= gain
114
135
115
- max_site = np .argmax (
116
- np .max (np .abs (templates ), axis = 1 ), axis = 1
117
- ) # TODO: combine this with above function. Maybe the above function can be templates only, and everything spike-related is here.
118
- max_site = np .argmax (np .max (np .abs (unwhite_templates ), axis = 1 ), axis = 1 )
119
- spike_sites = max_site [params ["spike_templates" ]]
120
-
121
- # TODO: here the max site is the same for all spikes from the same template.
122
- # is this the case for spikeinterface? Should we estimate max-site per spike from
123
- # the PCs?
124
-
125
136
if localised_spikes_only :
126
137
# Interpolate the channel ids to location.
127
138
# Remove spikes > 5 um from average position
@@ -130,23 +141,32 @@ def compute_spike_amplitude_and_depth(
130
141
# TODO: a couple of approaches. 1) do everything in 3D, draw a sphere around prediction, take spikes only within the sphere
131
142
# 2) do separate for x, y. But resolution will be much lower, making things noisier, also harder to determine threshold.
132
143
# 3) just use depth. Probably go for that. check with others.
133
- spike_depths = weighted_locs [:, 1 ]
144
+ spike_depths = spike_locations [:, 1 ]
134
145
b = stats .linregress (spike_depths , spike_sites ).slope
135
146
i = np .abs (spike_sites - b * spike_depths ) <= 5 # TODO: need to expose this
136
147
137
148
params ["spike_indexes" ] = params ["spike_indexes" ][i ]
138
149
spike_amplitudes = spike_amplitudes [i ]
139
- weighted_locs = weighted_locs [i , :]
150
+ spike_locations = spike_locations [i , :]
151
+
152
+ return params ["spike_indexes" ], spike_amplitudes , spike_locations , spike_sites
140
153
141
- return params ["spike_indexes" ], spike_amplitudes , weighted_locs , spike_sites # TODO: rename everything
142
154
155
+ def _strip_spikes_in_place (params , indices ):
156
+ """ """
157
+ params ["spike_templates" ] = params ["spike_templates" ][
158
+ indices
159
+ ] # TODO: make an function for this. because we do this a lot
160
+ params ["spike_indexes" ] = params ["spike_indexes" ][indices ]
161
+ params ["spike_clusters" ] = params ["spike_clusters" ][indices ]
162
+ params ["temp_scaling_amplitudes" ] = params ["temp_scaling_amplitudes" ][indices ]
163
+ params ["pc_features" ] = params ["pc_features" ][indices ] # TODO: be conciststetn! change indees to indices
143
164
144
- def get_template_info_and_spike_amplitudes (
165
+
166
+ def get_unwhite_template_info (
145
167
templates : np .ndarray ,
146
168
inverse_whitening_matrix : np .ndarray ,
147
- ycoords : np .ndarray ,
148
- spike_templates : np .ndarray ,
149
- template_scaling_amplitudes : np .ndarray ,
169
+ channel_positions : np .ndarray ,
150
170
) -> tuple [np .ndarray , ...]:
151
171
"""
152
172
Calculate the amplitude and depths of (unwhitened) templates and spikes.
@@ -163,28 +183,20 @@ def get_template_info_and_spike_amplitudes(
163
183
inverse_whitening_matrix: np.ndarray
164
184
Inverse of the whitening matrix used in KS preprocessing, used to
165
185
unwhiten templates.
166
- ycoords : np.ndarray
167
- (num_channels,) array of the y-axis (depth) channel positions.
168
- spike_templates : np.ndarray
169
- (num_spikes,) array indicating the template associated with each spike.
170
- template_scaling_amplitudes : np.ndarray
171
- (num_spikes,) array holding the scaling amplitudes, by which the
172
- template was scaled to match each spike.
186
+ channel_positions : np.ndarray
187
+ (num_channels, 2) array of the x, y channel positions.
173
188
174
189
Returns
175
190
-------
176
- spike_amplitudes : np.ndarray
177
- (num_spikes,) array of the amplitude of each spike.
178
- spike_depths : np.ndarray
179
- (num_spikes,) array of the depth (probe y-axis) of each spike. Note
180
- this is just the template depth for each spike (i.e. depth of all spikes
181
- from the same cluster are identical).
182
- template_amplitudes : np.ndarray
183
- (num_templates,) Amplitude of each template, calculated as average of spike amplitudes.
184
- template_depths : np.ndarray
185
- (num_templates,) array of the depth of each template.
191
+ template_amplitudes_unscaled : np.ndarray
192
+ (num_templates,) array of the unscaled tempalte amplitudes. These can be
193
+ used to calculate spike amplitude with `template_amplitude_scalings`.
194
+ template_locations : np.ndarray
195
+ (num_templates, 2) array of the x, y positions (center of mass) of each template.
186
196
unwhite_templates : np.ndarray
187
197
Unwhitened templates (num_clusters, num_samples, num_channels).
198
+ template_max_site : np.array
199
+ The maximum loading spike for the unwhitened template.
188
200
trough_peak_durations : np.ndarray
189
201
(num_templates, ) array of durations from trough to peak for each template waveform
190
202
waveforms : np.ndarray
@@ -195,43 +207,31 @@ def get_template_info_and_spike_amplitudes(
195
207
for idx , template in enumerate (templates ):
196
208
unwhite_templates [idx , :, :] = templates [idx , :, :] @ inverse_whitening_matrix
197
209
198
- # First, calculate the depth of each template from the amplitude
199
- # on each channel by the center of mass method.
200
-
201
210
# Take the max amplitude for each channel, then use the channel
202
- # with most signal as template amplitude. Zero any small channel amplitudes.
211
+ # with most signal as template amplitude.
203
212
template_amplitudes_per_channel = np .max (unwhite_templates , axis = 1 ) - np .min (unwhite_templates , axis = 1 )
204
213
205
214
template_amplitudes_unscaled = np .max (template_amplitudes_per_channel , axis = 1 )
206
215
207
- threshold_values = 0.3 * template_amplitudes_unscaled
208
- template_amplitudes_per_channel [template_amplitudes_per_channel < threshold_values [:, np .newaxis ]] = 0
216
+ # Zero any small channel amplitudes
217
+ # threshold_values = 0.3 * template_amplitudes_unscaled TODO: remove this to be more general. Agree?
218
+ # template_amplitudes_per_channel[template_amplitudes_per_channel < threshold_values[:, np.newaxis]] = 0
209
219
210
220
# Calculate the template depth as the center of mass based on channel amplitudes
211
- template_depths = np .sum (template_amplitudes_per_channel * ycoords [np .newaxis , :], axis = 1 ) / np .sum (
212
- template_amplitudes_per_channel , axis = 1
213
- )
214
-
215
- # Next, find the depth of each spike based on its template. Recompute the template
216
- # amplitudes as the average of the spike amplitudes ('since
217
- # tempScalingAmps are equal mean for all templates')
218
- spike_amplitudes = template_amplitudes_unscaled [spike_templates ] * template_scaling_amplitudes
219
-
220
- # Take the average of all spike amplitudes to get actual template amplitudes
221
- # (since tempScalingAmps are equal mean for all templates)
222
- num_indices = templates .shape [0 ]
223
- sum_per_index = np .zeros (num_indices , dtype = np .float64 )
224
- np .add .at (sum_per_index , spike_templates , spike_amplitudes )
225
- counts = np .bincount (spike_templates , minlength = num_indices )
226
- template_amplitudes = np .divide (sum_per_index , counts , out = np .zeros_like (sum_per_index ), where = counts != 0 )
221
+ weights = template_amplitudes_per_channel / np .sum (template_amplitudes_per_channel , axis = 1 )[:, np .newaxis ]
222
+ template_locations = weights @ channel_positions
227
223
228
224
# Get channel with the largest amplitude (take that as the waveform)
229
- max_site = np .argmax (np .max (np .abs (templates ), axis = 1 ), axis = 1 )
225
+ template_max_site = np .argmax (
226
+ np .max (np .abs (unwhite_templates ), axis = 1 ), axis = 1
227
+ ) # TODO: i changed this to use unwhitened templates instead of templates. This okay?
230
228
231
229
# Use template channel with max signal as waveform
232
- waveforms = np .empty (templates .shape [:2 ])
233
- for idx , template in enumerate (templates ):
234
- waveforms [idx , :] = templates [idx , :, max_site [idx ]]
230
+ waveforms = np .empty (
231
+ unwhite_templates .shape [:2 ]
232
+ ) # TODO: i changed this to use unwhitened templates instead of templates. This okay?
233
+ for idx , template in enumerate (unwhite_templates ):
234
+ waveforms [idx , :] = unwhite_templates [idx , :, template_max_site [idx ]]
235
235
236
236
# Get trough-to-peak time for each template. Find the trough as the
237
237
# minimum signal for the template waveform. The duration (in
@@ -244,15 +244,26 @@ def get_template_info_and_spike_amplitudes(
244
244
trough_peak_durations [idx ] = np .argmax (tmp_max [waveform_trough [idx ] :])
245
245
246
246
return (
247
- spike_amplitudes ,
248
- template_depths ,
249
- template_amplitudes ,
247
+ template_amplitudes_unscaled ,
248
+ template_locations ,
249
+ template_max_site ,
250
250
unwhite_templates ,
251
251
trough_peak_durations ,
252
252
waveforms ,
253
253
)
254
254
255
255
256
+ def compute_template_amplitudes_from_spikes ():
257
+ # Take the average of all spike amplitudes to get actual template amplitudes
258
+ # (since tempScalingAmps are equal mean for all templates)
259
+ num_indices = templates .shape [0 ]
260
+ sum_per_index = np .zeros (num_indices , dtype = np .float64 )
261
+ np .add .at (sum_per_index , spike_templates , spike_amplitudes )
262
+ counts = np .bincount (spike_templates , minlength = num_indices )
263
+ template_amplitudes = np .divide (sum_per_index , counts , out = np .zeros_like (sum_per_index ), where = counts != 0 )
264
+ return template_amplitudes
265
+
266
+
256
267
def _load_ks_dir (sorter_output : Path , exclude_noise : bool = True , load_pcs : bool = False ) -> dict :
257
268
"""
258
269
Loads the output of Kilosort into a `params` dict.
0 commit comments