Skip to content

Commit f75d3e5

Browse files
Adapt CellExplorerSortingExtractor to new format (#1628)
* refactor * refactor * passing all tests * kick yaml out * revert to file_path convention * allow the deprecated value to actually be used * correct deprecation warning * Update src/spikeinterface/extractors/cellexplorersortingextractor.py Co-authored-by: Cody Baker <[email protected]> * change file paths * correct typing * hook * Revert "hook" This reverts commit d8ee970. * test hook 2 * Revert "test hook 2" This reverts commit 0a861f4. * hook 3 * update pre-commit --------- Co-authored-by: Cody Baker <[email protected]>
1 parent a5cab08 commit f75d3e5

File tree

5 files changed

+247
-118
lines changed

5 files changed

+247
-118
lines changed

pyproject.toml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,6 @@ extractors = [
6262
"pyedflib>=0.1.30",
6363
"sonpy;python_version<'3.10'",
6464
"lxml", # lxml for neuroscope
65-
"hdf5storage", # hdf5storage and scipy for cellexplorer
6665
"scipy",
6766
# ONE-api and ibllib for streaming IBL
6867
"ONE-api>=1.19.1",
Lines changed: 177 additions & 88 deletions
Original file line numberDiff line numberDiff line change
@@ -1,121 +1,209 @@
1+
from __future__ import annotations
2+
13
import numpy as np
24
from pathlib import Path
3-
from typing import Union, Optional
5+
import warnings
6+
import datetime
47

58
from ..core import BaseSorting, BaseSortingSegment
69
from ..core.core_tools import define_function_from_class
710

811

9-
try:
10-
import scipy.io
11-
import hdf5storage
12-
13-
HAVE_SCIPY_AND_HDF5STORAGE = True
14-
except ImportError:
15-
HAVE_SCIPY_AND_HDF5STORAGE = False
16-
17-
18-
PathType = Union[str, Path]
19-
OptionalPathType = Optional[PathType]
20-
21-
2212
class CellExplorerSortingExtractor(BaseSorting):
2313
"""
24-
Extracts spiking information from .mat files stored in the CellExplorer format.
25-
Spike times are stored in units of seconds.
14+
Extracts spiking information from a `.mat` file stored in the CellExplorer format.
15+
Spike times are stored in units of seconds so we transform them to units of samples.
16+
17+
The newer version of the format is described here:
18+
https://cellexplorer.org/data-structure/
19+
20+
Whereas the old format is described here:
21+
https://github.com/buzsakilab/buzcode/wiki/Data-Formatting-Standards
2622
2723
Parameters
2824
----------
29-
spikes_matfile_path : PathType
30-
Path to the sorting_id.spikes.cellinfo.mat file.
25+
file_path: str | Path
26+
Path to `.mat` file containing spikes. Usually named `session_id.spikes.cellinfo.mat`
27+
sampling_frequency: float | None, optional
28+
The sampling frequency of the data. If None, it will be extracted from the files.
29+
session_info_file_path: str | Path | None, optional
30+
Path to the `sessionInfo.mat` file. If None, it will be inferred from the file_path.
3131
"""
3232

3333
extractor_name = "CellExplorerSortingExtractor"
34-
installed = HAVE_SCIPY_AND_HDF5STORAGE
3534
is_writable = True
3635
mode = "file"
37-
installation_mesg = (
38-
"To use the CellExplorerSortingExtractor install scipy and hdf5storage: \n\n pip install scipy hdf5storage"
39-
)
36+
installation_mesg = "To use the CellExplorerSortingExtractor install scipy and h5py"
4037

4138
def __init__(
4239
self,
43-
spikes_matfile_path: PathType,
44-
session_info_matfile_path: OptionalPathType = None,
45-
sampling_frequency: Optional[float] = None,
40+
file_path: str | Path | None = None,
41+
sampling_frequency: float | None = None,
42+
session_info_file_path: str | Path | None = None,
43+
spikes_matfile_path: str | Path | None = None,
44+
session_info_matfile_path: str | Path | None = None,
4645
):
47-
assert self.installed, self.installation_mesg
46+
try:
47+
import h5py
48+
import scipy.io
49+
except ImportError:
50+
raise ImportError(self.installation_mesg)
51+
52+
assert (
53+
file_path is not None or spikes_matfile_path is not None
54+
), "Either file_path or spikes_matfile_path must be provided!"
55+
56+
if spikes_matfile_path is not None:
57+
# Raise an error if the warning period has expired
58+
deprecation_issued = datetime.datetime(2023, 4, 1)
59+
deprecation_deadline = deprecation_issued + datetime.timedelta(days=180)
60+
if datetime.datetime.now() > deprecation_deadline:
61+
raise ValueError("The spikes_matfile_path argument is no longer supported in. Use file_path instead.")
62+
63+
# Otherwise, issue a DeprecationWarning
64+
else:
65+
warnings.warn(
66+
"The spikes_matfile_path argument is deprecated and will be removed in six months. "
67+
"Use file_path instead.",
68+
DeprecationWarning,
69+
)
70+
file_path = spikes_matfile_path if file_path is None else file_path
71+
72+
if session_info_matfile_path is not None:
73+
# Raise an error if the warning period has expired
74+
deprecation_issued = datetime.datetime(2023, 4, 1)
75+
deprecation_deadline = deprecation_issued + datetime.timedelta(days=180)
76+
if datetime.datetime.now() > deprecation_deadline:
77+
raise ValueError(
78+
"The session_info_matfile_path argument is no longer supported in. Use session_info_file_path instead."
79+
)
80+
81+
# Otherwise, issue a DeprecationWarning
82+
else:
83+
warnings.warn(
84+
"The session_info_matfile_path argument is deprecated and will be removed in six months. "
85+
"Use session_info_file_path instead.",
86+
DeprecationWarning,
87+
)
88+
session_info_file_path = (
89+
session_info_matfile_path if session_info_file_path is None else session_info_file_path
90+
)
4891

49-
spikes_matfile_path = Path(spikes_matfile_path)
50-
assert spikes_matfile_path.is_file(), f"The spikes_matfile_path ({spikes_matfile_path}) must exist!"
92+
self.spikes_cellinfo_path = Path(file_path).absolute()
93+
assert self.spikes_cellinfo_path.is_file(), f"The spikes.cellinfo.mat file must exist in {self.folder_path}!"
5194

52-
if sampling_frequency is None:
53-
folder_path = spikes_matfile_path.parent
54-
sorting_id = spikes_matfile_path.name.split(".")[0]
55-
if session_info_matfile_path is None:
56-
session_info_matfile_path = folder_path / f"{sorting_id}.sessionInfo.mat"
57-
session_info_matfile_path = Path(session_info_matfile_path)
58-
assert (session_info_matfile_path).is_file(), f"No {sorting_id}.sessionInfo.mat file found in the folder!"
59-
60-
try:
61-
session_info_mat = scipy.io.loadmat(file_name=str(session_info_matfile_path))
62-
self.read_session_info_with_scipy = True
63-
except NotImplementedError:
64-
session_info_mat = hdf5storage.loadmat(file_name=str(session_info_matfile_path))
65-
self.read_session_info_with_scipy = False
66-
67-
assert session_info_mat["sessionInfo"]["rates"][0][0]["wideband"], (
68-
"The sesssionInfo.mat file must contain "
69-
"a 'sessionInfo' struct with field 'rates' containing field 'wideband' to extract the sampling frequency!"
70-
)
71-
if self.read_session_info_with_scipy:
72-
sampling_frequency = float(
73-
session_info_mat["sessionInfo"]["rates"][0][0]["wideband"][0][0][0][0]
74-
) # careful not to confuse it with the lfpsamplingrate; reported in units Hz
75-
else:
76-
sampling_frequency = float(
77-
session_info_mat["sessionInfo"]["rates"][0][0]["wideband"][0][0]
78-
) # careful not to confuse it with the lfpsamplingrate; reported in units Hz
95+
self.folder_path = self.spikes_cellinfo_path.parent
96+
self.session_info_file_path = session_info_file_path
7997

98+
self.session_id = self.spikes_cellinfo_path.stem.split(".")[0]
99+
100+
read_as_hdf5 = False
80101
try:
81-
spikes_mat = scipy.io.loadmat(file_name=str(spikes_matfile_path))
82-
self.read_spikes_info_with_scipy = True
102+
matlab_file = scipy.io.loadmat(file_name=str(self.spikes_cellinfo_path), simplify_cells=True)
103+
spikes_mat = matlab_file["spikes"]
104+
assert isinstance(spikes_mat, dict), f"field `spikes` must be a dict, not {type(spikes_mat)}!"
105+
83106
except NotImplementedError:
84-
spikes_mat = hdf5storage.loadmat(file_name=str(spikes_matfile_path))
85-
self.read_spikes_info_with_scipy = False
107+
matlab_file = h5py.File(name=self.spikes_cellinfo_path, mode="r")
108+
spikes_mat = matlab_file["spikes"]
109+
assert isinstance(spikes_mat, h5py.Group), f"field `spikes` must be a Group, not {type(spikes_mat)}!"
110+
read_as_hdf5 = True
86111

87-
assert np.all(
88-
np.isin(["UID", "times"], spikes_mat["spikes"].dtype.names)
89-
), "The spikes.cellinfo.mat file must contain a 'spikes' struct with fields 'UID' and 'times'!"
112+
if sampling_frequency is None:
113+
# First try the new format of spikes.cellinfo.mat files where sampling rate is included in the file
114+
sr_data = spikes_mat.get("sr", None)
115+
sampling_frequency = sr_data[()] if isinstance(sr_data, h5py.Dataset) else None
116+
117+
if sampling_frequency is None:
118+
sampling_frequency = self._retrieve_sampling_frequency_from_session_info()
119+
120+
sampling_frequency = float(sampling_frequency)
121+
122+
unit_ids_available = "UID" in spikes_mat.keys()
123+
assert unit_ids_available, f"The `spikes struct` must contain field 'UID'! fields: {spikes_mat.keys()}"
124+
125+
spike_times_available = "times" in spikes_mat.keys()
126+
assert spike_times_available, f"The `spike struct` must contain field 'times'! fields: {spikes_mat.keys()}"
127+
128+
unit_ids = spikes_mat["UID"]
129+
spike_times = spikes_mat["times"]
130+
131+
if read_as_hdf5:
132+
assert isinstance(unit_ids, h5py.Dataset), f"`unit_ids` must be a Dataset, not {type(unit_ids)}!"
133+
assert isinstance(spike_times, h5py.Dataset), f"`spike_times` must be a Dataset, not {type(spike_times)}!"
134+
135+
unit_ids = unit_ids[:].squeeze().astype("int")
136+
references = (ref[0] for ref in spike_times[:]) # These are HDF5 references
137+
spike_times_data = (matlab_file[ref] for ref in references if isinstance(matlab_file[ref], h5py.Dataset))
138+
# Format as a list of numpy arrays
139+
spike_times = [data[()].squeeze() for data in spike_times_data]
90140

91141
# CellExplorer reports spike times in units seconds; SpikeExtractors uses time units of sampling frames
92-
# Rounding is necessary to prevent data loss from int-casting floating point errors
93-
if self.read_spikes_info_with_scipy:
94-
unit_ids = np.asarray(spikes_mat["spikes"]["UID"][0][0][0])
95-
spiketrains = [
96-
(np.array([y[0] for y in x]) * sampling_frequency).round().astype(np.int64)
97-
for x in spikes_mat["spikes"]["times"][0][0][0]
98-
]
99-
else:
100-
unit_ids = np.asarray(spikes_mat["spikes"]["UID"][0][0])
101-
spiketrains = [
102-
(np.array([y[0] for y in x]) * sampling_frequency).round().astype(np.int64)
103-
for x in spikes_mat["spikes"]["times"][0][0]
104-
]
142+
unit_ids = unit_ids[:].tolist()
143+
spiketrains_dict = {unit_id: spike_times[index] for index, unit_id in enumerate(unit_ids)}
144+
for unit_id in unit_ids:
145+
spiketrains_dict[unit_id] = (sampling_frequency * spiketrains_dict[unit_id]).round().astype(np.int64)
146+
# Rounding is necessary to prevent data loss from int-casting floating point errors
105147

106148
BaseSorting.__init__(self, unit_ids=unit_ids, sampling_frequency=sampling_frequency)
107-
sorting_segment = CellExplorerSortingSegment(spiketrains, unit_ids)
149+
sorting_segment = CellExplorerSortingSegment(spiketrains_dict, unit_ids)
108150
self.add_sorting_segment(sorting_segment)
109151

110152
self.extra_requirements.append("scipy")
111-
self.extra_requirements.append("hdf5storage")
112153

113-
self._kwargs = dict(spikes_matfile_path=str(spikes_matfile_path.absolute()))
154+
self._kwargs = dict(
155+
file_path=str(self.spikes_cellinfo_path),
156+
sampling_frequency=sampling_frequency,
157+
session_info_file_path=str(session_info_file_path),
158+
)
159+
160+
def _retrieve_sampling_frequency_from_session_info(self) -> float:
161+
"""
162+
Retrieve the sampling frequency from the `sessionInfo.mat` file when available.
163+
164+
This function tries to locate a .sessionInfo.mat file corresponding to the current session. It then loads this
165+
file (either as a standard .mat file or as an HDF5 file if the former is not possible) and extracts the wideband
166+
sampling frequency from the 'rates' field of the 'sessionInfo' struct.
167+
168+
Returns
169+
-------
170+
float
171+
The wideband sampling frequency for the current session.
172+
"""
173+
import h5py
174+
import scipy.io
175+
176+
if self.session_info_file_path is None:
177+
self.session_info_file_path = self.folder_path / f"{self.session_id}.sessionInfo.mat"
178+
179+
self.session_info_file_path = Path(self.session_info_file_path).absolute()
180+
assert (
181+
self.session_info_file_path.is_file()
182+
), f"No {self.session_id}.sessionInfo.mat file found in the {self.folder_path}!, can't inferr sampling rate"
183+
184+
read_as_hdf5 = False
185+
try:
186+
session_info_mat = scipy.io.loadmat(file_name=str(self.session_info_file_path), simplify_cells=True)
187+
except NotImplementedError:
188+
session_info_mat = h5py.File(name=str(self.session_info_file_path), mode="r")
189+
read_as_hdf5 = True
190+
191+
rates = session_info_mat["sessionInfo"]["rates"]
192+
wideband_in_rates = "wideband" in rates.keys()
193+
assert wideband_in_rates, "a 'sessionInfo' should contain a 'wideband' to extract the sampling frequency!"
194+
195+
# Not to be confused with the lfpsamplingrate (also present in rates); reported in units of Hz
196+
sampling_frequency = rates["wideband"]
197+
198+
if read_as_hdf5:
199+
sampling_frequency = sampling_frequency[()]
200+
201+
return sampling_frequency
114202

115203

116204
class CellExplorerSortingSegment(BaseSortingSegment):
117-
def __init__(self, spiketrains, unit_ids):
118-
self._spiketrains = spiketrains
205+
def __init__(self, spiketrains_dict, unit_ids):
206+
self.spiketrains_dict = spiketrains_dict
119207
self._unit_ids = list(unit_ids)
120208
BaseSortingSegment.__init__(self)
121209

@@ -125,14 +213,15 @@ def get_unit_spike_train(
125213
start_frame,
126214
end_frame,
127215
) -> np.ndarray:
128-
# must be implemented in subclass
129-
if start_frame is None:
130-
start_frame = 0
131-
if end_frame is None:
132-
end_frame = np.inf
133-
spike_frames = self._spiketrains[self._unit_ids.index(unit_id)]
134-
inds = np.where((start_frame <= spike_frames) & (spike_frames < end_frame))
135-
return spike_frames[inds]
216+
spike_frames = self.spiketrains_dict[unit_id]
217+
# clip
218+
if start_frame is not None:
219+
spike_frames = spike_frames[spike_frames >= start_frame]
220+
221+
if end_frame is not None:
222+
spike_frames = spike_frames[spike_frames <= end_frame]
223+
224+
return spike_frames
136225

137226

138227
read_cellexplorer = define_function_from_class(source_class=CellExplorerSortingExtractor, name="read_cellexplorer")

0 commit comments

Comments
 (0)