Skip to content

Add shift start time function. #3509

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
Nov 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 29 additions & 0 deletions src/spikeinterface/core/baserecording.py
Original file line number Diff line number Diff line change
Expand Up @@ -509,6 +509,35 @@ def reset_times(self):
rs.t_start = None
rs.sampling_frequency = self.sampling_frequency

def shift_times(self, shift: int | float, segment_index: int | None = None) -> None:
"""
Shift all times by a scalar value.

Parameters
----------
shift : int | float
The shift to apply. If positive, times will be increased by `shift`.
e.g. shifting by 1 will be like the recording started 1 second later.
If negative, the start time will be decreased i.e. as if the recording
started earlier.

segment_index : int | None
The segment on which to shift the times.
If `None`, all segments will be shifted.
"""
if segment_index is None:
segments_to_shift = range(self.get_num_segments())
else:
segments_to_shift = (segment_index,)

for idx in segments_to_shift:
rs = self._recording_segments[idx]

if self.has_time_vector(segment_index=idx):
rs.time_vector += shift
else:
rs.t_start += shift

def sample_index_to_time(self, sample_ind, segment_index=None):
"""
Transform sample index into time in seconds
Expand Down
92 changes: 89 additions & 3 deletions src/spikeinterface/core/tests/test_time_handling.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,10 @@ class TestTimeHandling:
is generated on the fly. Both time representations are tested here.
"""

# Fixtures #####
# #########################################################################
# Fixtures
# #########################################################################

