@@ -49,7 +49,7 @@ def compute_spike_amplitude_and_depth(
49
49
50
50
Notes
51
51
-----
52
- In `_template_positions_amplitudes ` spike depths is calculated as simply the template
52
+ In `get_template_info_and_spike_amplitudes ` spike depths is calculated as simply the template
53
53
depth, for each spike (so it is the same for all spikes in a cluster). Here we need
54
54
to find the depth of each individual spike, using its low-dimensional projection.
55
55
`pc_features` (num_spikes, num_PC, num_channels) holds the PC values for each spike.
@@ -101,7 +101,7 @@ def compute_spike_amplitude_and_depth(
101
101
# multiplied by the `template_scaling_amplitudes`.
102
102
103
103
# Compute amplitudes, scale if required and drop un-localised spikes before returning.
104
- spike_amplitudes , _ , _ , _ , unwhite_templates , * _ = _template_positions_amplitudes (
104
+ spike_amplitudes , _ , _ , _ , unwhite_templates , * _ = get_template_info_and_spike_amplitudes (
105
105
params ["templates" ],
106
106
params ["whitening_matrix_inv" ],
107
107
ycoords ,
@@ -112,9 +112,16 @@ def compute_spike_amplitude_and_depth(
112
112
if gain is not None :
113
113
spike_amplitudes *= gain
114
114
115
+ max_site = np .argmax (
116
+ np .max (np .abs (templates ), axis = 1 ), axis = 1
117
+ ) # TODO: combine this with above function. Maybe the above function can be templates only, and everything spike-related is here.
115
118
max_site = np .argmax (np .max (np .abs (unwhite_templates ), axis = 1 ), axis = 1 )
116
119
spike_sites = max_site [params ["spike_templates" ]]
117
120
121
+ # TODO: here the max site is the same for all spikes from the same template.
122
+ # is this the case for spikeinterface? Should we estimate max-site per spike from
123
+ # the PCs?
124
+
118
125
if localised_spikes_only :
119
126
# Interpolate the channel ids to location.
120
127
# Remove spikes > 5 um from average position
@@ -134,45 +141,7 @@ def compute_spike_amplitude_and_depth(
134
141
return params ["spike_indexes" ], spike_amplitudes , weighted_locs , spike_sites # TODO: rename everything
135
142
136
143
137
- def _filter_large_amplitude_spikes (
138
- spike_times : np .ndarray ,
139
- spike_amplitudes : np .ndarray ,
140
- spike_depths : np .ndarray ,
141
- large_amplitude_only_segment_size ,
142
- ) -> tuple [np .ndarray , ...]:
143
- """
144
- Return spike properties with only the largest-amplitude spikes included. The probe
145
- is split into egments, and within each segment the mean and std computed.
146
- Any spike less than 1.5x the standard deviation in amplitude of it's segment is excluded
147
- Splitting the probe is only done for the exclusion step, the returned array are flat.
148
-
149
- Takes as input arrays `spike_times`, `spike_depths` and `spike_amplitudes` and returns
150
- copies of these arrays containing only the large amplitude spikes.
151
- """
152
- spike_bool = np .zeros_like (spike_amplitudes , dtype = bool )
153
-
154
- segment_size_um = large_amplitude_only_segment_size
155
- probe_segments_left_edges = np .arange (np .floor (spike_depths .max () / segment_size_um ) + 1 ) * segment_size_um
156
-
157
- for segment_left_edge in probe_segments_left_edges :
158
- segment_right_edge = segment_left_edge + segment_size_um
159
-
160
- spikes_in_seg = np .where (np .logical_and (spike_depths >= segment_left_edge , spike_depths < segment_right_edge ))[
161
- 0
162
- ]
163
- spike_amps_in_seg = spike_amplitudes [spikes_in_seg ]
164
- is_high_amplitude = spike_amps_in_seg > np .mean (spike_amps_in_seg ) + 1.5 * np .std (spike_amps_in_seg , ddof = 1 )
165
-
166
- spike_bool [spikes_in_seg ] = is_high_amplitude
167
-
168
- spike_times = spike_times [spike_bool ]
169
- spike_amplitudes = spike_amplitudes [spike_bool ]
170
- spike_depths = spike_depths [spike_bool ]
171
-
172
- return spike_times , spike_amplitudes , spike_depths
173
-
174
-
175
- def _template_positions_amplitudes (
144
+ def get_template_info_and_spike_amplitudes (
176
145
templates : np .ndarray ,
177
146
inverse_whitening_matrix : np .ndarray ,
178
147
ycoords : np .ndarray ,
@@ -256,9 +225,6 @@ def _template_positions_amplitudes(
256
225
counts = np .bincount (spike_templates , minlength = num_indices )
257
226
template_amplitudes = np .divide (sum_per_index , counts , out = np .zeros_like (sum_per_index ), where = counts != 0 )
258
227
259
- # Each spike's depth is the depth of its template
260
- spike_depths = template_depths [spike_templates ]
261
-
262
228
# Get channel with the largest amplitude (take that as the waveform)
263
229
max_site = np .argmax (np .max (np .abs (templates ), axis = 1 ), axis = 1 )
264
230
@@ -279,7 +245,6 @@ def _template_positions_amplitudes(
279
245
280
246
return (
281
247
spike_amplitudes ,
282
- spike_depths ,
283
248
template_depths ,
284
249
template_amplitudes ,
285
250
unwhite_templates ,
0 commit comments