Skip to content

Commit ba47792

Browse files
authored
Merge pull request #1495 from alejoe91/sync-973-to-main
Sync #973 to main
2 parents 52c0657 + ad58775 commit ba47792

File tree

4 files changed

+146
-0
lines changed

4 files changed

+146
-0
lines changed

doc/api.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,7 @@ spikeinterface.preprocessing
149149
.. autofunction:: correct_lsb
150150
.. autofunction:: detect_bad_channels
151151
.. autofunction:: filter
152+
.. autofunction:: gaussian_bandpass_filter
152153
.. autofunction:: highpass_filter
153154
.. autofunction:: highpass_spatial_filter
154155
.. autofunction:: interpolate_bad_channels
Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
from typing import Iterable, Union
2+
import numpy as np
3+
from scipy.stats import norm
4+
from spikeinterface.core import BaseRecording, BaseRecordingSegment, get_chunk_with_margin
5+
from spikeinterface.core.core_tools import define_function_from_class
6+
from .basepreprocessor import BasePreprocessor, BasePreprocessorSegment
7+
8+
9+
class GaussianBandpassFilterRecording(BasePreprocessor):
10+
"""
11+
Class for performing a bandpass gaussian filtering/smoothing on a recording.
12+
This is done by a convolution with a Gaussian kernel, which acts as a lowpass-filter.
13+
The highpass-filter can be computed by subtracting the result.
14+
15+
Here, the bandpass is computed in the Fourier domain to accelerate the computation.
16+
17+
Parameters
18+
----------
19+
recording: BaseRecording
20+
The recording extractor to be filtered.
21+
freq_min: float
22+
The lower frequency cutoff for the bandpass filter.
23+
freq_max: float
24+
The higher frequency cutoff for the bandpass filter.
25+
26+
Returns
27+
-------
28+
gaussian_bandpass_filtered_recording: GaussianBandpassFilterRecording
29+
The filtered recording extractor object.
30+
"""
31+
name = 'gaussian_bandpass_filter'
32+
33+
def __init__(self, recording: BaseRecording, freq_min: float = 300., freq_max: float = 5000.):
34+
sf = recording.sampling_frequency
35+
BasePreprocessor.__init__(self, recording)
36+
self.annotate(is_filtered=True)
37+
38+
for parent_segment in recording._recording_segments:
39+
self.add_recording_segment(GaussianFilterRecordingSegment(parent_segment, freq_min, freq_max))
40+
41+
self._kwargs = {'recording': recording.to_dict(), 'freq_min': freq_min, 'freq_max': freq_max}
42+
43+
44+
class GaussianFilterRecordingSegment(BasePreprocessorSegment):
45+
46+
def __init__(self, parent_recording_segment: BaseRecordingSegment, freq_min: float, freq_max: float):
47+
BasePreprocessorSegment.__init__(self, parent_recording_segment)
48+
49+
self.freq_min = freq_min
50+
self.freq_max = freq_max
51+
self.cached_gaussian = dict()
52+
53+
sf = parent_recording_segment.sampling_frequency
54+
low_sigma = sf / (2*np.pi * freq_min)
55+
high_sigma = sf / (2*np.pi * freq_max)
56+
self.margin = int(max(low_sigma, high_sigma) * 6. + 1)
57+
58+
def get_traces(self, start_frame: Union[int, None] = None, end_frame: Union[int, None] = None,
59+
channel_indices: Union[Iterable, None] = None):
60+
traces, left_margin, right_margin = get_chunk_with_margin(self.parent_recording_segment, start_frame,
61+
end_frame, channel_indices, self.margin)
62+
dtype = traces.dtype
63+
64+
traces_fft = np.fft.fft(traces, axis=0)
65+
gauss_low = self._create_gaussian(traces.shape[0], self.freq_min)
66+
gauss_high = self._create_gaussian(traces.shape[0], self.freq_max)
67+
68+
filtered_fft = traces_fft * (gauss_high - gauss_low)[:, None]
69+
filtered_traces = np.real(np.fft.ifft(filtered_fft, axis=0))
70+
71+
if right_margin > 0:
72+
return filtered_traces[left_margin : -right_margin, :].astype(dtype)
73+
else:
74+
return filtered_traces[left_margin:, :].astype(dtype)
75+
76+
def _create_gaussian(self, N: int, cutoff_f: float):
77+
if cutoff_f in self.cached_gaussian and N in self.cached_gaussian[cutoff_f]:
78+
return self.cached_gaussian[cutoff_f][N]
79+
80+
sf = self.parent_recording_segment.sampling_frequency
81+
faxis = np.fft.fftfreq(N, d=1/sf)
82+
83+
if cutoff_f > sf / 8: # The Fourier transform of a Gaussian with a very low sigma isn't a Gaussian.
84+
sigma = sf / (2*np.pi * cutoff_f)
85+
limit = int(round(6*sigma)) + 1
86+
xaxis = np.arange(-limit, limit+1) / sigma
87+
gaussian = norm.pdf(xaxis) / sigma
88+
gaussian = np.abs(np.fft.fft(gaussian, n=N))
89+
else:
90+
gaussian = norm.pdf(faxis / cutoff_f) * np.sqrt(2*np.pi)
91+
92+
if cutoff_f not in self.cached_gaussian:
93+
self.cached_gaussian[cutoff_f] = dict()
94+
self.cached_gaussian[cutoff_f][N] = gaussian
95+
96+
return gaussian
97+
98+
gaussian_bandpass_filter = define_function_from_class(source_class=GaussianBandpassFilterRecording, name="gaussian_filter")