@pytest.fixture(scope="session")
def time_vector_recording(self):
"""
Expand Down Expand Up @@ -95,7 +98,10 @@ def _get_fixture_data(self, request, fixture_name):
raw_recording, times_recording, all_times = time_recording_fixture
return (raw_recording, times_recording, all_times)

# Tests #####
# #########################################################################
# Tests
# #########################################################################

def test_has_time_vector(self, time_vector_recording):
"""
Test the `has_time_vector` function returns `False` before
Expand Down Expand Up @@ -305,7 +311,87 @@ def test_sorting_analyzer_get_durations_no_recording(self, time_vector_recording

assert np.array_equal(sorting_analyzer.get_total_duration(), raw_recording.get_total_duration())

# Helpers ####
@pytest.mark.parametrize("fixture_name", ["time_vector_recording", "t_start_recording"])
@pytest.mark.parametrize("shift", [-123.456, 123.456])
def test_shift_time_all_segments(self, request, fixture_name, shift):
    """
    Call `shift_times` without a `segment_index` (i.e. the `None`
    default) and verify that the times of every segment moved by
    exactly `shift` relative to the times stored before the call.
    """
    _, shifted_recording, _ = self._get_fixture_data(request, fixture_name)

    segment_count, times_before = self._store_all_times(shifted_recording)

    # No `segment_index` passed, so the default `None` is used.
    shifted_recording.shift_times(shift)

    for seg in range(segment_count):
        # Undo the shift and compare against the stored pre-shift times.
        unshifted = shifted_recording.get_times(segment_index=seg) - shift
        assert np.allclose(times_before[seg], unshifted, rtol=0, atol=1e-8)

@pytest.mark.parametrize("fixture_name", ["time_vector_recording", "t_start_recording"])
@pytest.mark.parametrize("shift", [-123.456, 123.456])
def test_shift_times_different_segments(self, request, fixture_name, shift):
"""
Shift each segment separately, and check the shifted segment only
is shifted as expected.
"""
_, times_recording, all_times = self._get_fixture_data(request, fixture_name)

num_segments, orig_seg_data = self._store_all_times(times_recording)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe you could just do recording.copy?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice, this will allow us to remove a lot of boilerplate and make the assert more explicit. However, check this out!:

        import copy

        _, times_recording, all_times = self._get_fixture_data(request, fixture_name)

        orig_times_recording = copy.deepcopy(times_recording)

        # num_segments, orig_seg_data = self._store_all_times(times_recording)

        times_recording.shift_times(shift)  # use default `segment_index=None`

        for idx in range(num_segments):
            assert np.allclose(
                orig_times_recording.get_times(segment_index=idx),
                times_recording.get_times(segment_index=idx) - shift,
                rtol=0, atol=1e-8
            )

This fails, orig_times_recording.get_times(segment_index=idx) gives

# array([0.00000000e+00, 3.33333333e-05, 6.66666667e-05, ...,
 #      9.99990000e+00, 9.99993333e+00, 9.99996667e+00])

It seems the time vector info is not being propagated in the copy, though it seemed to work okay for the save/load 🤔. I will have to come back to this; do you have any ideas why this might be happening?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Mmm yeah, I think that's probably by design. I think cloning uses the to_dict mechanism, and that is coupled with saving to disk and serializing. Because timestamps can be large, we probably don't want that there.

The solution is to change copy (which should be an in-memory thing) to copy the time vector, but I think that should be done in another PR.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A little out of scope, but do you remember when we were looking up copy vs clone at Edinburgh? I forget what it said (but basically one is deep and one is supposed to be shallow). It would probably be worth us making clear in our docstrings whether our copies are shallow or not (also for another PR).

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm a little confused sorry, what is clone vs copy? BTW I forgot to say based on the suggested command, I got an error:

(Pdb) times_recording.copy()
*** AttributeError: 'NoiseGeneratorRecording' object has no attribute 'copy'

should this have worked?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm a little confused sorry, what is clone vs copy? BTW I forgot to say based on the suggested command, I got an error:

(Pdb) times_recording.copy()
*** AttributeError: 'NoiseGeneratorRecording' object has no attribute 'copy'

should this have worked?

No I don't think this would work. I guess @h-mayorquin was suggesting to do a:

from copy import copy

copy(times_recording)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm a little confused sorry, what is clone vs copy?

Some projects use copy(deep=True/False) to indicate whether to just do metadata vs actual array data whereas other projects use clone vs copy to make the deep vs shallow distinction (again deep being everything and shallow just being metadata). Because there is no requirement for clone and copy being shallow or deep it would probably be better for us as a project to either make it clear in the docstring or for us to use the clone/copy names to make clear which functions are deep and which are shallow. Is that clearer? Again not crucial to this PR, but related to the fact that the array is not being copied (ie our copy is shallow).

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @zm711 that makes sense! I opened an issue to discuss this #3546.


# For each segment, shift the segment only and check the
# times are updated as expected.
for idx in range(num_segments):

scaler = idx + 2
times_recording.shift_times(shift * scaler, segment_index=idx)

assert np.allclose(
orig_seg_data[idx], times_recording.get_times(segment_index=idx) - shift * scaler, rtol=0, atol=1e-8
)

# Just do a little check that we are not
# accidentally changing some other segments,
# which should remain unchanged at this point in the loop.
if idx != num_segments - 1:
assert np.array_equal(orig_seg_data[idx + 1], times_recording.get_times(segment_index=idx + 1))

@pytest.mark.parametrize("fixture_name", ["time_vector_recording", "t_start_recording"])
def test_save_and_load_time_shift(self, request, fixture_name, tmp_path):
    """
    Shift the times of a recording, save it to disk, load it back and
    verify that every segment of the loaded recording reports the same
    (shifted) times as the in-memory recording.
    """
    _, recording, _ = self._get_fixture_data(request, fixture_name)
    save_folder = tmp_path / "my_file"

    recording.shift_times(shift=100)
    recording.save(folder=save_folder)

    reloaded = si.load_extractor(save_folder)

    for seg in range(recording.get_num_segments()):
        in_memory_times = recording.get_times(segment_index=seg)
        reloaded_times = reloaded.get_times(segment_index=seg)
        assert np.array_equal(in_memory_times, reloaded_times)

def _store_all_times(self, recording):
"""
Convenience function to store original times of all segments to a dict.
"""
num_segments = recording.get_num_segments()
seg_data = {}

for idx in range(num_segments):
seg_data[idx] = copy.deepcopy(recording.get_times(segment_index=idx))

return num_segments, seg_data

# #########################################################################
# Helpers
# #########################################################################

def _check_times_match(self, recording, all_times):
"""
For every segment in a recording, check the `get_times()`
Expand Down