
Commit 8de09d7

Merge branch 'master' into merge_ap_lfp_neuropix

2 parents: c163acc + ff91f9a

33 files changed: +899 / -196 lines

.github/workflows/full-test.yml

Lines changed: 1 addition & 1 deletion

@@ -154,7 +154,7 @@ jobs:
           python ./.github/build_job_summary.py report_core.txt >> $GITHUB_STEP_SUMMARY
           rm report_core.txt
       - name: Test extractors
-        if: ${{ steps.modules-changed.outputs.EXTRACTORS_CHANGED == 'true' }}
+        if: ${{ steps.modules-changed.outputs.EXTRACTORS_CHANGED == 'true' || steps.modules-changed.outputs.CORE_CHANGED == 'true' }}
         run: |
           source ~/test_env/bin/activate
           pytest -m extractors -vv -ra --durations=0 --durations-min=0.001 | tee report_extractors.txt; test ${PIPESTATUS[0]} -eq 0 || exit 1

pyproject.toml

Lines changed: 0 additions & 1 deletion

@@ -25,7 +25,6 @@ dependencies = [
     "threadpoolctl",
     "tqdm",
     "probeinterface>=0.2.16",
-    "psutil",
 ]

 [build-system]

spikeinterface/core/baserecording.py

Lines changed: 2 additions & 1 deletion

@@ -43,8 +43,9 @@ def __repr__(self):
         nchan = self.get_num_channels()
         sf_khz = self.get_sampling_frequency() / 1000.
         duration = self.get_total_duration()
+        dtype = self.get_dtype()
         memory_size = self.get_memory_size()
-        txt = f"{clsname}: {nchan} channels - {nseg} segments - {sf_khz:0.1f}kHz - {duration:0.3f}s - {memory_size}"
+        txt = f"{clsname}: {nchan} channels - {nseg} segments - {sf_khz:0.1f}kHz - {duration:0.3f}s - {dtype} type - {memory_size}"
         if 'file_paths' in self._kwargs:
             txt += '\n  file_paths: {}'.format(self._kwargs['file_paths'])
         if 'file_path' in self._kwargs:
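
A quick sketch of what the updated __repr__ would now print, assuming spikeinterface's generate_recording helper (names and figures below are illustrative, not from this commit):

import spikeinterface.core as si

rec = si.generate_recording(num_channels=4, durations=[10.0])
# The repr now includes the dtype between the duration and the memory size,
# e.g. "...: 4 channels - 1 segments - 30.0kHz - 10.000s - float32 type - 4.58 MiB"
print(rec)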

spikeinterface/core/core_tools.py

Lines changed: 28 additions & 7 deletions

@@ -793,8 +793,10 @@ def recursive_path_modifier(d, func, target='path', copy=True):
     else:
         for k, v in d.items():
             if target in k:
-                # paths can be str or list of str
-                if isinstance(v, str):
+                # paths can be str or list of str or None
+                if v is None:
+                    continue
+                if isinstance(v, (str, Path)):
                     dc[k] = func(v)
                 elif isinstance(v, list):
                     dc[k] = [func(e) for e in v]
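
To illustrate the new None and Path handling, a minimal sketch (the dictionary contents and the relocation function are made up for illustration):

from pathlib import Path

d = {
    "file_path": Path("/data/rec.bin"),            # Path objects are now accepted
    "folder_path": None,                           # None values are now skipped
    "file_paths": ["/data/a.bin", "/data/b.bin"],  # lists are still mapped element-wise
}
# relocate every matched path to a hypothetical new base folder
new_d = recursive_path_modifier(d, lambda p: str(Path("/new_base") / Path(p).name))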
@@ -815,13 +817,32 @@ def recursive_key_finder(d, key):


 def convert_bytes_to_str(byte_value: int) -> str:
     """
-    Converts a number of bytes to a value in either KiB, MiB, GiB, or TiB.
+    Convert a number of bytes to a human-readable string with an appropriate unit.

-    Args:
-        byte_value (int): The number of bytes to convert.
+    This function converts a given number of bytes into a human-readable string
+    representing the value in either bytes (B), kibibytes (KiB), mebibytes (MiB),
+    gibibytes (GiB), or tebibytes (TiB). The function uses the IEC binary prefixes
+    (1 KiB = 1024 B, 1 MiB = 1024 KiB, etc.) to determine the appropriate unit.

-    Returns:
-        str: The converted value with the appropriate unit (KiB, MiB, GiB, or TiB).
+    Parameters
+    ----------
+    byte_value : int
+        The number of bytes to convert.
+
+    Returns
+    -------
+    str
+        The converted value as a formatted string with two decimal places,
+        followed by a space and the appropriate unit (B, KiB, MiB, GiB, or TiB).
+
+    Examples
+    --------
+    >>> convert_bytes_to_str(1024)
+    '1.00 KiB'
+    >>> convert_bytes_to_str(1048576)
+    '1.00 MiB'
+    >>> convert_bytes_to_str(45056)
+    '44.00 KiB'
     """
     suffixes = ['B', 'KiB', 'MiB', 'GiB', 'TiB']
     i = 0
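
The diff cuts off here; a minimal sketch of how the remaining loop presumably proceeds, consistent with the docstring above (not the verbatim repository code):

def convert_bytes_to_str(byte_value: int) -> str:
    suffixes = ['B', 'KiB', 'MiB', 'GiB', 'TiB']
    i = 0
    # divide by 1024 until the value fits the largest applicable IEC unit
    while byte_value >= 1024 and i < len(suffixes) - 1:
        byte_value /= 1024
        i += 1
    return f"{byte_value:.2f} {suffixes[i]}"

assert convert_bytes_to_str(1024) == '1.00 KiB'
assert convert_bytes_to_str(1048576) == '1.00 MiB'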

spikeinterface/core/recording_tools.py

Lines changed: 138 additions & 65 deletions

@@ -1,8 +1,15 @@
 import numpy as np


-def get_random_data_chunks(recording, return_scaled=False, num_chunks_per_segment=20,
-                           chunk_size=10000, concatenated=True, seed=0, margin_frames=0):
+def get_random_data_chunks(
+    recording,
+    return_scaled=False,
+    num_chunks_per_segment=20,
+    chunk_size=10000,
+    concatenated=True,
+    seed=0,
+    margin_frames=0,
+):
     """
     Extract random chunks across segments
@@ -31,22 +38,30 @@ def get_random_data_chunks(recording, return_scaled=False, num_chunks_per_segment=20,
     # Should be done by changing kwargs with total_num_chunks=XXX and total_duration=YYYY
     # And randomize the number of chunk per segment weighted by segment duration

     # check chunk size
     for segment_index in range(recording.get_num_segments()):
-        assert chunk_size < recording.get_num_samples(segment_index), (f"chunk_size is greater than the number "
-                                                                       f"of samples for segment index {segment_index}. "
-                                                                       f"Use a smaller chunk_size!")
+        assert chunk_size < recording.get_num_samples(segment_index), (
+            f"chunk_size is greater than the number "
+            f"of samples for segment index {segment_index}. "
+            f"Use a smaller chunk_size!"
+        )

     chunk_list = []
     for segment_index in range(recording.get_num_segments()):
         length = recording.get_num_frames(segment_index)

-        random_starts = np.random.RandomState(seed=seed).randint(margin_frames, length - chunk_size - margin_frames, size=num_chunks_per_segment)
+        random_starts = np.random.RandomState(seed=seed).randint(
+            margin_frames,
+            length - chunk_size - margin_frames,
+            size=num_chunks_per_segment,
+        )
         for start_frame in random_starts:
-            chunk = recording.get_traces(start_frame=start_frame,
-                                         end_frame=start_frame + chunk_size,
-                                         segment_index=segment_index,
-                                         return_scaled=return_scaled)
+            chunk = recording.get_traces(
+                start_frame=start_frame,
+                end_frame=start_frame + chunk_size,
+                segment_index=segment_index,
+                return_scaled=return_scaled,
+            )
             chunk_list.append(chunk)
     if concatenated:
         return np.concatenate(chunk_list, axis=0)
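
A usage sketch for the reformatted function (generate_recording and the exact import path are assumptions, not part of this diff):

import spikeinterface.core as si

rec = si.generate_recording(num_channels=4, durations=[10.0])  # 30 kHz by default
# 20 random chunks of 10000 samples from the single segment, concatenated in time
chunks = si.get_random_data_chunks(rec, num_chunks_per_segment=20, chunk_size=10000)
print(chunks.shape)  # (200000, 4)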
@@ -59,7 +74,9 @@ def get_channel_distances(recording):
     Distance between channel pairs
     """
     locations = recording.get_channel_locations()
-    channel_distances = np.linalg.norm(locations[:, np.newaxis] - locations[np.newaxis, :], axis=2)
+    channel_distances = np.linalg.norm(
+        locations[:, np.newaxis] - locations[np.newaxis, :], axis=2
+    )

     return channel_distances
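The broadcasting trick above is worth spelling out: an (n, 1, 2) array minus a (1, n, 2) array yields an (n, n, 2) array of pairwise offsets, and the norm over the last axis gives the (n, n) distance matrix. A self-contained check with made-up locations:

import numpy as np

locations = np.array([[0.0, 0.0], [0.0, 20.0], [30.0, 0.0]])
dist = np.linalg.norm(locations[:, np.newaxis] - locations[np.newaxis, :], axis=2)
print(dist[0, 1], dist[0, 2])  # 20.0 30.0
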
@@ -95,40 +112,59 @@ def get_closest_channels(recording, channel_ids=None, num_channels=None):
     for i in range(locations.shape[0]):
         distances = np.linalg.norm(locations[i, :] - locations, axis=1)
         order = np.argsort(distances)
-        closest_channels_inds.append(order[1:num_channels + 1])
-        dists.append(distances[order][1:num_channels + 1])
+        closest_channels_inds.append(order[1 : num_channels + 1])
+        dists.append(distances[order][1 : num_channels + 1])

     return np.array(closest_channels_inds), np.array(dists)


 def get_noise_levels(recording, return_scaled=True, **random_chunk_kwargs):
     """
     Estimate noise for each channel using the MAD method.

     Internally it samples some chunks across segments,
     then uses the MAD estimator (more robust than STD).
     """
-    random_chunks = get_random_data_chunks(recording, return_scaled=return_scaled, **random_chunk_kwargs)
+    random_chunks = get_random_data_chunks(
+        recording, return_scaled=return_scaled, **random_chunk_kwargs
+    )
     med = np.median(random_chunks, axis=0, keepdims=True)
     # hard-coded so that core doesn't depend on scipy
-    noise_levels = np.median(np.abs(random_chunks - med), axis=0) / 0.6744897501960817
+    noise_levels = (
+        np.median(np.abs(random_chunks - med), axis=0) / 0.6744897501960817
+    )
     return noise_levels


-def get_chunk_with_margin(rec_segment, start_frame, end_frame,
-                          channel_indices, margin, add_zeros=False,
-                          window_on_margin=False, dtype=None):
+def get_chunk_with_margin(
+    rec_segment,
+    start_frame,
+    end_frame,
+    channel_indices,
+    margin,
+    add_zeros=False,
+    add_reflect_padding=False,
+    window_on_margin=False,
+    dtype=None,
+):
     """
     Helper to get a chunk with margin.
+
+    The margin is extracted from the recording when possible. If
+    at the edge of the recording, no margin is used unless one
+    of `add_zeros` or `add_reflect_padding` is True. In the first
+    case zero padding is used, in the second case np.pad is called
+    with mode="reflect".
     """
     length = rec_segment.get_num_samples()

     if channel_indices is None:
         channel_indices = slice(None)

-    if not add_zeros:
-        assert not window_on_margin, 'window_on_margin can be used only for add_zeros=True'
+    if not (add_zeros or add_reflect_padding):
+        if window_on_margin and not add_zeros:
+            raise ValueError("window_on_margin requires add_zeros=True")
         if start_frame is None:
             left_margin = 0
             start_frame = 0
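
The hard-coded 0.6744897501960817 is the 0.75 quantile of the standard normal distribution (scipy.stats.norm.ppf(0.75)); dividing the MAD by it makes the estimator consistent with the standard deviation for Gaussian noise, without pulling scipy into core. A quick self-contained check (illustrative, not from the diff):

import numpy as np

rng = np.random.default_rng(0)
x = rng.normal(loc=0.0, scale=5.0, size=(100_000, 4))  # 4 fake channels, true std = 5

med = np.median(x, axis=0, keepdims=True)
mad_std = np.median(np.abs(x - med), axis=0) / 0.6744897501960817
print(mad_std)  # every entry close to 5.0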
@@ -144,10 +180,14 @@ def get_chunk_with_margin(rec_segment, start_frame, end_frame,
             right_margin = length - end_frame
         else:
             right_margin = margin
-        traces_chunk = rec_segment.get_traces(start_frame - left_margin, end_frame + right_margin, channel_indices)
+        traces_chunk = rec_segment.get_traces(
+            start_frame - left_margin,
+            end_frame + right_margin,
+            channel_indices,
+        )

     else:
-        # add_zeros=True
+        # either add_zeros or reflect_padding
         assert start_frame is not None
         assert end_frame is not None
         chunk_size = end_frame - start_frame
@@ -167,41 +207,66 @@ def get_chunk_with_margin(rec_segment, start_frame, end_frame,
             end_frame2 = end_frame + margin
             right_pad = 0

-        traces_chunk = rec_segment.get_traces(start_frame2, end_frame2, channel_indices)
-
-
-        if dtype is not None or window_on_margin or left_pad > 0 or right_pad > 0:
+        traces_chunk = rec_segment.get_traces(
+            start_frame2, end_frame2, channel_indices
+        )
+
+        if (
+            dtype is not None
+            or window_on_margin
+            or left_pad > 0
+            or right_pad > 0
+        ):
             need_copy = True
         else:
             need_copy = False

+        left_margin = margin
+        right_margin = margin
+
         if need_copy:
             if dtype is None:
                 dtype = traces_chunk.dtype
-            traces_chunk2 = np.zeros((full_size, traces_chunk.shape[1]), dtype=dtype)
-            i0 = left_pad
-            i1 = left_pad + traces_chunk.shape[0]
-            traces_chunk2[i0: i1, :] = traces_chunk
+
             left_margin = margin
             if end_frame < (length + margin):
                 right_margin = margin
             else:
                 right_margin = end_frame + margin - length
-            if window_on_margin:
-                # apply inplace taper on border
-                taper = (1 - np.cos(np.arange(margin) / margin * np.pi)) / 2
-                taper = taper[:, np.newaxis]
-                traces_chunk2[:margin] *= taper
-                traces_chunk2[-margin:] *= taper[::-1]
-            traces_chunk = traces_chunk2
-        else:
-            left_margin = margin
-            right_margin = margin
+
+            if add_zeros:
+                traces_chunk2 = np.zeros(
+                    (full_size, traces_chunk.shape[1]), dtype=dtype
+                )
+                i0 = left_pad
+                i1 = left_pad + traces_chunk.shape[0]
+                traces_chunk2[i0:i1, :] = traces_chunk
+                if window_on_margin:
+                    # apply inplace taper on border
+                    taper = (
+                        1 - np.cos(np.arange(margin) / margin * np.pi)
+                    ) / 2
+                    taper = taper[:, np.newaxis]
+                    traces_chunk2[:margin] *= taper
+                    traces_chunk2[-margin:] *= taper[::-1]
+                traces_chunk = traces_chunk2
+            elif add_reflect_padding:
+                # in this case, we don't want to taper
+                traces_chunk = np.pad(
+                    traces_chunk.astype(dtype),
+                    [(left_pad, right_pad), (0, 0)],
+                    mode="reflect",
+                )
+            else:
+                # we need a copy to change the dtype
+                traces_chunk = np.asarray(traces_chunk, dtype=dtype)

     return traces_chunk, left_margin, right_margin


-def order_channels_by_depth(recording, channel_ids=None, dimensions=('x', 'y')):
+def order_channels_by_depth(
+    recording, channel_ids=None, dimensions=("x", "y")
+):
     """
     Order channels by depth, by first ordering the x-axis, and then the y-axis.
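
For intuition on the new add_reflect_padding branch: np.pad with mode="reflect" mirrors samples at the recording edge instead of inserting zeros, so the margin keeps realistic signal statistics. A toy example (not from the diff):

import numpy as np

traces = np.arange(5, dtype="float32")[:, np.newaxis]  # one channel: 0, 1, 2, 3, 4
padded = np.pad(traces, [(2, 2), (0, 0)], mode="reflect")
print(padded[:, 0])  # [2. 1. 0. 1. 2. 3. 4. 3. 2.]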
@@ -213,7 +278,7 @@ def order_channels_by_depth(recording, channel_ids=None, dimensions=('x', 'y')):
         If given, a subset of channels to order locations for
     dimensions : str or tuple
         If str, it needs to be 'x', 'y', 'z'.
-        If tuple, it sorts the locations in two dimensions using lexsort.
+        If tuple, it sorts the locations in two dimensions using lexsort.
         This approach is recommended since there is less ambiguity, by default ('x', 'y')

     Returns
@@ -229,18 +294,20 @@ def order_channels_by_depth(recording, channel_ids=None, dimensions=('x', 'y')):
         locations = locations[channel_inds, :]

     if isinstance(dimensions, str):
-        dim = ['x', 'y', 'z'].index(dimensions)
+        dim = ["x", "y", "z"].index(dimensions)
         assert dim < ndim, "Invalid dimensions!"
-        order_f = np.argsort(locations[:, dim])
+        order_f = np.argsort(locations[:, dim], kind="stable")
     else:
-        assert isinstance(dimensions, tuple), "dimensions can be a str or a tuple"
+        assert isinstance(
+            dimensions, tuple
+        ), "dimensions can be a str or a tuple"
         locations_to_sort = ()
         for dim in dimensions:
-            dim = ['x', 'y', 'z'].index(dim)
+            dim = ["x", "y", "z"].index(dim)
             assert dim < ndim, "Invalid dimensions!"
-            locations_to_sort += (locations[:, dim], )
+            locations_to_sort += (locations[:, dim],)
         order_f = np.lexsort(locations_to_sort)
-        order_r = np.argsort(order_f)
+        order_r = np.argsort(order_f, kind="stable")

     return order_f, order_r
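Note that np.lexsort treats the last key as the primary one, so with dimensions=('x', 'y') the channels end up ordered by depth (y), with ties broken by x; the new kind="stable" keeps equal elements in a deterministic order. A small check with made-up locations:

import numpy as np

x = np.array([20.0, 0.0, 20.0, 0.0])
y = np.array([0.0, 40.0, 40.0, 0.0])
order_f = np.lexsort((x, y))
print(order_f)  # [3 0 1 2] -> (0,0), (20,0), (0,40), (20,40)
order_r = np.argsort(order_f, kind="stable")
print(order_r)  # [1 2 3 0]: position of each original channel after sorting
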
@@ -253,21 +320,27 @@ def check_probe_do_not_overlap(probes):
     for i in range(len(probes)):
         probe_i = probes[i]
         # check that all positions in probe_j are outside probe_i boundaries
-        x_bounds_i = [np.min(probe_i.contact_positions[:, 0]),
-                      np.max(probe_i.contact_positions[:, 0])]
-        y_bounds_i = [np.min(probe_i.contact_positions[:, 1]),
-                      np.max(probe_i.contact_positions[:, 1])]
+        x_bounds_i = [
+            np.min(probe_i.contact_positions[:, 0]),
+            np.max(probe_i.contact_positions[:, 0]),
+        ]
+        y_bounds_i = [
+            np.min(probe_i.contact_positions[:, 1]),
+            np.max(probe_i.contact_positions[:, 1]),
+        ]

         for j in range(i + 1, len(probes)):
             probe_j = probes[j]

-            if np.any(np.array([x_bounds_i[0] < cp[0] < x_bounds_i[1] and
-                                y_bounds_i[0] < cp[1] < y_bounds_i[1]
-                                for cp in probe_j.contact_positions])):
+            if np.any(
+                np.array(
+                    [
+                        x_bounds_i[0] < cp[0] < x_bounds_i[1]
+                        and y_bounds_i[0] < cp[1] < y_bounds_i[1]
+                        for cp in probe_j.contact_positions
+                    ]
+                )
+            ):
                 raise Exception(
-                    "Probes are overlapping! Retrieve locations of single probes separately")
+                    "Probes are overlapping! Retrieve locations of single probes separately"
+                )
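
The overlap test is a bounding-box containment check: it raises if any contact of probe_j falls strictly inside probe_i's x/y extent. A stripped-down version with plain arrays (illustrative only):

import numpy as np

contacts_i = np.array([[0.0, 0.0], [0.0, 100.0], [20.0, 100.0]])
contacts_j = np.array([[10.0, 50.0]])  # inside probe_i's bounding box

x_lo, x_hi = contacts_i[:, 0].min(), contacts_i[:, 0].max()
y_lo, y_hi = contacts_i[:, 1].min(), contacts_i[:, 1].max()
overlap = any(x_lo < x < x_hi and y_lo < y < y_hi for x, y in contacts_j)
print(overlap)  # True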
