@@ -243,6 +243,73 @@ def test_slice_recording(self, time_type, bounds):
243
243
244
244
assert np .allclose (rec_slice .get_times (0 ), all_times [0 ][start_frame :end_frame ], rtol = 0 , atol = 1e-8 )
245
245
246
+ def test_get_durations (self , time_vector_recording , t_start_recording ):
247
+ """
248
+ Test the `get_durations` functions that return the total duration
249
+ for a segment. Test that it is correct after adding both `t_start`
250
+ or `time_vector` to the recording.
251
+ """
252
+ raw_recording , tvector_recording , all_time_vectors = time_vector_recording
253
+ _ , tstart_recording , all_t_starts = t_start_recording
254
+
255
+ ts = 1 / raw_recording .get_sampling_frequency ()
256
+
257
+ all_raw_durations = []
258
+ all_vector_durations = []
259
+ for segment_index in range (raw_recording .get_num_segments ()):
260
+
261
+ # Test before `t_start` and `t_start` (`t_start` is just an offset,
262
+ # should not affect duration).
263
+ raw_duration = all_t_starts [segment_index ][- 1 ] - all_t_starts [segment_index ][0 ] + ts
264
+
265
+ assert np .isclose (raw_recording .get_duration (segment_index ), raw_duration , rtol = 0 , atol = 1e-8 )
266
+ assert np .isclose (tstart_recording .get_duration (segment_index ), raw_duration , rtol = 0 , atol = 1e-8 )
267
+
268
+ # Test the duration from the time vector.
269
+ vector_duration = all_time_vectors [segment_index ][- 1 ] - all_time_vectors [segment_index ][0 ] + ts
270
+
271
+ assert tvector_recording .get_duration (segment_index ) == vector_duration
272
+
273
+ all_raw_durations .append (raw_duration )
274
+ all_vector_durations .append (vector_duration )
275
+
276
+ # Finally test the total recording duration
277
+ assert np .isclose (tstart_recording .get_total_duration (), sum (all_raw_durations ), rtol = 0 , atol = 1e-8 )
278
+ assert np .isclose (tvector_recording .get_total_duration (), sum (all_vector_durations ), rtol = 0 , atol = 1e-8 )
279
+
280
+ def test_sorting_analyzer_get_durations_from_recording (self , time_vector_recording ):
281
+ """
282
+ Test that when a recording is set on `sorting_analyzer`, the
283
+ total duration is propagated from the recording to the
284
+ `sorting_analyzer.get_total_duration()` function.
285
+ """
286
+ _ , times_recording , _ = time_vector_recording
287
+
288
+ sorting = si .generate_sorting (
289
+ durations = [times_recording .get_duration (s ) for s in range (times_recording .get_num_segments ())]
290
+ )
291
+ sorting_analyzer = si .create_sorting_analyzer (sorting , recording = times_recording )
292
+
293
+ assert np .array_equal (sorting_analyzer .get_total_duration (), times_recording .get_total_duration ())
294
+
295
+ def test_sorting_analyzer_get_durations_no_recording (self , time_vector_recording ):
296
+ """
297
+ Test when the `sorting_analzyer` does not have a recording set,
298
+ the total duration is calculated on the fly from num samples and
299
+ sampling frequency (thus matching `raw_recording` with no times set
300
+ that uses the same method to calculate the total duration).
301
+ """
302
+ raw_recording , _ , _ = time_vector_recording
303
+
304
+ sorting = si .generate_sorting (
305
+ durations = [raw_recording .get_duration (s ) for s in range (raw_recording .get_num_segments ())]
306
+ )
307
+ sorting_analyzer = si .create_sorting_analyzer (sorting , recording = raw_recording )
308
+
309
+ sorting_analyzer ._recording = None
310
+
311
+ assert np .array_equal (sorting_analyzer .get_total_duration (), raw_recording .get_total_duration ())
312
+
246
313
# Helpers ####
247
314
def _check_times_match (self , recording , all_times ):
248
315
"""
0 commit comments