Skip to content

Commit 50390ed

Browse files
committed
merge with master
2 parents aa3142e + 5d3a509 commit 50390ed

File tree

9 files changed

+251
-68
lines changed

9 files changed

+251
-68
lines changed

doc/api.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,7 @@ spikeinterface.preprocessing
149149
.. autofunction:: correct_lsb
150150
.. autofunction:: detect_bad_channels
151151
.. autofunction:: filter
152+
.. autofunction:: highpass_filter
152153
.. autofunction:: highpass_spatial_filter
153154
.. autofunction:: interpolate_bad_channels
154155
.. autofunction:: normalize_by_quantile
@@ -159,6 +160,7 @@ spikeinterface.preprocessing
159160
.. autofunction:: scale
160161
.. autofunction:: whiten
161162
.. autofunction:: zero_channel_pad
163+
.. autofunction:: zscore
162164

163165

164166
spikeinterface.postprocessing

doc/modules/core.rst

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -591,8 +591,8 @@ same sampling frequency, number of segments, and number of samples:
591591

592592
.. code-block:: python
593593
594-
recA_4_chans = read_binray('fileA.raw')
595-
recB_4_chans = read_binray('fileB.raw')
594+
recA_4_chans = read_binary('fileA.raw')
595+
recB_4_chans = read_binary('fileB.raw')
596596
rec_8_chans = aggregate_channels([recA_4_chans, recB_4_chans])
597597
598598
We can also aggregate (or stack) multiple sortings on the unit axis using the

spikeinterface/preprocessing/common_reference.py

Lines changed: 19 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
from .basepreprocessor import BasePreprocessor, BasePreprocessorSegment
66
from ..core import get_closest_channels
77

8+
from .filter import fix_dtype
9+
810

