Skip to content

[Proposal] Generic preprocessing function for arbitrary preprocessing steps #2813

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 2 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 36 additions & 0 deletions src/spikeinterface/preprocessing/generic.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
from functools import partial

from spikeinterface.preprocessing.basepreprocessor import BasePreprocessor, BasePreprocessorSegment
from spikeinterface.core.core_tools import define_function_from_class


class GenericPreprocessor(BasePreprocessor):
def __init__(self, recording, function, **function_kwargs):
super().__init__(recording)
self._serializability["json"] = False

# Heavy computation can be done at the __init__ if needed
self.function_to_apply = partial(function, **function_kwargs)

# Initialize segments
for segment in recording._recording_segments:
processed_segment = GenericPreprocessorSegment(segment, self.function_to_apply)
self.add_recording_segment(processed_segment)

self._kwargs = {"recording": recording, "func": function}
self._kwargs.update(**function_kwargs)


class GenericPreprocessorSegment(BasePreprocessorSegment):
def __init__(self, parent_segment, function_to_apply):
super().__init__(parent_segment)
self.function_to_apply = function_to_apply # Function to apply to the traces

def get_traces(self, start_frame, end_frame, channel_indices):
# Fetch the traces from the parent segment
traces = self.parent_recording_segment.get_traces(start_frame, end_frame, channel_indices)
# Apply the function to the traces
return self.function_to_apply(traces)


generic_preprocessor = define_function_from_class(GenericPreprocessor, name="generic_preprocessor")
4 changes: 3 additions & 1 deletion src/spikeinterface/preprocessing/preprocessinglist.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from __future__ import annotations

### PREPROCESSORS ###


from .resample import ResampleRecording, resample
from .decimate import DecimateRecording, decimate
from .filter import (
Expand Down Expand Up @@ -43,7 +45,7 @@
from .depth_order import DepthOrderRecording, depth_order
from .astype import AstypeRecording, astype
from .unsigned_to_signed import UnsignedToSignedRecording, unsigned_to_signed

from .generic import GenericPreprocessor, generic_preprocessor

preprocessers_full_list = [
# filter stuff
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
import numpy as np
import pytest

from spikeinterface.core.generate import generate_recording
from spikeinterface.preprocessing import GenericPreprocessor


def test_basic_use():

recording = generate_recording(num_channels=4, durations=[1.0])
recording = recording.rename_channels(["a", "b", "c", "d"])
function = np.mean # function to apply to the traces

# Initialize the preprocessor
preprocessor = GenericPreprocessor(recording, function)

traces = preprocessor.get_traces(channel_ids=["a", "d"])
expected_traces = np.mean(recording.get_traces(channel_ids=["a", "d"]))

np.testing.assert_allclose(traces, expected_traces)
Loading