Skip to content

Commit ec9c7f1

Browse files
committed
continue playing around with inter-session displacement generator.
1 parent 204578b commit ec9c7f1

File tree

5 files changed

+409
-606
lines changed

5 files changed

+409
-606
lines changed

src/spikeinterface/core/generate.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1633,6 +1633,10 @@ def generate_templates(
16331633
channel_factors = alpha * np.exp(-distances / spatial_decay)
16341634
wfs = wf[:, np.newaxis] * channel_factors[np.newaxis, :]
16351635

1636+
# if u in [3, 4]:
1637+
# import matplotlib.pyplot as plt
1638+
# breakpoint()
1639+
16361640
# This mimic a propagation delay for distant channel
16371641
propagation_speed = params["propagation_speed"][u]
16381642
if propagation_speed is not None:
@@ -1901,6 +1905,13 @@ def get_traces(
19011905
upsample_ind = self.upsample_vector[i]
19021906
template = self.templates[unit_ind, :, :, upsample_ind]
19031907

1908+
# if unit_ind == 4:
1909+
# from spikeinterface.core import order_channels_by_depth
1910+
# import matplotlib.pyplot as plt
1911+
# off = np.atleast_2d(np.arange(template.shape[1])* 10)
1912+
# plt.plot(template); plt.show()
1913+
# breakpoint()
1914+
19041915
if channel_indices is not None:
19051916
template = template[:, channel_indices]
19061917

@@ -1926,6 +1937,10 @@ def get_traces(
19261937
wf = wf * self.amplitude_vector[i]
19271938
traces[start_traces:end_traces] += wf.astype(traces.dtype, copy=False)
19281939

1940+
# TODO: does templates on one channel not overwrite template on another channel?
1941+
# no, because the waveform is ADDED, but I guess the noise is added multiple times?
1942+
# but it is gaussian, so shouldn't matter.
1943+
19291944
return traces.astype(self.dtype, copy=False)
19301945

19311946
def get_num_samples(self) -> int:
Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,104 @@
1+
import spikeinterface.full as si
2+
from spikeinterface.generation.drifting_generator import generate_drifting_recording
3+
from spikeinterface.generation.session_displacement_generator import generate_inter_session_displacement_recordings
4+
import matplotlib.pyplot as plt
5+
6+
if False:
7+
8+
_, raw_recording, _ = generate_drifting_recording(
9+
num_units=10,
10+
duration=50,
11+
generate_sorting_kwargs=dict(firing_rates=(15, 25), refractory_period_ms=4.0),
12+
seed=42,
13+
generate_displacement_vector_kwargs=dict(
14+
motion_list=[
15+
dict(drift_mode="zigzag", non_rigid_gradient=0.01, t_start_drift=5, t_end_drift=45, period_s=10),
16+
],
17+
),
18+
)
19+
20+
# TODO: understand why sometimes the units are spraed all over the probe.
21+
# In my example, for example, they are clearly ok on the raw image.
22+
# but detect peaks is terrible.
23+
import numpy as np
24+
25+
default_unit_params_range = dict(
26+
alpha=(100.0, 500.0),
27+
depolarization_ms=(0.09, 0.14),
28+
repolarization_ms=(0.5, 0.8),
29+
recovery_ms=(1.0, 1.5),
30+
positive_amplitude=(0.1, 0.25),
31+
smooth_ms=(0.03, 0.07),
32+
spatial_decay=(20, 40),
33+
propagation_speed=(250.0, 350.0), # um / ms
34+
b=(0.5, 1), # (0.5, 1)
35+
c=(0.5, 1), # (0.5, 1)
36+
x_angle=(0, np.pi),
37+
y_angle=(0, np.pi),
38+
z_angle=(0, np.pi), # (0, 2)
39+
)
40+
41+
if False:
42+
rec, _ = si.generate_ground_truth_recording(
43+
num_units=5,
44+
durations=[25],
45+
num_channels=128,
46+
# generate_sorting_kwargs=dict(firing_rates=(200, 300), refractory_period_ms=2.0),
47+
# noise_kwargs=dict(noise_levels=5),
48+
# generate_templates_kwargs=dict(unit_params=dict(alpha=300)), # TODO: increase alpha default?
49+
seed=44,
50+
)
51+
rec_list = [rec]
52+
53+
default_unit_params_range["alpha"] = (400, 500) # do this or change the margin on generate_unit_locations_kwargs
54+
55+
rec_list, _ = generate_inter_session_displacement_recordings(
56+
non_rigid_gradient=None, # 0.05,
57+
num_units=5,
58+
rec_durations=(25, 25, 25), # TODO: checks on inputs
59+
rec_shifts=(
60+
0,
61+
200,
62+
400,
63+
), # WTF happening at +100 um?? maybe needs to be discrete in terms of channels or it is ignored? shouldnt matter...
64+
generate_sorting_kwargs=dict(firing_rates=(149, 150), refractory_period_ms=4.0),
65+
generate_templates_kwargs=dict(unit_params=default_unit_params_range, ms_before=1.5, ms_after=3), # mode=sphere
66+
seed=44,
67+
generate_unit_locations_kwargs=dict(
68+
margin_um=0.0, # if this is say 20, then units go off the edge of the probe and are such low amplitude they are not picked up.
69+
minimum_z=5.0,
70+
maximum_z=45.0,
71+
minimum_distance=18.0,
72+
max_iteration=100,
73+
distance_strict=False,
74+
),
75+
)
76+
77+
from spikeinterface.sortingcomponents.peak_detection import detect_peaks
78+
from spikeinterface.sortingcomponents.peak_localization import localize_peaks
79+
80+
# < 150 might be a problem for detecting alpha
81+
for rec in rec_list:
82+
83+
si.plot_traces(rec, time_range=(0, 1))
84+
# plt.show()
85+
86+
peaks = detect_peaks(rec, method="locally_exclusive")
87+
peak_locs = localize_peaks(rec, peaks, method="grid_convolution")
88+
89+
print(peaks.shape)
90+
91+
si.plot_drift_raster_map(peaks=peaks, peak_locations=peak_locs, recording=rec, clim=(-300, 0))
92+
plt.show()
93+
94+
# TODO: could try and create peaks, peak locs from
95+
# unit locations and rec.spike_vector...
96+
rec.spike_vector.size # (sample, unit, segment)
97+
# peaks (sample, channel, amplitude, segment)
98+
99+
# Why are units dropping out (distance too far from probe)
100+
# Why are the amplitudes changing (clim)
101+
# why are the traces look like they have too many units (current settings)
102+
# TODO: I think it is OK when it goes over the edge, just need to check explicitly what happens.
103+
# It will be greater than margin, etc. Is it explicitly clipped or just allowed to go outside of range?
104+
# I guess this will ignore the margin. That's fine. I think here is it just going outside of range.

0 commit comments

Comments
 (0)