Skip to content

Commit 146179c

Browse files
committed
Add some notes.
1 parent 84b9a17 commit 146179c

File tree

1 file changed

+10
-45
lines changed

1 file changed

+10
-45
lines changed

src/spikeinterface/working/load_kilosort_utils.py

Lines changed: 10 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ def compute_spike_amplitude_and_depth(
4949
5050
Notes
5151
-----
52-
In `_template_positions_amplitudes` spike depths is calculated as simply the template
52+
In `get_template_info_and_spike_amplitudes` spike depths is calculated as simply the template
5353
depth, for each spike (so it is the same for all spikes in a cluster). Here we need
5454
to find the depth of each individual spike, using its low-dimensional projection.
5555
`pc_features` (num_spikes, num_PC, num_channels) holds the PC values for each spike.
@@ -101,7 +101,7 @@ def compute_spike_amplitude_and_depth(
101101
# multiplied by the `template_scaling_amplitudes`.
102102

103103
# Compute amplitudes, scale if required and drop un-localised spikes before returning.
104-
spike_amplitudes, _, _, _, unwhite_templates, *_ = _template_positions_amplitudes(
104+
spike_amplitudes, _, _, _, unwhite_templates, *_ = get_template_info_and_spike_amplitudes(
105105
params["templates"],
106106
params["whitening_matrix_inv"],
107107
ycoords,
@@ -112,9 +112,16 @@ def compute_spike_amplitude_and_depth(
112112
if gain is not None:
113113
spike_amplitudes *= gain
114114

115+
max_site = np.argmax(
116+
np.max(np.abs(templates), axis=1), axis=1
117+
) # TODO: combine this with above function. Maybe the above function can be templates only, and everything spike-related is here.
115118
max_site = np.argmax(np.max(np.abs(unwhite_templates), axis=1), axis=1)
116119
spike_sites = max_site[params["spike_templates"]]
117120

121+
# TODO: here the max site is the same for all spikes from the same template.
122+
# is this the case for spikeinterface? Should we estimate max-site per spike from
123+
# the PCs?
124+
118125
if localised_spikes_only:
119126
# Interpolate the channel ids to location.
120127
# Remove spikes > 5 um from average position
@@ -134,45 +141,7 @@ def compute_spike_amplitude_and_depth(
134141
return params["spike_indexes"], spike_amplitudes, weighted_locs, spike_sites # TODO: rename everything
135142

136143

137-
def _filter_large_amplitude_spikes(
138-
spike_times: np.ndarray,
139-
spike_amplitudes: np.ndarray,
140-
spike_depths: np.ndarray,
141-
large_amplitude_only_segment_size,
142-
) -> tuple[np.ndarray, ...]:
143-
"""
144-
Return spike properties with only the largest-amplitude spikes included. The probe
145-
is split into egments, and within each segment the mean and std computed.
146-
Any spike less than 1.5x the standard deviation in amplitude of it's segment is excluded
147-
Splitting the probe is only done for the exclusion step, the returned array are flat.
148-
149-
Takes as input arrays `spike_times`, `spike_depths` and `spike_amplitudes` and returns
150-
copies of these arrays containing only the large amplitude spikes.
151-
"""
152-
spike_bool = np.zeros_like(spike_amplitudes, dtype=bool)
153-
154-
segment_size_um = large_amplitude_only_segment_size
155-
probe_segments_left_edges = np.arange(np.floor(spike_depths.max() / segment_size_um) + 1) * segment_size_um
156-
157-
for segment_left_edge in probe_segments_left_edges:
158-
segment_right_edge = segment_left_edge + segment_size_um
159-
160-
spikes_in_seg = np.where(np.logical_and(spike_depths >= segment_left_edge, spike_depths < segment_right_edge))[
161-
0
162-
]
163-
spike_amps_in_seg = spike_amplitudes[spikes_in_seg]
164-
is_high_amplitude = spike_amps_in_seg > np.mean(spike_amps_in_seg) + 1.5 * np.std(spike_amps_in_seg, ddof=1)
165-
166-
spike_bool[spikes_in_seg] = is_high_amplitude
167-
168-
spike_times = spike_times[spike_bool]
169-
spike_amplitudes = spike_amplitudes[spike_bool]
170-
spike_depths = spike_depths[spike_bool]
171-
172-
return spike_times, spike_amplitudes, spike_depths
173-
174-
175-
def _template_positions_amplitudes(
144+
def get_template_info_and_spike_amplitudes(
176145
templates: np.ndarray,
177146
inverse_whitening_matrix: np.ndarray,
178147
ycoords: np.ndarray,
@@ -256,9 +225,6 @@ def _template_positions_amplitudes(
256225
counts = np.bincount(spike_templates, minlength=num_indices)
257226
template_amplitudes = np.divide(sum_per_index, counts, out=np.zeros_like(sum_per_index), where=counts != 0)
258227

259-
# Each spike's depth is the depth of its template
260-
spike_depths = template_depths[spike_templates]
261-
262228
# Get channel with the largest amplitude (take that as the waveform)
263229
max_site = np.argmax(np.max(np.abs(templates), axis=1), axis=1)
264230

@@ -279,7 +245,6 @@ def _template_positions_amplitudes(
279245

280246
return (
281247
spike_amplitudes,
282-
spike_depths,
283248
template_depths,
284249
template_amplitudes,
285250
unwhite_templates,

0 commit comments

Comments
 (0)