|
| 1 | +from typing import Iterable, Union |
| 2 | +import numpy as np |
| 3 | +from scipy.stats import norm |
| 4 | +from spikeinterface.core import BaseRecording, BaseRecordingSegment, get_chunk_with_margin |
| 5 | +from spikeinterface.core.core_tools import define_function_from_class |
| 6 | +from .basepreprocessor import BasePreprocessor, BasePreprocessorSegment |
| 7 | + |
| 8 | + |
| 9 | +class GaussianBandpassFilterRecording(BasePreprocessor): |
| 10 | + """ |
| 11 | + Class for performing a bandpass gaussian filtering/smoothing on a recording. |
| 12 | + This is done by a convolution with a Gaussian kernel, which acts as a lowpass-filter. |
| 13 | + The highpass-filter can be computed by subtracting the result. |
| 14 | +
|
| 15 | + Here, the bandpass is computed in the Fourier domain to accelerate the computation. |
| 16 | +
|
| 17 | + Parameters |
| 18 | + ---------- |
| 19 | + recording: BaseRecording |
| 20 | + The recording extractor to be filtered. |
| 21 | + freq_min: float |
| 22 | + The lower frequency cutoff for the bandpass filter. |
| 23 | + freq_max: float |
| 24 | + The higher frequency cutoff for the bandpass filter. |
| 25 | +
|
| 26 | + Returns |
| 27 | + ------- |
| 28 | + gaussian_bandpass_filtered_recording: GaussianBandpassFilterRecording |
| 29 | + The filtered recording extractor object. |
| 30 | + """ |
| 31 | + name = 'gaussian_bandpass_filter' |
| 32 | + |
| 33 | + def __init__(self, recording: BaseRecording, freq_min: float = 300., freq_max: float = 5000.): |
| 34 | + sf = recording.sampling_frequency |
| 35 | + BasePreprocessor.__init__(self, recording) |
| 36 | + self.annotate(is_filtered=True) |
| 37 | + |
| 38 | + for parent_segment in recording._recording_segments: |
| 39 | + self.add_recording_segment(GaussianFilterRecordingSegment(parent_segment, freq_min, freq_max)) |
| 40 | + |
| 41 | + self._kwargs = {'recording': recording.to_dict(), 'freq_min': freq_min, 'freq_max': freq_max} |
| 42 | + |
| 43 | + |
| 44 | +class GaussianFilterRecordingSegment(BasePreprocessorSegment): |
| 45 | + |
| 46 | + def __init__(self, parent_recording_segment: BaseRecordingSegment, freq_min: float, freq_max: float): |
| 47 | + BasePreprocessorSegment.__init__(self, parent_recording_segment) |
| 48 | + |
| 49 | + self.freq_min = freq_min |
| 50 | + self.freq_max = freq_max |
| 51 | + self.cached_gaussian = dict() |
| 52 | + |
| 53 | + sf = parent_recording_segment.sampling_frequency |
| 54 | + low_sigma = sf / (2*np.pi * freq_min) |
| 55 | + high_sigma = sf / (2*np.pi * freq_max) |
| 56 | + self.margin = int(max(low_sigma, high_sigma) * 6. + 1) |
| 57 | + |
| 58 | + def get_traces(self, start_frame: Union[int, None] = None, end_frame: Union[int, None] = None, |
| 59 | + channel_indices: Union[Iterable, None] = None): |
| 60 | + traces, left_margin, right_margin = get_chunk_with_margin(self.parent_recording_segment, start_frame, |
| 61 | + end_frame, channel_indices, self.margin) |
| 62 | + dtype = traces.dtype |
| 63 | + |
| 64 | + traces_fft = np.fft.fft(traces, axis=0) |
| 65 | + gauss_low = self._create_gaussian(traces.shape[0], self.freq_min) |
| 66 | + gauss_high = self._create_gaussian(traces.shape[0], self.freq_max) |
| 67 | + |
| 68 | + filtered_fft = traces_fft * (gauss_high - gauss_low)[:, None] |
| 69 | + filtered_traces = np.real(np.fft.ifft(filtered_fft, axis=0)) |
| 70 | + |
| 71 | + if right_margin > 0: |
| 72 | + return filtered_traces[left_margin : -right_margin, :].astype(dtype) |
| 73 | + else: |
| 74 | + return filtered_traces[left_margin:, :].astype(dtype) |
| 75 | + |
| 76 | + def _create_gaussian(self, N: int, cutoff_f: float): |
| 77 | + if cutoff_f in self.cached_gaussian and N in self.cached_gaussian[cutoff_f]: |
| 78 | + return self.cached_gaussian[cutoff_f][N] |
| 79 | + |
| 80 | + sf = self.parent_recording_segment.sampling_frequency |
| 81 | + faxis = np.fft.fftfreq(N, d=1/sf) |
| 82 | + |
| 83 | + if cutoff_f > sf / 8: # The Fourier transform of a Gaussian with a very low sigma isn't a Gaussian. |
| 84 | + sigma = sf / (2*np.pi * cutoff_f) |
| 85 | + limit = int(round(6*sigma)) + 1 |
| 86 | + xaxis = np.arange(-limit, limit+1) / sigma |
| 87 | + gaussian = norm.pdf(xaxis) / sigma |
| 88 | + gaussian = np.abs(np.fft.fft(gaussian, n=N)) |
| 89 | + else: |
| 90 | + gaussian = norm.pdf(faxis / cutoff_f) * np.sqrt(2*np.pi) |
| 91 | + |
| 92 | + if cutoff_f not in self.cached_gaussian: |
| 93 | + self.cached_gaussian[cutoff_f] = dict() |
| 94 | + self.cached_gaussian[cutoff_f][N] = gaussian |
| 95 | + |
| 96 | + return gaussian |
| 97 | + |
| 98 | +gaussian_bandpass_filter = define_function_from_class(source_class=GaussianBandpassFilterRecording, name="gaussian_filter") |
0 commit comments