Skip to content

Commit 00e1cf9

Browse files
authored
Merge pull request #3120 from JoeZiminski/fix_save_to_memory_t_start
Fix `t_starts` not propagated to `save_to_memory`.
2 parents f273020 + 30606cc commit 00e1cf9

File tree

3 files changed

+292
-57
lines changed

3 files changed

+292
-57
lines changed

src/spikeinterface/core/baserecording.py

Lines changed: 19 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -498,24 +498,35 @@ def time_to_sample_index(self, time_s, segment_index=None):
498498
rs = self._recording_segments[segment_index]
499499
return rs.time_to_sample_index(time_s)
500500

501-
def _save(self, format="binary", verbose: bool = False, **save_kwargs):
501+
def _get_t_starts(self):
502502
# handle t_starts
503503
t_starts = []
504504
has_time_vectors = []
505-
for segment_index, rs in enumerate(self._recording_segments):
505+
for rs in self._recording_segments:
506506
d = rs.get_times_kwargs()
507507
t_starts.append(d["t_start"])
508-
has_time_vectors.append(d["time_vector"] is not None)
509508

510509
if all(t_start is None for t_start in t_starts):
511510
t_starts = None
511+
return t_starts
512512

513+
def _get_time_vectors(self):
514+
time_vectors = []
515+
for rs in self._recording_segments:
516+
d = rs.get_times_kwargs()
517+
time_vectors.append(d["time_vector"])
518+
if all(time_vector is None for time_vector in time_vectors):
519+
time_vectors = None
520+
return time_vectors
521+
522+
def _save(self, format="binary", verbose: bool = False, **save_kwargs):
513523
kwargs, job_kwargs = split_job_kwargs(save_kwargs)
514524

515525
if format == "binary":
516526
folder = kwargs["folder"]
517527
file_paths = [folder / f"traces_cached_seg{i}.raw" for i in range(self.get_num_segments())]
518528
dtype = kwargs.get("dtype", None) or self.get_dtype()
529+
t_starts = self._get_t_starts()
519530

520531
write_binary_recording(self, file_paths=file_paths, dtype=dtype, verbose=verbose, **job_kwargs)
521532

@@ -572,11 +583,11 @@ def _save(self, format="binary", verbose: bool = False, **save_kwargs):
572583
probegroup = self.get_probegroup()
573584
cached.set_probegroup(probegroup)
574585

575-
for segment_index, rs in enumerate(self._recording_segments):
576-
d = rs.get_times_kwargs()
577-
time_vector = d["time_vector"]
578-
if time_vector is not None:
579-
cached._recording_segments[segment_index].time_vector = time_vector
586+
time_vectors = self._get_time_vectors()
587+
if time_vectors is not None:
588+
for segment_index, time_vector in enumerate(time_vectors):
589+
if time_vector is not None:
590+
cached.set_times(time_vector, segment_index=segment_index)
580591

581592
return cached
582593

src/spikeinterface/core/numpyextractors.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,9 @@ def __init__(self, traces_list, sampling_frequency, t_starts=None, channel_ids=N
8383
@staticmethod
8484
def from_recording(source_recording, **job_kwargs):
8585
traces_list, shms = write_memory_recording(source_recording, dtype=None, **job_kwargs)
86+
87+
t_starts = source_recording._get_t_starts()
88+
8689
if shms[0] is not None:
8790
# if the computation was done in parallel then traces_list is shared array
8891
# this can lead to problem
@@ -91,13 +94,14 @@ def from_recording(source_recording, **job_kwargs):
9194
for shm in shms:
9295
shm.close()
9396
shm.unlink()
94-
# TODO later : propagte t_starts ?
97+
9598
recording = NumpyRecording(
9699
traces_list,
97100
source_recording.get_sampling_frequency(),
98-
t_starts=None,
101+
t_starts=t_starts,
99102
channel_ids=source_recording.channel_ids,
100103
)
104+
return recording
101105

102106

103107
class NumpyRecordingSegment(BaseRecordingSegment):
@@ -206,15 +210,15 @@ def __del__(self):
206210
def from_recording(source_recording, **job_kwargs):
207211
traces_list, shms = write_memory_recording(source_recording, buffer_type="sharedmem", **job_kwargs)
208212

209-
# TODO later : propagte t_starts ?
213+
t_starts = source_recording._get_t_starts()
210214

211215
recording = SharedMemoryRecording(
212216
shm_names=[shm.name for shm in shms],
213217
shape_list=[traces.shape for traces in traces_list],
214218
dtype=source_recording.dtype,
215219
sampling_frequency=source_recording.sampling_frequency,
216220
channel_ids=source_recording.channel_ids,
217-
t_starts=None,
221+
t_starts=t_starts,
218222
main_shm_owner=True,
219223
)
220224

0 commit comments

Comments
 (0)