From d545d6acb1b5ce395130cfb80e321abba055fedf Mon Sep 17 00:00:00 2001 From: Sebastien Date: Fri, 21 Feb 2025 11:12:55 +0100 Subject: [PATCH 01/19] WIP --- .../preprocessing/preprocessinglist.py | 2 + .../preprocessing/silence_artefacts.py | 200 ++++++++++++++++++ .../preprocessing/silence_periods.py | 2 +- 3 files changed, 203 insertions(+), 1 deletion(-) create mode 100644 src/spikeinterface/preprocessing/silence_artefacts.py diff --git a/src/spikeinterface/preprocessing/preprocessinglist.py b/src/spikeinterface/preprocessing/preprocessinglist.py index bdf5f2219c..1708af2df1 100644 --- a/src/spikeinterface/preprocessing/preprocessinglist.py +++ b/src/spikeinterface/preprocessing/preprocessinglist.py @@ -43,6 +43,7 @@ from .depth_order import DepthOrderRecording, depth_order from .astype import AstypeRecording, astype from .unsigned_to_signed import UnsignedToSignedRecording, unsigned_to_signed +from .silence_artefacts import SilenceArtefactsRecording, silence_artefacts preprocessers_full_list = [ @@ -79,6 +80,7 @@ DirectionalDerivativeRecording, AstypeRecording, UnsignedToSignedRecording, + SilenceArtefactsRecording, ] preprocesser_dict = {pp_class.name: pp_class for pp_class in preprocessers_full_list} diff --git a/src/spikeinterface/preprocessing/silence_artefacts.py b/src/spikeinterface/preprocessing/silence_artefacts.py new file mode 100644 index 0000000000..70ba086ca8 --- /dev/null +++ b/src/spikeinterface/preprocessing/silence_artefacts.py @@ -0,0 +1,200 @@ +from __future__ import annotations + +import numpy as np + +from spikeinterface.core.core_tools import define_function_from_class +from silence_periods import SilencedPeriodsRecording +from rectify import RectifyRecording +from filter_gaussian import FilterGaussianRecording +from ..core.job_tools import split_job_kwargs, fix_job_kwargs + +from ..core import get_noise_levels +from ..core.generate import NoiseGeneratorRecording +from .basepreprocessor import BasePreprocessor + + +from ..core.node_pipeline import PeakDetector, base_peak_dtype +import numpy as np + +class DetectThresholdCrossing(PeakDetector): + + name = "threshold_crossings" + preferred_mp_context = None + + def __init__( + self, + recording, + detect_threshold=5, + noise_levels=None, + random_chunk_kwargs={}, + ): + PeakDetector.__init__(self, recording, return_output=True) + if noise_levels is None: + noise_levels = get_noise_levels(recording, return_scaled=False, **random_chunk_kwargs) + self.abs_thresholds = noise_levels * detect_threshold + self._dtype = np.dtype(base_peak_dtype + [("onset", "bool")]) + + def get_trace_margin(self): + return 0 + + def get_dtype(self): + return self._dtype + + def compute(self, traces, start_frame, end_frame, segment_index, max_margin): + z = (traces - self.abs_thresholds) + threshold_mask = np.diff((z > 0) != 0, axis=0) + indices = np.where(threshold_mask) + local_peaks = np.zeros(indices[0].size, dtype=self._dtype) + local_peaks["sample_index"] = indices[0] + local_peaks["channel_index"] = indices[1] + for channel_ind in np.unique(indices[1]): + mask = np.flatnonzero(indices[1] == channel_ind) + local_peaks["onset"][mask[::2]] = True + local_peaks["onset"][mask[1::2]] = False + idx = np.argsort(local_peaks["sample_index"]) + local_peaks = local_peaks[idx] + return (local_peaks, ) + + +def detect_onsets(enveloppe, detect_threshold=5, **job_kwargs): + + from spikeinterface.core.node_pipeline import ( + run_node_pipeline, + ) + + node0 = DetectThresholdCrossing(enveloppe) + + peaks = run_node_pipeline( + enveloppe, + [node0], + 
job_kwargs, + job_name="detecting threshold crossings", + ) + + results = {} + for channel_ind, channel_id in enumerate(enveloppe.channel_ids): + + mask = peaks["channel_index"] == channel_ind + sub_peaks = peaks[mask] + onset_mask = sub_peaks["onset"] == 1 + onsets = sub_peaks[onset_mask] + offsets = sub_peaks[~onset_mask] + periods = [] + + if len(onsets) > 0: + if onsets['sample_index'][0] > offsets['sample_index'][0]: + periods += [(0, offsets['sample_index'][0])] + offsets = offsets[1:] + + for i in range(len(onsets)): + periods += [(onsets['sample_index'][i], offsets['sample_index'][i])] + + if len(onsets) > len(offsets): + periods += [(onsets['sample_index'][0], enveloppe.get_num_samples())] + + results[channel_id] = periods + + return results + + +class SilencedArtefactsRecording(SilencedPeriodsRecording): + """ + Silence user-defined periods from recording extractor traces. The code will construct + an enveloppe of the recording (as a low pass filtered version of the traces) and detect + threshold crossings to identify the periods to silence. The periods are then silenced either + on a per channel basis or across all channels by replacing the values by zeros or by + adding gaussian noise with the same variance as the one in the recordings + + Parameters + ---------- + recording : RecordingExtractor + The recording extractor to silence putative artefacts + per_channel : bool, default: True + If True, the periods are silenced on a per channel basis. If False, the periods are silenced + across all channels + detect_threshold : float, default: 5 + The threshold to detect artefacts. The threshold is computed as `detect_threshold * noise_level` + freq_max : float, default: 20 + The maximum frequency for the low pass filter used + noise_levels : array + Noise levels if already computed + seed : int | None, default: None + Random seed for `get_noise_levels` and `NoiseGeneratorRecording`. + If none, `get_noise_levels` uses `seed=0` and `NoiseGeneratorRecording` generates a random seed using `numpy.random.default_rng`. + mode : "zeros" | "noise, default: "zeros" + Determines what periods are replaced by. Can be one of the following: + + - "zeros": Artifacts are replaced by zeros. 
+ + - "noise": The periods are filled with a gaussion noise that has the + same variance that the one in the recordings, on a per channel + basis + **random_chunk_kwargs : Keyword arguments for `spikeinterface.core.get_random_data_chunk()` function + + Returns + ------- + silenced_recording : SilencedArtefactsRecording + The recording extractor after silencing detected artefacts + """ + + def __init__(self, + recording, + per_channel=True, + detect_threshold=5, + freq_max=20., + mode="zeros", + noise_levels=None, + seed=None, + **random_chunk_kwargs): + + + _, job_kwargs = split_job_kwargs(random_chunk_kwargs) + job_kwargs = fix_job_kwargs(job_kwargs) + + available_modes = ("zeros", "noise") + recording = RectifyRecording(recording) + recording = FilterGaussianRecording(recording, freq_min=0, freq_max=20) + + periods = detect_onsets(recording, + detect_threshold=detect_threshold, + **random_chunk_kwargs) + + # some checks + assert mode in available_modes, f"mode {mode} is not an available mode: {available_modes}" + + if mode in ["noise"]: + if noise_levels is None: + random_slices_kwargs = random_chunk_kwargs.copy() + random_slices_kwargs["seed"] = seed + noise_levels = get_noise_levels( + recording, return_scaled=False, random_slices_kwargs=random_slices_kwargs + ) + noise_generator = NoiseGeneratorRecording( + num_channels=recording.get_num_channels(), + sampling_frequency=recording.sampling_frequency, + durations=[recording.select_segments(i).get_duration() for i in range(recording.get_num_segments())], + dtype=recording.dtype, + seed=seed, + noise_levels=noise_levels, + strategy="on_the_fly", + noise_block_size=int(recording.sampling_frequency), + ) + else: + noise_generator = None + + + + BasePreprocessor.__init__(self, recording) + for seg_index, parent_segment in enumerate(recording._recording_segments): + periods = list_periods[seg_index] + periods = np.asarray(periods, dtype="int64") + periods = np.sort(periods, axis=0) + rec_segment = SilencedArtefactsRecording(parent_segment, periods, mode, noise_generator, seg_index) + self.add_recording_segment(rec_segment) + + self._kwargs = dict(recording=recording, list_periods=list_periods, mode=mode, seed=seed) + self._kwargs.update(random_chunk_kwargs) + + +# function for API +silence_artefacts = define_function_from_class(source_class=SilencedArtefactsRecording, name="silence_artefacts") diff --git a/src/spikeinterface/preprocessing/silence_periods.py b/src/spikeinterface/preprocessing/silence_periods.py index 00d9a1a407..8adeb879cf 100644 --- a/src/spikeinterface/preprocessing/silence_periods.py +++ b/src/spikeinterface/preprocessing/silence_periods.py @@ -5,7 +5,7 @@ from spikeinterface.core.core_tools import define_function_from_class from .basepreprocessor import BasePreprocessor, BasePreprocessorSegment -from ..core import get_random_data_chunks, get_noise_levels +from ..core import get_noise_levels from ..core.generate import NoiseGeneratorRecording From c69b7be7c53f6c4360668317f0d66ab1d09eb360 Mon Sep 17 00:00:00 2001 From: Sebastien Date: Fri, 21 Feb 2025 11:29:34 +0100 Subject: [PATCH 02/19] WIP --- .../preprocessing/preprocessinglist.py | 4 +- .../preprocessing/silence_artefacts.py | 123 +++++++----------- 2 files changed, 49 insertions(+), 78 deletions(-) diff --git a/src/spikeinterface/preprocessing/preprocessinglist.py b/src/spikeinterface/preprocessing/preprocessinglist.py index 1708af2df1..4fc397a50d 100644 --- a/src/spikeinterface/preprocessing/preprocessinglist.py +++ 
b/src/spikeinterface/preprocessing/preprocessinglist.py @@ -43,7 +43,7 @@ from .depth_order import DepthOrderRecording, depth_order from .astype import AstypeRecording, astype from .unsigned_to_signed import UnsignedToSignedRecording, unsigned_to_signed -from .silence_artefacts import SilenceArtefactsRecording, silence_artefacts +from .silence_artefacts import SilencedArtefactsRecording, silence_artefacts preprocessers_full_list = [ @@ -80,7 +80,7 @@ DirectionalDerivativeRecording, AstypeRecording, UnsignedToSignedRecording, - SilenceArtefactsRecording, + SilencedArtefactsRecording, ] preprocesser_dict = {pp_class.name: pp_class for pp_class in preprocessers_full_list} diff --git a/src/spikeinterface/preprocessing/silence_artefacts.py b/src/spikeinterface/preprocessing/silence_artefacts.py index 70ba086ca8..fc53237a5c 100644 --- a/src/spikeinterface/preprocessing/silence_artefacts.py +++ b/src/spikeinterface/preprocessing/silence_artefacts.py @@ -3,9 +3,9 @@ import numpy as np from spikeinterface.core.core_tools import define_function_from_class -from silence_periods import SilencedPeriodsRecording -from rectify import RectifyRecording -from filter_gaussian import FilterGaussianRecording +from .silence_periods import SilencedPeriodsRecording +from .rectify import RectifyRecording +from .filter_gaussian import GaussianFilterRecording from ..core.job_tools import split_job_kwargs, fix_job_kwargs from ..core import get_noise_levels @@ -41,58 +41,61 @@ def get_dtype(self): return self._dtype def compute(self, traces, start_frame, end_frame, segment_index, max_margin): - z = (traces - self.abs_thresholds) + z = (traces - self.abs_thresholds).mean(1) threshold_mask = np.diff((z > 0) != 0, axis=0) - indices = np.where(threshold_mask) - local_peaks = np.zeros(indices[0].size, dtype=self._dtype) - local_peaks["sample_index"] = indices[0] - local_peaks["channel_index"] = indices[1] - for channel_ind in np.unique(indices[1]): - mask = np.flatnonzero(indices[1] == channel_ind) - local_peaks["onset"][mask[::2]] = True - local_peaks["onset"][mask[1::2]] = False - idx = np.argsort(local_peaks["sample_index"]) - local_peaks = local_peaks[idx] + indices, = np.where(threshold_mask) + local_peaks = np.zeros(indices.size, dtype=self._dtype) + local_peaks["sample_index"] = indices + #local_peaks["channel_index"] = indices[-1] + #for channel_ind in np.unique(indices[1]): + # mask = np.flatnonzero(indices[1] == channel_ind) + # local_peaks["onset"][mask[::2]] = True + # local_peaks["onset"][mask[1::2]] = False + #idx = np.argsort(local_peaks["sample_index"]) + #local_peaks = local_peaks[idx] + local_peaks["onset"][::2] = True + local_peaks["onset"][1::2] = False return (local_peaks, ) -def detect_onsets(enveloppe, detect_threshold=5, **job_kwargs): +def detect_onsets(recording, detect_threshold=5, **job_kwargs): from spikeinterface.core.node_pipeline import ( run_node_pipeline, ) - node0 = DetectThresholdCrossing(enveloppe) + node0 = DetectThresholdCrossing(recording, detect_threshold, **job_kwargs) peaks = run_node_pipeline( - enveloppe, + recording, [node0], job_kwargs, job_name="detecting threshold crossings", ) - results = {} - for channel_ind, channel_id in enumerate(enveloppe.channel_ids): - - mask = peaks["channel_index"] == channel_ind - sub_peaks = peaks[mask] - onset_mask = sub_peaks["onset"] == 1 - onsets = sub_peaks[onset_mask] - offsets = sub_peaks[~onset_mask] - periods = [] - - if len(onsets) > 0: - if onsets['sample_index'][0] > offsets['sample_index'][0]: - periods += [(0, 
offsets['sample_index'][0])] - offsets = offsets[1:] + results = [] + print(peaks) + # for channel_ind, channel_id in enumerate(recording.channel_ids): + + # mask = peaks["channel_index"] == channel_ind + # sub_peaks = peaks[mask] + # onset_mask = sub_peaks["onset"] == 1 + # onsets = sub_peaks[onset_mask] + # offsets = sub_peaks[~onset_mask] + # periods = [] + + # # if len(onsets) > 0: + # # if onsets['sample_index'][0] > offsets['sample_index'][0]: + # # periods += [(0, offsets['sample_index'][0])] + # # offsets = offsets[1:] - for i in range(len(onsets)): - periods += [(onsets['sample_index'][i], offsets['sample_index'][i])] + # # for i in range(len(onsets)): + # # periods += [(onsets['sample_index'][i], offsets['sample_index'][i])] - if len(onsets) > len(offsets): - periods += [(onsets['sample_index'][0], enveloppe.get_num_samples())] + # # if len(onsets) > len(offsets): + # # periods += [(onsets['sample_index'][0], recording.get_num_samples())] - results[channel_id] = periods + # results[channel_id] = periods return results @@ -109,9 +112,6 @@ class SilencedArtefactsRecording(SilencedPeriodsRecording): ---------- recording : RecordingExtractor The recording extractor to silence putative artefacts - per_channel : bool, default: True - If True, the periods are silenced on a per channel basis. If False, the periods are silenced - across all channels detect_threshold : float, default: 5 The threshold to detect artefacts. The threshold is computed as `detect_threshold * noise_level` freq_max : float, default: 20 @@ -151,49 +151,20 @@ def __init__(self, _, job_kwargs = split_job_kwargs(random_chunk_kwargs) job_kwargs = fix_job_kwargs(job_kwargs) - available_modes = ("zeros", "noise") recording = RectifyRecording(recording) - recording = FilterGaussianRecording(recording, freq_min=0, freq_max=20) + recording = GaussianFilterRecording(recording, freq_min=None, freq_max=freq_max) periods = detect_onsets(recording, detect_threshold=detect_threshold, **random_chunk_kwargs) - # some checks - assert mode in available_modes, f"mode {mode} is not an available mode: {available_modes}" - - if mode in ["noise"]: - if noise_levels is None: - random_slices_kwargs = random_chunk_kwargs.copy() - random_slices_kwargs["seed"] = seed - noise_levels = get_noise_levels( - recording, return_scaled=False, random_slices_kwargs=random_slices_kwargs - ) - noise_generator = NoiseGeneratorRecording( - num_channels=recording.get_num_channels(), - sampling_frequency=recording.sampling_frequency, - durations=[recording.select_segments(i).get_duration() for i in range(recording.get_num_segments())], - dtype=recording.dtype, - seed=seed, - noise_levels=noise_levels, - strategy="on_the_fly", - noise_block_size=int(recording.sampling_frequency), - ) - else: - noise_generator = None - - - - BasePreprocessor.__init__(self, recording) - for seg_index, parent_segment in enumerate(recording._recording_segments): - periods = list_periods[seg_index] - periods = np.asarray(periods, dtype="int64") - periods = np.sort(periods, axis=0) - rec_segment = SilencedArtefactsRecording(parent_segment, periods, mode, noise_generator, seg_index) - self.add_recording_segment(rec_segment) - - self._kwargs = dict(recording=recording, list_periods=list_periods, mode=mode, seed=seed) - self._kwargs.update(random_chunk_kwargs) + SilencedPeriodsRecording.__init__(self, + recording, + periods, + mode=mode, + noise_levels=noise_levels, + seed=seed, + **random_chunk_kwargs) # function for API From c5a538c4676b2e4c24f5b6c65fc95e49df0dc1f4 Mon Sep 17 
00:00:00 2001 From: Sebastien Date: Fri, 21 Feb 2025 11:43:28 +0100 Subject: [PATCH 03/19] WIP --- .../preprocessing/silence_artefacts.py | 74 ++++++++----------- 1 file changed, 31 insertions(+), 43 deletions(-) diff --git a/src/spikeinterface/preprocessing/silence_artefacts.py b/src/spikeinterface/preprocessing/silence_artefacts.py index fc53237a5c..132b168b09 100644 --- a/src/spikeinterface/preprocessing/silence_artefacts.py +++ b/src/spikeinterface/preprocessing/silence_artefacts.py @@ -7,10 +7,7 @@ from .rectify import RectifyRecording from .filter_gaussian import GaussianFilterRecording from ..core.job_tools import split_job_kwargs, fix_job_kwargs - from ..core import get_noise_levels -from ..core.generate import NoiseGeneratorRecording -from .basepreprocessor import BasePreprocessor from ..core.node_pipeline import PeakDetector, base_peak_dtype @@ -43,61 +40,54 @@ def get_dtype(self): def compute(self, traces, start_frame, end_frame, segment_index, max_margin): z = (traces - self.abs_thresholds).mean(1) threshold_mask = np.diff((z > 0) != 0, axis=0) - indices, = np.where(threshold_mask) + indices = np.flatnonzero(threshold_mask) local_peaks = np.zeros(indices.size, dtype=self._dtype) local_peaks["sample_index"] = indices - #local_peaks["channel_index"] = indices[-1] - #for channel_ind in np.unique(indices[1]): - # mask = np.flatnonzero(indices[1] == channel_ind) - # local_peaks["onset"][mask[::2]] = True - # local_peaks["onset"][mask[1::2]] = False - #idx = np.argsort(local_peaks["sample_index"]) - #local_peaks = local_peaks[idx] local_peaks["onset"][::2] = True local_peaks["onset"][1::2] = False return (local_peaks, ) -def detect_onsets(recording, detect_threshold=5, **job_kwargs): +def detect_onsets(recording, detect_threshold=5, **extra_kwargs): from spikeinterface.core.node_pipeline import ( run_node_pipeline, ) - node0 = DetectThresholdCrossing(recording, detect_threshold, **job_kwargs) + random_chunk_kwargs, job_kwargs = split_job_kwargs(extra_kwargs) + job_kwargs = fix_job_kwargs(job_kwargs) + + node0 = DetectThresholdCrossing(recording, detect_threshold, **random_chunk_kwargs) peaks = run_node_pipeline( recording, [node0], job_kwargs, - job_name="detecting threshold crossings", + job_name="detect threshold crossings", ) - results = [] - print(peaks) - # for channel_ind, channel_id in enumerate(recording.channel_ids): - - # mask = peaks["channel_index"] == channel_ind - # sub_peaks = peaks[mask] - # onset_mask = sub_peaks["onset"] == 1 - # onsets = sub_peaks[onset_mask] - # offsets = sub_peaks[~onset_mask] - # periods = [] - - # # if len(onsets) > 0: - # # if onsets['sample_index'][0] > offsets['sample_index'][0]: - # # periods += [(0, offsets['sample_index'][0])] - # # offsets = offsets[1:] - - # # for i in range(len(onsets)): - # # periods += [(onsets['sample_index'][i], offsets['sample_index'][i])] - - # # if len(onsets) > len(offsets): - # # periods += [(onsets['sample_index'][0], recording.get_num_samples())] - - # results[channel_id] = periods + periods = [] + num_seg = recording.get_num_segments() + for seg_index in range(num_seg): + sub_periods = [] + mask = peaks["segment_index"] == 0 + sub_peaks = peaks[mask] + onsets = sub_peaks[sub_peaks['onset']] + offsets = sub_peaks[~sub_peaks['onset']] + + if onsets['sample_index'][0] > offsets['sample_index'][0]: + sub_periods += [(0, offsets['sample_index'][0])] + offsets = offsets[1:] - return results + for i in range(min(len(onsets), len(offsets))): + sub_periods += [(onsets['sample_index'][i], offsets['sample_index'][i])] 
+ + if len(onsets) > len(offsets): + sub_periods += [(onsets['sample_index'][0], recording.get_num_samples(seg_index))] + + periods.append(sub_periods) + + return periods class SilencedArtefactsRecording(SilencedPeriodsRecording): @@ -139,17 +129,12 @@ class SilencedArtefactsRecording(SilencedPeriodsRecording): def __init__(self, recording, - per_channel=True, detect_threshold=5, freq_max=20., mode="zeros", noise_levels=None, seed=None, **random_chunk_kwargs): - - - _, job_kwargs = split_job_kwargs(random_chunk_kwargs) - job_kwargs = fix_job_kwargs(job_kwargs) recording = RectifyRecording(recording) recording = GaussianFilterRecording(recording, freq_min=None, freq_max=freq_max) @@ -166,6 +151,9 @@ def __init__(self, seed=seed, **random_chunk_kwargs) + #self._kwargs = dict(recording=recording, list_periods=list_periods, mode=mode, seed=seed) + #self._kwargs.update(random_chunk_kwargs) + # function for API silence_artefacts = define_function_from_class(source_class=SilencedArtefactsRecording, name="silence_artefacts") From b1ce7262b2ad73960c1319c0ee1d6913993375bf Mon Sep 17 00:00:00 2001 From: Sebastien Date: Fri, 21 Feb 2025 11:47:48 +0100 Subject: [PATCH 04/19] Finishing the node --- .../preprocessing/silence_artefacts.py | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/src/spikeinterface/preprocessing/silence_artefacts.py b/src/spikeinterface/preprocessing/silence_artefacts.py index 132b168b09..82b8af999d 100644 --- a/src/spikeinterface/preprocessing/silence_artefacts.py +++ b/src/spikeinterface/preprocessing/silence_artefacts.py @@ -111,7 +111,7 @@ class SilencedArtefactsRecording(SilencedPeriodsRecording): seed : int | None, default: None Random seed for `get_noise_levels` and `NoiseGeneratorRecording`. If none, `get_noise_levels` uses `seed=0` and `NoiseGeneratorRecording` generates a random seed using `numpy.random.default_rng`. - mode : "zeros" | "noise, default: "zeros" + mode : "zeros" | "noise", default: "zeros" Determines what periods are replaced by. Can be one of the following: - "zeros": Artifacts are replaced by zeros. 
@@ -131,28 +131,29 @@ def __init__(self, recording, detect_threshold=5, freq_max=20., - mode="zeros", + mode="noise", noise_levels=None, - seed=None, + seed=None, + list_periods=None, **random_chunk_kwargs): - recording = RectifyRecording(recording) - recording = GaussianFilterRecording(recording, freq_min=None, freq_max=freq_max) + enveloppe = RectifyRecording(recording) + enveloppe = GaussianFilterRecording(enveloppe, freq_min=None, freq_max=freq_max) - periods = detect_onsets(recording, + if list_periods is None: + list_periods = detect_onsets(enveloppe, detect_threshold=detect_threshold, **random_chunk_kwargs) SilencedPeriodsRecording.__init__(self, recording, - periods, + list_periods, mode=mode, noise_levels=noise_levels, seed=seed, **random_chunk_kwargs) - #self._kwargs = dict(recording=recording, list_periods=list_periods, mode=mode, seed=seed) - #self._kwargs.update(random_chunk_kwargs) + self._kwargs.update({'detect_threshold' : detect_threshold, 'freq_max' : freq_max}) # function for API From 816e4bc519428919abb37d29cf3281c47cb1db69 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 21 Feb 2025 11:11:48 +0000 Subject: [PATCH 05/19] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../preprocessing/silence_artefacts.py | 81 +++++++++---------- 1 file changed, 39 insertions(+), 42 deletions(-) diff --git a/src/spikeinterface/preprocessing/silence_artefacts.py b/src/spikeinterface/preprocessing/silence_artefacts.py index 82b8af999d..170b843d4b 100644 --- a/src/spikeinterface/preprocessing/silence_artefacts.py +++ b/src/spikeinterface/preprocessing/silence_artefacts.py @@ -7,17 +7,18 @@ from .rectify import RectifyRecording from .filter_gaussian import GaussianFilterRecording from ..core.job_tools import split_job_kwargs, fix_job_kwargs -from ..core import get_noise_levels +from ..core import get_noise_levels from ..core.node_pipeline import PeakDetector, base_peak_dtype import numpy as np + class DetectThresholdCrossing(PeakDetector): - + name = "threshold_crossings" preferred_mp_context = None - + def __init__( self, recording, @@ -36,16 +37,16 @@ def get_trace_margin(self): def get_dtype(self): return self._dtype - + def compute(self, traces, start_frame, end_frame, segment_index, max_margin): - z = (traces - self.abs_thresholds).mean(1) + z = (traces - self.abs_thresholds).mean(1) threshold_mask = np.diff((z > 0) != 0, axis=0) - indices = np.flatnonzero(threshold_mask) + indices = np.flatnonzero(threshold_mask) local_peaks = np.zeros(indices.size, dtype=self._dtype) local_peaks["sample_index"] = indices local_peaks["onset"][::2] = True local_peaks["onset"][1::2] = False - return (local_peaks, ) + return (local_peaks,) def detect_onsets(recording, detect_threshold=5, **extra_kwargs): @@ -54,17 +55,17 @@ def detect_onsets(recording, detect_threshold=5, **extra_kwargs): run_node_pipeline, ) - random_chunk_kwargs, job_kwargs = split_job_kwargs(extra_kwargs) + random_chunk_kwargs, job_kwargs = split_job_kwargs(extra_kwargs) job_kwargs = fix_job_kwargs(job_kwargs) node0 = DetectThresholdCrossing(recording, detect_threshold, **random_chunk_kwargs) - + peaks = run_node_pipeline( recording, [node0], job_kwargs, job_name="detect threshold crossings", - ) + ) periods = [] num_seg = recording.get_num_segments() @@ -72,21 +73,21 @@ def detect_onsets(recording, detect_threshold=5, **extra_kwargs): sub_periods = [] mask = peaks["segment_index"] == 0 sub_peaks = peaks[mask] 
- onsets = sub_peaks[sub_peaks['onset']] - offsets = sub_peaks[~sub_peaks['onset']] - - if onsets['sample_index'][0] > offsets['sample_index'][0]: - sub_periods += [(0, offsets['sample_index'][0])] + onsets = sub_peaks[sub_peaks["onset"]] + offsets = sub_peaks[~sub_peaks["onset"]] + + if onsets["sample_index"][0] > offsets["sample_index"][0]: + sub_periods += [(0, offsets["sample_index"][0])] offsets = offsets[1:] - + for i in range(min(len(onsets), len(offsets))): - sub_periods += [(onsets['sample_index'][i], offsets['sample_index'][i])] + sub_periods += [(onsets["sample_index"][i], offsets["sample_index"][i])] if len(onsets) > len(offsets): - sub_periods += [(onsets['sample_index'][0], recording.get_num_samples(seg_index))] + sub_periods += [(onsets["sample_index"][0], recording.get_num_samples(seg_index))] periods.append(sub_periods) - + return periods @@ -95,7 +96,7 @@ class SilencedArtefactsRecording(SilencedPeriodsRecording): Silence user-defined periods from recording extractor traces. The code will construct an enveloppe of the recording (as a low pass filtered version of the traces) and detect threshold crossings to identify the periods to silence. The periods are then silenced either - on a per channel basis or across all channels by replacing the values by zeros or by + on a per channel basis or across all channels by replacing the values by zeros or by adding gaussian noise with the same variance as the one in the recordings Parameters @@ -127,33 +128,29 @@ class SilencedArtefactsRecording(SilencedPeriodsRecording): The recording extractor after silencing detected artefacts """ - def __init__(self, - recording, - detect_threshold=5, - freq_max=20., - mode="noise", - noise_levels=None, - seed=None, - list_periods=None, - **random_chunk_kwargs): + def __init__( + self, + recording, + detect_threshold=5, + freq_max=20.0, + mode="noise", + noise_levels=None, + seed=None, + list_periods=None, + **random_chunk_kwargs, + ): enveloppe = RectifyRecording(recording) enveloppe = GaussianFilterRecording(enveloppe, freq_min=None, freq_max=freq_max) if list_periods is None: - list_periods = detect_onsets(enveloppe, - detect_threshold=detect_threshold, - **random_chunk_kwargs) - - SilencedPeriodsRecording.__init__(self, - recording, - list_periods, - mode=mode, - noise_levels=noise_levels, - seed=seed, - **random_chunk_kwargs) - - self._kwargs.update({'detect_threshold' : detect_threshold, 'freq_max' : freq_max}) + list_periods = detect_onsets(enveloppe, detect_threshold=detect_threshold, **random_chunk_kwargs) + + SilencedPeriodsRecording.__init__( + self, recording, list_periods, mode=mode, noise_levels=noise_levels, seed=seed, **random_chunk_kwargs + ) + + self._kwargs.update({"detect_threshold": detect_threshold, "freq_max": freq_max}) # function for API From d2958510b868de9d67660faee6a52f53dc3c7018 Mon Sep 17 00:00:00 2001 From: Pierre Yger Date: Fri, 21 Feb 2025 13:18:12 +0100 Subject: [PATCH 06/19] WIP --- .../preprocessing/silence_artefacts.py | 26 +++++++++++++------ 1 file changed, 18 insertions(+), 8 deletions(-) diff --git a/src/spikeinterface/preprocessing/silence_artefacts.py b/src/spikeinterface/preprocessing/silence_artefacts.py index 82b8af999d..d92373fe20 100644 --- a/src/spikeinterface/preprocessing/silence_artefacts.py +++ b/src/spikeinterface/preprocessing/silence_artefacts.py @@ -38,9 +38,9 @@ def get_dtype(self): return self._dtype def compute(self, traces, start_frame, end_frame, segment_index, max_margin): - z = (traces - self.abs_thresholds).mean(1) + z = (traces - 
self.abs_thresholds).mean(1) threshold_mask = np.diff((z > 0) != 0, axis=0) - indices = np.flatnonzero(threshold_mask) + indices = np.flatnonzero(threshold_mask) local_peaks = np.zeros(indices.size, dtype=self._dtype) local_peaks["sample_index"] = indices local_peaks["onset"][::2] = True @@ -75,6 +75,10 @@ def detect_onsets(recording, detect_threshold=5, **extra_kwargs): onsets = sub_peaks[sub_peaks['onset']] offsets = sub_peaks[~sub_peaks['onset']] + if len(onsets) == 0 and len(offsets) == 0: + periods.append([]) + continue + if onsets['sample_index'][0] > offsets['sample_index'][0]: sub_periods += [(0, offsets['sample_index'][0])] offsets = offsets[1:] @@ -130,20 +134,26 @@ class SilencedArtefactsRecording(SilencedPeriodsRecording): def __init__(self, recording, detect_threshold=5, - freq_max=20., - mode="noise", + verbose=False, + freq_max=5., + mode="zeros", noise_levels=None, seed=None, list_periods=None, **random_chunk_kwargs): - enveloppe = RectifyRecording(recording) - enveloppe = GaussianFilterRecording(enveloppe, freq_min=None, freq_max=freq_max) + self.enveloppe = RectifyRecording(recording) + self.enveloppe = GaussianFilterRecording(self.enveloppe, freq_min=None, freq_max=freq_max) if list_periods is None: - list_periods = detect_onsets(enveloppe, + list_periods = detect_onsets(self.enveloppe, detect_threshold=detect_threshold, **random_chunk_kwargs) + if verbose: + for i, periods in enumerate(list_periods): + total_time = np.sum([end-start for start, end in periods]) + percentage = 100 * total_time / recording.get_num_samples(i) + print(f"{percentage}% of segment {i} has been flagged as artefactual") SilencedPeriodsRecording.__init__(self, recording, @@ -153,7 +163,7 @@ def __init__(self, seed=seed, **random_chunk_kwargs) - self._kwargs.update({'detect_threshold' : detect_threshold, 'freq_max' : freq_max}) + self._kwargs.update({'detect_threshold' : detect_threshold, 'freq_max' : freq_max, "verbose" : verbose, 'enveloppe' : self.enveloppe}) # function for API From 09291d47b6da6ef0f814432999ec5f9700d71b24 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 21 Feb 2025 12:19:45 +0000 Subject: [PATCH 07/19] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../preprocessing/silence_artefacts.py | 47 +++++++++++-------- 1 file changed, 27 insertions(+), 20 deletions(-) diff --git a/src/spikeinterface/preprocessing/silence_artefacts.py b/src/spikeinterface/preprocessing/silence_artefacts.py index 5229ac9b9c..5c7f72f3a2 100644 --- a/src/spikeinterface/preprocessing/silence_artefacts.py +++ b/src/spikeinterface/preprocessing/silence_artefacts.py @@ -73,15 +73,15 @@ def detect_onsets(recording, detect_threshold=5, **extra_kwargs): sub_periods = [] mask = peaks["segment_index"] == 0 sub_peaks = peaks[mask] - onsets = sub_peaks[sub_peaks['onset']] - offsets = sub_peaks[~sub_peaks['onset']] - + onsets = sub_peaks[sub_peaks["onset"]] + offsets = sub_peaks[~sub_peaks["onset"]] + if len(onsets) == 0 and len(offsets) == 0: periods.append([]) continue - if onsets['sample_index'][0] > offsets['sample_index'][0]: - sub_periods += [(0, offsets['sample_index'][0])] + if onsets["sample_index"][0] > offsets["sample_index"][0]: + sub_periods += [(0, offsets["sample_index"][0])] offsets = offsets[1:] for i in range(min(len(onsets), len(offsets))): @@ -132,27 +132,27 @@ class SilencedArtefactsRecording(SilencedPeriodsRecording): The recording extractor after silencing detected 
artefacts """ - def __init__(self, - recording, - detect_threshold=5, - verbose=False, - freq_max=5., - mode="zeros", - noise_levels=None, - seed=None, - list_periods=None, - **random_chunk_kwargs): + def __init__( + self, + recording, + detect_threshold=5, + verbose=False, + freq_max=5.0, + mode="zeros", + noise_levels=None, + seed=None, + list_periods=None, + **random_chunk_kwargs, + ): self.enveloppe = RectifyRecording(recording) self.enveloppe = GaussianFilterRecording(self.enveloppe, freq_min=None, freq_max=freq_max) if list_periods is None: - list_periods = detect_onsets(self.enveloppe, - detect_threshold=detect_threshold, - **random_chunk_kwargs) + list_periods = detect_onsets(self.enveloppe, detect_threshold=detect_threshold, **random_chunk_kwargs) if verbose: for i, periods in enumerate(list_periods): - total_time = np.sum([end-start for start, end in periods]) + total_time = np.sum([end - start for start, end in periods]) percentage = 100 * total_time / recording.get_num_samples(i) print(f"{percentage}% of segment {i} has been flagged as artefactual") @@ -160,7 +160,14 @@ def __init__(self, self, recording, list_periods, mode=mode, noise_levels=noise_levels, seed=seed, **random_chunk_kwargs ) - self._kwargs.update({'detect_threshold' : detect_threshold, 'freq_max' : freq_max, "verbose" : verbose, 'enveloppe' : self.enveloppe}) + self._kwargs.update( + { + "detect_threshold": detect_threshold, + "freq_max": freq_max, + "verbose": verbose, + "enveloppe": self.enveloppe, + } + ) # function for API From 3cf2615afb15587815074a65624ec3fc975cdc3b Mon Sep 17 00:00:00 2001 From: Sebastien Date: Mon, 10 Mar 2025 12:45:47 +0100 Subject: [PATCH 08/19] Renaming --- .../{silence_artefacts.py => silence_artifacts.py} | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) rename src/spikeinterface/preprocessing/{silence_artefacts.py => silence_artifacts.py} (92%) diff --git a/src/spikeinterface/preprocessing/silence_artefacts.py b/src/spikeinterface/preprocessing/silence_artifacts.py similarity index 92% rename from src/spikeinterface/preprocessing/silence_artefacts.py rename to src/spikeinterface/preprocessing/silence_artifacts.py index 5c7f72f3a2..27edf2c32c 100644 --- a/src/spikeinterface/preprocessing/silence_artefacts.py +++ b/src/spikeinterface/preprocessing/silence_artifacts.py @@ -95,7 +95,7 @@ def detect_onsets(recording, detect_threshold=5, **extra_kwargs): return periods -class SilencedArtefactsRecording(SilencedPeriodsRecording): +class SilencedArtifactsRecording(SilencedPeriodsRecording): """ Silence user-defined periods from recording extractor traces. The code will construct an enveloppe of the recording (as a low pass filtered version of the traces) and detect @@ -106,9 +106,9 @@ class SilencedArtefactsRecording(SilencedPeriodsRecording): Parameters ---------- recording : RecordingExtractor - The recording extractor to silence putative artefacts + The recording extractor to silence putative artifacts detect_threshold : float, default: 5 - The threshold to detect artefacts. The threshold is computed as `detect_threshold * noise_level` + The threshold to detect artifacts. 
The threshold is computed as `detect_threshold * noise_level` freq_max : float, default: 20 The maximum frequency for the low pass filter used noise_levels : array @@ -128,8 +128,8 @@ class SilencedArtefactsRecording(SilencedPeriodsRecording): Returns ------- - silenced_recording : SilencedArtefactsRecording - The recording extractor after silencing detected artefacts + silenced_recording : SilencedArtifactsRecording + The recording extractor after silencing detected artifacts """ def __init__( @@ -154,7 +154,7 @@ def __init__( for i, periods in enumerate(list_periods): total_time = np.sum([end - start for start, end in periods]) percentage = 100 * total_time / recording.get_num_samples(i) - print(f"{percentage}% of segment {i} has been flagged as artefactual") + print(f"{percentage}% of segment {i} has been flagged as artifactual") SilencedPeriodsRecording.__init__( self, recording, list_periods, mode=mode, noise_levels=noise_levels, seed=seed, **random_chunk_kwargs @@ -171,4 +171,4 @@ def __init__( # function for API -silence_artefacts = define_function_from_class(source_class=SilencedArtefactsRecording, name="silence_artefacts") +silence_artifacts = define_function_from_class(source_class=SilencedArtifactsRecording, name="silence_artifacts") From ee74efca8ebb973917519b6a85982a0765ab7292 Mon Sep 17 00:00:00 2001 From: Sebastien Date: Mon, 10 Mar 2025 12:47:33 +0100 Subject: [PATCH 09/19] WIP --- src/spikeinterface/preprocessing/preprocessinglist.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/spikeinterface/preprocessing/preprocessinglist.py b/src/spikeinterface/preprocessing/preprocessinglist.py index 4fc397a50d..f65690edf8 100644 --- a/src/spikeinterface/preprocessing/preprocessinglist.py +++ b/src/spikeinterface/preprocessing/preprocessinglist.py @@ -43,7 +43,7 @@ from .depth_order import DepthOrderRecording, depth_order from .astype import AstypeRecording, astype from .unsigned_to_signed import UnsignedToSignedRecording, unsigned_to_signed -from .silence_artefacts import SilencedArtefactsRecording, silence_artefacts +from .silence_artifacts import SilencedArtifactsRecording, silence_artifacts preprocessers_full_list = [ @@ -80,7 +80,7 @@ DirectionalDerivativeRecording, AstypeRecording, UnsignedToSignedRecording, - SilencedArtefactsRecording, + SilencedArtifactsRecording, ] preprocesser_dict = {pp_class.name: pp_class for pp_class in preprocessers_full_list} From cdd48ec03ee0f9d72f40c12ec8e3be9ae7d4d04d Mon Sep 17 00:00:00 2001 From: Sebastien Date: Mon, 10 Mar 2025 13:20:02 +0100 Subject: [PATCH 10/19] WIP --- .../preprocessing/silence_artifacts.py | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/src/spikeinterface/preprocessing/silence_artifacts.py b/src/spikeinterface/preprocessing/silence_artifacts.py index 27edf2c32c..f9ea66fb51 100644 --- a/src/spikeinterface/preprocessing/silence_artifacts.py +++ b/src/spikeinterface/preprocessing/silence_artifacts.py @@ -79,19 +79,14 @@ def detect_onsets(recording, detect_threshold=5, **extra_kwargs): if len(onsets) == 0 and len(offsets) == 0: periods.append([]) continue - - if onsets["sample_index"][0] > offsets["sample_index"][0]: - sub_periods += [(0, offsets["sample_index"][0])] - offsets = offsets[1:] + elif len(onsets) > 0 and len(offsets) == 0: + periods.append([(onsets["sample_index"][0], recording.get_num_samples(seg_index))]) + continue for i in range(min(len(onsets), len(offsets))): sub_periods += [(onsets["sample_index"][i], offsets["sample_index"][i])] - if len(onsets) > 
len(offsets): - sub_periods += [(onsets["sample_index"][0], recording.get_num_samples(seg_index))] - periods.append(sub_periods) - return periods @@ -156,6 +151,9 @@ def __init__( percentage = 100 * total_time / recording.get_num_samples(i) print(f"{percentage}% of segment {i} has been flagged as artifactual") + if 'enveloppe' in random_chunk_kwargs: + random_chunk_kwargs.pop('enveloppe') + SilencedPeriodsRecording.__init__( self, recording, list_periods, mode=mode, noise_levels=noise_levels, seed=seed, **random_chunk_kwargs ) From 0636637fcacaa881e90f73d1f796bf02c04f9181 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 10 Mar 2025 12:24:05 +0000 Subject: [PATCH 11/19] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/spikeinterface/preprocessing/silence_artifacts.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/spikeinterface/preprocessing/silence_artifacts.py b/src/spikeinterface/preprocessing/silence_artifacts.py index f9ea66fb51..ed2507430f 100644 --- a/src/spikeinterface/preprocessing/silence_artifacts.py +++ b/src/spikeinterface/preprocessing/silence_artifacts.py @@ -151,9 +151,9 @@ def __init__( percentage = 100 * total_time / recording.get_num_samples(i) print(f"{percentage}% of segment {i} has been flagged as artifactual") - if 'enveloppe' in random_chunk_kwargs: - random_chunk_kwargs.pop('enveloppe') - + if "enveloppe" in random_chunk_kwargs: + random_chunk_kwargs.pop("enveloppe") + SilencedPeriodsRecording.__init__( self, recording, list_periods, mode=mode, noise_levels=noise_levels, seed=seed, **random_chunk_kwargs ) From 39e47706f8a730b91180503858536f6a09e3a1d1 Mon Sep 17 00:00:00 2001 From: Sebastien Date: Thu, 27 Mar 2025 09:21:55 +0100 Subject: [PATCH 12/19] WIP --- .../preprocessing/silence_artifacts.py | 35 +++++++++++++++---- .../preprocessing/silence_periods.py | 5 ++- .../tests/test_silence_artifacts.py | 17 +++++++++ 3 files changed, 47 insertions(+), 10 deletions(-) create mode 100644 src/spikeinterface/preprocessing/tests/test_silence_artifacts.py diff --git a/src/spikeinterface/preprocessing/silence_artifacts.py b/src/spikeinterface/preprocessing/silence_artifacts.py index f9f484fd01..9ea11c985f 100644 --- a/src/spikeinterface/preprocessing/silence_artifacts.py +++ b/src/spikeinterface/preprocessing/silence_artifacts.py @@ -2,12 +2,12 @@ import numpy as np -from spikeinterface.core.core_tools import define_function_from_class +from spikeinterface.core.core_tools import define_function_handling_dict_from_class from spikeinterface.preprocessing.silence_periods import SilencedPeriodsRecording from spikeinterface.preprocessing.rectify import RectifyRecording from spikeinterface.preprocessing.filter_gaussian import GaussianFilterRecording from spikeinterface.core.job_tools import split_job_kwargs, fix_job_kwargs -from spikeinterface.core.core_tools import get_noise_levels +from spikeinterface.core.recording_tools import get_noise_levels from spikeinterface.core.node_pipeline import PeakDetector, base_peak_dtype import numpy as np @@ -28,6 +28,7 @@ def __init__( if noise_levels is None: noise_levels = get_noise_levels(recording, return_scaled=False, **random_chunk_kwargs) self.abs_thresholds = noise_levels * detect_threshold + print(self.abs_thresholds) self._dtype = np.dtype(base_peak_dtype + [("onset", "bool")]) def get_trace_margin(self): @@ -37,7 +38,7 @@ def get_dtype(self): return self._dtype def compute(self, 
traces, start_frame, end_frame, segment_index, max_margin): - z = (traces - self.abs_thresholds).mean(1) + z = np.median(traces - self.abs_thresholds, 1) threshold_mask = np.diff((z > 0) != 0, axis=0) indices = np.flatnonzero(threshold_mask) local_peaks = np.zeros(indices.size, dtype=self._dtype) @@ -47,7 +48,7 @@ def compute(self, traces, start_frame, end_frame, segment_index, max_margin): return (local_peaks,) -def detect_onsets(recording, detect_threshold=5, **extra_kwargs): +def detect_onsets(recording, detect_threshold=5, min_duration_ms=50, **extra_kwargs): from spikeinterface.core.node_pipeline import ( run_node_pipeline, @@ -66,6 +67,8 @@ def detect_onsets(recording, detect_threshold=5, **extra_kwargs): ) periods = [] + fs = recording.sampling_frequency + max_duration_samples = int(min_duration_ms*fs/1000) num_seg = recording.get_num_segments() for seg_index in range(num_seg): sub_periods = [] @@ -74,17 +77,28 @@ def detect_onsets(recording, detect_threshold=5, **extra_kwargs): onsets = sub_peaks[sub_peaks["onset"]] offsets = sub_peaks[~sub_peaks["onset"]] + onset_time = 0 + offset_time = recording.get_num_samples(seg_index) + + while onset_time < offset_time: + if len(onsets) == 0 and len(offsets) == 0: periods.append([]) continue elif len(onsets) > 0 and len(offsets) == 0: - periods.append([(onsets["sample_index"][0], recording.get_num_samples(seg_index))]) + offset = recording.get_num_samples(seg_index) + if (offset - onsets["sample_index"][0]) > max_duration_samples: + periods.append([( onsets["sample_index"][0], offset)]) continue + max_size = min(len(onsets), len(offsets)) + np.where(onsets[:max_size]) + for i in range(min(len(onsets), len(offsets))): sub_periods += [(onsets["sample_index"][i], offsets["sample_index"][i])] periods.append(sub_periods) + #print(periods) return periods @@ -104,6 +118,8 @@ class SilencedArtifactsRecording(SilencedPeriodsRecording): The threshold to detect artifacts. The threshold is computed as `detect_threshold * noise_level` freq_max : float, default: 20 The maximum frequency for the low pass filter used + min_duration_ms : float, default: 50 + The minimum duration for a threshold crossing to be considered as an artefact. 
noise_levels : array Noise levels if already computed seed : int | None, default: None @@ -131,6 +147,7 @@ def __init__( detect_threshold=5, verbose=False, freq_max=5.0, + min_duration_ms=50, mode="zeros", noise_levels=None, seed=None, @@ -142,7 +159,10 @@ def __init__( self.enveloppe = GaussianFilterRecording(self.enveloppe, freq_min=None, freq_max=freq_max) if list_periods is None: - list_periods = detect_onsets(self.enveloppe, detect_threshold=detect_threshold, **random_chunk_kwargs) + list_periods = detect_onsets(self.enveloppe, + detect_threshold=detect_threshold, + min_duration_ms=min_duration_ms, + **random_chunk_kwargs) if verbose: for i, periods in enumerate(list_periods): total_time = np.sum([end - start for start, end in periods]) @@ -161,10 +181,11 @@ def __init__( "detect_threshold": detect_threshold, "freq_max": freq_max, "verbose": verbose, + "min_duration_ms": min_duration_ms, "enveloppe": self.enveloppe, } ) # function for API -silence_artifacts = define_function_from_class(source_class=SilencedArtifactsRecording, name="silence_artifacts") +silence_artifacts = define_function_handling_dict_from_class(source_class=SilencedArtifactsRecording, name="silence_artifacts") diff --git a/src/spikeinterface/preprocessing/silence_periods.py b/src/spikeinterface/preprocessing/silence_periods.py index 3f7335937a..7badc12021 100644 --- a/src/spikeinterface/preprocessing/silence_periods.py +++ b/src/spikeinterface/preprocessing/silence_periods.py @@ -3,9 +3,8 @@ import numpy as np from spikeinterface.core.core_tools import define_function_handling_dict_from_class -from .basepreprocessor import BasePreprocessor, BasePreprocessorSegment - -from spikeinterface.core import get_noise_levels +from spikeinterface.preprocessing.basepreprocessor import BasePreprocessor, BasePreprocessorSegment +from spikeinterface.core.recording_tools import get_noise_levels from spikeinterface.core.generate import NoiseGeneratorRecording diff --git a/src/spikeinterface/preprocessing/tests/test_silence_artifacts.py b/src/spikeinterface/preprocessing/tests/test_silence_artifacts.py new file mode 100644 index 0000000000..2c7e6f951b --- /dev/null +++ b/src/spikeinterface/preprocessing/tests/test_silence_artifacts.py @@ -0,0 +1,17 @@ +import pytest + +import numpy as np + +from spikeinterface.core import generate_recording +from spikeinterface.preprocessing import silence_artifacts + + +def test_silence_artifacts(): + # one segment only + rec = generate_recording(durations=[10.0]) + + rec_rmart_mean = silence_artifacts(rec) + + +if __name__ == "__main__": + test_remove_artifacts() From e2139f015cd92fe4993b9e08a57b27ca7790cb90 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 27 Mar 2025 08:24:30 +0000 Subject: [PATCH 13/19] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../preprocessing/silence_artifacts.py | 22 +++++++++++-------- .../tests/test_silence_artifacts.py | 2 +- 2 files changed, 14 insertions(+), 10 deletions(-) diff --git a/src/spikeinterface/preprocessing/silence_artifacts.py b/src/spikeinterface/preprocessing/silence_artifacts.py index f256aa83d0..eced5baacd 100644 --- a/src/spikeinterface/preprocessing/silence_artifacts.py +++ b/src/spikeinterface/preprocessing/silence_artifacts.py @@ -67,7 +67,7 @@ def detect_onsets(recording, detect_threshold=5, min_duration_ms=50, **extra_kwa periods = [] fs = recording.sampling_frequency - max_duration_samples = 
int(min_duration_ms*fs/1000) + max_duration_samples = int(min_duration_ms * fs / 1000) num_seg = recording.get_num_segments() for seg_index in range(num_seg): sub_periods = [] @@ -81,14 +81,14 @@ def detect_onsets(recording, detect_threshold=5, min_duration_ms=50, **extra_kwa while onset_time < offset_time: pass - + if len(onsets) == 0 and len(offsets) == 0: periods.append([]) continue elif len(onsets) > 0 and len(offsets) == 0: offset = recording.get_num_samples(seg_index) if (offset - onsets["sample_index"][0]) > max_duration_samples: - periods.append([( onsets["sample_index"][0], offset)]) + periods.append([(onsets["sample_index"][0], offset)]) continue max_size = min(len(onsets), len(offsets)) @@ -98,7 +98,7 @@ def detect_onsets(recording, detect_threshold=5, min_duration_ms=50, **extra_kwa sub_periods += [(onsets["sample_index"][i], offsets["sample_index"][i])] periods.append(sub_periods) - #print(periods) + # print(periods) return periods @@ -159,10 +159,12 @@ def __init__( self.enveloppe = GaussianFilterRecording(self.enveloppe, freq_min=None, freq_max=freq_max) if list_periods is None: - list_periods = detect_onsets(self.enveloppe, - detect_threshold=detect_threshold, - min_duration_ms=min_duration_ms, - **random_chunk_kwargs) + list_periods = detect_onsets( + self.enveloppe, + detect_threshold=detect_threshold, + min_duration_ms=min_duration_ms, + **random_chunk_kwargs, + ) if verbose: for i, periods in enumerate(list_periods): total_time = np.sum([end - start for start, end in periods]) @@ -188,4 +190,6 @@ def __init__( # function for API -silence_artifacts = define_function_handling_dict_from_class(source_class=SilencedArtifactsRecording, name="silence_artifacts") +silence_artifacts = define_function_handling_dict_from_class( + source_class=SilencedArtifactsRecording, name="silence_artifacts" +) diff --git a/src/spikeinterface/preprocessing/tests/test_silence_artifacts.py b/src/spikeinterface/preprocessing/tests/test_silence_artifacts.py index 2c7e6f951b..8742b3a9f8 100644 --- a/src/spikeinterface/preprocessing/tests/test_silence_artifacts.py +++ b/src/spikeinterface/preprocessing/tests/test_silence_artifacts.py @@ -11,7 +11,7 @@ def test_silence_artifacts(): rec = generate_recording(durations=[10.0]) rec_rmart_mean = silence_artifacts(rec) - + if __name__ == "__main__": test_remove_artifacts() From d1d2097d6f4c474eb18687a897e379c610529b48 Mon Sep 17 00:00:00 2001 From: Pierre Yger Date: Thu, 10 Apr 2025 21:23:04 +0200 Subject: [PATCH 14/19] Making the detector reproducible --- .../preprocessing/silence_artifacts.py | 35 +++++++++---------- 1 file changed, 16 insertions(+), 19 deletions(-) diff --git a/src/spikeinterface/preprocessing/silence_artifacts.py b/src/spikeinterface/preprocessing/silence_artifacts.py index eced5baacd..3d7b414e4b 100644 --- a/src/spikeinterface/preprocessing/silence_artifacts.py +++ b/src/spikeinterface/preprocessing/silence_artifacts.py @@ -5,6 +5,7 @@ from spikeinterface.core.core_tools import define_function_handling_dict_from_class from spikeinterface.preprocessing.silence_periods import SilencedPeriodsRecording from spikeinterface.preprocessing.rectify import RectifyRecording +from spikeinterface.preprocessing.common_reference import CommonReferenceRecording from spikeinterface.preprocessing.filter_gaussian import GaussianFilterRecording from spikeinterface.core.job_tools import split_job_kwargs, fix_job_kwargs from spikeinterface.core.recording_tools import get_noise_levels @@ -22,11 +23,13 @@ def __init__( recording, detect_threshold=5, 
noise_levels=None, - random_chunk_kwargs={}, + seed=None, + random_slices_kwargs={}, ): PeakDetector.__init__(self, recording, return_output=True) if noise_levels is None: - noise_levels = get_noise_levels(recording, return_scaled=False, **random_chunk_kwargs) + random_slices_kwargs.update({"seed" : seed}) + noise_levels = get_noise_levels(recording, return_scaled=False, random_slices_kwargs=random_slices_kwargs) self.abs_thresholds = noise_levels * detect_threshold self._dtype = np.dtype(base_peak_dtype + [("onset", "bool")]) @@ -37,8 +40,8 @@ def get_dtype(self): return self._dtype def compute(self, traces, start_frame, end_frame, segment_index, max_margin): - z = np.median(traces - self.abs_thresholds, 1) - threshold_mask = np.diff((z > 0) != 0, axis=0) + z = np.median(traces/self.abs_thresholds, 1) + threshold_mask = np.diff((z > 1) != 0, axis=0) indices = np.flatnonzero(threshold_mask) local_peaks = np.zeros(indices.size, dtype=self._dtype) local_peaks["sample_index"] = indices @@ -69,6 +72,7 @@ def detect_onsets(recording, detect_threshold=5, min_duration_ms=50, **extra_kwa fs = recording.sampling_frequency max_duration_samples = int(min_duration_ms * fs / 1000) num_seg = recording.get_num_segments() + for seg_index in range(num_seg): sub_periods = [] mask = peaks["segment_index"] == 0 @@ -76,12 +80,6 @@ def detect_onsets(recording, detect_threshold=5, min_duration_ms=50, **extra_kwa onsets = sub_peaks[sub_peaks["onset"]] offsets = sub_peaks[~sub_peaks["onset"]] - onset_time = 0 - offset_time = recording.get_num_samples(seg_index) - - while onset_time < offset_time: - pass - if len(onsets) == 0 and len(offsets) == 0: periods.append([]) continue @@ -91,14 +89,11 @@ def detect_onsets(recording, detect_threshold=5, min_duration_ms=50, **extra_kwa periods.append([(onsets["sample_index"][0], offset)]) continue - max_size = min(len(onsets), len(offsets)) - np.where(onsets[:max_size]) - for i in range(min(len(onsets), len(offsets))): sub_periods += [(onsets["sample_index"][i], offsets["sample_index"][i])] periods.append(sub_periods) - # print(periods) + return periods @@ -152,18 +147,20 @@ def __init__( noise_levels=None, seed=None, list_periods=None, - **random_chunk_kwargs, + **random_slices_kwargs, ): self.enveloppe = RectifyRecording(recording) self.enveloppe = GaussianFilterRecording(self.enveloppe, freq_min=None, freq_max=freq_max) + self.enveloppe = CommonReferenceRecording(self.enveloppe) if list_periods is None: list_periods = detect_onsets( self.enveloppe, detect_threshold=detect_threshold, min_duration_ms=min_duration_ms, - **random_chunk_kwargs, + seed=seed, + **random_slices_kwargs, ) if verbose: for i, periods in enumerate(list_periods): @@ -171,11 +168,11 @@ def __init__( percentage = 100 * total_time / recording.get_num_samples(i) print(f"{percentage}% of segment {i} has been flagged as artifactual") - if "enveloppe" in random_chunk_kwargs: - random_chunk_kwargs.pop("enveloppe") + if "enveloppe" in random_slices_kwargs: + random_slices_kwargs.pop("enveloppe") SilencedPeriodsRecording.__init__( - self, recording, list_periods, mode=mode, noise_levels=noise_levels, seed=seed, **random_chunk_kwargs + self, recording, list_periods, mode=mode, noise_levels=noise_levels, seed=seed, **random_slices_kwargs ) self._kwargs.update( From 2f022dc27b90f39318532140c39577f5a4719f12 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 10 Apr 2025 19:24:01 +0000 Subject: [PATCH 15/19] [pre-commit.ci] auto fixes from 
pre-commit.com hooks for more information, see https://pre-commit.ci --- src/spikeinterface/preprocessing/silence_artifacts.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/spikeinterface/preprocessing/silence_artifacts.py b/src/spikeinterface/preprocessing/silence_artifacts.py index 3d7b414e4b..990ce78205 100644 --- a/src/spikeinterface/preprocessing/silence_artifacts.py +++ b/src/spikeinterface/preprocessing/silence_artifacts.py @@ -28,7 +28,7 @@ def __init__( ): PeakDetector.__init__(self, recording, return_output=True) if noise_levels is None: - random_slices_kwargs.update({"seed" : seed}) + random_slices_kwargs.update({"seed": seed}) noise_levels = get_noise_levels(recording, return_scaled=False, random_slices_kwargs=random_slices_kwargs) self.abs_thresholds = noise_levels * detect_threshold self._dtype = np.dtype(base_peak_dtype + [("onset", "bool")]) @@ -40,7 +40,7 @@ def get_dtype(self): return self._dtype def compute(self, traces, start_frame, end_frame, segment_index, max_margin): - z = np.median(traces/self.abs_thresholds, 1) + z = np.median(traces / self.abs_thresholds, 1) threshold_mask = np.diff((z > 1) != 0, axis=0) indices = np.flatnonzero(threshold_mask) local_peaks = np.zeros(indices.size, dtype=self._dtype) From 269773695895490921f3734778b3add0b37ac922 Mon Sep 17 00:00:00 2001 From: Pierre Yger Date: Thu, 10 Apr 2025 21:40:42 +0200 Subject: [PATCH 16/19] Fixing bug in silence_periods --- src/spikeinterface/preprocessing/silence_artifacts.py | 2 +- src/spikeinterface/preprocessing/silence_periods.py | 5 +---- 2 files changed, 2 insertions(+), 5 deletions(-) diff --git a/src/spikeinterface/preprocessing/silence_artifacts.py b/src/spikeinterface/preprocessing/silence_artifacts.py index 3d7b414e4b..3a27e511ff 100644 --- a/src/spikeinterface/preprocessing/silence_artifacts.py +++ b/src/spikeinterface/preprocessing/silence_artifacts.py @@ -128,7 +128,7 @@ class SilencedArtifactsRecording(SilencedPeriodsRecording): - "noise": The periods are filled with a gaussion noise that has the same variance that the one in the recordings, on a per channel basis - **random_chunk_kwargs : Keyword arguments for `spikeinterface.core.get_random_data_chunk()` function + **random_slices_kwargs : Keyword arguments for `spikeinterface.core.get_random_data_chunk()` function Returns ------- diff --git a/src/spikeinterface/preprocessing/silence_periods.py b/src/spikeinterface/preprocessing/silence_periods.py index 7badc12021..2ece69e2ae 100644 --- a/src/spikeinterface/preprocessing/silence_periods.py +++ b/src/spikeinterface/preprocessing/silence_periods.py @@ -46,12 +46,10 @@ class SilencedPeriodsRecording(BasePreprocessor): def __init__(self, recording, list_periods, mode="zeros", noise_levels=None, seed=None, **random_chunk_kwargs): available_modes = ("zeros", "noise") num_seg = recording.get_num_segments() - if num_seg == 1: - if isinstance(list_periods, (list, np.ndarray)) and np.array(list_periods).ndim == 2: + if isinstance(list_periods, (list, np.ndarray)) and np.array(list_periods).ndim == 1: # when unique segment accept list instead of of list of list/arrays list_periods = [list_periods] - # some checks assert mode in available_modes, f"mode {mode} is not an available mode: {available_modes}" @@ -111,7 +109,6 @@ def __init__(self, parent_recording_segment, periods, mode, noise_generator, seg def get_traces(self, start_frame, end_frame, channel_indices): traces = self.parent_recording_segment.get_traces(start_frame, end_frame, channel_indices) traces = 
-
         if len(self.periods) > 0:
             new_interval = np.array([start_frame, end_frame])
             lower_index = np.searchsorted(self.periods[:, 1], new_interval[0])

From 83853e5ddab3a10556114a152ebc7a69a9537c69 Mon Sep 17 00:00:00 2001
From: Pierre Yger
Date: Fri, 11 Apr 2025 09:40:57 +0200
Subject: [PATCH 17/19] Fix

---
 .../preprocessing/silence_artifacts.py | 37 ++++++++++++-------
 1 file changed, 23 insertions(+), 14 deletions(-)

diff --git a/src/spikeinterface/preprocessing/silence_artifacts.py b/src/spikeinterface/preprocessing/silence_artifacts.py
index 760a3b4f48..ce6a15449d 100644
--- a/src/spikeinterface/preprocessing/silence_artifacts.py
+++ b/src/spikeinterface/preprocessing/silence_artifacts.py
@@ -68,6 +68,9 @@ def detect_onsets(recording, detect_threshold=5, min_duration_ms=50, **extra_kwa
         job_name="detect threshold crossings",
     )
+    order = np.lexsort((peaks["sample_index"], peaks["segment_index"]))
+    peaks = peaks[order]
+
     periods = []
     fs = recording.sampling_frequency
     max_duration_samples = int(min_duration_ms * fs / 1000)
@@ -77,20 +80,25 @@ def detect_onsets(recording, detect_threshold=5, min_duration_ms=50, **extra_kwa
         sub_periods = []
         mask = peaks["segment_index"] == 0
         sub_peaks = peaks[mask]
-        onsets = sub_peaks[sub_peaks["onset"]]
-        offsets = sub_peaks[~sub_peaks["onset"]]
-
-        if len(onsets) == 0 and len(offsets) == 0:
-            periods.append([])
-            continue
-        elif len(onsets) > 0 and len(offsets) == 0:
-            offset = recording.get_num_samples(seg_index)
-            if (offset - onsets["sample_index"][0]) > max_duration_samples:
-                periods.append([(onsets["sample_index"][0], offset)])
-            continue
-
-        for i in range(min(len(onsets), len(offsets))):
-            sub_periods += [(onsets["sample_index"][i], offsets["sample_index"][i])]
+        if len(sub_peaks) > 0:
+            if not sub_peaks["onset"][0]:
+                local_peaks = np.zeros(1, dtype=np.dtype(base_peak_dtype + [("onset", "bool")]))
+                local_peaks["sample_index"] = 0
+                local_peaks["onset"] = True
+                sub_peaks = np.hstack((local_peaks, sub_peaks))
+            if sub_peaks["onset"][-1]:
+                local_peaks = np.zeros(1, dtype=np.dtype(base_peak_dtype + [("onset", "bool")]))
+                local_peaks["sample_index"] = recording.get_num_samples(seg_index)
+                local_peaks["onset"] = False
+                sub_peaks = np.hstack((sub_peaks, local_peaks))
+
+            indices = np.flatnonzero(np.diff(sub_peaks["onset"]))
+            for i, j in zip(indices[:-1], indices[1:]):
+                if sub_peaks["onset"][i]:
+                    start = sub_peaks["sample_index"][i]
+                    end = sub_peaks["sample_index"][j]
+                    if end - start > max_duration_samples:
+                        sub_periods.append((start, end))

         periods.append(sub_periods)

@@ -162,6 +170,7 @@ def __init__(
                 seed=seed,
                 **random_slices_kwargs,
             )
+
         if verbose:
             for i, periods in enumerate(list_periods):
                 total_time = np.sum([end - start for start, end in periods])

From 7dfdd17eb0073db772244c041a6887a9ea29e1ad Mon Sep 17 00:00:00 2001
From: Pierre Yger
Date: Fri, 11 Apr 2025 10:03:53 +0200
Subject: [PATCH 18/19] Patching

---
 src/spikeinterface/preprocessing/silence_periods.py | 6 +++---
 .../preprocessing/tests/test_silence_artifacts.py   | 9 ++++-----
 2 files changed, 7 insertions(+), 8 deletions(-)

diff --git a/src/spikeinterface/preprocessing/silence_periods.py b/src/spikeinterface/preprocessing/silence_periods.py
index 2ece69e2ae..e27ad821d9 100644
--- a/src/spikeinterface/preprocessing/silence_periods.py
+++ b/src/spikeinterface/preprocessing/silence_periods.py
@@ -47,8 +47,8 @@ def __init__(self, recording, list_periods, mode="zeros", noise_levels=None, see
         available_modes = ("zeros", "noise")
         num_seg = recording.get_num_segments()
         if num_seg == 1:
-            if isinstance(list_periods, (list, np.ndarray)) and np.array(list_periods).ndim == 1:
-                # when unique segment accept list instead of of list of list/arrays
+            if isinstance(list_periods, (list, np.ndarray)) and np.array(list_periods).ndim == 2:
+                # when unique segment accept list instead of list of list/arrays
                 list_periods = [list_periods]
         # some checks
         assert mode in available_modes, f"mode {mode} is not an available mode: {available_modes}"
@@ -109,7 +109,7 @@ def __init__(self, parent_recording_segment, periods, mode, noise_generator, seg
     def get_traces(self, start_frame, end_frame, channel_indices):
         traces = self.parent_recording_segment.get_traces(start_frame, end_frame, channel_indices)
         traces = traces.copy()
-        if len(self.periods) > 0:
+        if self.periods.size > 0:
             new_interval = np.array([start_frame, end_frame])
             lower_index = np.searchsorted(self.periods[:, 1], new_interval[0])
             upper_index = np.searchsorted(self.periods[:, 0], new_interval[1])
diff --git a/src/spikeinterface/preprocessing/tests/test_silence_artifacts.py b/src/spikeinterface/preprocessing/tests/test_silence_artifacts.py
index 8742b3a9f8..dffdc72ae9 100644
--- a/src/spikeinterface/preprocessing/tests/test_silence_artifacts.py
+++ b/src/spikeinterface/preprocessing/tests/test_silence_artifacts.py
@@ -8,10 +8,9 @@ def test_silence_artifacts():
     # one segment only
-    rec = generate_recording(durations=[10.0])
-
-    rec_rmart_mean = silence_artifacts(rec)
-
+    rec = generate_recording(durations=[10.0, 10])
+    new_rec = silence_artifacts(rec, detect_threshold=5, freq_max=5.0, min_duration_ms=50)
+

 if __name__ == "__main__":
-    test_remove_artifacts()
+    test_silence_artifacts()

From 3dc4a954a346b4ed911f867b6eee612dcb3ea0ac Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Fri, 11 Apr 2025 08:04:23 +0000
Subject: [PATCH 19/19] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 .../preprocessing/tests/test_silence_artifacts.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/spikeinterface/preprocessing/tests/test_silence_artifacts.py b/src/spikeinterface/preprocessing/tests/test_silence_artifacts.py
index dffdc72ae9..2baa4bf1b3 100644
--- a/src/spikeinterface/preprocessing/tests/test_silence_artifacts.py
+++ b/src/spikeinterface/preprocessing/tests/test_silence_artifacts.py
@@ -10,7 +10,7 @@ def test_silence_artifacts():
     # one segment only
     rec = generate_recording(durations=[10.0, 10])
     new_rec = silence_artifacts(rec, detect_threshold=5, freq_max=5.0, min_duration_ms=50)
-
+

 if __name__ == "__main__":
     test_silence_artifacts()
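NOTE (illustration only, not part of the patch series): an end-to-end usage sketch of the new preprocessing step. It assumes the function ends up exported as spikeinterface.preprocessing.silence_artifacts once this series is merged, and that generate_recording is imported from spikeinterface.core; the parameter values simply mirror the updated test above.

import numpy as np
from spikeinterface.core import generate_recording
from spikeinterface.preprocessing import silence_artifacts  # assumed export path after merge

# two-segment toy recording, mirroring the updated test
rec = generate_recording(durations=[10.0, 10.0], seed=42)

# internally builds a rectified, low-passed envelope, flags periods where it exceeds
# 5 x the per-channel noise level for more than 50 ms, and zeros out those samples
clean = silence_artifacts(rec, detect_threshold=5, freq_max=5.0, min_duration_ms=50, mode="zeros")

traces = clean.get_traces(segment_index=0)
print(traces.shape)

mode="noise" (inherited from SilencedPeriodsRecording) would instead fill the flagged periods with noise matched to each channel's variance.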