Skip to content

Commit c0eb2f5

Browse files
committed
Adding a few more options.
1 parent ff2de84 commit c0eb2f5

File tree

4 files changed

+104
-72
lines changed

4 files changed

+104
-72
lines changed

debugging/alignment_utils.py

Lines changed: 77 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
# -----------------------------------------------------------------------------
2020

2121
# TODO: this function might be pointless
22-
def get_entire_session_hist(recording, peaks, peak_locations, spatial_bin_edges):
22+
def get_entire_session_hist(recording, peaks, peak_locations, spatial_bin_edges, log_scale):
2323
"""
2424
TODO: assumes 1-segment recording
2525
"""
@@ -41,11 +41,14 @@ def get_entire_session_hist(recording, peaks, peak_locations, spatial_bin_edges)
4141

4242
spatial_centers = get_bin_centers(spatial_bin_edges)
4343

44+
if log_scale:
45+
entire_session_hist = np.log10(1 + entire_session_hist)
46+
4447
return entire_session_hist, temporal_bin_edges, spatial_centers
4548

4649

4750
def get_chunked_histogram( # TODO: this function might be pointless
48-
recording, peaks, peak_locations, bin_s, spatial_bin_edges, weight_with_amplitude=False
51+
recording, peaks, peak_locations, bin_s, spatial_bin_edges, log_scale, weight_with_amplitude=False
4952
):
5053
chunked_session_hist, temporal_bin_edges, _ = \
5154
make_2d_motion_histogram(
@@ -66,6 +69,9 @@ def get_chunked_histogram( # TODO: this function might be pointless
6669
bin_times = np.diff(temporal_bin_edges)[:, np.newaxis]
6770
chunked_session_hist /= bin_times
6871

72+
if log_scale:
73+
chunked_session_hist = np.log10(1 + chunked_session_hist)
74+
6975
return chunked_session_hist, temporal_centers, spatial_centers
7076

7177
# -----------------------------------------------------------------------------
@@ -354,72 +360,81 @@ def run_kilosort_like_rigid_registration(all_hists, non_rigid_windows):
354360
return -optimal_shift_indices # TODO: these are reversed at this stage
355361

356362

363+
# TODO: I wonder if it is better to estimate the hitsogram with finer bin size
364+
# than try and interpolate the xcorr. What about smoothing the activity histograms directly?
365+
366+
# TOOD: the iterative_template seems a little different to the interpolation
367+
# of nonrigid segments that is described in the NP2.0 paper. Oh, the KS
368+
# implementation is different to that described in the paper/ where is the
369+
# Akima spline interpolation?
370+
371+
# TODO: make sure that the num bins will always align.
372+
# Apply the linear shifts, don't roll, as we don't want circular (why would the top of the probe appear at the bottom?)
373+
# They roll the windowed version that is zeros, but here we want all done up front to simplify later code
374+
375+
# TODO: this is basically a re-implimentation of the nonrigid part
376+
# of iterative template. Want to leave separate for now for prototyping
377+
# but should combine the shared parts later.
378+
379+
# TOOD: important differenence, this does not roll, will need to test when new spikes are added...
380+
381+
# TODO: try out logarithmic scaling as some neurons fire too much...
382+
383+
384+
357385
def run_alignment_estimation(
358-
all_session_hists, spatial_bin_centers, rigid, robust=False
386+
all_session_hists, spatial_bin_centers, rigid, num_nonrigid_bins, robust=False
359387
):
360388
"""
361389
"""
390+
# TODO: figure out best way to represent this, should probably be
391+
# suffix _list instead of prefixed all_ for consistency
362392
if isinstance(all_session_hists, list):
363-
all_session_hists = np.array(all_session_hists) # TODO: figure out best way to represent this, should probably be suffix _list instead of prefixed all_ for consistency
393+
all_session_hists = np.array(all_session_hists)
364394

365395
num_bins = spatial_bin_centers.size
366396
num_sessions = all_session_hists.shape[0]
367397

398+
# TODO: rename
368399
hist_array = _compute_rigid_hist_crosscorr(
369400
num_sessions, num_bins, all_session_hists, robust
370-
) # TODO: rename
401+
)
371402

372403
optimal_shift_indices = -np.mean(hist_array, axis=0)[:, np.newaxis]
373-
# (2, 1)
404+
405+
# First, perform the rigid alignment.
406+
374407
if rigid:
375-
# TODO: this just shifts everything to the center. It would be (better?)
376-
# to optmize so all shifts are about the same.
408+
# TODO: used to get window center, for now just get them from the spatial bin
409+
# centers and use no margin, which was applied earlier. Same below.
377410
non_rigid_windows, non_rigid_window_centers = get_spatial_windows(
378411
spatial_bin_centers,
379-
# TODO: used to get window center, for now just get them from the spatial bin centers and use no margin, which was applied earlier
380412
spatial_bin_centers,
381413
rigid=True,
382-
win_shape="gaussian", # rect
383-
win_step_um=None, # TODO: expose! CHECK THIS!
384-
# win_scale_um=win_scale_um,
414+
win_shape="gaussian",
415+
win_step_um=None,
385416
win_margin_um=0,
386-
# zero_threshold=None,
417+
# win_scale_um=win_scale_um,
418+
# zero_threshold=None,
387419
)
388420

389-
return optimal_shift_indices, non_rigid_window_centers # TODO: rename rigid, also this is weird to pass back bins in the rigid case
390-
391-
# TODO: this is basically a re-implimentation of the nonrigid part
392-
# of iterative template. Want to leave separate for now for prototyping
393-
# but should combine the shared parts later.
421+
# TODO: rename rigid, also this is weird to pass back bins in the rigid case
422+
return optimal_shift_indices, non_rigid_window_centers
394423

395-
num_steps = 7
396-
win_step_um = (np.max(spatial_bin_centers) - np.min(spatial_bin_centers)) / num_steps
424+
win_step_um = (np.max(spatial_bin_centers) - np.min(spatial_bin_centers)) / num_nonrigid_bins
397425

398426
non_rigid_windows, non_rigid_window_centers = get_spatial_windows(
399-
spatial_bin_centers, # TODO: used to get window center, for now just get them from the spatial bin centers and use no margin, which was applied earlier
427+
spatial_bin_centers,
400428
spatial_bin_centers,
401429
rigid=False,
402-
win_shape="gaussian", # rect
430+
win_shape="gaussian",
403431
win_step_um=win_step_um, # TODO: expose!
404-
# win_scale_um=win_scale_um,
405432
win_margin_um=0,
406-
# zero_threshold=None,
433+
# win_scale_um=win_scale_um,
434+
# zero_threshold=None,
407435
)
408-
# TODO: I wonder if it is better to estimate the hitsogram with finer bin size
409-
# than try and interpolate the xcorr. What about smoothing the activity histograms directly?
410-
411-
# TOOD: the iterative_template seems a little different to the interpolation
412-
# of nonrigid segments that is described in the NP2.0 paper. Oh, the KS
413-
# implementation is different to that described in the paper/ where is the
414-
# Akima spline interpolation?
415-
416-
# TODO: make sure that the num bins will always align.
417-
# Apply the linear shifts, don't roll, as we don't want circular (why would the top of the probe appear at the bottom?)
418-
# They roll the windowed version that is zeros, but here we want all done up front to simplify later code
419436

420-
import matplotlib.pyplot as plt
421-
422-
# TODO: for recursive version, shift cannot be larger than previous shift!
437+
# Shift the histograms according to the rigid shift
423438
shifted_histograms = np.zeros_like(all_session_hists)
424439
for i in range(all_session_hists.shape[0]):
425440

@@ -431,20 +446,39 @@ def run_alignment_estimation(
431446
cut_padded_hist = padded_hist[abs_shift:] if shift > 0 else padded_hist[:-abs_shift]
432447
shifted_histograms[i, :] = cut_padded_hist
433448

449+
# For each nonrigid window, compute the shift
434450
non_rigid_shifts = np.zeros((num_sessions, non_rigid_windows.shape[0]))
435-
for i, window in enumerate(non_rigid_windows): # TODO: use same name
451+
for i, window in enumerate(non_rigid_windows):
436452

437453
windowed_histogram = shifted_histograms * window
438454

455+
# NOTE: this method just xcorr the entire window,
456+
# does not provide subset of samples like kilosort_like
439457
window_hist_array = _compute_rigid_hist_crosscorr(
440-
num_sessions, num_bins, windowed_histogram, robust=False # this method just xcorr the entire window does not provide subset of samples like kilosort_like
458+
num_sessions, num_bins, windowed_histogram, robust=False
441459
)
442460
non_rigid_shifts[:, i] = -np.mean(window_hist_array, axis=0)
443461

444-
return optimal_shift_indices + non_rigid_shifts, non_rigid_window_centers # TODO: tidy up
462+
akima = False # TODO: decide whether to keep, factor to own function
463+
if akima:
464+
from scipy.interpolate import Akima1DInterpolator
465+
x = win_step_um * np.arange(non_rigid_windows.shape[0])
466+
xs = spatial_bin_centers
467+
468+
new_nonrigid_shifts = np.zeros((non_rigid_shifts.shape[0], num_bins))
469+
for ses_idx in range(non_rigid_shifts.shape[0]):
470+
471+
y = non_rigid_shifts[ses_idx]
472+
y_new = Akima1DInterpolator(x, y, method="akima", extrapolate=True)(xs) # requires scipy 14
473+
new_nonrigid_shifts[ses_idx, :] = y_new
474+
475+
shifts = optimal_shift_indices + new_nonrigid_shifts
476+
non_rigid_window_centers = spatial_bin_centers
477+
else:
478+
shifts = optimal_shift_indices + non_rigid_shifts
479+
480+
return shifts, non_rigid_window_centers
445481

446-
# TODO: what about the Akima Spline
447-
# TODO: try out logarithmic scaling as some neurons fire too much...
448482

449483
def _compute_rigid_hist_crosscorr(num_sessions, num_bins, all_session_hists, robust=False):
450484
""""""

debugging/all_recordings.pickle

1.82 MB
Binary file not shown.

debugging/session_alignment.py

Lines changed: 16 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -31,18 +31,6 @@
3131
- how to measure 'confidence' (peak height? std?) larger peaks may have
3232
higher std, but we care about them more, so I think this is largely pointless.
3333
34-
weight_on_confidence = True
35-
# TODO: better handle single-time point estimation.
36-
if weight_on_confidence and np.any(std_devs): # TODO: there is no reason this can be done just for poisson, can be done for all... maybe POisson has better variance estimate, do properly!
37-
# do exponential
38-
# this is a bad idea, we literally want to weight on height!
39-
stds = np.array(std_devs)
40-
stds = stds[~(stds==0)]
41-
stds = (stds - np.min(stds)) / (np.max(stds) - np.min(stds))
42-
43-
# TODO: or weight by confidence? this is basically the same as weighting by signal due to poisson variation
44-
stds = stds * (2 - np.exp(2 * stds)) # TODO: expose param, does this even make sense? does it scale?
45-
stds[np.where(stds<0)] = 0
4634
4735
trimmed_percentiles = (20, 80) # TODO: this is originally in the context of Poisson estimation
4836
if trimmed_percentiles is not False:
@@ -66,12 +54,12 @@
6654
# entire session. Otherwise, we will want to add chunking as part of above.
6755

6856
def run_inter_session_displacement_correction(
69-
recordings_list, peaks_list, peak_locations_list, bin_um, histogram_estimation_method, alignment_method, rigid=True
57+
recordings_list, peaks_list, peak_locations_list, bin_um, histogram_estimation_method, alignment_method, log_scale=True, rigid=True, num_nonrigid_bins=7
7058
): # TOOD: rename
7159
"""
7260
"""
7361
motion_estimates_list, all_temporal_bin_centers, spatial_bin_centers, non_rigid_bin_centers, histogram_info = estimate_inter_session_displacement(
74-
recordings_list, peaks_list, peak_locations_list, bin_um, histogram_estimation_method, alignment_method, rigid
62+
recordings_list, peaks_list, peak_locations_list, bin_um, histogram_estimation_method, alignment_method, rigid, log_scale, num_nonrigid_bins
7563
)
7664

7765
# _, non_ridgid_spatial_windows = alignment_utils.get_spatial_windows_alignment(
@@ -83,7 +71,7 @@ def run_inter_session_displacement_correction(
8371
)
8472

8573
corrected_peak_locations_list, corrected_session_histogram_list = _session_displacement_correct_peaks_and_generate_histogram(
86-
corrected_recordings_list, peaks_list, peak_locations_list, motion_objects_list, spatial_bin_centers
74+
corrected_recordings_list, peaks_list, peak_locations_list, motion_objects_list, spatial_bin_centers, log_scale
8775
)
8876

8977
extra_outputs_dict = {
@@ -101,7 +89,7 @@ def run_inter_session_displacement_correction(
10189

10290

10391
def _session_displacement_correct_peaks_and_generate_histogram(
104-
recordings_list, peaks_list, peak_locations_list, motion_objects_list, spatial_bin_centers
92+
recordings_list, peaks_list, peak_locations_list, motion_objects_list, spatial_bin_centers, log_scale
10593
):
10694
"""
10795
"""
@@ -121,13 +109,13 @@ def _session_displacement_correct_peaks_and_generate_histogram(
121109

122110
for i in range(len(corrected_peak_locations_list)): # TODO: unwrap a bit
123111
corrected_session_histogram_list.append(
124-
alignment_utils.get_entire_session_hist(recordings_list[i], peaks_list[i], corrected_peak_locations_list[i], spatial_bin_centers)[0]
112+
alignment_utils.get_entire_session_hist(recordings_list[i], peaks_list[i], corrected_peak_locations_list[i], spatial_bin_centers, log_scale)[0]
125113
)
126114

127115
return corrected_peak_locations_list, corrected_session_histogram_list
128116

129117
def estimate_inter_session_displacement(
130-
recordings_list, peaks_list, peak_locations_list, bin_um, histogram_estimation_method, alignment_method, rigid
118+
recordings_list, peaks_list, peak_locations_list, bin_um, histogram_estimation_method, alignment_method, rigid, log_scale, num_nonrigid_bins
131119
):
132120
"""
133121
"""
@@ -148,7 +136,7 @@ def estimate_inter_session_displacement(
148136
for recording, peaks, peak_locations in zip(recordings_list, peaks_list, peak_locations_list):
149137

150138
session_hist, temporal_bin_centers, session_chunked_hists, chunked_hist_stdevs = _get_single_session_activity_histogram(
151-
recording, peaks, peak_locations, histogram_estimation_method, spatial_bin_edges
139+
recording, peaks, peak_locations, histogram_estimation_method, spatial_bin_edges, log_scale
152140
)
153141

154142
all_session_hists.append(session_hist)
@@ -167,7 +155,7 @@ def estimate_inter_session_displacement(
167155
) * bin_um
168156
else:
169157
all_motion_arrays, non_rigid_bin_centers = alignment_utils.run_alignment_estimation( # TODO: rename because some times rigid!
170-
all_session_hists, spatial_bin_centers, rigid
158+
all_session_hists, spatial_bin_centers, rigid, num_nonrigid_bins
171159
) # TODO: here the motion arrays are made negative initially. In motion correction they are done later. Discuss with others and make consistent.
172160
all_motion_arrays *= bin_um
173161

@@ -180,7 +168,7 @@ def estimate_inter_session_displacement(
180168
return all_motion_arrays, all_temporal_bin_centers, spatial_bin_centers, non_rigid_bin_centers, extra_outputs_dict
181169

182170

183-
def _get_single_session_activity_histogram(recording, peaks, peak_locations, method, spatial_bin_edges):
171+
def _get_single_session_activity_histogram(recording, peaks, peak_locations, method, spatial_bin_edges, log_scale):
184172
"""
185173
"""
186174
accepted_methods = ["entire_session", "chunked_mean", "chunked_median", "chunked_supremum", "chunked_poisson"]
@@ -190,18 +178,22 @@ def _get_single_session_activity_histogram(recording, peaks, peak_locations, met
190178
)
191179
# First, get the histogram across the entire session
192180
entire_session_hist, temporal_bin_centers, _ = alignment_utils.get_entire_session_hist( # TODO: assert spatial bin edges
193-
recording, peaks, peak_locations, spatial_bin_edges
181+
recording, peaks, peak_locations, spatial_bin_edges, log_scale=False
194182
)
195183

196184
if method == "entire_session":
185+
186+
if log_scale:
187+
entire_session_hist = np.log10(1 + entire_session_hist)
188+
197189
return entire_session_hist, temporal_bin_centers, None, None
198190

199191
# If method is not "entire_session", estimate the session
200192
# histogram based on histograms calculated from chunks.
201-
bin_s, percentile_lambda = alignment_utils.estimate_chunk_size(entire_session_hist, recording)
193+
bin_s, percentile_lambda = alignment_utils.estimate_chunk_size(entire_session_hist, recording) # TODO: handle with log properly
202194

203195
chunked_session_hist, chunked_temporal_bin_centers, _ = alignment_utils.get_chunked_histogram( # TODO: do the centering higher levle as duplciating
204-
recording, peaks, peak_locations, bin_s, spatial_bin_edges
196+
recording, peaks, peak_locations, bin_s, spatial_bin_edges, log_scale
205197
)
206198
session_std = np.sum(np.std(chunked_session_hist, axis=0)) / chunked_session_hist.shape[1]
207199

debugging/test_session_alignment.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -39,18 +39,24 @@
3939

4040
# What we really want to find is maximal subset of the data that matches
4141

42+
# TOOD: here use natural log for scaling, should prob go to base 10
43+
44+
# TODO: major check, refactor and tidy up
45+
# list out carefully all notes
46+
# handle the case where the passed recordings are not motion correction recordings.
47+
4248
SAVE = False
4349
PLOT = False
4450
BIN_UM = 2
4551

4652
if SAVE:
4753
scalings = [np.ones(25), np.r_[np.zeros(10), np.ones(15)]] # TODO: there is something wrong here, because why are the maximum histograms not removed?
4854
recordings_list, _ = generate_session_displacement_recordings(
49-
non_rigid_gradient=0.1, # None,
50-
num_units=25,
51-
recording_durations=(100, 100),
55+
non_rigid_gradient=None, # 0.05, # None,
56+
num_units=15,
57+
recording_durations=(100, 100, 100, 100),
5258
recording_shifts=(
53-
(0, 0), (0, 75),
59+
(0, 0), (0, 75), (0, -125), (0, 25),
5460
),
5561
recording_amplitude_scalings=None, # {"method": "by_amplitude_and_firing_rate", "scalings": scalings},
5662
seed=None,
@@ -76,7 +82,7 @@
7682

7783

7884
corrected_recordings_list, motion_objects_list, extra_info = session_alignment.run_inter_session_displacement_correction(
79-
recordings_list, peaks_list, peak_locations_list, bin_um=BIN_UM, histogram_estimation_method="entire_session", alignment_method="mean_crosscorr", rigid=True
85+
recordings_list, peaks_list, peak_locations_list, bin_um=BIN_UM, histogram_estimation_method="entire_session", alignment_method="mean_crosscorr", rigid=False, log_scale=True, num_nonrigid_bins=7
8086
)
8187

8288
plotting.SessionAlignmentWidget(

0 commit comments

Comments
 (0)