-
Notifications
You must be signed in to change notification settings - Fork 222
Add shift start time
function.
#3509
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
f55a940
620f801
22d5dfc
458a3dc
3e98c67
4d7246a
a1cf336
469b3b0
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -15,7 +15,10 @@ class TestTimeHandling: | |
is generated on the fly. Both time representations are tested here. | ||
""" | ||
|
||
# Fixtures ##### | ||
# ######################################################################### | ||
# Fixtures | ||
# ######################################################################### | ||
|
||
@pytest.fixture(scope="session") | ||
def time_vector_recording(self): | ||
""" | ||
|
@@ -95,7 +98,10 @@ def _get_fixture_data(self, request, fixture_name): | |
raw_recording, times_recording, all_times = time_recording_fixture | ||
return (raw_recording, times_recording, all_times) | ||
|
||
# Tests ##### | ||
# ######################################################################### | ||
# Tests | ||
# ######################################################################### | ||
|
||
def test_has_time_vector(self, time_vector_recording): | ||
""" | ||
Test the `has_time_vector` function returns `False` before | ||
|
@@ -305,7 +311,87 @@ def test_sorting_analyzer_get_durations_no_recording(self, time_vector_recording | |
|
||
assert np.array_equal(sorting_analyzer.get_total_duration(), raw_recording.get_total_duration()) | ||
|
||
# Helpers #### | ||
@pytest.mark.parametrize("fixture_name", ["time_vector_recording", "t_start_recording"]) | ||
@pytest.mark.parametrize("shift", [-123.456, 123.456]) | ||
def test_shift_time_all_segments(self, request, fixture_name, shift): | ||
""" | ||
Shift the times in every segment using the `None` default, then | ||
check that every segment of the recording is shifted as expected. | ||
""" | ||
_, times_recording, all_times = self._get_fixture_data(request, fixture_name) | ||
|
||
num_segments, orig_seg_data = self._store_all_times(times_recording) | ||
|
||
times_recording.shift_times(shift) # use default `segment_index=None` | ||
|
||
for idx in range(num_segments): | ||
assert np.allclose( | ||
orig_seg_data[idx], times_recording.get_times(segment_index=idx) - shift, rtol=0, atol=1e-8 | ||
) | ||
|
||
@pytest.mark.parametrize("fixture_name", ["time_vector_recording", "t_start_recording"]) | ||
@pytest.mark.parametrize("shift", [-123.456, 123.456]) | ||
def test_shift_times_different_segments(self, request, fixture_name, shift): | ||
""" | ||
Shift each segment separately, and check the shifted segment only | ||
is shifted as expected. | ||
""" | ||
_, times_recording, all_times = self._get_fixture_data(request, fixture_name) | ||
|
||
num_segments, orig_seg_data = self._store_all_times(times_recording) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Maybe you could just do recording.copy? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Nice this will allow to remove a lot of boiler plate and make the assert more explicit. However check this out!:
This fails,
It seems the time vector info is not being propagated in the copy, though it seemed to work okay for the save/ load 🤔 will have to come back to this, do you have ideas why this might be happening? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Mmm yeah, I think that that's probably by design. I think cloning uses to do_dict mechanism and that coupled with saving to disk and serializing. Because timestamps can be large we probably don't want that there. The solution is to change copy (which should be a in memory thing) to copy the time vector but I think that should be done in another PR. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. A little out of scope, but do you remember when we were looking up copy vs clone at Edinburgh. I forget what it said (but basically one is deep and one is suppose to be shallow). It would probably be worth us making clear in our docstrings if our copy's are shallow or not (also for another PR). There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'm a little confused sorry, what is
should this have worked? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
No I don't think this would work. I guess @h-mayorquin was suggesting to do a:
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Some projects use There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
|
||
# For each segment, shift the segment only and check the | ||
# times are updated as expected. | ||
for idx in range(num_segments): | ||
|
||
scaler = idx + 2 | ||
times_recording.shift_times(shift * scaler, segment_index=idx) | ||
|
||
assert np.allclose( | ||
orig_seg_data[idx], times_recording.get_times(segment_index=idx) - shift * scaler, rtol=0, atol=1e-8 | ||
) | ||
|
||
# Just do a little check that we are not | ||
# accidentally changing some other segments, | ||
# which should remain unchanged at this point in the loop. | ||
if idx != num_segments - 1: | ||
assert np.array_equal(orig_seg_data[idx + 1], times_recording.get_times(segment_index=idx + 1)) | ||
|
||
@pytest.mark.parametrize("fixture_name", ["time_vector_recording", "t_start_recording"]) | ||
def test_save_and_load_time_shift(self, request, fixture_name, tmp_path): | ||
""" | ||
Save the shifted data and check the shift is propagated correctly. | ||
""" | ||
_, times_recording, all_times = self._get_fixture_data(request, fixture_name) | ||
|
||
shift = 100 | ||
times_recording.shift_times(shift=shift) | ||
|
||
times_recording.save(folder=tmp_path / "my_file") | ||
|
||
loaded_recording = si.load_extractor(tmp_path / "my_file") | ||
|
||
for idx in range(times_recording.get_num_segments()): | ||
assert np.array_equal( | ||
times_recording.get_times(segment_index=idx), loaded_recording.get_times(segment_index=idx) | ||
) | ||
|
||
def _store_all_times(self, recording): | ||
""" | ||
Convenience function to store original times of all segments to a dict. | ||
""" | ||
num_segments = recording.get_num_segments() | ||
seg_data = {} | ||
|
||
for idx in range(num_segments): | ||
seg_data[idx] = copy.deepcopy(recording.get_times(segment_index=idx)) | ||
|
||
return num_segments, seg_data | ||
|
||
# ######################################################################### | ||
# Helpers | ||
# ######################################################################### | ||
|
||
def _check_times_match(self, recording, all_times): | ||
""" | ||
For every segment in a recording, check the `get_times()` | ||
|
Uh oh!
There was an error while loading. Please reload this page.