Skip to content

Commit a14ee81

Browse files
authored
Merge pull request #3118 from JoeZiminski/add_time_vector_case_to_get_duration
Add time vector case to `get_durations`.
2 parents e2fe22e + 7714724 commit a14ee81

File tree

4 files changed

+82
-5
lines changed

4 files changed

+82
-5
lines changed

src/spikeinterface/core/baserecording.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -233,8 +233,14 @@ def get_duration(self, segment_index=None) -> float:
233233
The duration in seconds
234234
"""
235235
segment_index = self._check_segment_index(segment_index)
236-
segment_num_samples = self.get_num_samples(segment_index=segment_index)
237-
segment_duration = segment_num_samples / self.get_sampling_frequency()
236+
237+
if self.has_time_vector(segment_index):
238+
times = self.get_times(segment_index)
239+
segment_duration = times[-1] - times[0] + (1 / self.get_sampling_frequency())
240+
else:
241+
segment_num_samples = self.get_num_samples(segment_index=segment_index)
242+
segment_duration = segment_num_samples / self.get_sampling_frequency()
243+
238244
return segment_duration
239245

240246
def get_total_duration(self) -> float:
@@ -246,7 +252,7 @@ def get_total_duration(self) -> float:
246252
float
247253
The duration in seconds
248254
"""
249-
duration = self.get_total_samples() / self.get_sampling_frequency()
255+
duration = sum([self.get_duration(idx) for idx in range(self.get_num_segments())])
250256
return duration
251257

252258
def get_memory_size(self, segment_index=None) -> int:

src/spikeinterface/core/sortinganalyzer.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -821,7 +821,10 @@ def get_total_samples(self) -> int:
821821
return s
822822

823823
def get_total_duration(self) -> float:
824-
duration = self.get_total_samples() / self.sampling_frequency
824+
if self.has_recording() or self.has_temporary_recording():
825+
duration = self.recording.get_total_duration()
826+
else:
827+
duration = self.get_total_samples() / self.sampling_frequency
825828
return duration
826829

827830
def get_num_channels(self) -> int:

src/spikeinterface/core/tests/test_time_handling.py

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -243,6 +243,73 @@ def test_slice_recording(self, time_type, bounds):
243243

244244
assert np.allclose(rec_slice.get_times(0), all_times[0][start_frame:end_frame], rtol=0, atol=1e-8)
245245

246+
def test_get_durations(self, time_vector_recording, t_start_recording):
247+
"""
248+
Test the `get_durations` functions that return the total duration
249+
for a segment. Test that it is correct after adding both `t_start`
250+
or `time_vector` to the recording.
251+
"""
252+
raw_recording, tvector_recording, all_time_vectors = time_vector_recording
253+
_, tstart_recording, all_t_starts = t_start_recording
254+
255+
ts = 1 / raw_recording.get_sampling_frequency()
256+
257+
all_raw_durations = []
258+
all_vector_durations = []
259+
for segment_index in range(raw_recording.get_num_segments()):
260+
261+
# Test before `t_start` and `t_start` (`t_start` is just an offset,
262+
# should not affect duration).
263+
raw_duration = all_t_starts[segment_index][-1] - all_t_starts[segment_index][0] + ts
264+
265+
assert np.isclose(raw_recording.get_duration(segment_index), raw_duration, rtol=0, atol=1e-8)
266+
assert np.isclose(tstart_recording.get_duration(segment_index), raw_duration, rtol=0, atol=1e-8)
267+
268+
# Test the duration from the time vector.
269+
vector_duration = all_time_vectors[segment_index][-1] - all_time_vectors[segment_index][0] + ts
270+
271+
assert tvector_recording.get_duration(segment_index) == vector_duration
272+
273+
all_raw_durations.append(raw_duration)
274+
all_vector_durations.append(vector_duration)
275+
276+
# Finally test the total recording duration
277+
assert np.isclose(tstart_recording.get_total_duration(), sum(all_raw_durations), rtol=0, atol=1e-8)
278+
assert np.isclose(tvector_recording.get_total_duration(), sum(all_vector_durations), rtol=0, atol=1e-8)
279+
280+
def test_sorting_analyzer_get_durations_from_recording(self, time_vector_recording):
281+
"""
282+
Test that when a recording is set on `sorting_analyzer`, the
283+
total duration is propagated from the recording to the
284+
`sorting_analyzer.get_total_duration()` function.
285+
"""
286+
_, times_recording, _ = time_vector_recording
287+
288+
sorting = si.generate_sorting(
289+
durations=[times_recording.get_duration(s) for s in range(times_recording.get_num_segments())]
290+
)
291+
sorting_analyzer = si.create_sorting_analyzer(sorting, recording=times_recording)
292+
293+
assert np.array_equal(sorting_analyzer.get_total_duration(), times_recording.get_total_duration())
294+
295+
def test_sorting_analyzer_get_durations_no_recording(self, time_vector_recording):
296+
"""
297+
Test when the `sorting_analzyer` does not have a recording set,
298+
the total duration is calculated on the fly from num samples and
299+
sampling frequency (thus matching `raw_recording` with no times set
300+
that uses the same method to calculate the total duration).
301+
"""
302+
raw_recording, _, _ = time_vector_recording
303+
304+
sorting = si.generate_sorting(
305+
durations=[raw_recording.get_duration(s) for s in range(raw_recording.get_num_segments())]
306+
)
307+
sorting_analyzer = si.create_sorting_analyzer(sorting, recording=raw_recording)
308+
309+
sorting_analyzer._recording = None
310+
311+
assert np.array_equal(sorting_analyzer.get_total_duration(), raw_recording.get_total_duration())
312+
246313
# Helpers ####
247314
def _check_times_match(self, recording, all_times):
248315
"""

src/spikeinterface/generation/hybrid_tools.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -400,6 +400,7 @@ def generate_hybrid_recording(
400400
num_segments = recording.get_num_segments()
401401
dtype = recording.dtype
402402
durations = np.array([recording.get_duration(segment_index) for segment_index in range(num_segments)])
403+
num_samples = np.array([recording.get_num_samples(segment_index) for segment_index in range(num_segments)])
403404
channel_locations = probe.contact_positions
404405

405406
assert (
@@ -548,7 +549,7 @@ def generate_hybrid_recording(
548549
displacement_vectors=displacement_vectors,
549550
displacement_sampling_frequency=displacement_sampling_frequency,
550551
displacement_unit_factor=displacement_unit_factor,
551-
num_samples=(np.array(durations) * sampling_frequency).astype("int64"),
552+
num_samples=num_samples.astype("int64"),
552553
amplitude_factor=amplitude_factor,
553554
)
554555

0 commit comments

Comments
 (0)