7
7
8
8
from scipy import stats
9
9
10
- # TODO: spike_times -> spike_indexes
10
+ # TODO: spike_times -> spike_indices
11
11
"""
12
12
Notes
13
13
-----
14
14
- not everything is used for current purposes
15
15
- things might be useful in future for making a sorting analyzer - compute template amplitude as average of spike amplitude.
16
16
"""
17
17
18
+ ########################################################################################################################
19
+ # Get Spike Data
20
+ ########################################################################################################################
21
+
18
22
19
23
def compute_spike_amplitude_and_depth (
20
24
sorter_output : str | Path ,
21
25
localised_spikes_only ,
22
26
exclude_noise ,
23
27
gain : float | None = None ,
24
- localised_spikes_channel_cutoff : int = None , # TODO
28
+ localised_spikes_channel_cutoff : int = None ,
25
29
) -> tuple [np .ndarray , ...]:
26
30
"""
27
31
Compute the amplitude and depth of all detected spikes from the kilosort output.
@@ -46,8 +50,8 @@ def compute_spike_amplitude_and_depth(
46
50
47
51
Returns
48
52
-------
49
- spike_indexes : np.ndarray
50
- (num_spikes,) array of spike indexes .
53
+ spike_indices : np.ndarray
54
+ (num_spikes,) array of spike indices .
51
55
spike_amplitudes : np.ndarray
52
56
(num_spikes,) array of corresponding spike amplitudes.
53
57
spike_depths : np.ndarray
@@ -66,7 +70,7 @@ def compute_spike_amplitude_and_depth(
66
70
if isinstance (sorter_output , str ):
67
71
sorter_output = Path (sorter_output )
68
72
69
- params = _load_ks_dir (sorter_output , load_pcs = True , exclude_noise = exclude_noise )
73
+ params = load_ks_dir (sorter_output , load_pcs = True , exclude_noise = exclude_noise )
70
74
71
75
if localised_spikes_only :
72
76
localised_templates = []
@@ -81,10 +85,52 @@ def compute_spike_amplitude_and_depth(
81
85
82
86
localised_template_by_spike = np .isin (params ["spike_templates" ], localised_templates )
83
87
84
- _strip_spikes (params , localised_template_by_spike )
88
+ params ["spike_templates" ] = params ["spike_templates" ][localised_template_by_spike ]
89
+ params ["spike_indices" ] = params ["spike_indices" ][localised_template_by_spike ]
90
+ params ["spike_clusters" ] = params ["spike_clusters" ][localised_template_by_spike ]
91
+ params ["temp_scaling_amplitudes" ] = params ["temp_scaling_amplitudes" ][localised_template_by_spike ]
92
+ params ["pc_features" ] = params ["pc_features" ][localised_template_by_spike ]
93
+
94
+ spike_locations , spike_max_sites = _get_locations_from_pc_features (params )
95
+
96
+ # Amplitude is calculated for each spike as the template amplitude
97
+ # multiplied by the `template_scaling_amplitudes`.
98
+ template_amplitudes_unscaled , * _ = get_unwhite_template_info (
99
+ params ["templates" ],
100
+ params ["whitening_matrix_inv" ],
101
+ params ["channel_positions" ],
102
+ )
103
+ spike_amplitudes = template_amplitudes_unscaled [params ["spike_templates" ]] * params ["temp_scaling_amplitudes" ]
104
+
105
+ if gain is not None :
106
+ spike_amplitudes *= gain
85
107
108
+ compute_template_amplitudes_from_spikes (params ["templates" ], params ["spike_templates" ], spike_amplitudes )
109
+
110
+ if localised_spikes_only :
111
+ # Interpolate the channel ids to location.
112
+ # Remove spikes > 5 um from average position
113
+ # Above we already removed non-localized templates, but that on its own is insufficient.
114
+ # Note for IMEC probe adding a constant term kills the regression making the regressors rank deficient
115
+ # TODO: a couple of approaches. 1) do everything in 3D, draw a sphere around prediction, take spikes only within the sphere
116
+ # 2) do separate for x, y. But resolution will be much lower, making things noisier, also harder to determine threshold.
117
+ # 3) just use depth. Probably go for that. check with others.
118
+ spike_depths = spike_locations [:, 1 ]
119
+ b = stats .linregress (spike_depths , spike_max_sites ).slope
120
+ i = np .abs (spike_max_sites - b * spike_depths ) <= 5
121
+
122
+ params ["spike_indices" ] = params ["spike_indices" ][i ]
123
+ spike_amplitudes = spike_amplitudes [i ]
124
+ spike_locations = spike_locations [i , :]
125
+ spike_max_sites = spike_max_sites [i ]
126
+
127
+ return params ["spike_indices" ], spike_amplitudes , spike_locations , spike_max_sites
128
+
129
+
130
+ def _get_locations_from_pc_features (params ):
131
+ """ """
86
132
# Compute spike depths
87
- pc_features = params ["pc_features" ][:, 0 , :] # Do this compute
133
+ pc_features = params ["pc_features" ][:, 0 , :]
88
134
pc_features [pc_features < 0 ] = 0
89
135
90
136
# Some spikes do not load at all onto the first PC. To avoid biasing the
@@ -109,58 +155,28 @@ def compute_spike_amplitude_and_depth(
109
155
"to extend this code section to handle more components."
110
156
)
111
157
112
- # Get the channel indexes corresponding to the 32 channels from the PC.
158
+ # Get the channel indices corresponding to the 32 channels from the PC.
113
159
spike_features_indices = params ["pc_features_indices" ][params ["spike_templates" ], :]
114
160
115
161
# Compute the spike locations as the center of mass of the PC scores
116
162
spike_feature_coords = params ["channel_positions" ][spike_features_indices , :]
117
- norm_weights = pc_features / np .sum (pc_features , axis = 1 )[:, np .newaxis ] # TOOD: see why they use square
163
+ norm_weights = (
164
+ pc_features / np .sum (pc_features , axis = 1 )[:, np .newaxis ]
165
+ ) # TOOD: discuss use of square. Probbaly do not use to keep in line with COM in SI.
118
166
spike_locations = spike_feature_coords * norm_weights [:, :, np .newaxis ]
119
167
spike_locations = np .sum (spike_locations , axis = 1 )
120
168
121
169
# TODO: now max site per spike is computed from PCs, not as the channel max site as previous
122
- spike_sites = spike_features_indices [np .arange (spike_features_indices .shape [0 ]), np .argmax (norm_weights , axis = 1 )]
170
+ spike_max_sites = spike_features_indices [
171
+ np .arange (spike_features_indices .shape [0 ]), np .argmax (norm_weights , axis = 1 )
172
+ ]
123
173
124
- # Amplitude is calculated for each spike as the template amplitude
125
- # multiplied by the `template_scaling_amplitudes`.
126
- template_amplitudes_unscaled , * _ = get_unwhite_template_info (
127
- params ["templates" ],
128
- params ["whitening_matrix_inv" ],
129
- params ["channel_positions" ],
130
- )
131
- spike_amplitudes = template_amplitudes_unscaled [params ["spike_templates" ]] * params ["temp_scaling_amplitudes" ]
132
-
133
- if gain is not None :
134
- spike_amplitudes *= gain
135
-
136
- if localised_spikes_only :
137
- # Interpolate the channel ids to location.
138
- # Remove spikes > 5 um from average position
139
- # Above we already removed non-localized templates, but that on its own is insufficient.
140
- # Note for IMEC probe adding a constant term kills the regression making the regressors rank deficient
141
- # TODO: a couple of approaches. 1) do everything in 3D, draw a sphere around prediction, take spikes only within the sphere
142
- # 2) do separate for x, y. But resolution will be much lower, making things noisier, also harder to determine threshold.
143
- # 3) just use depth. Probably go for that. check with others.
144
- spike_depths = spike_locations [:, 1 ]
145
- b = stats .linregress (spike_depths , spike_sites ).slope
146
- i = np .abs (spike_sites - b * spike_depths ) <= 5 # TODO: need to expose this
174
+ return spike_locations , spike_max_sites
147
175
148
- params ["spike_indexes" ] = params ["spike_indexes" ][i ]
149
- spike_amplitudes = spike_amplitudes [i ]
150
- spike_locations = spike_locations [i , :]
151
176
152
- return params ["spike_indexes" ], spike_amplitudes , spike_locations , spike_sites
153
-
154
-
155
- def _strip_spikes_in_place (params , indices ):
156
- """ """
157
- params ["spike_templates" ] = params ["spike_templates" ][
158
- indices
159
- ] # TODO: make an function for this. because we do this a lot
160
- params ["spike_indexes" ] = params ["spike_indexes" ][indices ]
161
- params ["spike_clusters" ] = params ["spike_clusters" ][indices ]
162
- params ["temp_scaling_amplitudes" ] = params ["temp_scaling_amplitudes" ][indices ]
163
- params ["pc_features" ] = params ["pc_features" ][indices ] # TODO: be conciststetn! change indees to indices
177
+ ########################################################################################################################
178
+ # Get Template Data
179
+ ########################################################################################################################
164
180
165
181
166
182
def get_unwhite_template_info (
@@ -213,7 +229,7 @@ def get_unwhite_template_info(
213
229
214
230
template_amplitudes_unscaled = np .max (template_amplitudes_per_channel , axis = 1 )
215
231
216
- # Zero any small channel amplitudes
232
+ # Zero any small channel amplitudes TODO: removed this.
217
233
# threshold_values = 0.3 * template_amplitudes_unscaled TODO: remove this to be more general. Agree?
218
234
# template_amplitudes_per_channel[template_amplitudes_per_channel < threshold_values[:, np.newaxis]] = 0
219
235
@@ -253,9 +269,11 @@ def get_unwhite_template_info(
253
269
)
254
270
255
271
256
- def compute_template_amplitudes_from_spikes ():
257
- # Take the average of all spike amplitudes to get actual template amplitudes
258
- # (since tempScalingAmps are equal mean for all templates)
272
+ def compute_template_amplitudes_from_spikes (templates , spike_templates , spike_amplitudes ):
273
+ """
274
+ Take the average of all spike amplitudes to get actual template amplitudes
275
+ (since tempScalingAmps are equal mean for all templates)
276
+ """
259
277
num_indices = templates .shape [0 ]
260
278
sum_per_index = np .zeros (num_indices , dtype = np .float64 )
261
279
np .add .at (sum_per_index , spike_templates , spike_amplitudes )
@@ -264,7 +282,12 @@ def compute_template_amplitudes_from_spikes():
264
282
return template_amplitudes
265
283
266
284
267
- def _load_ks_dir (sorter_output : Path , exclude_noise : bool = True , load_pcs : bool = False ) -> dict :
285
+ ########################################################################################################################
286
+ # Load Parameters from KS Directory
287
+ ########################################################################################################################
288
+
289
+
290
+ def load_ks_dir (sorter_output : Path , exclude_noise : bool = True , load_pcs : bool = False ) -> dict :
268
291
"""
269
292
Loads the output of Kilosort into a `params` dict.
270
293
@@ -300,7 +323,7 @@ def _load_ks_dir(sorter_output: Path, exclude_noise: bool = True, load_pcs: bool
300
323
301
324
params = read_python (sorter_output / "params.py" )
302
325
303
- spike_indexes = np .load (sorter_output / "spike_times.npy" )
326
+ spike_indices = np .load (sorter_output / "spike_times.npy" )
304
327
spike_templates = np .load (sorter_output / "spike_templates.npy" )
305
328
306
329
if (clusters_path := sorter_output / "spike_clusters.csv" ).is_dir ():
@@ -328,7 +351,7 @@ def _load_ks_dir(sorter_output: Path, exclude_noise: bool = True, load_pcs: bool
328
351
noise_cluster_ids = cluster_ids [cluster_groups == 0 ]
329
352
not_noise_clusters_by_spike = ~ np .isin (spike_clusters .ravel (), noise_cluster_ids )
330
353
331
- spike_indexes = spike_indexes [not_noise_clusters_by_spike ]
354
+ spike_indices = spike_indices [not_noise_clusters_by_spike ]
332
355
spike_templates = spike_templates [not_noise_clusters_by_spike ]
333
356
temp_scaling_amplitudes = temp_scaling_amplitudes [not_noise_clusters_by_spike ]
334
357
@@ -343,7 +366,7 @@ def _load_ks_dir(sorter_output: Path, exclude_noise: bool = True, load_pcs: bool
343
366
cluster_groups = 3 * np .ones (cluster_ids .size )
344
367
345
368
new_params = {
346
- "spike_indexes " : spike_indexes .squeeze (),
369
+ "spike_indices " : spike_indices .squeeze (),
347
370
"spike_templates" : spike_templates .squeeze (),
348
371
"spike_clusters" : spike_clusters .squeeze (),
349
372
"pc_features" : pc_features ,
0 commit comments