Skip to content

Commit 6c87d22

Browse files
committed
Add tests.
1 parent 6b617ab commit 6c87d22

File tree

1 file changed

+319
-0
lines changed

1 file changed

+319
-0
lines changed

src/spikeinterface/core/tests/test_time_handling.py

Lines changed: 319 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,328 @@
1+
import copy
2+
13
import pytest
24
import numpy as np
35

46
from spikeinterface.core import generate_recording, generate_sorting
7+
import spikeinterface.full as si
8+
9+
10+
class TestTimeHandling:
11+
12+
# Fixtures #####
13+
@pytest.fixture(scope="session")
14+
def raw_recording(self):
15+
"""
16+
A three-segment raw recording without times added.
17+
"""
18+
durations = [10, 15, 20]
19+
recording = generate_recording(num_channels=4, durations=durations)
20+
return recording
21+
22+
@pytest.fixture(scope="session")
23+
def time_vector_recording(self, raw_recording):
24+
"""
25+
Add time vectors to the recording, returning the
26+
raw recording, recording with time vectors added to
27+
segments, and list a the time vectors added to the recording.
28+
"""
29+
return self._get_time_vector_recording(raw_recording)
30+
31+
@pytest.fixture(scope="session")
32+
def t_start_recording(self, raw_recording):
33+
"""
34+
Add a t_starts to the recording, returning the
35+
raw recording, recording with t_starts added to segments,
36+
and a list of the time vectors generated from adding the
37+
t_start to the recording times.
38+
"""
39+
return self._get_t_start_recording(raw_recording)
40+
41+
def _get_time_vector_recording(self, raw_recording):
42+
"""
43+
Loop through all recording segments, adding a different time
44+
vector to each segment. The time vector is the original times with
45+
a t_start and irregularly spaced offsets to mimic irregularly
46+
spaced timeseries data. Return the original recording,
47+
recoridng with time vectors added and list including the added time vectors.
48+
"""
49+
times_recording = copy.deepcopy(raw_recording)
50+
all_time_vectors = []
51+
for segment_index in range(raw_recording.get_num_segments()):
52+
53+
t_start = segment_index + 1 * 100
54+
offsets = np.arange(times_recording.get_num_samples(segment_index)) * (
55+
1 / times_recording.get_sampling_frequency()
56+
)
57+
time_vector = t_start + times_recording.get_times(segment_index) + offsets
58+
59+
all_time_vectors.append(time_vector)
60+
times_recording.set_times(times=time_vector, segment_index=segment_index)
61+
62+
assert np.array_equal(
63+
times_recording._recording_segments[segment_index].time_vector,
64+
time_vector,
65+
), "time_vector was not properly set during test setup"
66+
67+
return (raw_recording, times_recording, all_time_vectors)
68+
69+
def _get_t_start_recording(self, raw_recording):
70+
"""
71+
For each segment in the recording, add a different `t_start`.
72+
Return a list of time vectors generating from the recording times
73+
+ the t_starts.
74+
"""
75+
t_start_recording = copy.deepcopy(raw_recording)
76+
77+
all_t_starts = []
78+
for segment_index in range(raw_recording.get_num_segments()):
79+
80+
t_start = (segment_index + 1) * 100
81+
82+
all_t_starts.append(t_start + t_start_recording.get_times(segment_index))
83+
t_start_recording.set_times(times=t_start, segment_index=segment_index)
84+
85+
assert np.array_equal(
86+
t_start_recording._recording_segments[segment_index].t_start,
87+
t_start,
88+
), "t_start was not properly set during test setup"
89+
90+
return (raw_recording, t_start_recording, all_t_starts)
91+
92+
def _get_fixture_data(self, request, fixture_name):
93+
"""
94+
A convenience function to get the data from a fixture
95+
based on the name. This is used to allow parameterising
96+
tests across fixtures.
97+
"""
98+
time_recording_fixture = request.getfixturevalue(fixture_name)
99+
raw_recording, times_recording, all_times = time_recording_fixture
100+
return (raw_recording, times_recording, all_times)
101+
102+
# Tests #####
103+
def test_has_time_vector(self, time_vector_recording):
104+
"""
105+
Test the `has_time_vector` function returns `False` before
106+
a time vector is added and `True` afterwards.
107+
"""
108+
raw_recording, times_recording, _ = time_vector_recording
109+
110+
for segment_idx in range(raw_recording.get_num_segments()):
111+
112+
assert raw_recording.has_time_vector(segment_idx) is False
113+
assert times_recording.has_time_vector(segment_idx) is True
114+
115+
def test_get_durations(self, time_vector_recording, t_start_recording):
116+
"""
117+
Test the `get_durations` functions that return the total duration
118+
for a segment. Test that it is correct after adding both `t_start`
119+
or `time_vector` to the recording.
120+
"""
121+
raw_recording, tvector_recording, all_time_vectors = time_vector_recording
122+
_, tstart_recording, all_t_starts = t_start_recording
123+
124+
ts = 1 / raw_recording.get_sampling_frequency()
125+
126+
all_raw_durations = []
127+
all_vector_durations = []
128+
for segment_index in range(raw_recording.get_num_segments()):
129+
130+
# Test before `t_start` and `t_start` (`t_start` is just an offset,
131+
# should not affect duration).
132+
raw_duration = all_t_starts[segment_index][-1] - all_t_starts[segment_index][0] + ts
133+
134+
assert np.isclose(raw_recording.get_duration(segment_index), raw_duration, rtol=0, atol=1e-8)
135+
assert np.isclose(tstart_recording.get_duration(segment_index), raw_duration, rtol=0, atol=1e-8)
136+
137+
# Test the duration from the time vector.
138+
vector_duration = all_time_vectors[segment_index][-1] - all_time_vectors[segment_index][0] + ts
139+
140+
assert tvector_recording.get_duration(segment_index) == vector_duration
141+
142+
all_raw_durations.append(raw_duration)
143+
all_vector_durations.append(vector_duration)
144+
145+
# Finally test the total recording duration
146+
assert np.isclose(tstart_recording.get_total_duration(), sum(all_raw_durations), rtol=0, atol=1e-8)
147+
assert np.isclose(tvector_recording.get_total_duration(), sum(all_vector_durations), rtol=0, atol=1e-8)
148+
149+
@pytest.mark.parametrize("mode", ["binary", "zarr"])
150+
@pytest.mark.parametrize("fixture_name", ["time_vector_recording", "t_start_recording"])
151+
def test_times_propagated_to_save_folder(self, request, fixture_name, mode, tmp_path):
152+
"""
153+
Test `t_start` or `time_vector` is propagated to a saved recording,
154+
by saving, reloading, and checking times are correct.
155+
"""
156+
_, times_recording, all_times = self._get_fixture_data(request, fixture_name)
157+
158+
folder_name = "recording"
159+
recording_cache = times_recording.save(format=mode, folder=tmp_path / folder_name)
160+
161+
if mode == "zarr":
162+
folder_name += ".zarr"
163+
recording_load = si.load_extractor(tmp_path / folder_name)
164+
165+
self._check_times_match(recording_cache, all_times)
166+
self._check_times_match(recording_load, all_times)
167+
168+
@pytest.mark.parametrize("fixture_name", ["time_vector_recording", "t_start_recording"])
169+
@pytest.mark.parametrize("sharedmem", [True, False])
170+
def test_times_propagated_to_save_memory(self, request, fixture_name, sharedmem):
171+
"""
172+
Test t_start and time_vector are propagated to recording saved into memory.
173+
"""
174+
_, times_recording, all_times = self._get_fixture_data(request, fixture_name)
175+
176+
recording_load = times_recording.save(format="memory", sharedmem=sharedmem)
177+
178+
self._check_times_match(recording_load, all_times)
179+
180+
@pytest.mark.parametrize("fixture_name", ["time_vector_recording", "t_start_recording"])
181+
def test_time_propagated_to_select_segments(self, request, fixture_name):
182+
"""
183+
Test that when `recording.select_segments()` is used, the times
184+
are propagated to the new recoridng object.
185+
"""
186+
_, times_recording, all_times = self._get_fixture_data(request, fixture_name)
187+
188+
for segment_index in range(times_recording.get_num_segments()):
189+
segment = times_recording.select_segments(segment_index)
190+
assert np.array_equal(segment.get_times(), all_times[segment_index])
191+
192+
@pytest.mark.parametrize("fixture_name", ["time_vector_recording", "t_start_recording"])
193+
def test_times_propagated_to_sorting(self, request, fixture_name):
194+
"""
195+
Check that when attached to a sorting object, the times are propagated
196+
to the object. This means that all spike times should respect the
197+
`t_start` or `time_vector` added.
198+
"""
199+
raw_recording, times_recording, all_times = self._get_fixture_data(request, fixture_name)
200+
sorting = self._get_sorting_with_recording_attached(
201+
recording_for_durations=raw_recording, recording_to_attach=times_recording
202+
)
203+
for segment_index in range(raw_recording.get_num_segments()):
204+
205+
if fixture_name == "time_vector_recording":
206+
assert sorting.has_time_vector(segment_index=segment_index)
207+
208+
self._check_spike_times_are_correct(sorting, times_recording, segment_index)
209+
210+
@pytest.mark.parametrize("fixture_name", ["time_vector_recording", "t_start_recording"])
211+
def test_time_sample_converters(self, request, fixture_name):
212+
"""
213+
Test the `recording.sample_time_to_index` and
214+
`recording.time_to_sample_index` convenience functions.
215+
"""
216+
raw_recording, times_recording, all_times = self._get_fixture_data(request, fixture_name)
217+
with pytest.raises(ValueError) as e:
218+
times_recording.sample_index_to_time(0)
219+
assert "Provide 'segment_index'" in str(e)
220+
221+
for segment_index in range(times_recording.get_num_segments()):
222+
223+
sample_index = np.random.randint(low=0, high=times_recording.get_num_samples(segment_index))
224+
time_ = times_recording.sample_index_to_time(sample_index, segment_index=segment_index)
225+
226+
assert time_ == all_times[segment_index][sample_index]
227+
228+
new_sample_index = times_recording.time_to_sample_index(time_, segment_index=segment_index)
229+
230+
assert new_sample_index == sample_index
231+
232+
@pytest.mark.parametrize("time_type", ["time_vector", "t_start"])
233+
@pytest.mark.parametrize("bounds", ["start", "middle", "end"])
234+
def test_slice_recording(self, time_type, bounds):
235+
"""
236+
Test after `frame_slice` and `time_slice` a recording or
237+
sorting (for `frame_slice`), the recording times are
238+
correct with respect to the set `t_start` or `time_vector`.
239+
"""
240+
raw_recording = generate_recording(num_channels=4, durations=[10])
241+
242+
if time_type == "time_vector":
243+
raw_recording, times_recording, all_times = self._get_time_vector_recording(raw_recording)
244+
else:
245+
raw_recording, times_recording, all_times = self._get_t_start_recording(raw_recording)
246+
247+
sorting = self._get_sorting_with_recording_attached(
248+
recording_for_durations=raw_recording, recording_to_attach=times_recording
249+
)
250+
251+
# Take some different times, including min and max bounds of
252+
# the recording, and some arbitaray times in the middle (20% and 80%).
253+
if bounds == "start":
254+
start_frame = 0
255+
end_frame = int(times_recording.get_num_samples(0) * 0.8)
256+
elif bounds == "end":
257+
start_frame = int(times_recording.get_num_samples(0) * 0.2)
258+
end_frame = times_recording.get_num_samples(0) - 1
259+
elif bounds == "middle":
260+
start_frame = int(times_recording.get_num_samples(0) * 0.2)
261+
end_frame = int(times_recording.get_num_samples(0) * 0.8)
262+
263+
# Slice the recording and get the new times are correct
264+
rec_frame_slice = times_recording.frame_slice(start_frame=start_frame, end_frame=end_frame)
265+
sort_frame_slice = sorting.frame_slice(start_frame=start_frame, end_frame=end_frame)
266+
267+
assert np.allclose(rec_frame_slice.get_times(0), all_times[0][start_frame:end_frame], rtol=0, atol=1e-8)
268+
269+
self._check_spike_times_are_correct(sort_frame_slice, rec_frame_slice, segment_index=0)
270+
271+
# Test `time_slice`
272+
start_time = times_recording.sample_index_to_time(start_frame)
273+
end_time = times_recording.sample_index_to_time(end_frame)
274+
275+
rec_slice = times_recording.time_slice(start_time=start_time, end_time=end_time)
276+
277+
assert np.allclose(rec_slice.get_times(0), all_times[0][start_frame:end_frame], rtol=0, atol=1e-8)
278+
279+
# Helpers ####
280+
def _check_times_match(self, recording, all_times):
281+
"""
282+
For every segment in a recording, check the `get_times()`
283+
match the expected times in the list of time vectors, `all_times`.
284+
"""
285+
for segment_index in range(recording.get_num_segments()):
286+
assert np.array_equal(recording.get_times(segment_index), all_times[segment_index])
287+
288+
def _check_spike_times_are_correct(self, sorting, times_recording, segment_index):
289+
"""
290+
For every unit in the `sorting`, for a particular segment, check that
291+
the unit times match the times of the original recording as
292+
retrieved with `get_times()`.
293+
"""
294+
for unit_id in sorting.get_unit_ids():
295+
spike_times = sorting.get_unit_spike_train(unit_id, segment_index=segment_index, return_times=True)
296+
spike_indexes = sorting.get_unit_spike_train(unit_id, segment_index=segment_index)
297+
rec_times = times_recording.get_times(segment_index=segment_index)
298+
299+
assert np.array_equal(
300+
spike_times,
301+
rec_times[spike_indexes],
302+
)
303+
304+
def _get_sorting_with_recording_attached(self, recording_for_durations, recording_to_attach):
305+
"""
306+
Convenience function to create a sorting object with
307+
a recording attached. Typically use the raw recordings
308+
for the durations of which to make the sorter, as
309+
the generate_sorter is not setup to handle the
310+
(strange) edge case of the irregularly spaced
311+
test time vectors.
312+
"""
313+
durations = [
314+
recording_for_durations.get_duration(idx) for idx in range(recording_for_durations.get_num_segments())
315+
]
316+
317+
sorting = generate_sorting(num_units=10, durations=durations)
318+
319+
sorting.register_recording(recording_to_attach)
320+
assert sorting.has_recording()
321+
322+
return sorting
5323

6324

325+
# TODO: deprecate original implementations ###
7326
def test_time_handling(create_cache_folder):
8327
cache_folder = create_cache_folder
9328
durations = [[10], [10, 5]]

0 commit comments

Comments
 (0)