Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -122,14 +122,23 @@ def interpolate_motion_on_traces(
time_bins = interpolation_time_bin_centers_s
if time_bins is None:
time_bins = motion.temporal_bins_s[segment_index]
bin_s = time_bins[1] - time_bins[0]
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just to check my understanding, the searchsorted on the bin_edges is functionally equivalent to this approach? (but of course searchsorted is less verbose)

bins_start = time_bins[0] - 0.5 * bin_s
# nearest bin center for each frame?
bin_inds = (times - bins_start) // bin_s
bin_inds = bin_inds.astype(int)

# nearest interpolation bin:
# seachsorted(b, t, side="right") == i means that b[i-1] <= t < b[i]
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
# seachsorted(b, t, side="right") == i means that b[i-1] <= t < b[i]
# searchsorted(b, t, side="right") == i means that b[i-1] <= t < b[i]

# hence the -1. doing it with "left" is not as nice -- we want t==b[0]
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Wow, I cannot believe that is not the default behaviour of "left"!

# to lead to i=1 (rounding down).
# time_bins are bin centers, but we want to snap to the nearest center.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't quite understand this line

# idea is to get the left bin edges and bin the interp times.
# this is like subtracting bin_dt_s/2, but allows non-equally-spaced bins.
bin_left = np.zeros_like(time_bins)
# it's fine to use the first bin center for the first left edge
bin_left[0] = time_bins[0]
bin_left[1:] = 0.5 * (time_bins[1:] + time_bins[:-1])
bin_inds = np.searchsorted(bin_left, times, side="right") - 1

# the time bins may not cover the whole set of times in the recording,
# so we need to clip these indices to the valid range
np.clip(bin_inds, 0, time_bins.size, out=bin_inds)
np.clip(bin_inds, 0, time_bins.size - 1, out=bin_inds)

# -- what are the possibilities here anyway?
bins_here = np.arange(bin_inds[0], bin_inds[-1] + 1)
Expand Down Expand Up @@ -433,9 +442,6 @@ def __init__(
self.motion = motion

def get_traces(self, start_frame, end_frame, channel_indices):
if self.time_vector is not None:
raise NotImplementedError("InterpolateMotionRecording does not yet support recordings with time_vectors.")

if start_frame is None:
start_frame = 0
if end_frame is None:
Expand Down
Original file line number Diff line number Diff line change
@@ -1,16 +1,14 @@
from pathlib import Path
import warnings

import numpy as np
import pytest
import spikeinterface.core as sc
from spikeinterface import download_dataset
from spikeinterface.sortingcomponents.motion import Motion
from spikeinterface.sortingcomponents.motion.motion_interpolation import (
InterpolateMotionRecording,
correct_motion_on_peaks,
interpolate_motion,
interpolate_motion_on_traces,
)
from spikeinterface.sortingcomponents.motion import Motion
from spikeinterface.sortingcomponents.tests.common import make_dataset


Expand Down Expand Up @@ -115,6 +113,66 @@ def test_interpolation_simple():
assert np.all(traces_corrected[:, 2:] == 0)


def test_cross_band_interpolation():
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is a really nice test, super useful also for conceptualising the cross-band interpolation

"""Simple version of using LFP to interpolate AP data

This also tests the time vector implementation in interpolation.
The idea is to have two recordings which are all 0s with a 1 that
moves from one channel to another after 3s. They're at different
sampling frequencies. motion estimation in one sampling frequency
applied to the other should still lead to perfect correction.
"""
from spikeinterface.sortingcomponents.motion import estimate_motion

# sampling freqs and timing for AP and LFP recordings
fs_lfp = 50.0
fs_ap = 300.0
t_start = 10.0
total_duration = 5.0
nt_lfp = int(fs_lfp * total_duration)
nt_ap = int(fs_ap * total_duration)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is nt num_timepoints? could it be expanded?

t_switch = 3

# because interpolation uses bin centers logic, there will be a half
# bin offset at the change point in the AP recording.
halfbin_ap_lfp = int(0.5 * (fs_ap / fs_lfp))

# channel geometry
nc = 10
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

could this be num_channels or num_chans?

geom = np.c_[np.zeros(nc), np.arange(nc)]

# make an LFP recording which drifts a bit
traces_lfp = np.zeros((nt_lfp, nc))
traces_lfp[: int(t_switch * fs_lfp), 5] = 1.0
traces_lfp[int(t_switch * fs_lfp) :, 6] = 1.0
rec_lfp = sc.NumpyRecording(traces_lfp, sampling_frequency=fs_lfp)
rec_lfp.set_dummy_probe_from_locations(geom)

# same for AP
traces_ap = np.zeros((nt_ap, nc))
traces_ap[: int(t_switch * fs_ap) - halfbin_ap_lfp, 5] = 1.0
traces_ap[int(t_switch * fs_ap) - halfbin_ap_lfp :, 6] = 1.0
rec_ap = sc.NumpyRecording(traces_ap, sampling_frequency=fs_ap)
rec_ap.set_dummy_probe_from_locations(geom)

# set times for both, and silence the warning
with warnings.catch_warnings():
warnings.simplefilter("ignore", category=UserWarning)
rec_lfp.set_times(t_start + np.arange(nt_lfp) / fs_lfp)
rec_ap.set_times(t_start + np.arange(nt_ap) / fs_ap)

# estimate motion
motion = estimate_motion(rec_lfp, method="dredge_lfp", rigid=True)

# nearest to keep it simple
rec_corrected = interpolate_motion(rec_ap, motion, spatial_interpolation_method="nearest", num_closest=2)
traces_corrected = rec_corrected.get_traces()
target = np.zeros((nt_ap, nc - 2))
target[:, 4] = 1
ii, jj = np.nonzero(traces_corrected)
assert np.array_equal(traces_corrected, target)


def test_InterpolateMotionRecording():
rec, sorting = make_dataset()
motion = make_fake_motion(rec)
Expand Down Expand Up @@ -148,5 +206,6 @@ def test_InterpolateMotionRecording():
if __name__ == "__main__":
# test_correct_motion_on_peaks()
# test_interpolate_motion_on_traces()
test_interpolation_simple()
test_InterpolateMotionRecording()
# test_interpolation_simple()
# test_InterpolateMotionRecording()
test_cross_band_interpolation()