Skip to content

Commit a0cee31

Browse files
committed
Very minimal first working version.
1 parent 27d66ba commit a0cee31

File tree

3 files changed

+70
-19
lines changed

3 files changed

+70
-19
lines changed

src/spikeinterface/preprocessing/inter_session_displacement.py

Lines changed: 57 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
def correct_inter_session_displacement(
3636
recordings_list: list[BaseRecording],
3737
existing_motion_info: Optional[list[Dict]] = None,
38+
keep_channels_constant=False,
3839
detect_kwargs={}, # TODO: make non-mutable (same for motion.py)
3940
select_kwargs={},
4041
localize_peaks_kwargs={},
@@ -44,10 +45,10 @@ def correct_inter_session_displacement(
4445
from spikeinterface.sortingcomponents.peak_detection import detect_peaks, detect_peak_methods
4546
from spikeinterface.sortingcomponents.peak_selection import select_peaks
4647
from spikeinterface.sortingcomponents.peak_localization import localize_peaks, localize_peak_methods
47-
from spikeinterface.sortingcomponents.motion_estimation import estimate_motion
48-
from spikeinterface.sortingcomponents.motion_interpolation import InterpolateMotionRecording
48+
from spikeinterface.sortingcomponents.motion.motion_estimation import estimate_motion
49+
from spikeinterface.sortingcomponents.motion.motion_interpolation import InterpolateMotionRecording
4950
from spikeinterface.core.node_pipeline import ExtractDenseWaveforms, run_node_pipeline
50-
from spikeinterface.sortingcomponents.motion_utils import Motion
51+
from spikeinterface.sortingcomponents.motion.motion_utils import Motion, get_spatial_windows
5152

5253
# TODO: do not accept multi-segment recordings.
5354
# TODO: check all recordings have the same probe dimensions!
@@ -101,12 +102,13 @@ def correct_inter_session_displacement(
101102
peaks_list = [info["peaks"] for info in existing_motion_info]
102103
peak_locations_list = [info["peak_locations"] for info in existing_motion_info]
103104

104-
from spikeinterface.sortingcomponents.motion_estimation import make_2d_motion_histogram, make_3d_motion_histograms
105+
from spikeinterface.sortingcomponents.motion.motion_utils import make_2d_motion_histogram, make_3d_motion_histograms
105106

106107
# make motion histogram
107108
motion_histogram_dim = "2D" # "2D" or "3D", for now only handle 2D case
108109

109110
motion_histogram_list = []
111+
all_temporal_bin_edges = [] # TODO: fix naming
110112

111113
bin_um = 2 # TODO: critial paraneter. easier to take no binning and gaus smooth?
112114

@@ -125,13 +127,13 @@ def correct_inter_session_displacement(
125127
peak_locations,
126128
weight_with_amplitude=False,
127129
direction="y",
128-
bin_duration_s=recording.get_duration(segment_index=0), # 1.0,
130+
bin_s=recording.get_duration(segment_index=0), # 1.0,
129131
bin_um=bin_um,
130-
margin_um=50,
132+
hist_margin_um=50,
131133
spatial_bin_edges=None,
132134
)
133135
else:
134-
assert NotImplementedError
136+
assert NotImplementedError # TODO: might be old API pre-dredge
135137
motion_histogram = make_3d_motion_histograms(
136138
recording,
137139
peaks,
@@ -146,8 +148,8 @@ def correct_inter_session_displacement(
146148
)
147149
motion_histogram_list.append(motion_histogram[0].squeeze())
148150
# store bin edges
149-
temporal_bin_edges = motion_histogram[1]
150-
spatial_bin_edges = motion_histogram[2]
151+
all_temporal_bin_edges.append(motion_histogram[1])
152+
spatial_bin_edges_um = motion_histogram[2] # should be same across all recordings
151153

152154
# Do some checks on temporal and spatial bin edges that they are all the same?
153155
# TODO: do some smoothing? Try some other methds (e.g. NMI, KL divergence)
@@ -183,6 +185,12 @@ def correct_inter_session_displacement(
183185
# TODO: think will need to make this negative
184186
shifts[i] = (midpoint - np.argmax(conv)) * bin_um # # TODO: the bin spacing is super important for resoltuion
185187

188+
# half
189+
# TODO: need to figure out interpolation to the center point, weird;y
190+
# the below does not work
191+
# shifts[0] = (shifts[1] / 2)
192+
# shifts[1] = (shifts[1] / 2) * -1
193+
# print("SHIFTS", shifts)
186194
# TODO: handle only the 2D case for now
187195
# TODO: do multi-session optimisation
188196

@@ -196,16 +204,37 @@ def correct_inter_session_displacement(
196204
for i, recording in enumerate(recordings_list):
197205

198206
# TODO: direct copy, use 'get_window' from motion machinery
199-
bin_centers = spatial_bin_edges[:-1] + bin_um / 2.0
200-
n = bin_centers.size
201-
non_rigid_windows = [np.ones(n, dtype="float64")]
202-
middle = (spatial_bin_edges[0] + spatial_bin_edges[-1]) / 2.0
203-
non_rigid_window_centers = np.array([middle])
204-
205-
motion_array = shifts[i] # TODO: this is the rigid case!
207+
if False:
208+
bin_centers = spatial_bin_edges[:-1] + bin_um / 2.0
209+
n = bin_centers.size
210+
non_rigid_windows = [np.ones(n, dtype="float64")]
211+
middle = (spatial_bin_edges[0] + spatial_bin_edges[-1]) / 2.0
212+
non_rigid_window_centers = np.array([middle])
213+
214+
dim = 1 # ["x", "y", "z"].index(direction)
215+
contact_depths = recording.get_channel_locations()[:, dim]
216+
spatial_bin_centers = 0.5 * (spatial_bin_edges_um[1:] + spatial_bin_edges_um[:-1])
217+
218+
_, window_centers = get_spatial_windows(
219+
contact_depths, spatial_bin_centers, rigid=True # TODO: handle non-rigid case
220+
)
221+
# win_shape=win_shape, TODO: handle defaults better
222+
# win_step_um=win_step_um,
223+
# win_scale_um=win_scale_um,
224+
# win_margin_um=win_margin_um,
225+
# zero_threshold=1e-5,
226+
227+
# if shifts[i] == 0:
228+
## all_recording_corrected.append(recording) # TODO
229+
# continue
230+
temporal_bin_edges = all_temporal_bin_edges[i]
206231
temporal_bins = 0.5 * (temporal_bin_edges[1:] + temporal_bin_edges[:-1])
232+
233+
motion_array = np.zeros((temporal_bins.size, window_centers.size)) # TODO: check this is the expected shape
234+
motion_array[:, :] = shifts[i] # TODO: this is the rigid case!
235+
207236
motion = Motion(
208-
[np.atleast_2d(motion_array)], [temporal_bins], non_rigid_window_centers, direction="y"
237+
[motion_array], [temporal_bins], window_centers, direction="y"
209238
) # will be same for all except for shifts
210239
all_motion_info.append(motion) # not certain on this
211240

@@ -225,4 +254,15 @@ def correct_inter_session_displacement(
225254
"all_motion_histograms": motion_histogram_list, # TODO: naming
226255
"all_shifts": shifts,
227256
}
257+
258+
if keep_channels_constant:
259+
# TODO: use set
260+
import functools
261+
262+
common_channels = functools.reduce(
263+
np.intersect1d, [recording.channel_ids for recording in all_recording_corrected]
264+
)
265+
266+
all_recording_corrected = [recording.channel_slice(common_channels) for recording in all_recording_corrected]
267+
228268
return all_recording_corrected, displacement_info # TODO: output more stuff later e.g. the Motion object

src/spikeinterface/sortingcomponents/motion/motion_interpolation.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,10 @@ def interpolate_motion_on_traces(
120120
time_bins = interpolation_time_bin_centers_s
121121
if time_bins is None:
122122
time_bins = motion.temporal_bins_s[segment_index]
123-
bin_s = time_bins[1] - time_bins[0]
123+
bin_s = (
124+
time_bins[1] - time_bins[0] if time_bins.size > 1 else time_bins * 2
125+
) # TODO: check this is * 2 but yes must be because its in the middle NO ITS NOT if first time is not 0
126+
# must use a different stragery
124127
bins_start = time_bins[0] - 0.5 * bin_s
125128
# nearest bin center for each frame?
126129
bin_inds = (times - bins_start) // bin_s
@@ -130,7 +133,7 @@ def interpolate_motion_on_traces(
130133
np.clip(bin_inds, 0, time_bins.size, out=bin_inds)
131134

132135
# -- what are the possibilities here anyway?
133-
bins_here = np.arange(bin_inds[0], bin_inds[-1] + 1)
136+
bins_here = np.arange(bin_inds[0], bin_inds[-1] + 1) # TODO: just replace this with 0
134137

135138
# inperpolation kernel will be the same per temporal bin
136139
interp_times = np.empty(total_num_chans)
@@ -307,12 +310,14 @@ def __init__(
307310
sigma_um=sigma_um, p=p, num_closest=num_closest, **spatial_interpolation_kwargs
308311
)
309312
if border_mode == "remove_channels":
313+
310314
locs = channel_locations[:, motion.dim]
311315
l0, l1 = np.min(locs), np.max(locs)
312316

313317
# check if channels stay inside the probe extents for all segments
314318
channel_inside = np.ones(locs.shape[0], dtype="bool")
315319
for segment_index in range(recording.get_num_segments()):
320+
316321
# evaluate the positions of all channels over all time bins
317322
channel_displacements = motion.get_displacement_at_time_and_depth(
318323
times_s=motion.temporal_bins_s[segment_index],

src/spikeinterface/sortingcomponents/motion/motion_utils.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,12 @@ def get_displacement_at_time_and_depth(self, times_s, locations_um, segment_inde
148148
# reshape to grid domain shape if necessary
149149
displacement = displacement.reshape(out_shape)
150150

151+
# TODO: hacky
152+
if self.temporal_bins_s[segment_index].size == 1 and self.spatial_bins_um.size == 1:
153+
assert np.all(np.isnan(displacement))
154+
assert self.displacement[segment_index].size == 1
155+
displacement[:] = self.displacement[segment_index]
156+
151157
return displacement
152158

153159
def to_dict(self):

0 commit comments

Comments
 (0)