Skip to content

Commit 64dc406

Browse files
committed
Finish naive alignment to first session.
1 parent 87f129a commit 64dc406

File tree

1 file changed

+21
-13
lines changed

1 file changed

+21
-13
lines changed

src/spikeinterface/preprocessing/inter_session_displacement.py

Lines changed: 21 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -105,13 +105,16 @@ def correct_inter_session_displacement(
105105

106106
motion_histogram_list = []
107107

108+
bin_um = 2 # TODO: critial paraneter. easier to take no binning and gaus smooth?
109+
108110
# TODO: own function
109111
for recording, peaks, peak_locations in zip(
110112
recordings_list,
111113
peaks_list,
112114
peak_locations_list, # TODO: this is overwriting above variable names. Own function!
113115
): # TODO: do a lot of checks to make sure these bin sizes make sesnese
114116
# Do some checks on temporal and spatial bin edges that they are all the same?
117+
115118
if motion_histogram_dim == "2D":
116119
motion_histogram = make_2d_motion_histogram(
117120
recording,
@@ -120,7 +123,7 @@ def correct_inter_session_displacement(
120123
weight_with_amplitude=False,
121124
direction="y",
122125
bin_duration_s=recording.get_duration(segment_index=0), # 1.0,
123-
bin_um=2.0,
126+
bin_um=bin_um,
124127
margin_um=50,
125128
spatial_bin_edges=None,
126129
)
@@ -131,7 +134,7 @@ def correct_inter_session_displacement(
131134
peak_locations,
132135
direction="y",
133136
bin_duration_s=recording.get_duration(segment_index=0), # 1.0,
134-
bin_um=2.0,
137+
bin_um=bin_um,
135138
margin_um=50,
136139
num_amp_bins=20,
137140
log_transform=True,
@@ -148,32 +151,37 @@ def correct_inter_session_displacement(
148151
# which is (2P-1)^N where P is length of motion histogram and N is number of recordings.
149152
# TODO: double-check what is done in kilosort-like / DREDGE
150153
# put histograms into X and do X^T X then mean(U), det or eigs of covar mat
154+
# can try iterative template. Not sure it will work so well taking the mean
155+
# over only a few histograms that could be wildy different.
156+
# Displacemene
151157
num_recordings = len(recordings_list)
152158

153159
shifts = np.zeros(num_recordings)
154160

155161
# TODO: not checked any of the below properly
156162
first_hist = motion_histogram_list[0] / motion_histogram_list[0].sum()
157-
first_hist -= np.mean(first_hist) # TODO: pretty sure not necessary
163+
# first_hist -= np.mean(first_hist) # TODO: pretty sure not necessary
158164

159165
for i in range(1, num_recordings):
160166

161167
hist = motion_histogram_list[i] / motion_histogram_list[i].sum()
162-
hist -= np.mean(hist) # TODO: pretty sure not necessary
163-
164-
conv = np.convolve(first_hist, hist, mode="full")
168+
# hist -= np.mean(hist) # TODO: pretty sure not necessary
169+
conv = np.correlate(first_hist, hist, mode="full")
165170

166-
if hist.size % 2 == 0:
167-
midpoint = hist.size / 2
171+
if conv.size % 2 == 0:
172+
midpoint = conv.size / 2
168173
else:
169-
midpoint = (hist.size - 1) / 2 # TODO: carefully double check!
174+
midpoint = (conv.size - 1) / 2 # TODO: carefully double check!
170175

171-
shifts[i] = midpoint - np.argmax(conv) # TODO: the bin spacing is super important for resoltuion
172-
173-
# TODO
174-
import matplotlib.pyplot as plt
176+
# TODO: think will need to make this negative
177+
shifts[i] = (midpoint - np.argmax(conv)) * bin_um # # TODO: the bin spacing is super important for resoltuion
175178

176179
# TODO: handle only the 2D case for now
177180
# TODO: do multi-session optimisation
178181

179182
# Handle drift
183+
184+
# TODO: add motion to motion if exists otherwise create InterpolateMotionRecording object!
185+
# Will need the y-axis bins for this
186+
motion = Motion([motion_array], [temporal_bins], non_rigid_window_centers, direction=direction)
187+
recording_corrected = InterpolateMotionRecording(recording, motion, **interpolate_motion_kwargs)

0 commit comments

Comments
 (0)