911
class CommonReferenceRecording(BasePreprocessor):
1012
"""
@@ -35,6 +37,8 @@ class CommonReferenceRecording(BasePreprocessor):
3537
Use in the local CAR implementation as the selecting annulus (exclude radius, include radius)
3638
verbose: bool
3739
If True, output is verbose
40+
dtype: None or dtype
41+
If None the parent dtype is kept.
3842
3943
Returns
4044
-------
@@ -45,7 +49,7 @@ class CommonReferenceRecording(BasePreprocessor):
4549
name = 'common_reference'
4650

4751
def __init__(self, recording, reference='global', operator='median', groups=None, ref_channel_ids=None,
48-
local_radius=(30, 55), verbose=False):
52+
local_radius=(30, 55), verbose=False, dtype=None):
4953

5054
num_chans = recording.get_num_channels()
5155
neighbors = None
@@ -79,7 +83,8 @@ def __init__(self, recording, reference='global', operator='median', groups=None
7983
neighbors[i] = closest_inds[i, mask]
8084
assert len(neighbors[i]) > 0, "No reference channels available in the local annulus for selection."
8185

82-
BasePreprocessor.__init__(self, recording)
86+
dtype_ = fix_dtype(recording, dtype)
87+
BasePreprocessor.__init__(self, recording, dtype=dtype_)
8388

8489
# transforms groups (ids) to groups (indices)
8590
if groups is not None:
@@ -92,15 +97,16 @@ def __init__(self, recording, reference='global', operator='median', groups=None
9297
for parent_segment in recording._recording_segments:
9398
rec_segment = CommonReferenceRecordingSegment(parent_segment,
9499
reference, operator, groups, ref_channel_inds, local_radius,
95-
neighbors)
100+
neighbors, dtype_)
96101
self.add_recording_segment(rec_segment)
97102

98103
self._kwargs = dict(recording=recording, reference=reference, groups=groups, operator=operator,
99-
ref_channel_ids=ref_channel_ids, local_radius=local_radius)
104+
ref_channel_ids=ref_channel_ids, local_radius=local_radius, dtype=dtype_.str)
100105

101106

102107
class CommonReferenceRecordingSegment(BasePreprocessorSegment):
103-
def __init__(self, parent_recording_segment, reference, operator, groups, ref_channel_inds, local_radius, neighbors):
108+
def __init__(self, parent_recording_segment, reference, operator, groups, ref_channel_inds, local_radius,
109+
neighbors, dtype):
104110
BasePreprocessorSegment.__init__(self, parent_recording_segment)
105111

106112
self.reference = reference
@@ -110,6 +116,7 @@ def __init__(self, parent_recording_segment, reference, operator, groups, ref_ch
110116
self.local_radius = local_radius
111117
self.neighbors = neighbors
112118
self.temp = None
119+
self.dtype = dtype
113120

114121
if self.operator == 'median':
115122
self.operator_func = lambda x: np.median(x, axis=1, out=self.temp)[:, None]
@@ -119,31 +126,32 @@ def __init__(self, parent_recording_segment, reference, operator, groups, ref_ch
119126
def get_traces(self, start_frame, end_frame, channel_indices):
120127
# need input trace
121128
all_traces = self.parent_recording_segment.get_traces(start_frame, end_frame, slice(None))
129+
all_traces = all_traces.astype(self.dtype)
122130
self.temp = np.zeros((all_traces.shape[0],),dtype=all_traces.dtype)
123131
_channel_indices = np.arange(all_traces.shape[1])
124132
if channel_indices is not None:
125133
_channel_indices = _channel_indices[channel_indices]
126134

127135

128136
if self.reference == 'global':
129-
out_traces = np.zeros((all_traces.shape[0], _channel_indices.size), dtype=all_traces.dtype)
137+
out_traces = np.zeros((all_traces.shape[0], _channel_indices.size), dtype=self.dtype)
130138
for chan_inds, chan_group_inds in self._groups(_channel_indices):
131139
out_inds = np.array([np.where(_channel_indices == i)[0][0] for i in chan_inds])
132140
out_traces[:, out_inds] = all_traces[:, chan_inds] \
133141
- self.operator_func(all_traces[:, chan_group_inds])
134142

135143
elif self.reference == 'single':
136-
out_traces = np.zeros((all_traces.shape[0], _channel_indices.size), dtype=all_traces.dtype)
144+
out_traces = np.zeros((all_traces.shape[0], _channel_indices.size), dtype=self.dtype)
137145
for i, (chan_inds, _) in enumerate(self._groups(_channel_indices)):
138146
out_inds = np.array([np.where(_channel_indices == i)[0][0] for i in chan_inds])
139147
out_traces[:, out_inds] = all_traces[:, chan_inds] \
140148
- self.operator_func(all_traces[:, [self.ref_channel_inds[i]]])
141149

142150
elif self.reference == 'local':
143-
out_traces = np.hstack([
144-
all_traces[:, [chan_ind]] - self.operator_func(all_traces[:, self.neighbors[chan_ind]])
145-
for chan_ind in _channel_indices])
146-
151+
out_traces = np.zeros((all_traces.shape[0], _channel_indices.size), dtype=self.dtype)
152+
for i, chan_ind in enumerate(_channel_indices):
153+
out_traces[:, [i]] = all_traces[:, [chan_ind]] - \
154+
self.operator_func(all_traces[:, self.neighbors[chan_ind]])
147155
return out_traces
148156

149157
def _groups(self, channel_indices):

spikeinterface/preprocessing/normalize_scale.py

Lines changed: 49 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44

55
from .basepreprocessor import BasePreprocessor, BasePreprocessorSegment
66

7+
from .filter import fix_dtype
8+
79
from ..core import get_random_data_chunks
810

911

@@ -50,7 +52,7 @@ class NormalizeByQuantileRecording(BasePreprocessor):
5052
Random seed for reproducibility
5153
dtype: str or np.dtype
5254
The dtype of the output traces. Default "float32"
53-
**random_chunk_kwargs: keyword arguments for `get_random_data_chunks()` function
55+
**random_chunk_kwargs: Keyword arguments for `spikeinterface.core.get_random_data_chunks()` function
5456
5557
Returns
5658
-------
@@ -196,7 +198,7 @@ class CenterRecording(BasePreprocessor):
196198
'median' (default) | 'mean'
197199
dtype: str or np.dtype
198200
The dtype of the output traces. Default "float32"
199-
**random_chunk_kwargs: keyword arguments for `get_random_data_chunks()` function
201+
**random_chunk_kwargs: Keyword arguments for `spikeinterface.core.get_random_data_chunks()` function
200202
201203
Returns
202204
-------
@@ -247,35 +249,63 @@ class ZScoreRecording(BasePreprocessor):
247249
The recording extractor to be centered
248250
mode: str
249251
"median+mad" (default) or "mean+std"
250-
dtype: str or np.dtype
251-
The dtype of the output traces. Default "float32"
252-
**random_chunk_kwargs: keyword arguments for `get_random_data_chunks()` function
252+
dtype: None or dtype
253+
If None the parent dtype is kept.
254+
For an integer dtype, int_scale must also be given.
255+
gain : None or np.array
256+
Pre-computed gain.
257+
offset : None or np.array
258+
Pre-computed offset
259+
int_scale : None or float
260+
Apply a scaling factor to fit the integer range.
261+
This is used when the dtype is an integer, so that the output is scaled.
262+
For example, a value of `int_scale=200` will scale the zscore value to a standard deviation of 200.
263+
**random_chunk_kwargs: Keyword arguments for `spikeinterface.core.get_random_data_chunks()` function
253264
254265
Returns
255266
-------
256267
centered_traces: ScaleRecording
257268
The centered traces recording extractor object
258269
"""
259-
260270
name = "zscore"
261271

