Skip to content

Commit 2aa83e1

Browse files
committed
Unify params and extend tests
1 parent 50390ed commit 2aa83e1

File tree

2 files changed

+33
-26
lines changed

2 files changed

+33
-26
lines changed

spikeinterface/preprocessing/silence_periods.py

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -29,8 +29,7 @@ class SilencedPeriodsRecording(BasePreprocessor):
2929
- 'noise': The periods are filled with a gaussion noise that has the
3030
same variance that the one in the recordings, on a per channel
3131
basis
32-
**random_chunk_kwargs
33-
Random seed for random chunk, by default None
32+
**random_chunk_kwargs: Keyword arguments for `spikeinterface.core.get_random_data_chunk()` function
3433
3534
Returns
3635
-------
@@ -52,6 +51,8 @@ def __init__(self, recording, list_periods, mode='zeros',
5251
list_periods = [list_periods]
5352

5453
# some checks
54+
assert mode in available_modes, f"mode {mode} is not an available mode: {available_modes}"
55+
5556
assert isinstance(list_periods, list), "'list_periods' must be a list (one per segment)"
5657
assert len(list_periods) == num_seg, "'list_periods' must have the same length as the number of segments"
5758
assert all(isinstance(list_periods[i], (list, np.ndarray)) for i in range(num_seg)), \
@@ -60,9 +61,8 @@ def __init__(self, recording, list_periods, mode='zeros',
6061
for periods in list_periods:
6162
if len(periods) > 0:
6263
assert np.all(np.diff(np.array(periods), axis=1) > 0), "t_stops should be larger than t_starts"
63-
assert all(e < s for (_, e), (s, _) in zip(periods, periods[1:])), "Intervals should not overlap"
64-
65-
assert mode in available_modes, f"mode {mode} is not an available mode: {available_modes}"
64+
assert np.all(periods[i][1] < periods[i + 1][0] for i in np.arange(len(periods) - 1)), \
65+
"Intervals should not overlap"
6666

6767
if mode in ['noise']:
6868
noise_levels = get_noise_levels(recording, return_scaled=False, concatenated=True, **random_chunk_kwargs)
@@ -73,7 +73,7 @@ def __init__(self, recording, list_periods, mode='zeros',
7373
for seg_index, parent_segment in enumerate(recording._recording_segments):
7474
periods = list_periods[seg_index]
7575
periods = np.asarray(periods, dtype='int64')
76-
periods = np.sort(self.periods, axis=0)
76+
periods = np.sort(periods, axis=0)
7777
rec_segment = SilencedPeriodsRecordingSegment(parent_segment, periods, mode, noise_levels)
7878
self.add_recording_segment(rec_segment)
7979

@@ -84,7 +84,6 @@ def __init__(self, recording, list_periods, mode='zeros',
8484
class SilencedPeriodsRecordingSegment(BasePreprocessorSegment):
8585

8686
def __init__(self, parent_recording_segment, periods, mode, noise_levels):
87-
8887
BasePreprocessorSegment.__init__(self, parent_recording_segment)
8988
self.periods = periods
9089
self.mode = mode
@@ -103,22 +102,23 @@ def get_traces(self, start_frame, end_frame, channel_indices):
103102

104103
if len(self.periods) > 0:
105104
new_interval = np.array([start_frame, end_frame])
106-
lower_index = np.searchsorted(self.periods[:,1], new_interval[0])
107-
upper_index = np.searchsorted(self.periods[:,0], new_interval[1])
105+
lower_index = np.searchsorted(self.periods[:, 1], new_interval[0])
106+
upper_index = np.searchsorted(self.periods[:, 0], new_interval[1])
108107

109108
if upper_index > lower_index:
110109

111-
intersection = self.periods[lower_index:upper_index]
110+
periods_in_interval = self.periods[lower_index:upper_index]
112111

113-
for i in intersection:
112+
for period in periods_in_interval:
114113

115-
onset = max(0, i[0] - start_frame)
116-
offset = min(i[1] - start_frame, end_frame)
114+
onset = max(0, period[0] - start_frame)
115+
offset = min(period[1] - start_frame, end_frame)
117116

118117
if self.mode == 'zeros':
119118
traces[onset:offset, :] = 0
120119
elif self.mode == 'noise':
121-
traces[onset:offset, :] = self.noise_levels[channel_indices] * np.random.randn(offset-onset, num_channels)
120+
traces[onset:offset, :] = self.noise_levels[channel_indices] * \
121+
np.random.randn(offset - onset, num_channels)
122122

123123
return traces
124124

spikeinterface/preprocessing/tests/test_silence.py

Lines changed: 19 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -25,21 +25,28 @@ def test_silence():
2525

2626
rec0 = silence_periods(rec, list_periods=[[[0, 1000], [5000, 6000]], []], mode="zeros")
2727
rec0.save(verbose=False)
28-
traces0 = rec0.get_traces(segment_index=0, start_frame=0, end_frame=1000)
29-
traces1 = rec0.get_traces(segment_index=0, start_frame=5000, end_frame=6000)
30-
assert np.all(traces0 == 0) and np.all(traces1 == 0)
28+
traces_in0 = rec0.get_traces(segment_index=0, start_frame=0, end_frame=1000)
29+
traces_in1 = rec0.get_traces(segment_index=0, start_frame=5000, end_frame=6000)
30+
traces_out0 = rec0.get_traces(segment_index=0, start_frame=2000, end_frame=3000)
31+
assert np.all(traces_in0 == 0)
32+
assert np.all(traces_in1 == 0)
33+
assert not np.all(traces_out0 == 0)
3134

3235
rec1 = silence_periods(rec, list_periods=[[[0, 1000], [5000, 6000]], []], mode="noise")
33-
rec1.save(verbose=False)
36+
rec1 = rec1.save(folder=cache_folder / "rec_w_noise", verbose=False)
3437
noise_levels = get_noise_levels(rec, return_scaled=False)
35-
traces0 = rec1.get_traces(segment_index=0, start_frame=0, end_frame=1000)
36-
traces1 = rec1.get_traces(segment_index=0, start_frame=5000, end_frame=6000)
37-
assert np.abs((np.std(traces0, axis=0) - noise_levels) < 0.1).sum() and np.abs((np.std(traces1, axis=0) - noise_levels)).sum() < 0.1
38-
39-
traces0 = rec0.get_traces(segment_index=0, start_frame=900, end_frame=5100)
40-
traces = rec.get_traces(segment_index=0, start_frame=900, end_frame=5100)
41-
assert np.all(traces[100:-100] == traces0[100:-100]) and np.all(traces0[:100] == 0) and np.all(traces0[-100:] == 0)
42-
38+
traces_in0 = rec1.get_traces(segment_index=0, start_frame=0, end_frame=1000)
39+
traces_in1 = rec1.get_traces(segment_index=0, start_frame=5000, end_frame=6000)
40+
assert np.abs((np.std(traces_in0, axis=0) - noise_levels) < 0.1).sum()
41+
assert np.abs((np.std(traces_in1, axis=0) - noise_levels)).sum() < 0.1
42+
43+
traces_mix = rec0.get_traces(segment_index=0, start_frame=900, end_frame=5100)
44+
traces_original = rec.get_traces(segment_index=0, start_frame=900, end_frame=5100)
45+
assert np.all(traces_original[100:-100] == traces_mix[100:-100])
46+
assert np.all(traces_mix[:100] == 0)
47+
assert np.all(traces_mix[-100:] == 0)
48+
assert not np.all(traces_mix[:200] == 0)
49+
assert not np.all(traces_mix[:-200] == 0)
4350

4451

4552
if __name__ == '__main__':

0 commit comments

Comments
 (0)