src/spikeinterface/preprocessing/preprocessinglist.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
NotchFilterRecording, notch_filter,
66
HighpassFilterRecording, highpass_filter,
77
)
8+
from .filter_gaussian import (GaussianBandpassFilterRecording, gaussian_bandpass_filter)
89
from .normalize_scale import (
910
NormalizeByQuantileRecording, normalize_by_quantile,
1011
ScaleRecording, scale,
@@ -34,6 +35,7 @@
3435
BandpassFilterRecording,
3536
HighpassFilterRecording,
3637
NotchFilterRecording,
38+
GaussianBandpassFilterRecording,
3739

3840
# gain offset stuff
3941
NormalizeByQuantileRecording,
Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
import numpy as np
2+
import pytest
3+
from pathlib import Path
4+
from spikeinterface.core import load_extractor, set_global_tmp_folder
5+
from spikeinterface.core.testing import check_recordings_equal
6+
from spikeinterface.core.generate import generate_recording
7+
from spikeinterface.preprocessing import gaussian_bandpass_filter
8+
9+
10+
if hasattr(pytest, "global_test_folder"):
11+
cache_folder = pytest.global_test_folder / "preprocessing" / "gaussian_bandpass_filter"
12+
else:
13+
cache_folder = Path("cache_folder") / "preprocessing" / "gaussian_bandpass_filter"
14+
15+
set_global_tmp_folder(cache_folder)
16+
cache_folder.mkdir(parents=True, exist_ok=True)
17+
18+
19+
def test_filter_gaussian():
20+
recording = generate_recording(num_channels=3)
21+
recording.annotate(is_filtered=True)
22+
recording = recording.save(folder=cache_folder / "recording")
23+
24+
rec_filtered = gaussian_bandpass_filter(recording)
25+
26+
assert rec_filtered.dtype == recording.dtype
27+
assert rec_filtered.get_traces(segment_index=0, end_frame=100).dtype == rec_filtered.dtype
28+
assert rec_filtered.get_traces(segment_index=0, end_frame=600).shape == (600, 3)
29+
assert rec_filtered.get_traces(segment_index=0, start_frame=100, end_frame=600).shape == (500, 3)
30+
assert rec_filtered.get_traces(segment_index=1, start_frame=rec_filtered.get_num_frames(1) - 200).shape == (200, 3)
31+
32+
# Check dumpability
33+
saved_loaded = load_extractor(rec_filtered.to_dict())
34+
check_recordings_equal(rec_filtered, saved_loaded, return_scaled=False)
35+
36+
saved_1job = rec_filtered.save(folder=cache_folder / "1job")
37+
saved_2job = rec_filtered.save(folder=cache_folder / "2job", n_jobs=2, chunk_duration='1s')
38+
39+
for seg_idx in range(rec_filtered.get_num_segments()):
40+
original_trace = rec_filtered.get_traces(seg_idx)
41+
saved1_trace = saved_1job.get_traces(seg_idx)
42+
saved2_trace = saved_2job.get_traces(seg_idx)
43+
44+
assert np.allclose(original_trace[60:-60], saved1_trace[60:-60], rtol=1e-3, atol=1e-3)
45+
assert np.allclose(original_trace[60:-60], saved2_trace[60:-60], rtol=1e-3, atol=1e-3)

0 commit comments

Comments
 (0)