From f55a9405040310064f716015f8d9b0c976b97923 Mon Sep 17 00:00:00 2001 From: JoeZiminski Date: Mon, 28 Oct 2024 14:28:05 +0000 Subject: [PATCH 1/8] Add 'shift start time' function. --- src/spikeinterface/core/baserecording.py | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/src/spikeinterface/core/baserecording.py b/src/spikeinterface/core/baserecording.py index 5e2e9e4014..b8a0420794 100644 --- a/src/spikeinterface/core/baserecording.py +++ b/src/spikeinterface/core/baserecording.py @@ -509,6 +509,26 @@ def reset_times(self): rs.t_start = None rs.sampling_frequency = self.sampling_frequency + def shift_start_time(self, shift, segment_index=None): + """ + Shift the starting time of the times. + + shift : int | float + The shift to apply to the first time point. If positive, + the current start time will be increased by `shift`. If + negative, the start time will be decreased. + + segment_index : int | None + The segment on which to shift the times. + """ + segment_index = self._check_segment_index(segment_index) + rs = self._recording_segments[segment_index] + + if self.has_time_vector(): + rs.time_vector += shift + else: + rs.t_start += shift + def sample_index_to_time(self, sample_ind, segment_index=None): """ Transform sample index into time in seconds From 620f8013b8bf4f1332a7802dd3f6804ce068493c Mon Sep 17 00:00:00 2001 From: JoeZiminski Date: Wed, 13 Nov 2024 13:37:50 +0000 Subject: [PATCH 2/8] Apply to all segments if 'segment_index' is 'None'. --- src/spikeinterface/core/baserecording.py | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/src/spikeinterface/core/baserecording.py b/src/spikeinterface/core/baserecording.py index b8a0420794..7392caa69b 100644 --- a/src/spikeinterface/core/baserecording.py +++ b/src/spikeinterface/core/baserecording.py @@ -521,13 +521,20 @@ def shift_start_time(self, shift, segment_index=None): segment_index : int | None The segment on which to shift the times. """ - segment_index = self._check_segment_index(segment_index) - rs = self._recording_segments[segment_index] + self._check_segment_index(segment_index) - if self.has_time_vector(): - rs.time_vector += shift + if segment_index is None: + segments_to_shift = range(self.get_num_segments()) else: - rs.t_start += shift + segments_to_shift = (segment_index,) + + for idx in segments_to_shift: + rs = self._recording_segments[idx] + + if self.has_time_vector(): + rs.time_vector += shift + else: + rs.t_start += shift def sample_index_to_time(self, sample_ind, segment_index=None): """ From 22d5dfc2a552e00d7b55d7c28681e25a1f51a711 Mon Sep 17 00:00:00 2001 From: JoeZiminski Date: Wed, 13 Nov 2024 13:39:34 +0000 Subject: [PATCH 3/8] Add type hints. --- src/spikeinterface/core/baserecording.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/src/spikeinterface/core/baserecording.py b/src/spikeinterface/core/baserecording.py index 7392caa69b..0af9c4bb6a 100644 --- a/src/spikeinterface/core/baserecording.py +++ b/src/spikeinterface/core/baserecording.py @@ -509,7 +509,7 @@ def reset_times(self): rs.t_start = None rs.sampling_frequency = self.sampling_frequency - def shift_start_time(self, shift, segment_index=None): + def shift_start_time(self, shift: int | float, segment_index: int | None = None) -> None: """ Shift the starting time of the times. @@ -536,15 +536,14 @@ def shift_start_time(self, shift, segment_index=None): else: rs.t_start += shift - def sample_index_to_time(self, sample_ind, segment_index=None): - """ - Transform sample index into time in seconds - """ + def sample_index_to_time(self, sample_ind: int, segment_index: int | None = None): + """ """ segment_index = self._check_segment_index(segment_index) rs = self._recording_segments[segment_index] return rs.sample_index_to_time(sample_ind) - def time_to_sample_index(self, time_s, segment_index=None): + def time_to_sample_index(self, time_s: float, segment_index: int | None = None): + """ """ segment_index = self._check_segment_index(segment_index) rs = self._recording_segments[segment_index] return rs.time_to_sample_index(time_s) From 458a3dcc201380740583ef1f075951e83ee77ed8 Mon Sep 17 00:00:00 2001 From: JoeZiminski Date: Wed, 13 Nov 2024 13:43:45 +0000 Subject: [PATCH 4/8] Update name and docstring. --- src/spikeinterface/core/baserecording.py | 19 ++++++++++++------- 1 file changed, 12 insertions(+), 7 deletions(-) diff --git a/src/spikeinterface/core/baserecording.py b/src/spikeinterface/core/baserecording.py index 0af9c4bb6a..91f99f17b0 100644 --- a/src/spikeinterface/core/baserecording.py +++ b/src/spikeinterface/core/baserecording.py @@ -509,19 +509,24 @@ def reset_times(self): rs.t_start = None rs.sampling_frequency = self.sampling_frequency - def shift_start_time(self, shift: int | float, segment_index: int | None = None) -> None: + def shift_times(self, shift: int | float, segment_index: int | None = None) -> None: """ - Shift the starting time of the times. + Shift all times by a scalar value. The default behaviour is to + shift all segments uniformly. + Parameters + ---------- shift : int | float - The shift to apply to the first time point. If positive, - the current start time will be increased by `shift`. If - negative, the start time will be decreased. + The shift to apply. If positive, times will be increased by `shift`. + e.g. shifting by 1 will be like the recording started 1 second later. + If negative, the start time will be decreased i.e. as if the recording + started earlier. segment_index : int | None - The segment on which to shift the times. + The segment on which to shift the times. if `None`, all + segments will be shifted. """ - self._check_segment_index(segment_index) + self._check_segment_index(segment_index) # Check the segment index is valid only if segment_index is None: segments_to_shift = range(self.get_num_segments()) From 3e98c670a27671590613b7c1c4118780a8c47ce8 Mon Sep 17 00:00:00 2001 From: JoeZiminski Date: Wed, 13 Nov 2024 18:32:48 +0000 Subject: [PATCH 5/8] Add tests. --- .../core/tests/test_time_handling.py | 92 ++++++++++++++++++- 1 file changed, 89 insertions(+), 3 deletions(-) diff --git a/src/spikeinterface/core/tests/test_time_handling.py b/src/spikeinterface/core/tests/test_time_handling.py index a129316ee7..9b7ed11bbb 100644 --- a/src/spikeinterface/core/tests/test_time_handling.py +++ b/src/spikeinterface/core/tests/test_time_handling.py @@ -15,7 +15,10 @@ class TestTimeHandling: is generated on the fly. Both time representations are tested here. """ - # Fixtures ##### + # ######################################################################### + # Fixtures + # ######################################################################### + @pytest.fixture(scope="session") def time_vector_recording(self): """ @@ -95,7 +98,10 @@ def _get_fixture_data(self, request, fixture_name): raw_recording, times_recording, all_times = time_recording_fixture return (raw_recording, times_recording, all_times) - # Tests ##### + # ######################################################################### + # Tests + # ######################################################################### + def test_has_time_vector(self, time_vector_recording): """ Test the `has_time_vector` function returns `False` before @@ -305,7 +311,87 @@ def test_sorting_analyzer_get_durations_no_recording(self, time_vector_recording assert np.array_equal(sorting_analyzer.get_total_duration(), raw_recording.get_total_duration()) - # Helpers #### + @pytest.mark.parametrize("fixture_name", ["time_vector_recording", "t_start_recording"]) + @pytest.mark.parametrize("shift", [-123.456, 123.456]) + def test_shift_time_all_segments(self, request, fixture_name, shift): + """ + Shift the times in every segment using the `None` default, then + check that every segment of the recording is shifted as expected. + """ + _, times_recording, all_times = self._get_fixture_data(request, fixture_name) + + num_segments, orig_seg_data = self._store_all_times(times_recording) + + times_recording.shift_times(shift) # use default `segment_index=None` + + for idx in range(num_segments): + assert np.allclose( + orig_seg_data[idx], times_recording.get_times(segment_index=idx) - shift, rtol=0, atol=1e-8 + ) + + @pytest.mark.parametrize("fixture_name", ["time_vector_recording", "t_start_recording"]) + @pytest.mark.parametrize("shift", [-123.456, 123.456]) + def test_shift_times_different_segments(self, request, fixture_name, shift): + """ + Shift each segment separately, and check the shifted segment only + is shifted as expected. + """ + _, times_recording, all_times = self._get_fixture_data(request, fixture_name) + + num_segments, orig_seg_data = self._store_all_times(times_recording) + + # For each segment, shift the segment only and check the + # times are updated as expected. + for idx in range(num_segments): + + scaler = idx + 2 + times_recording.shift_times(shift * scaler, segment_index=idx) + + assert np.allclose( + orig_seg_data[idx], times_recording.get_times(segment_index=idx) - shift * scaler, rtol=0, atol=1e-8 + ) + + # Just do a little check that we are not + # accidentally changing some other segments, + # which should remain unchanged at this point in the loop. + if idx != num_segments - 1: + assert np.array_equal(orig_seg_data[idx + 1], times_recording.get_times(segment_index=idx + 1)) + + @pytest.mark.parametrize("fixture_name", ["time_vector_recording", "t_start_recording"]) + def test_save_and_load_time_shift(self, request, fixture_name, tmp_path): + """ + Save the shifted data and check the shift is propagated correctly. + """ + _, times_recording, all_times = self._get_fixture_data(request, fixture_name) + + shift = 100 + times_recording.shift_times(shift=shift) + + times_recording.save(folder=tmp_path / "my_file") + + loaded_recording = si.load_extractor(tmp_path / "my_file") + + for idx in range(times_recording.get_num_segments()): + assert np.array_equal( + times_recording.get_times(segment_index=idx), loaded_recording.get_times(segment_index=idx) + ) + + def _store_all_times(self, recording): + """ + Convenience function to store original times of all segments to a dict. + """ + num_segments = recording.get_num_segments() + seg_data = {} + + for idx in range(num_segments): + seg_data[idx] = copy.deepcopy(recording.get_times(segment_index=idx)) + + return num_segments, seg_data + + # ######################################################################### + # Helpers + # ######################################################################### + def _check_times_match(self, recording, all_times): """ For every segment in a recording, check the `get_times()` From 4d7246a529e3d17747cf5a496a0a04bd97f4eb09 Mon Sep 17 00:00:00 2001 From: JoeZiminski Date: Wed, 13 Nov 2024 18:33:17 +0000 Subject: [PATCH 6/8] Fixes on shift function. --- src/spikeinterface/core/baserecording.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/spikeinterface/core/baserecording.py b/src/spikeinterface/core/baserecording.py index 91f99f17b0..4b545dc7c7 100644 --- a/src/spikeinterface/core/baserecording.py +++ b/src/spikeinterface/core/baserecording.py @@ -526,8 +526,6 @@ def shift_times(self, shift: int | float, segment_index: int | None = None) -> N The segment on which to shift the times. if `None`, all segments will be shifted. """ - self._check_segment_index(segment_index) # Check the segment index is valid only - if segment_index is None: segments_to_shift = range(self.get_num_segments()) else: @@ -536,7 +534,7 @@ def shift_times(self, shift: int | float, segment_index: int | None = None) -> N for idx in segments_to_shift: rs = self._recording_segments[idx] - if self.has_time_vector(): + if self.has_time_vector(segment_index=idx): rs.time_vector += shift else: rs.t_start += shift From a1cf3367d18a549281208b25c622f2a1ee773226 Mon Sep 17 00:00:00 2001 From: JoeZiminski Date: Wed, 13 Nov 2024 18:35:32 +0000 Subject: [PATCH 7/8] Undo out of scope changes. --- src/spikeinterface/core/baserecording.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/src/spikeinterface/core/baserecording.py b/src/spikeinterface/core/baserecording.py index 4b545dc7c7..886f7db79f 100644 --- a/src/spikeinterface/core/baserecording.py +++ b/src/spikeinterface/core/baserecording.py @@ -539,14 +539,15 @@ def shift_times(self, shift: int | float, segment_index: int | None = None) -> N else: rs.t_start += shift - def sample_index_to_time(self, sample_ind: int, segment_index: int | None = None): - """ """ + def sample_index_to_time(self, sample_ind, segment_index=None): + """ + Transform sample index into time in seconds + """ segment_index = self._check_segment_index(segment_index) rs = self._recording_segments[segment_index] return rs.sample_index_to_time(sample_ind) - def time_to_sample_index(self, time_s: float, segment_index: int | None = None): - """ """ + def time_to_sample_index(self, time_s, segment_index=None): segment_index = self._check_segment_index(segment_index) rs = self._recording_segments[segment_index] return rs.time_to_sample_index(time_s) From 469b3b0e36fdbc0571d37e100d99d6c741af1377 Mon Sep 17 00:00:00 2001 From: JoeZiminski Date: Wed, 13 Nov 2024 18:37:20 +0000 Subject: [PATCH 8/8] Fix docstring. --- src/spikeinterface/core/baserecording.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/src/spikeinterface/core/baserecording.py b/src/spikeinterface/core/baserecording.py index 886f7db79f..6d9d2a827f 100644 --- a/src/spikeinterface/core/baserecording.py +++ b/src/spikeinterface/core/baserecording.py @@ -511,8 +511,7 @@ def reset_times(self): def shift_times(self, shift: int | float, segment_index: int | None = None) -> None: """ - Shift all times by a scalar value. The default behaviour is to - shift all segments uniformly. + Shift all times by a scalar value. Parameters ---------- @@ -523,8 +522,8 @@ def shift_times(self, shift: int | float, segment_index: int | None = None) -> N started earlier. segment_index : int | None - The segment on which to shift the times. if `None`, all - segments will be shifted. + The segment on which to shift the times. + If `None`, all segments will be shifted. """ if segment_index is None: segments_to_shift = range(self.get_num_segments())