diff --git a/src/spikeinterface/__init__.py b/src/spikeinterface/__init__.py index 5f2086ede7..99325d7d65 100644 --- a/src/spikeinterface/__init__.py +++ b/src/spikeinterface/__init__.py @@ -9,6 +9,7 @@ from .core import * import warnings + warnings.filterwarnings("ignore", message="distutils Version classes are deprecated") warnings.filterwarnings("ignore", message="the imp module is deprecated") diff --git a/src/spikeinterface/preprocessing/clip.py b/src/spikeinterface/preprocessing/clip.py index a2349c1ee9..e48d60ecc8 100644 --- a/src/spikeinterface/preprocessing/clip.py +++ b/src/spikeinterface/preprocessing/clip.py @@ -1,5 +1,11 @@ import numpy as np +try: + from numba import njit + HAVE_NUMBA = True +except ModuleNotFoundError as err: + HAVE_NUMBA = False + from spikeinterface.core.core_tools import define_function_from_class from .basepreprocessor import BasePreprocessor, BasePreprocessorSegment @@ -70,6 +76,10 @@ class BlankSaturationRecording(BasePreprocessor): fill_value: float or None The value to write instead of the saturating signal. If None, then the value is automatically computed as the median signal value + ms_before: float (default 0) + Time (ms) to replace before the saturation signal + ms_after: float (default 0) + Time (ms) to replace after the saturation signal num_chunks_per_segment: int (default 50) The number of chunks per segments to consider to estimate the threshold/fill_values chunk_size: int (default 500) @@ -83,8 +93,14 @@ class BlankSaturationRecording(BasePreprocessor): The filtered traces recording extractor object """ + name = 'blank_staturation' + + def __init__(self, recording, abs_threshold=None, quantile_threshold=None, + direction='upper', fill_value=None, + ms_before=0, ms_after=0, + num_chunks_per_segment=50, chunk_size=500, seed=0): + - name = "blank_staturation" def __init__( self, @@ -135,41 +151,106 @@ def __init__( BasePreprocessor.__init__(self, recording) for parent_segment in recording._recording_segments: - rec_segment = ClipRecordingSegment(parent_segment, a_min, value_min, a_max, value_max) + rec_segment = ClipRecordingSegment( + parent_segment, a_min, value_min, a_max, value_max, + ms_before=ms_before, ms_after=ms_after + ) self.add_recording_segment(rec_segment) - self._kwargs = dict( - recording=recording, - abs_threshold=abs_threshold, - quantile_threshold=quantile_threshold, - direction=direction, - fill_value=fill_value, - num_chunks_per_segment=num_chunks_per_segment, - chunk_size=chunk_size, - seed=seed, - ) + self._kwargs = dict(recording=recording, abs_threshold=abs_threshold, ms_before=ms_before, ms_after=ms_after, + quantile_threshold=quantile_threshold, direction=direction, fill_value=fill_value, + num_chunks_per_segment=num_chunks_per_segment, chunk_size=chunk_size, + seed=seed) + class ClipRecordingSegment(BasePreprocessorSegment): - def __init__(self, parent_recording_segment, a_min, value_min, a_max, value_max): + def __init__(self, parent_recording_segment, a_min, value_min, a_max, value_max, + ms_before=0, ms_after=0): BasePreprocessorSegment.__init__(self, parent_recording_segment) self.a_min = a_min self.value_min = value_min self.a_max = a_max self.value_max = value_max + self.ms_before = ms_before + self.ms_after = ms_after + def get_traces(self, start_frame, end_frame, channel_indices): traces = self.parent_recording_segment.get_traces(start_frame, end_frame, channel_indices) traces = traces.copy() + fs = self.parent_recording_segment.sampling_frequency + + frames_before = int(self.ms_before * fs // 1000) + frames_after = int(self.ms_after * fs // 1000) if self.a_min is not None: - traces[traces <= self.a_min] = self.value_min + traces = replace_slice_min(traces, self.a_min, frames_before, frames_after, self.value_min) + if self.a_max is not None: - traces[traces >= self.a_max] = self.value_max + traces = replace_slice_max(traces, self.a_max, frames_before, frames_after, self.value_max) return traces +def replace_slice_min(traces, a_min, frames_before, frames_after, value_min): + if HAVE_NUMBA: + return _replace_slice_min_numba(traces, a_min, frames_before, frames_after, value_min) + else: + return _replace_slice_min_for_loop(traces, a_min, frames_before, frames_after, value_min) + +def replace_slice_max(traces, a_max, frames_before, frames_after, value_max): + if HAVE_NUMBA: + return _replace_slice_max_numba(traces, a_max, frames_before, frames_after, value_max) + else: + return _replace_slice_max_for_loop(traces, a_max, frames_before, frames_after, value_max) + +# For loops +def _replace_slice_min_for_loop(traces, a_min, frames_before, frames_after, value_min): + min_indices, channels = np.where(traces <= a_min) + for index, chan in zip(min_indices, channels): + traces[max(0, index - frames_before):min(len(traces), index + frames_after + 1), chan] = value_min + return traces + +def _replace_slice_max_for_loop(traces, a_max, frames_before, frames_after, value_max): + max_indices, channels = np.where(traces >= a_max) + for index, chan in zip(max_indices, channels): + traces[max(0, index - frames_before):min(len(traces), index + frames_after + 1), chan] = value_max + return traces + +if HAVE_NUMBA: + # Numba + @njit(cache=True) + def _replace_slice_max_numba(traces, a_max, frames_before, frames_after, value_max): + m, n = traces.shape + to_clear = np.zeros(m, dtype=np.bool_) + for j in range(n): + to_clear[:] = False + for i in range(m): + if traces[i, j] >= a_max: + to_clear[ + max(0, i - frames_before) : min(m, i + frames_after + 1) + ] = True + for i in range(m): + if to_clear[i]: + traces[i, j] = value_max + return traces + + @njit(cache=True) + def _replace_slice_min_numba(traces, a_min, frames_before, frames_after, value_min): + m, n = traces.shape + to_clear = np.zeros(m, dtype=np.bool_) + for j in range(n): + to_clear[:] = False + for i in range(m): + if traces[i, j] <= a_min: + to_clear[ + max(0, i - frames_before) : min(m, i + frames_after + 1) + ] = True + for i in range(m): + if to_clear[i]: + traces[i, j] = value_min + return traces clip = define_function_from_class(source_class=ClipRecording, name="clip") blank_staturation = define_function_from_class(source_class=BlankSaturationRecording, name="blank_staturation")