35
35
def correct_inter_session_displacement (
36
36
recordings_list : list [BaseRecording ],
37
37
existing_motion_info : Optional [list [Dict ]] = None ,
38
+ keep_channels_constant = False ,
38
39
detect_kwargs = {}, # TODO: make non-mutable (same for motion.py)
39
40
select_kwargs = {},
40
41
localize_peaks_kwargs = {},
@@ -44,10 +45,10 @@ def correct_inter_session_displacement(
44
45
from spikeinterface .sortingcomponents .peak_detection import detect_peaks , detect_peak_methods
45
46
from spikeinterface .sortingcomponents .peak_selection import select_peaks
46
47
from spikeinterface .sortingcomponents .peak_localization import localize_peaks , localize_peak_methods
47
- from spikeinterface .sortingcomponents .motion_estimation import estimate_motion
48
- from spikeinterface .sortingcomponents .motion_interpolation import InterpolateMotionRecording
48
+ from spikeinterface .sortingcomponents .motion . motion_estimation import estimate_motion
49
+ from spikeinterface .sortingcomponents .motion . motion_interpolation import InterpolateMotionRecording
49
50
from spikeinterface .core .node_pipeline import ExtractDenseWaveforms , run_node_pipeline
50
- from spikeinterface .sortingcomponents .motion_utils import Motion
51
+ from spikeinterface .sortingcomponents .motion . motion_utils import Motion , get_spatial_windows
51
52
52
53
# TODO: do not accept multi-segment recordings.
53
54
# TODO: check all recordings have the same probe dimensions!
@@ -101,12 +102,13 @@ def correct_inter_session_displacement(
101
102
peaks_list = [info ["peaks" ] for info in existing_motion_info ]
102
103
peak_locations_list = [info ["peak_locations" ] for info in existing_motion_info ]
103
104
104
- from spikeinterface .sortingcomponents .motion_estimation import make_2d_motion_histogram , make_3d_motion_histograms
105
+ from spikeinterface .sortingcomponents .motion . motion_utils import make_2d_motion_histogram , make_3d_motion_histograms
105
106
106
107
# make motion histogram
107
108
motion_histogram_dim = "2D" # "2D" or "3D", for now only handle 2D case
108
109
109
110
motion_histogram_list = []
111
+ all_temporal_bin_edges = [] # TODO: fix naming
110
112
111
113
bin_um = 2 # TODO: critial paraneter. easier to take no binning and gaus smooth?
112
114
@@ -125,13 +127,13 @@ def correct_inter_session_displacement(
125
127
peak_locations ,
126
128
weight_with_amplitude = False ,
127
129
direction = "y" ,
128
- bin_duration_s = recording .get_duration (segment_index = 0 ), # 1.0,
130
+ bin_s = recording .get_duration (segment_index = 0 ), # 1.0,
129
131
bin_um = bin_um ,
130
- margin_um = 50 ,
132
+ hist_margin_um = 50 ,
131
133
spatial_bin_edges = None ,
132
134
)
133
135
else :
134
- assert NotImplementedError
136
+ assert NotImplementedError # TODO: might be old API pre-dredge
135
137
motion_histogram = make_3d_motion_histograms (
136
138
recording ,
137
139
peaks ,
@@ -146,8 +148,8 @@ def correct_inter_session_displacement(
146
148
)
147
149
motion_histogram_list .append (motion_histogram [0 ].squeeze ())
148
150
# store bin edges
149
- temporal_bin_edges = motion_histogram [1 ]
150
- spatial_bin_edges = motion_histogram [2 ]
151
+ all_temporal_bin_edges . append ( motion_histogram [1 ])
152
+ spatial_bin_edges_um = motion_histogram [2 ] # should be same across all recordings
151
153
152
154
# Do some checks on temporal and spatial bin edges that they are all the same?
153
155
# TODO: do some smoothing? Try some other methds (e.g. NMI, KL divergence)
@@ -183,6 +185,12 @@ def correct_inter_session_displacement(
183
185
# TODO: think will need to make this negative
184
186
shifts [i ] = (midpoint - np .argmax (conv )) * bin_um # # TODO: the bin spacing is super important for resoltuion
185
187
188
+ # half
189
+ # TODO: need to figure out interpolation to the center point, weird;y
190
+ # the below does not work
191
+ # shifts[0] = (shifts[1] / 2)
192
+ # shifts[1] = (shifts[1] / 2) * -1
193
+ # print("SHIFTS", shifts)
186
194
# TODO: handle only the 2D case for now
187
195
# TODO: do multi-session optimisation
188
196
@@ -196,16 +204,37 @@ def correct_inter_session_displacement(
196
204
for i , recording in enumerate (recordings_list ):
197
205
198
206
# TODO: direct copy, use 'get_window' from motion machinery
199
- bin_centers = spatial_bin_edges [:- 1 ] + bin_um / 2.0
200
- n = bin_centers .size
201
- non_rigid_windows = [np .ones (n , dtype = "float64" )]
202
- middle = (spatial_bin_edges [0 ] + spatial_bin_edges [- 1 ]) / 2.0
203
- non_rigid_window_centers = np .array ([middle ])
204
-
205
- motion_array = shifts [i ] # TODO: this is the rigid case!
207
+ if False :
208
+ bin_centers = spatial_bin_edges [:- 1 ] + bin_um / 2.0
209
+ n = bin_centers .size
210
+ non_rigid_windows = [np .ones (n , dtype = "float64" )]
211
+ middle = (spatial_bin_edges [0 ] + spatial_bin_edges [- 1 ]) / 2.0
212
+ non_rigid_window_centers = np .array ([middle ])
213
+
214
+ dim = 1 # ["x", "y", "z"].index(direction)
215
+ contact_depths = recording .get_channel_locations ()[:, dim ]
216
+ spatial_bin_centers = 0.5 * (spatial_bin_edges_um [1 :] + spatial_bin_edges_um [:- 1 ])
217
+
218
+ _ , window_centers = get_spatial_windows (
219
+ contact_depths , spatial_bin_centers , rigid = True # TODO: handle non-rigid case
220
+ )
221
+ # win_shape=win_shape, TODO: handle defaults better
222
+ # win_step_um=win_step_um,
223
+ # win_scale_um=win_scale_um,
224
+ # win_margin_um=win_margin_um,
225
+ # zero_threshold=1e-5,
226
+
227
+ # if shifts[i] == 0:
228
+ ## all_recording_corrected.append(recording) # TODO
229
+ # continue
230
+ temporal_bin_edges = all_temporal_bin_edges [i ]
206
231
temporal_bins = 0.5 * (temporal_bin_edges [1 :] + temporal_bin_edges [:- 1 ])
232
+
233
+ motion_array = np .zeros ((temporal_bins .size , window_centers .size )) # TODO: check this is the expected shape
234
+ motion_array [:, :] = shifts [i ] # TODO: this is the rigid case!
235
+
207
236
motion = Motion (
208
- [np . atleast_2d ( motion_array ) ], [temporal_bins ], non_rigid_window_centers , direction = "y"
237
+ [motion_array ], [temporal_bins ], window_centers , direction = "y"
209
238
) # will be same for all except for shifts
210
239
all_motion_info .append (motion ) # not certain on this
211
240
@@ -225,4 +254,15 @@ def correct_inter_session_displacement(
225
254
"all_motion_histograms" : motion_histogram_list , # TODO: naming
226
255
"all_shifts" : shifts ,
227
256
}
257
+
258
+ if keep_channels_constant :
259
+ # TODO: use set
260
+ import functools
261
+
262
+ common_channels = functools .reduce (
263
+ np .intersect1d , [recording .channel_ids for recording in all_recording_corrected ]
264
+ )
265
+
266
+ all_recording_corrected = [recording .channel_slice (common_channels ) for recording in all_recording_corrected ]
267
+
228
268
return all_recording_corrected , displacement_info # TODO: output more stuff later e.g. the Motion object
0 commit comments