262272
def __init__(
263273
self,
264274
recording,
265275
mode="median+mad",
276+
gain=None,
277+
offset=None,
278+
int_scale=None,
266279
dtype="float32",
267280
**random_chunk_kwargs
268281
):
269282

270283
assert mode in ("median+mad", "mean+std")
271284

285+
# fix dtype
286+
dtype_ = fix_dtype(recording, dtype)
287+
288+
if dtype_.kind == 'i':
289+
assert int_scale is not None, 'For recording with dtype=int you must set dtype=float32 OR set a scale'
290+
272291
random_data = get_random_data_chunks(recording, **random_chunk_kwargs)
273292

274-
if mode == "median+mad":
293+
if gain is not None:
294+
assert offset is not None
295+
gain = np.asarray(gain)
296+
offset = np.asarray(offset)
297+
n = recording.get_num_channels()
298+
if gain.ndim == 1:
299+
gain = gain[None, :]
300+
assert gain.shape[1] == n
301+
if offset.ndim == 1:
302+
offset = offset[None, :]
303+
assert offset.shape[1] == n
304+
elif mode == "median+mad":
275305
medians = np.median(random_data, axis=0)
276306
medians = medians[None, :]
277-
mads = np.median(np.abs(random_data - medians), axis=0) / 0.6745
278-
mads = mads[None, :]
307+
mads = np.median(np.abs(random_data - medians), axis=0) / 0.6744897501960817
308+
mads = mads[None, :]
279309
gain = 1 / mads
280310
offset = -medians / mads
281311
else:
@@ -285,6 +315,14 @@ def __init__(
285315
stds = stds[None, :]
286316
gain = 1 / stds
287317
offset = -means / stds
318+
319+
if int_scale is not None:
320+
gain *= int_scale
321+
offset *= int_scale
322+
323+
# convenient to have them here
324+
self.gain = gain
325+
self.offset = offset
288326

289327
BasePreprocessor.__init__(self, recording, dtype=dtype)
290328

@@ -298,6 +336,8 @@ def __init__(
298336
recording=recording,
299337
dtype=np.dtype(self._dtype).str,
300338
mode=mode,
339+
gain=gain.tolist(),
340+
offset=offset.tolist()
301341
)
302342
self._kwargs.update(random_chunk_kwargs)
303343

spikeinterface/preprocessing/preprocessinglist.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
ScaleRecording, scale,
1111
ZScoreRecording, zscore,
1212
CenterRecording, center)
13-
from .whiten import WhitenRecording, whiten
13+
from .whiten import WhitenRecording, whiten, compute_whitening_matrix
1414
from .rectify import RectifyRecording, rectify
1515
from .clip import (
1616
BlankSaturationRecording, blank_staturation,

spikeinterface/preprocessing/tests/test_normalize_scale.py

Lines changed: 17 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -62,22 +62,30 @@ def test_center():
6262

6363
def test_zscore():
6464
rec = generate_recording()
65-
# print("original")
6665
tr = rec.get_traces(segment_index=0)
67-
# print("medians", np.median(tr, axis=0))
68-
# print("stds", np.std(tr, axis=0))
6966

70-
# print("median+mad")
7167
rec2 = zscore(rec)
7268
tr = rec2.get_traces(segment_index=0)
73-
# print("medians", np.median(tr, axis=0))
74-
# print("stds", np.std(tr, axis=0))
69+
meds = np.median(tr, axis=0)
70+
mads = np.median(np.abs(tr - meds), axis=0) / 0.6744897501960817
71+
assert np.all(np.abs(meds) < 0.01)
72+
assert np.all(np.abs(mads - 1) < 0.01)
73+
assert 'gain' in rec2._kwargs
7574

76-
# print("mean+std")
7775
rec3 = zscore(rec, mode="mean+std")
7876
tr = rec3.get_traces(segment_index=0)
79-
# print("medians", np.median(tr, axis=0))
80-
# print("stds", np.std(tr, axis=0))
77+
assert np.all(np.abs(np.mean(tr, axis=0)) < 0.01)
78+
assert np.all(np.abs(np.std(tr, axis=0) - 1) < 0.01)
79+
80+
rec_int = scale(rec, dtype="int16", gain=100)
81+
with pytest.raises(AssertionError):
82+
rec4 = zscore(rec_int, dtype=None)
83+
rec4 = zscore(rec_int, dtype='float32', mode="mean+std")
84+
rec4 = zscore(rec_int, dtype='int16', int_scale=256, mode="mean+std")
85+
tr = rec4.get_traces(segment_index=0)
86+
assert np.all(np.abs(np.mean(tr, axis=0)) < 1)
87+
assert np.all(np.abs(np.std(tr, axis=0) - 256) < 1)
88+
8189

8290

8391
if __name__ == '__main__':

spikeinterface/preprocessing/tests/test_whiten.py

Lines changed: 27 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from spikeinterface import set_global_tmp_folder
66
from spikeinterface.core import generate_recording
77

8-
from spikeinterface.preprocessing import whiten, scale
8+
from spikeinterface.preprocessing import whiten, scale, compute_whitening_matrix
99

1010
if hasattr(pytest, "global_test_folder"):
1111
cache_folder = pytest.global_test_folder / "preprocessing"
@@ -16,7 +16,25 @@
1616

1717

1818
def test_whiten():
19-
rec = generate_recording()
19+
rec = generate_recording(num_channels=4)
20+
21+
print(rec.get_channel_locations())
22+
random_chunk_kwargs={}
23+
W, M = compute_whitening_matrix(rec, 'global', random_chunk_kwargs, apply_mean=False,
24+
radius_um=None, eps=1e-8)
25+
print(W)
26+
print(M)
27+
28+
with pytest.raises(AssertionError):
29+
W, M = compute_whitening_matrix(rec, 'local', random_chunk_kwargs, apply_mean=False,
30+
radius_um=None, eps=1e-8)
31+
W, M = compute_whitening_matrix(rec, 'local', random_chunk_kwargs, apply_mean=False,
32+
radius_um=25, eps=1e-8)
33+
# W must be sparse
34+
np.sum(W==0) == 6
35+
36+
37+
2038

2139
rec2 = whiten(rec)
2240
rec2.save(verbose=False)
@@ -32,6 +50,13 @@ def test_whiten():
3250
np.testing.assert_array_equal(rec3.get_traces(segment_index=0),
3351
rec_par.get_traces(segment_index=0))
3452

53+
with pytest.raises(AssertionError):
54+
rec4 = whiten(rec_int, dtype=None)
55+
rec4 = whiten(rec_int, dtype=None, int_scale=256)
56+
assert rec4.get_dtype() == "int16"
57+
assert rec4._kwargs['M'] is None
58+
59+
3560

3661
if __name__ == '__main__':
3762
test_whiten()

0 commit comments

Comments
 (0)