Skip to content

Commit c163acc

Browse files
authored
Merge branch 'master' into merge_ap_lfp_neuropix
2 parents e62cc32 + 8a50476 commit c163acc

File tree

12 files changed

+274
-54
lines changed

12 files changed

+274
-54
lines changed

spikeinterface/core/baserecording.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from .baserecordingsnippets import BaseRecordingSnippets
1111
from .core_tools import write_binary_recording, write_memory_recording, write_traces_to_zarr, check_json
1212
from .job_tools import split_job_kwargs, fix_job_kwargs
13+
from .core_tools import convert_bytes_to_str
1314

1415
from warnings import warn
1516

@@ -42,12 +43,17 @@ def __repr__(self):
4243
nchan = self.get_num_channels()
4344
sf_khz = self.get_sampling_frequency() / 1000.
4445
duration = self.get_total_duration()
45-
txt = f'{clsname}: {nchan} channels - {nseg} segments - {sf_khz:0.1f}kHz - {duration:0.3f}s'
46+
memory_size = self.get_memory_size()
47+
txt = f"{clsname}: {nchan} channels - {nseg} segments - {sf_khz:0.1f}kHz - {duration:0.3f}s - {memory_size}"
4648
if 'file_paths' in self._kwargs:
4749
txt += '\n file_paths: {}'.format(self._kwargs['file_paths'])
4850
if 'file_path' in self._kwargs:
4951
txt += '\n file_path: {}'.format(self._kwargs['file_path'])
5052
return txt
53+
54+
def get_memory_size(self):
    """Return the in-memory size of the traces as a human-readable string.

    The size is computed as total_samples * num_channels * bytes-per-sample
    and formatted via `convert_bytes_to_str` (e.g. "512.00 MiB").

    Returns
    -------
    str
        The total traces size with a binary unit suffix (B, KiB, MiB, GiB, TiB).
    """
    # Renamed from `bytes` to avoid shadowing the builtin of the same name.
    memory_bytes = self.get_total_samples() * self.get_num_channels() * self.get_dtype().itemsize
    return convert_bytes_to_str(memory_bytes)
5157

5258
def get_num_segments(self):
5359
"""Returns the number of segments.

spikeinterface/core/core_tools.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -811,3 +811,21 @@ def recursive_key_finder(d, key):
811811
else:
812812
if k == key:
813813
yield v
814+
815+
816+
def convert_bytes_to_str(byte_value: int) -> str:
    """
    Convert a number of bytes into a human-readable string.

    Parameters
    ----------
    byte_value: int
        The number of bytes to convert.

    Returns
    -------
    str
        The value rendered with two decimals and the appropriate binary
        unit: B, KiB, MiB, GiB, or TiB (e.g. "512.00 MiB").
    """
    suffixes = ['B', 'KiB', 'MiB', 'GiB', 'TiB']
    i = 0
    # Repeatedly scale by 1024 until the value fits the current unit,
    # capping at the largest available suffix (TiB).
    while byte_value >= 1024 and i < len(suffixes) - 1:
        byte_value /= 1024
        i += 1
    return f"{byte_value:.2f} {suffixes[i]}"

spikeinterface/core/frameslicerecording.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,18 @@ class FrameSliceRecording(BaseRecording):
1010
1111
Do not use this class directly but use `recording.frame_slice(...)`
1212
13+
Parameters
14+
----------
15+
parent_recording: BaseRecording
16+
start_frame: None or int
17+
Earliest included frame in the parent recording.
18+
Times are re-referenced to start_frame in the
19+
sliced object. Set to 0 by default.
20+
end_frame: None or int
21+
Latest frame in the parent recording. As for usual
22+
python slicing, the end frame is excluded.
23+
Set to the recording's total number of samples by
24+
default
1325
"""
1426

1527
def __init__(self, parent_recording, start_frame=None, end_frame=None):

spikeinterface/core/frameslicesorting.py

Lines changed: 50 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import warnings
33

44
from .basesorting import BaseSorting, BaseSortingSegment
5+
from .waveform_tools import has_exceeding_spikes
56

67

