
Commit 1e41c25

Merge pull request #1357 from samuelgarcia/pipeline_to_file
Add gather_func concept to ChunkRecordingExecutor.
2 parents: 5d3a509 + 2170b50 · commit 1e41c25

6 files changed (+358, -83 lines)


spikeinterface/core/job_tools.py

Lines changed: 13 additions & 7 deletions
@@ -258,6 +258,9 @@ class ChunkRecordingExecutor:
         If True, a progress bar is printed to monitor the progress of the process
     handle_returns: bool
         If True, the function can return values
+    gather_func: None or callable
+        Optional function that is called in the main thread and retrieves the results of each worker.
+        This function can be used instead of `handle_returns` to implement custom storage on-the-fly.
     n_jobs: int
         Number of jobs to be used (default 1). Use -1 to use as many jobs as number of cores
     total_memory: str
@@ -277,15 +280,16 @@ class ChunkRecordingExecutor:
         Limit the number of thread per process using threadpoolctl modules.
         This used only when n_jobs>1
         If None, no limits.
-
+
+
     Returns
     -------
     res: list
         If 'handle_returns' is True, the results for each chunk process
     """

     def __init__(self, recording, func, init_func, init_args, verbose=False, progress_bar=False, handle_returns=False,
-                 n_jobs=1, total_memory=None, chunk_size=None, chunk_memory=None, chunk_duration=None,
+                 gather_func=None, n_jobs=1, total_memory=None, chunk_size=None, chunk_memory=None, chunk_duration=None,
                  mp_context=None, job_name='', max_threads_per_process=1):
         self.recording = recording
         self.func = func
@@ -303,6 +307,7 @@ def __init__(self, recording, func, init_func, init_args, verbose=False, progres
         self.progress_bar = progress_bar

         self.handle_returns = handle_returns
+        self.gather_func = gather_func

         self.n_jobs = ensure_n_jobs(recording, n_jobs=n_jobs)
         self.chunk_size = ensure_chunk_size(recording,
@@ -339,6 +344,8 @@ def run(self):
                 res = self.func(segment_index, frame_start, frame_stop, worker_ctx)
                 if self.handle_returns:
                     returns.append(res)
+                if self.gather_func is not None:
+                    self.gather_func(res)
         else:
             n_jobs = min(self.n_jobs, len(all_chunks))
             ######## Do you want to limit the number of threads per process?
@@ -357,12 +364,11 @@ def run(self):
                 if self.progress_bar:
                     results = tqdm(results, desc=self.job_name, total=len(all_chunks))

-                if self.handle_returns:
-                    for res in results:
+                for res in results:
+                    if self.handle_returns:
                         returns.append(res)
-                else:
-                    for res in results:
-                        pass
+                    if self.gather_func is not None:
+                        self.gather_func(res)

         return returns

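For readers of this diff, here is a hedged usage sketch (not part of the commit) of the new gather_func hook: per-chunk results are streamed back in the main thread instead of being accumulated via handle_returns. The worker functions init_count and count_nonzero_chunk are hypothetical stand-ins for a real per-chunk job:

import numpy as np
from spikeinterface.core import generate_recording
from spikeinterface.core.job_tools import ChunkRecordingExecutor

def init_count(recording):
    # build the worker context once per worker
    return dict(recording=recording)

def count_nonzero_chunk(segment_index, start_frame, end_frame, worker_ctx):
    # per-chunk job: count the non-zero samples in this chunk
    traces = worker_ctx['recording'].get_traces(segment_index=segment_index,
                                                start_frame=start_frame, end_frame=end_frame)
    return int(np.count_nonzero(traces))

recording = generate_recording(durations=[2.])
counts = []
processor = ChunkRecordingExecutor(recording, count_nonzero_chunk, init_count, (recording,),
                                   gather_func=counts.append,  # called in the main thread for every chunk result
                                   n_jobs=1, chunk_duration='200ms')
processor.run()
total_nonzero = sum(counts)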
spikeinterface/core/tests/test_job_tools.py

Lines changed: 29 additions & 10 deletions
@@ -4,7 +4,7 @@
 from spikeinterface.core import generate_recording

 from spikeinterface.core.job_tools import divide_segment_into_chunks, ensure_n_jobs, ensure_chunk_size, \
-    ChunkRecordingExecutor, fix_job_kwargs, split_job_kwargs
+    ChunkRecordingExecutor, fix_job_kwargs, split_job_kwargs, divide_recording_into_chunks


 def test_divide_segment_into_chunks():
@@ -95,18 +95,37 @@ def test_ChunkRecordingExecutor():
                                        n_jobs=1, chunk_size=None)
     processor.run()

-    # chunk + loop
+    # simple gathering function
+    def gathering_result(res):
+        # print(res)
+        pass
+
+    # chunk + loop + gather_func
     processor = ChunkRecordingExecutor(recording, func, init_func, init_args,
-                                       verbose=True, progress_bar=False,
+                                       verbose=True, progress_bar=False, gather_func=gathering_result,
                                        n_jobs=1, chunk_memory="500k")
     processor.run()

-    # chunk + parallel
+    # more adavnce trick : gathering using class with callable
+    class GatherClass:
+        def __init__(self):
+            self.pos = 0
+
+        def __call__(self, res):
+            self.pos += 1
+            # print(self.pos, res)
+            pass
+    gathering_func2 = GatherClass()
+
+    # chunk + parallel + gather_func
     processor = ChunkRecordingExecutor(recording, func, init_func, init_args,
-                                       verbose=True, progress_bar=True,
+                                       verbose=True, progress_bar=True, gather_func=gathering_func2,
                                        n_jobs=2, chunk_duration="200ms",
                                        job_name='job_name')
     processor.run()
+    num_chunks = len(divide_recording_into_chunks(recording, processor.chunk_size))
+
+    assert gathering_func2.pos == num_chunks

     # chunk + parallel + spawn
     processor = ChunkRecordingExecutor(recording, func, init_func, init_args,
@@ -153,9 +172,9 @@ def test_split_job_kwargs():


 if __name__ == '__main__':
-    test_divide_segment_into_chunks()
-    test_ensure_n_jobs()
-    test_ensure_chunk_size()
+    # test_divide_segment_into_chunks()
+    # test_ensure_n_jobs()
+    # test_ensure_chunk_size()
     test_ChunkRecordingExecutor()
-    test_fix_job_kwargs()
-    test_split_job_kwargs()
+    # test_fix_job_kwargs()
+    # test_split_job_kwargs()

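The GatherClass in the test above only counts chunks; the same callable-class idea can cover the "custom storage on-the-fly" use case mentioned in the new docstring. A minimal sketch (not from this commit), assuming each chunk result is a 1-d numpy array:

import numpy as np

class GatherToBinaryFile:
    # hypothetical gatherer: append every chunk result to one binary file
    def __init__(self, path):
        self.file = open(path, 'wb')
        self.num_samples = 0

    def __call__(self, res):
        # res is whatever `func` returned for one chunk
        np.asarray(res, dtype='float32').tofile(self.file)
        self.num_samples += np.asarray(res).size

    def close(self):
        self.file.close()

# would be passed as gather_func=GatherToBinaryFile('chunk_results.raw') to ChunkRecordingExecutor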
spikeinterface/sortingcomponents/peak_detection.py

Lines changed: 33 additions & 15 deletions
@@ -8,7 +8,7 @@

 from ..core import get_chunk_with_margin

-from .peak_pipeline import PipelineNode, check_graph, run_nodes
+from .peak_pipeline import PipelineNode, check_graph, run_nodes, GatherToMemory, GatherToNpy
 from .tools import make_multi_method_doc

 try:
@@ -28,7 +28,9 @@
                    ('amplitude', 'float64'), ('segment_ind', 'int64')]


-def detect_peaks(recording, method='by_channel', pipeline_nodes=None, **kwargs):
+def detect_peaks(recording, method='by_channel', pipeline_nodes=None,
+                 gather_mode='memory', folder=None, names=None,
+                 **kwargs):
     """Peak detection based on threshold crossing in term of k x MAD.

     In 'by_channel' : peak are detected in each channel independently
@@ -42,6 +44,18 @@ def detect_peaks(recording, method='by_channel', pipeline_nodes=None, **kwargs):
     pipeline_nodes: None or list[PipelineNode]
         Optional additional PipelineNode need to computed just after detection time.
         This avoid reading the recording multiple times.
+    gather_mode: str
+        How to gather the results:
+
+        * "memory": results are returned as in-memory numpy arrays
+
+        * "npy": results are stored to .npy files in `folder`
+
+    folder: str or Path
+        If gather_mode is "npy", the folder where the files are created.
+    names: list
+        List of strings with file stems associated with returns.
+
     {method_doc}
     {job_doc}

@@ -66,27 +80,31 @@ def detect_peaks(recording, method='by_channel', pipeline_nodes=None, **kwargs):
     method_args = method_class.check_params(recording, **method_kwargs)

     extra_margin = 0
-    if pipeline_nodes is not None:
+    if pipeline_nodes is None:
+        squeeze_output = True
+    else:
         check_graph(pipeline_nodes)
         extra_margin = max(node.get_trace_margin() for node in pipeline_nodes)
-
+        squeeze_output = False
+
+    if gather_mode == 'memory':
+        gather_func = GatherToMemory()
+    elif gather_mode == 'npy':
+        gather_func = GatherToNpy(folder, names)
+    else:
+        raise ValueError(f"Wrong gather_mode : {gather_mode}. Available gather modes: 'memory' | 'npy'")
+
     func = _detect_peaks_chunk
     init_func = _init_worker_detect_peaks
     init_args = (recording, method, method_args, extra_margin, pipeline_nodes)
     processor = ChunkRecordingExecutor(recording, func, init_func, init_args,
-                                       handle_returns=True, job_name='detect peaks',
+                                       gather_func=gather_func, job_name='detect peaks',
                                        mp_context=mp_context, **job_kwargs)
-    outputs = processor.run()
-
-    if pipeline_nodes is None:
-        peaks = np.concatenate(outputs)
-        return peaks
-    else:
-        outs_concat = ()
-        for output_node in zip(*outputs):
-            outs_concat += (np.concatenate(output_node, axis=0), )
-        return outs_concat
+    processor.run()

+    outs = gather_func.finalize_buffers(squeeze_output=squeeze_output)
+    return outs
+

 def _init_worker_detect_peaks(recording, method, method_args, extra_margin, pipeline_nodes):
     """Initialize a worker for detecting peaks."""

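As an illustration of the reworked detect_peaks API, a hedged usage sketch (not part of the commit; detect_threshold and the job keyword arguments are assumed to be forwarded through **kwargs as before, and recording is an existing RecordingExtractor):

from spikeinterface.sortingcomponents.peak_detection import detect_peaks

# default in-memory gathering: peaks come back as a numpy structured array
peaks = detect_peaks(recording, method='by_channel', gather_mode='memory',
                     detect_threshold=5, n_jobs=2, chunk_duration='100ms', progress_bar=True)

# on-disk gathering: each output buffer is written to <folder>/<name>.npy
peaks = detect_peaks(recording, method='by_channel', gather_mode='npy',
                     folder='peak_detection_output', names=['peaks'],
                     detect_threshold=5, n_jobs=2, chunk_duration='100ms', progress_bar=True)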