Skip to content

Commit a52c128

Browse files
authored
Merge pull request #1310 from yger/silence_interval
Silence interval
2 parents 1e41c25 + 2aa83e1 commit a52c128

File tree

3 files changed

+182
-0
lines changed

3 files changed

+182
-0
lines changed

spikeinterface/preprocessing/preprocessinglist.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
ClipRecording, clip)
1818
from .common_reference import CommonReferenceRecording, common_reference
1919
from .remove_artifacts import RemoveArtifactsRecording, remove_artifacts
20+
from .silence_periods import SilencedPeriodsRecording, silence_periods
2021
from .phase_shift import PhaseShiftRecording, phase_shift
2122
from .zero_channel_pad import ZeroChannelPaddedRecording, zero_channel_pad
2223
from .deepinterpolation import DeepInterpolatedRecording, deepinterpolate
@@ -51,6 +52,7 @@
5152
RectifyRecording,
5253
ClipRecording,
5354
BlankSaturationRecording,
55+
SilencedPeriodsRecording,
5456
RemoveArtifactsRecording,
5557
ZeroChannelPaddedRecording,
5658
DeepInterpolatedRecording,
Lines changed: 127 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,127 @@
1+
import numpy as np
2+
import scipy.interpolate
3+
4+
from spikeinterface.core.core_tools import define_function_from_class
5+
from .basepreprocessor import BasePreprocessor, BasePreprocessorSegment
6+
7+
from ..core import get_random_data_chunks, get_noise_levels
8+
9+
class SilencedPeriodsRecording(BasePreprocessor):
10+
"""
11+
Silence user-defined periods from recording extractor traces. By default,
12+
periods are zeroed-out (mode = 'zeros'). You can also fill the periods with noise.
13+
Note that both methods assume that traces that are centered around zero.
14+
If this is not the case, make sure you apply a filter or center function prior to
15+
silencing periods.
16+
17+
Parameters
18+
----------
19+
recording: RecordingExtractor
20+
The recording extractor to silance periods
21+
list_periods: list of lists/arrays
22+
One list per segment of tuples (start_frame, end_frame) to silence
23+
24+
mode: str
25+
Determines what periods are replaced by. Can be one of the following:
26+
27+
- 'zeros' (default): Artifacts are replaced by zeros.
28+
29+
- 'noise': The periods are filled with a gaussion noise that has the
30+
same variance that the one in the recordings, on a per channel
31+
basis
32+
**random_chunk_kwargs: Keyword arguments for `spikeinterface.core.get_random_data_chunk()` function
33+
34+
Returns
35+
-------
36+
silence_recording: SilencedPeriodsRecording
37+
The recording extractor after silencing some periods
38+
"""
39+
name = 'silence_periods'
40+
41+
def __init__(self, recording, list_periods, mode='zeros',
42+
**random_chunk_kwargs):
43+
44+
available_modes = ('zeros', 'noise')
45+
num_seg = recording.get_num_segments()
46+
47+
48+
if num_seg == 1:
49+
if isinstance(list_periods, (list, np.ndarray)) and not np.isscalar(list_periods[0]):
50+
# when unique segment accept list instead of of list of list/arrays
51+
list_periods = [list_periods]
52+
53+
# some checks
54+
assert mode in available_modes, f"mode {mode} is not an available mode: {available_modes}"
55+
56+
assert isinstance(list_periods, list), "'list_periods' must be a list (one per segment)"
57+
assert len(list_periods) == num_seg, "'list_periods' must have the same length as the number of segments"
58+
assert all(isinstance(list_periods[i], (list, np.ndarray)) for i in range(num_seg)), \
59+
"Each element of 'list_periods' must be array-like"
60+
61+
for periods in list_periods:
62+
if len(periods) > 0:
63+
assert np.all(np.diff(np.array(periods), axis=1) > 0), "t_stops should be larger than t_starts"
64+
assert np.all(periods[i][1] < periods[i + 1][0] for i in np.arange(len(periods) - 1)), \
65+
"Intervals should not overlap"
66+
67+
if mode in ['noise']:
68+
noise_levels = get_noise_levels(recording, return_scaled=False, concatenated=True, **random_chunk_kwargs)
69+
else:
70+
noise_levels = None
71+
72+
BasePreprocessor.__init__(self, recording)
73+
for seg_index, parent_segment in enumerate(recording._recording_segments):
74+
periods = list_periods[seg_index]
75+
periods = np.asarray(periods, dtype='int64')
76+
periods = np.sort(periods, axis=0)
77+
rec_segment = SilencedPeriodsRecordingSegment(parent_segment, periods, mode, noise_levels)
78+
self.add_recording_segment(rec_segment)
79+
80+
self._kwargs = dict(recording=recording.to_dict(), list_periods=list_periods,
81+
mode=mode, noise_levels=noise_levels)
82+
83+
84+
class SilencedPeriodsRecordingSegment(BasePreprocessorSegment):
85+
86+
def __init__(self, parent_recording_segment, periods, mode, noise_levels):
87+
BasePreprocessorSegment.__init__(self, parent_recording_segment)
88+
self.periods = periods
89+
self.mode = mode
90+
self.noise_levels = noise_levels
91+
92+
def get_traces(self, start_frame, end_frame, channel_indices):
93+
94+
traces = self.parent_recording_segment.get_traces(start_frame, end_frame, channel_indices)
95+
traces = traces.copy()
96+
num_channels = traces.shape[1]
97+
98+
if start_frame is None:
99+
start_frame = 0
100+
if end_frame is None:
101+
end_frame = self.get_num_samples()
102+
103+
if len(self.periods) > 0:
104+
new_interval = np.array([start_frame, end_frame])
105+
lower_index = np.searchsorted(self.periods[:, 1], new_interval[0])
106+
upper_index = np.searchsorted(self.periods[:, 0], new_interval[1])
107+
108+
if upper_index > lower_index:
109+
110+
periods_in_interval = self.periods[lower_index:upper_index]
111+
112+
for period in periods_in_interval:
113+
114+
onset = max(0, period[0] - start_frame)
115+
offset = min(period[1] - start_frame, end_frame)
116+
117+
if self.mode == 'zeros':
118+
traces[onset:offset, :] = 0
119+
elif self.mode == 'noise':
120+
traces[onset:offset, :] = self.noise_levels[channel_indices] * \
121+
np.random.randn(offset - onset, num_channels)
122+
123+
return traces
124+
125+
126+
# function for API
127+
silence_periods = define_function_from_class(source_class=SilencedPeriodsRecording, name="silence_periods")
Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
import pytest
2+
from pathlib import Path
3+
import shutil
4+
5+
from spikeinterface import set_global_tmp_folder
6+
from spikeinterface.core import generate_recording
7+
8+
from spikeinterface.preprocessing import silence_periods
9+
10+
11+
from spikeinterface.core import get_noise_levels
12+
13+
import numpy as np
14+
15+
16+
if hasattr(pytest, "global_test_folder"):
17+
cache_folder = pytest.global_test_folder / "preprocessing"
18+
else:
19+
cache_folder = Path("cache_folder") / "preprocessing"
20+
21+
set_global_tmp_folder(cache_folder)
22+
23+
def test_silence():
24+
rec = generate_recording()
25+
26+
rec0 = silence_periods(rec, list_periods=[[[0, 1000], [5000, 6000]], []], mode="zeros")
27+
rec0.save(verbose=False)
28+
traces_in0 = rec0.get_traces(segment_index=0, start_frame=0, end_frame=1000)
29+
traces_in1 = rec0.get_traces(segment_index=0, start_frame=5000, end_frame=6000)
30+
traces_out0 = rec0.get_traces(segment_index=0, start_frame=2000, end_frame=3000)
31+
assert np.all(traces_in0 == 0)
32+
assert np.all(traces_in1 == 0)
33+
assert not np.all(traces_out0 == 0)
34+
35+
rec1 = silence_periods(rec, list_periods=[[[0, 1000], [5000, 6000]], []], mode="noise")
36+
rec1 = rec1.save(folder=cache_folder / "rec_w_noise", verbose=False)
37+
noise_levels = get_noise_levels(rec, return_scaled=False)
38+
traces_in0 = rec1.get_traces(segment_index=0, start_frame=0, end_frame=1000)
39+
traces_in1 = rec1.get_traces(segment_index=0, start_frame=5000, end_frame=6000)
40+
assert np.abs((np.std(traces_in0, axis=0) - noise_levels) < 0.1).sum()
41+
assert np.abs((np.std(traces_in1, axis=0) - noise_levels)).sum() < 0.1
42+
43+
traces_mix = rec0.get_traces(segment_index=0, start_frame=900, end_frame=5100)
44+
traces_original = rec.get_traces(segment_index=0, start_frame=900, end_frame=5100)
45+
assert np.all(traces_original[100:-100] == traces_mix[100:-100])
46+
assert np.all(traces_mix[:100] == 0)
47+
assert np.all(traces_mix[-100:] == 0)
48+
assert not np.all(traces_mix[:200] == 0)
49+
assert not np.all(traces_mix[:-200] == 0)
50+
51+
52+
if __name__ == '__main__':
53+
test_silence()

0 commit comments

Comments
 (0)