Skip to content

Commit 818cf86

Browse files
committed
Try a recursive nonrigid alignment, doesn't really work.
1 parent 1bfe176 commit 818cf86

File tree

4 files changed

+133
-21
lines changed

4 files changed

+133
-21
lines changed

debugging/alignment_utils.py

Lines changed: 127 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -392,6 +392,132 @@ def run_alignment_estimation(
392392
# of iterative template. Want to leave separate for now for prototyping
393393
# but should combine the shared parts later.
394394

395+
# TODO: I wonder if it is better to estimate the hitsogram with finer bin size
396+
# than try and interpolate the xcorr. What about smoothing the activity histograms directly?
397+
398+
# TOOD: the iterative_template seems a little different to the interpolation
399+
# of nonrigid segments that is described in the NP2.0 paper. Oh, the KS
400+
# implementation is different to that described in the paper/ where is the
401+
# Akima spline interpolation?
402+
403+
# TODO: make sure that the num bins will always align.
404+
# Apply the linear shifts, don't roll, as we don't want circular (why would the top of the probe appear at the bottom?)
405+
# They roll the windowed version that is zeros, but here we want all done up front to simplify later code
406+
407+
# TODO: what about the Akima Spline, this would be cooler
408+
# TODO: try out logarithmic scaling as some neurons fire too much...
409+
410+
num_bins = spatial_bin_centers.shape[0]
411+
# assert spatial_bin_centers.shape[0] %2 == 0, "num channels must be even"
412+
413+
min_num_bins = 10
414+
divs = 2**np.arange(10)
415+
divs = divs[np.where(num_bins / divs > min_num_bins)]
416+
417+
accumulated_shifts = []
418+
419+
step_shifts = []
420+
for step_idx, num_steps in enumerate(divs): # TOOD: use this compeltely dynamically, for rigid, kilosort-like and recursive
421+
422+
423+
bin_edges = np.arange(num_steps + 1) * (num_bins/num_steps)
424+
print(bin_edges)
425+
bin_edges = bin_edges.astype(int)
426+
# bin_edges = np.arange(1, num_steps + 1)[::-1]
427+
# bin_edges = np.r_[0, (num_bins / bin_edges).astype(int)]
428+
429+
non_rigid_windows = np.zeros((num_steps, num_bins))
430+
non_rigid_window_centers = np.zeros(num_steps)
431+
432+
for i in range(num_steps):
433+
non_rigid_windows[i, bin_edges[i]:bin_edges[i+1]] = 1
434+
non_rigid_window_centers[i] = np.mean(spatial_bin_centers[non_rigid_windows[i].astype(bool)]) # TODO: maybe not mean, maybe (max - min)/ 2 ... :S
435+
436+
if num_steps == 1:
437+
shifted_histograms = all_session_hists[:, :, np.newaxis]
438+
else:
439+
shifted_histograms = np.repeat(shifted_histograms, 2, axis=2)
440+
441+
# shifted_histograms *= non_rigid_windows # TODO
442+
443+
window_shifts = []
444+
for win_idx in range(shifted_histograms.shape[2]):
445+
446+
window_hist = shifted_histograms[:, :, win_idx] * non_rigid_windows[win_idx]
447+
448+
if np.any(window_hist):
449+
# this method just xcorr the entire window does not provide subset of samples like kilosort_like
450+
window_hist_array = _compute_rigid_hist_crosscorr(num_sessions, num_bins, window_hist, robust=False)
451+
452+
all_ses_shifts = -np.mean(window_hist_array, axis=0)
453+
454+
if np.any(all_ses_shifts > 400):
455+
breakpoint()
456+
457+
else:
458+
all_ses_shifts = np.zeros(window_hist.shape[0])
459+
460+
window_shifts.append(all_ses_shifts)
461+
462+
# perform shift
463+
for ses_idx, shift in enumerate(all_ses_shifts):
464+
465+
abs_shift = np.abs(shift).astype(int)
466+
467+
if abs_shift == 0:
468+
cut_padded_hist = window_hist[ses_idx]
469+
else:
470+
pad_tuple = (0, abs_shift) if shift > 0 else (abs_shift, 0) # TODO: check direction!
471+
472+
padded_hist = np.pad(window_hist[ses_idx], pad_tuple, mode="constant")
473+
cut_padded_hist = padded_hist[abs_shift:] if shift >= 0 else padded_hist[:-abs_shift]
474+
475+
shifted_histograms[ses_idx, :, win_idx] = cut_padded_hist
476+
477+
step_shifts.append(window_shifts)
478+
479+
breakpoint()
480+
"""
481+
482+
try:
483+
splitto_binno = np.split(y, divo)
484+
except:
485+
breakpoint()
486+
487+
hist_idx = np.where(i in binno for binno in splitto_binno)[0]
488+
489+
for ses_idx in range(num_sessions):
490+
491+
shift = int(accumulated_shifts[seg_idx][ses_idx, hist_idx])
492+
493+
abs_shift = np.abs(shift)
494+
pad_tuple = (0, abs_shift) if shift > 0 else (abs_shift, 0) # TODO: check direction!
495+
496+
padded_hist = np.pad(all_session_hists[ses_idx, :], pad_tuple, mode="constant")
497+
cut_padded_hist = padded_hist[abs_shift:] if shift >= 0 else padded_hist[:-abs_shift]
498+
try:
499+
shifted_histograms[ses_idx, hist_idx, :] = cut_padded_hist
500+
except:
501+
breakpoint()
502+
503+
non_rigid_shifts = np.zeros((num_sessions, non_rigid_windows.shape[0]))
504+
for i, window in enumerate(non_rigid_windows): # TODO: use same name
505+
506+
windowed_histogram = shifted_histograms[:, i, :] * window # these are shifted, but not windows. Maybe better to window then shift like kilosort.
507+
508+
window_hist_array = _compute_rigid_hist_crosscorr(
509+
num_sessions, num_bins, windowed_histogram, robust=False
510+
# this method just xcorr the entire window does not provide subset of samples like kilosort_like
511+
)
512+
non_rigid_shifts[:, i] = -np.mean(window_hist_array, axis=0)
513+
514+
accumulated_shifts.append(non_rigid_shifts)
515+
"""
516+
517+
return optimal_shift_indices + non_rigid_shifts, non_rigid_window_centers # TODO: tidy up
518+
519+
520+
"""
395521
num_steps = 7
396522
win_step_um = (np.max(spatial_bin_centers) - np.min(spatial_bin_centers)) / num_steps
397523
@@ -405,21 +531,7 @@ def run_alignment_estimation(
405531
win_margin_um=0,
406532
# zero_threshold=None,
407533
)
408-
# TODO: I wonder if it is better to estimate the hitsogram with finer bin size
409-
# than try and interpolate the xcorr. What about smoothing the activity histograms directly?
410-
411-
# TOOD: the iterative_template seems a little different to the interpolation
412-
# of nonrigid segments that is described in the NP2.0 paper. Oh, the KS
413-
# implementation is different to that described in the paper/ where is the
414-
# Akima spline interpolation?
415-
416-
# TODO: make sure that the num bins will always align.
417-
# Apply the linear shifts, don't roll, as we don't want circular (why would the top of the probe appear at the bottom?)
418-
# They roll the windowed version that is zeros, but here we want all done up front to simplify later code
419-
420-
import matplotlib.pyplot as plt
421534
422-
# TODO: for recursive version, shift cannot be larger than previous shift!
423535
shifted_histograms = np.zeros_like(all_session_hists)
424536
for i in range(all_session_hists.shape[0]):
425537
@@ -442,9 +554,7 @@ def run_alignment_estimation(
442554
non_rigid_shifts[:, i] = -np.mean(window_hist_array, axis=0)
443555
444556
return optimal_shift_indices + non_rigid_shifts, non_rigid_window_centers # TODO: tidy up
445-
446-
# TODO: what about the Akima Spline
447-
# TODO: try out logarithmic scaling as some neurons fire too much...
557+
"""
448558

449559
def _compute_rigid_hist_crosscorr(num_sessions, num_bins, all_session_hists, robust=False):
450560
""""""

debugging/all_recordings.pickle

986 KB
Binary file not shown.

debugging/session_alignment.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -136,7 +136,8 @@ def estimate_inter_session_displacement(
136136
min_y = np.min([np.min(locs["y"]) for locs in peak_locations_list])
137137
max_y = np.max([np.max(locs["y"]) for locs in peak_locations_list])
138138

139-
spatial_bin_edges = np.arange(min_y, max_y + bin_um, bin_um) # TODO: expose a margin...
139+
# TOOD: specifically chosen to get num bins to work!!!!!!!!!!!!!!!!!! #######################################################################################################
140+
spatial_bin_edges = np.linspace(min_y, max_y, 1024 + 1) # np.arange(min_y, max_y + bin_um, bin_um) # TODO: expose a margin...
140141
spatial_bin_centers = alignment_utils.get_bin_centers(spatial_bin_edges)
141142

142143
# Estimate an activity histogram per-session

debugging/test_session_alignment.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@
4040
# What we really want to find is maximal subset of the data that matches
4141

4242
SAVE = False
43-
PLOT = False
43+
PLOT = True
4444
BIN_UM = 2
4545

4646
if SAVE:
@@ -50,7 +50,7 @@
5050
num_units=25,
5151
recording_durations=(100, 100),
5252
recording_shifts=(
53-
(0, 0), (0, 75),
53+
(0, 0), (0, 75), # TODO: check the histogram, why is this shift not actually 75 um!??!?!? could be an x-axis plotting issue...
5454
),
5555
recording_amplitude_scalings=None, # {"method": "by_amplitude_and_firing_rate", "scalings": scalings},
5656
seed=None,
@@ -76,9 +76,10 @@
7676

7777

7878
corrected_recordings_list, motion_objects_list, extra_info = session_alignment.run_inter_session_displacement_correction(
79-
recordings_list, peaks_list, peak_locations_list, bin_um=BIN_UM, histogram_estimation_method="entire_session", alignment_method="mean_crosscorr", rigid=True
79+
recordings_list, peaks_list, peak_locations_list, bin_um=BIN_UM, histogram_estimation_method="entire_session", alignment_method="mean_crosscorr", rigid=False
8080
)
8181

82+
# TODO: make sure raster plot y-axis are aligned
8283
plotting.SessionAlignmentWidget(
8384
recordings_list,
8485
peaks_list,

0 commit comments

Comments
 (0)