78
class FrameSliceSorting(BaseSorting):
@@ -11,27 +12,67 @@ class FrameSliceSorting(BaseSorting):
1112
1213
Do not use this class directly but use `sorting.frame_slice(...)`
1314
15+
When a recording is registered for the parent sorting,
16+
a corresponding sliced recording is registered to the sliced sorting.
17+
18+
Note that the returned sliced sorting may be empty.
19+
20+
Parameters
21+
----------
22+
parent_sorting: BaseSorting
23+
start_frame: None or int
24+
Earliest included frame in the parent sorting(/recording).
25+
Spike times(/traces) are re-referenced to start_frame in the
26+
sliced objects. Set to 0 by default.
27+
end_frame: None or int
28+
Latest frame in the parent sorting(/recording). As for usual
29+
python slicing, the end frame is excluded (such that the max
30+
spike frame in the sliced sorting is `end_frame - start_frame - 1`)
31+
If None (default), the end_frame is either:
32+
- The total number of samples, if a recording is assigned
33+
- The maximum spike frame + 1, if no recording is assigned
1434
"""
1535

1636
def __init__(self, parent_sorting, start_frame=None, end_frame=None):
1737
unit_ids = parent_sorting.get_unit_ids()
1838

1939
assert parent_sorting.get_num_segments() == 1, 'FrameSliceSorting work only with one segment'
2040

21-
if start_frame is not None or end_frame is None:
22-
parent_size = 0
23-
for u in parent_sorting.get_unit_ids():
24-
parent_size = np.max([parent_size, np.max(parent_sorting.get_unit_spike_train(u))])
2541

2642
if start_frame is None:
2743
start_frame = 0
28-
else:
29-
assert 0 <= start_frame < parent_size
44+
assert 0 <= start_frame, "Invalid value for start_frame: expected positive integer."
3045

31-
if end_frame is None:
32-
end_frame = parent_size + 1
46+
if parent_sorting.has_recording():
47+
# Pull default end_frame from recording
48+
parent_n_samples = parent_sorting._recording.get_total_samples()
49+
if end_frame is None:
50+
end_frame = parent_n_samples
51+
assert end_frame <= parent_n_samples, (
52+
"`end_frame` should be smaller than the sortings total number of samples."
53+
)
54+
assert start_frame <= parent_n_samples, (
55+
"`start_frame` should be smaller than the sortings total number of samples."
56+
)
57+
if has_exceeding_spikes(parent_sorting._recording, parent_sorting):
58+
raise ValueError(
59+
"The sorting object has spikes exceeding the recording duration. You have to remove those spikes "
60+
"with the `spikeinterface.curation.remove_excess_spikes()` function"
61+
)
3362
else:
34-
assert end_frame > start_frame, "'start_frame' must be smaller than 'end_frame'!"
63+
# Pull default end_frame from spikes
64+
if end_frame is None:
65+
max_spike_time = 0
66+
for u in parent_sorting.get_unit_ids():
67+
max_spike_time = np.max([max_spike_time, np.max(parent_sorting.get_unit_spike_train(u))])
68+
end_frame = max_spike_time + 1
69+
70+
assert start_frame < end_frame, (
71+
"`start_frame` should be smaller than `end_frame`. "
72+
"This may be due to start_frame >= max_spike_time, if the end frame "
73+
"was not specified explicitly."
74+
)
75+
3576

3677
BaseSorting.__init__(self,
3778
sampling_frequency=parent_sorting.get_sampling_frequency(),

spikeinterface/core/generate.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -516,7 +516,7 @@ def generate_lazy_recording(full_traces_size_GiB: float, seed=None) -> Generator
516516
num_channels = 1024
517517

518518
GiB_to_bytes = 1024** 3
519-
full_traces_size_bytes = full_traces_size_GiB * GiB_to_bytes
519+
full_traces_size_bytes = int(full_traces_size_GiB * GiB_to_bytes)
520520
num_samples = int(full_traces_size_bytes / (num_channels * dtype.itemsize))
521521
durations = [num_samples / sampling_frequency]
522522

@@ -525,6 +525,7 @@ def generate_lazy_recording(full_traces_size_GiB: float, seed=None) -> Generator
525525

526526
return recording
527527

528+
528529
if __name__ == '__main__':
529530
print(generate_recording())
530531
print(generate_sorting())

spikeinterface/core/job_tools.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -397,3 +397,5 @@ def function_wrapper(args):
397397
else:
398398
with threadpool_limits(limits=max_threads_per_process):
399399
return _func(segment_index, start_frame, end_frame, _worker_ctx)
400+
401+
Lines changed: 99 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,32 +1,111 @@
1-
from spikeinterface.core import generate_sorting
1+
import warnings
2+
3+
import numpy as np
4+
from numpy.testing import assert_raises
5+
6+
from spikeinterface.core import NumpyRecording, NumpySorting
27

38

49
def test_FrameSliceSorting():
    """Check frame_slice on a single-segment sorting, with and without an
    attached recording (the default end_frame comes from the recording when
    one is registered, otherwise from the last spike)."""
    sf = 10
    nsamp = 1000
    max_spike_time = 900
    min_spike_time = 100
    unit_0_train = np.arange(min_spike_time + 10, max_spike_time - 10)
    spike_times = {
        "0": unit_0_train,
        "1": np.arange(min_spike_time, max_spike_time),
    }
    # Sorting with attached rec
    sorting = NumpySorting.from_dict([spike_times], sf)
    rec = NumpyRecording([np.zeros((nsamp, 5))], sampling_frequency=sf)
    sorting.register_recording(rec)
    # Sorting without attached rec
    sorting_norec = NumpySorting.from_dict([spike_times], sf)
    # Sorting with attached rec and exceeding spikes
    sorting_exceeding = NumpySorting.from_dict([spike_times], sf)
    rec_exceeding = NumpyRecording([np.zeros((max_spike_time - 1, 5))], sampling_frequency=sf)
    with warnings.catch_warnings():
        warnings.filterwarnings("ignore")
        sorting_exceeding.register_recording(rec_exceeding)

    mid_frame = nsamp // 2

    # duration of all slices is mid_frame. Spike trains are re-referenced to the start_time
    # Vary start_frame/end_frame combination
    start_frame, end_frame = None, None
    sub_sorting = sorting.frame_slice(start_frame, end_frame)
    assert np.array_equal(sub_sorting.get_unit_spike_train("0"), unit_0_train)
    assert sub_sorting._recording.get_total_samples() == nsamp
    # BUGFIX: slice the recording-less sorting (was slicing `sorting` again,
    # leaving the no-recording code path untested).
    sub_sorting_norec = sorting_norec.frame_slice(start_frame, end_frame)
    assert np.array_equal(sub_sorting_norec.get_unit_spike_train("0"), unit_0_train)

    start_frame, end_frame = None, mid_frame
    sub_sorting = sorting.frame_slice(start_frame, end_frame)
    assert np.array_equal(
        sub_sorting.get_unit_spike_train("0"),
        [t for t in unit_0_train if t < mid_frame]
    )
    assert sub_sorting._recording.get_total_samples() == mid_frame
    sub_sorting_norec = sorting_norec.frame_slice(start_frame, end_frame)
    assert np.array_equal(
        sub_sorting_norec.get_unit_spike_train("0"),
        sub_sorting.get_unit_spike_train("0")
    )

    start_frame, end_frame = mid_frame, None
    sub_sorting = sorting.frame_slice(start_frame, end_frame)
    assert np.array_equal(
        sub_sorting.get_unit_spike_train("0"),
        [t - mid_frame for t in unit_0_train if t >= mid_frame]
    )
    assert sub_sorting._recording.get_total_samples() == nsamp - mid_frame
    sub_sorting_norec = sorting_norec.frame_slice(start_frame, end_frame)
    assert np.array_equal(
        sub_sorting_norec.get_unit_spike_train("0"),
        sub_sorting.get_unit_spike_train("0")
    )

    start_frame, end_frame = mid_frame - 10, mid_frame + 10
    sub_sorting = sorting.frame_slice(start_frame, end_frame)
    assert np.array_equal(
        sub_sorting.get_unit_spike_train("0"),
        [t - start_frame for t in unit_0_train if start_frame <= t < end_frame]
    )
    assert sub_sorting._recording.get_total_samples() == 20
    sub_sorting_norec = sorting_norec.frame_slice(start_frame, end_frame)
    assert np.array_equal(
        sub_sorting_norec.get_unit_spike_train("0"),
        sub_sorting.get_unit_spike_train("0")
    )

    # Edge cases: start_frame > end_frame
    assert_raises(Exception, sorting.frame_slice, 100, 90)

    # Edge case: start_frame > max_spike_time
    # Fails without rec (since end_frame is last spike)
    assert_raises(Exception, sorting_norec.frame_slice, max_spike_time + 1, None)
    # Empty sorting with rec
    sub_sorting = sorting.frame_slice(max_spike_time + 1, None)
    assert np.array_equal(
        sub_sorting.get_unit_spike_train("1"),
        []
    )

    # Edge case: end_frame <= min_spike_time
    # Empty sorting
    sub_sorting = sorting.frame_slice(None, min_spike_time)
    assert np.array_equal(sub_sorting.get_unit_spike_train("1"), [])

    # Edge case: start_frame = end_frame
    assert_raises(Exception, sorting.frame_slice, max_spike_time, max_spike_time)

    # Sorting with exceeding spikes
    assert_raises(Exception, sorting_exceeding.frame_slice, None, None)
30109

31110
if __name__ == '__main__':
32111
test_FrameSliceSorting()

spikeinterface/core/tests/test_generate.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,3 +98,9 @@ def test_generate_lazy_recording():
9898
print(f"Difference between the last two {(memory_after_traces_MiB - traces_size_MiB)} MiB")
9999

100100
(memory_after_instanciation_MiB + traces_size_MiB) == pytest.approx(memory_after_traces_MiB, rel=relative_tolerance)
101+
102+
103+
def test_generate_lazy_recording_under_giga():
    """Half a GiB of lazily-generated traces should report exactly 512 MiB."""
    half_gib_recording = generate_lazy_recording(full_traces_size_GiB=0.5)
    assert half_gib_recording.get_memory_size() == "512.00 MiB"

spikeinterface/extractors/neoextractors/openephys.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -81,15 +81,15 @@ class OpenEphysBinaryRecordingExtractor(NeoBaseRecordingExtractor):
8181
Parameters
8282
----------
8383
folder_path: str
84-
The folder path to load the recordings from.
84+
The folder path to the root folder (containing the record node folders).
8585
load_sync_channel : bool
8686
If False (default) and a SYNC channel is present (e.g. Neuropixels), this is not loaded.
8787
If True, the SYNC channel is loaded and can be accessed in the analog signals.
88-
load_sync_channel : bool
88+
load_sync_timestamps : bool
8989
If True, the synchronized_timestamps are loaded and set as times to the recording.
9090
If False (default), only the t_start and sampling rate are set, and timestamps are assumed
9191
to be uniform and linearly increasing.
92-
experiment_name: str, list, or None
92+
experiment_names: str, list, or None
9393
If multiple experiments are available, this argument allows users to select one
9494
or more experiments. If None, all experiments are loaded as blocks.
9595
E.g. 'experiment_names="experiment2"', 'experiment_names=["experiment1", "experiment2"]'

spikeinterface/extractors/phykilosortextractors.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,16 +35,19 @@ def __init__(self, folder_path, exclude_cluster_groups=None, keep_good_only=Fals
3535

3636
phy_folder = Path(folder_path)
3737

38-
spike_times = np.load(phy_folder / 'spike_times.npy')
38+
spike_times = np.load(phy_folder / 'spike_times.npy').astype(int)
3939

4040
if (phy_folder / 'spike_clusters.npy').is_file():
4141
spike_clusters = np.load(phy_folder / 'spike_clusters.npy')
4242
else:
4343
spike_clusters = np.load(phy_folder / 'spike_templates.npy')
4444

45+
# spike_times and spike_clusters can be 2d sometimes --> convert to 1d.
46+
spike_times = np.atleast_1d(spike_times.squeeze())
47+
spike_clusters = np.atleast_1d(spike_clusters.squeeze())
48+
4549
clust_id = np.unique(spike_clusters)
4650
unit_ids = list(clust_id)
47-
spike_times = spike_times.astype(int)
4851
params = read_python(str(phy_folder / 'params.py'))
4952
sampling_frequency = params['sample_rate']
5053

0 commit comments

Comments
 (0)