@@ -498,24 +498,35 @@ def time_to_sample_index(self, time_s, segment_index=None):
498
498
rs = self ._recording_segments [segment_index ]
499
499
return rs .time_to_sample_index (time_s )
500
500
501
- def _save (self , format = "binary" , verbose : bool = False , ** save_kwargs ):
501
+ def _get_t_starts (self ):
502
502
# handle t_starts
503
503
t_starts = []
504
504
has_time_vectors = []
505
- for segment_index , rs in enumerate ( self ._recording_segments ) :
505
+ for rs in self ._recording_segments :
506
506
d = rs .get_times_kwargs ()
507
507
t_starts .append (d ["t_start" ])
508
- has_time_vectors .append (d ["time_vector" ] is not None )
509
508
510
509
if all (t_start is None for t_start in t_starts ):
511
510
t_starts = None
511
+ return t_starts
512
512
513
+ def _get_time_vectors (self ):
514
+ time_vectors = []
515
+ for rs in self ._recording_segments :
516
+ d = rs .get_times_kwargs ()
517
+ time_vectors .append (d ["time_vector" ])
518
+ if all (time_vector is None for time_vector in time_vectors ):
519
+ time_vectors = None
520
+ return time_vectors
521
+
522
+ def _save (self , format = "binary" , verbose : bool = False , ** save_kwargs ):
513
523
kwargs , job_kwargs = split_job_kwargs (save_kwargs )
514
524
515
525
if format == "binary" :
516
526
folder = kwargs ["folder" ]
517
527
file_paths = [folder / f"traces_cached_seg{ i } .raw" for i in range (self .get_num_segments ())]
518
528
dtype = kwargs .get ("dtype" , None ) or self .get_dtype ()
529
+ t_starts = self ._get_t_starts ()
519
530
520
531
write_binary_recording (self , file_paths = file_paths , dtype = dtype , verbose = verbose , ** job_kwargs )
521
532
@@ -572,11 +583,11 @@ def _save(self, format="binary", verbose: bool = False, **save_kwargs):
572
583
probegroup = self .get_probegroup ()
573
584
cached .set_probegroup (probegroup )
574
585
575
- for segment_index , rs in enumerate ( self ._recording_segments ):
576
- d = rs . get_times_kwargs ()
577
- time_vector = d [ " time_vector" ]
578
- if time_vector is not None :
579
- cached ._recording_segments [ segment_index ]. time_vector = time_vector
586
+ time_vectors = self ._get_time_vectors ()
587
+ if time_vectors is not None :
588
+ for segment_index , time_vector in enumerate ( time_vectors ):
589
+ if time_vector is not None :
590
+ cached .set_times ( time_vector , segment_index = segment_index )
580
591
581
592
return cached
582
593
0 commit comments