1
+ from __future__ import annotations
2
+
1
3
import numpy as np
2
4
from pathlib import Path
3
- from typing import Union , Optional
5
+ import warnings
6
+ import datetime
4
7
5
8
from ..core import BaseSorting , BaseSortingSegment
6
9
from ..core .core_tools import define_function_from_class
7
10
8
11
9
- try :
10
- import scipy .io
11
- import hdf5storage
12
-
13
- HAVE_SCIPY_AND_HDF5STORAGE = True
14
- except ImportError :
15
- HAVE_SCIPY_AND_HDF5STORAGE = False
16
-
17
-
18
- PathType = Union [str , Path ]
19
- OptionalPathType = Optional [PathType ]
20
-
21
-
22
12
class CellExplorerSortingExtractor(BaseSorting):
    """
    Extracts spiking information from a `.mat` file stored in the CellExplorer format.
    Spike times are stored in units of seconds so we transform them to units of samples.

    The newer version of the format is described here:
    https://cellexplorer.org/data-structure/

    Whereas the old format is described here:
    https://github.com/buzsakilab/buzcode/wiki/Data-Formatting-Standards

    Parameters
    ----------
    file_path: str | Path
        Path to `.mat` file containing spikes. Usually named `session_id.spikes.cellinfo.mat`
    sampling_frequency: float | None, optional
        The sampling frequency of the data. If None, it is taken from the `sr` field of the
        spikes file when that field holds a single number, and otherwise from the
        `sessionInfo.mat` file.
    session_info_file_path: str | Path | None, optional
        Path to the `sessionInfo.mat` file. If None, it will be inferred from the file_path.
    spikes_matfile_path: str | Path | None, optional
        Deprecated alias of `file_path`. Raises ValueError once the deprecation deadline has passed.
    session_info_matfile_path: str | Path | None, optional
        Deprecated alias of `session_info_file_path`. Raises ValueError once the deprecation
        deadline has passed.
    """

    extractor_name = "CellExplorerSortingExtractor"
    is_writable = True
    mode = "file"
    installation_mesg = "To use the CellExplorerSortingExtractor install scipy and h5py"

    def __init__(
        self,
        file_path: str | Path | None = None,
        sampling_frequency: float | None = None,
        session_info_file_path: str | Path | None = None,
        spikes_matfile_path: str | Path | None = None,
        session_info_matfile_path: str | Path | None = None,
    ):
        try:
            import h5py
            import scipy.io
        except ImportError as err:
            raise ImportError(self.installation_mesg) from err

        assert (
            file_path is not None or spikes_matfile_path is not None
        ), "Either file_path or spikes_matfile_path must be provided!"

        # Legacy argument names: the new name wins when both are given.
        if spikes_matfile_path is not None:
            self._handle_deprecated_argument("spikes_matfile_path", "file_path")
            if file_path is None:
                file_path = spikes_matfile_path

        if session_info_matfile_path is not None:
            self._handle_deprecated_argument("session_info_matfile_path", "session_info_file_path")
            if session_info_file_path is None:
                session_info_file_path = session_info_matfile_path

        self.spikes_cellinfo_path = Path(file_path).absolute()
        # folder_path must be set before the assertion below because its message uses it.
        self.folder_path = self.spikes_cellinfo_path.parent
        assert self.spikes_cellinfo_path.is_file(), f"The spikes.cellinfo.mat file must exist in {self.folder_path}!"

        self.session_info_file_path = session_info_file_path
        self.session_id = self.spikes_cellinfo_path.stem.split(".")[0]

        # scipy signals files it cannot parse (HDF5-based .mat) with NotImplementedError;
        # those are opened with h5py instead.
        read_as_hdf5 = False
        try:
            matlab_file = scipy.io.loadmat(file_name=str(self.spikes_cellinfo_path), simplify_cells=True)
        except NotImplementedError:
            matlab_file = h5py.File(name=self.spikes_cellinfo_path, mode="r")
            read_as_hdf5 = True

        try:
            spikes_mat = matlab_file["spikes"]
            if read_as_hdf5:
                assert isinstance(spikes_mat, h5py.Group), f"field `spikes` must be a Group, not {type(spikes_mat)}!"
            else:
                assert isinstance(spikes_mat, dict), f"field `spikes` must be a dict, not {type(spikes_mat)}!"

            if sampling_frequency is None:
                # First try the new format of spikes.cellinfo.mat files where sampling rate is included in the file
                sr_data = spikes_mat.get("sr", None)
                if read_as_hdf5:
                    sr_data = sr_data[()] if isinstance(sr_data, h5py.Dataset) else None
                if sr_data is not None:
                    sr_array = np.asarray(sr_data)
                    # Only accept a single numeric value; anything else falls back to sessionInfo.
                    if sr_array.size == 1 and np.issubdtype(sr_array.dtype, np.number):
                        sampling_frequency = sr_array.item()

            if sampling_frequency is None:
                sampling_frequency = self._retrieve_sampling_frequency_from_session_info()

            # `.item()` unwraps size-1 arrays of any rank (e.g. a (1, 1) HDF5 scalar) before float().
            sampling_frequency = float(np.asarray(sampling_frequency).item())

            unit_ids_available = "UID" in spikes_mat.keys()
            assert unit_ids_available, f"The `spikes struct` must contain field 'UID'! fields: {spikes_mat.keys()}"

            spike_times_available = "times" in spikes_mat.keys()
            assert spike_times_available, f"The `spike struct` must contain field 'times'! fields: {spikes_mat.keys()}"

            unit_ids = spikes_mat["UID"]
            spike_times = spikes_mat["times"]

            if read_as_hdf5:
                assert isinstance(unit_ids, h5py.Dataset), f"`unit_ids` must be a Dataset, not {type(unit_ids)}!"
                assert isinstance(spike_times, h5py.Dataset), f"`spike_times` must be a Dataset, not {type(spike_times)}!"

                # ravel (not squeeze) so that a single-unit file still yields a 1-D array.
                unit_ids = unit_ids[:].ravel().astype("int")
                references = (ref[0] for ref in spike_times[:])  # These are HDF5 references
                spike_times_data = (
                    matlab_file[ref] for ref in references if isinstance(matlab_file[ref], h5py.Dataset)
                )
                # Materialise every train as a numpy array while the file is still open
                spike_times = [data[()].squeeze() for data in spike_times_data]
        finally:
            if read_as_hdf5:
                matlab_file.close()

        # CellExplorer reports spike times in units seconds; SpikeExtractors uses time units of sampling frames
        # Rounding is necessary to prevent data loss from int-casting floating point errors
        unit_ids = unit_ids[:].tolist()
        spiketrains_dict = {
            unit_id: (sampling_frequency * np.asarray(spike_times[index]).ravel()).round().astype(np.int64)
            for index, unit_id in enumerate(unit_ids)
        }

        BaseSorting.__init__(self, unit_ids=unit_ids, sampling_frequency=sampling_frequency)
        sorting_segment = CellExplorerSortingSegment(spiketrains_dict, unit_ids)
        self.add_sorting_segment(sorting_segment)

        self.extra_requirements.append("scipy")
        self.extra_requirements.append("h5py")

        self._kwargs = dict(
            file_path=str(self.spikes_cellinfo_path),
            sampling_frequency=sampling_frequency,
            # Keep None as None; str(None) would later be interpreted as a path named "None".
            session_info_file_path=None if session_info_file_path is None else str(session_info_file_path),
        )

    @staticmethod
    def _handle_deprecated_argument(old_name: str, new_name: str) -> None:
        """
        Warn about, or reject, the use of a deprecated constructor argument.

        Raises
        ------
        ValueError
            If the current date is past the deprecation deadline (180 days after 2023-04-01).
        """
        deprecation_issued = datetime.datetime(2023, 4, 1)
        deprecation_deadline = deprecation_issued + datetime.timedelta(days=180)
        if datetime.datetime.now() > deprecation_deadline:
            raise ValueError(f"The {old_name} argument is no longer supported in. Use {new_name} instead.")

        warnings.warn(
            f"The {old_name} argument is deprecated and will be removed in six months. Use {new_name} instead.",
            DeprecationWarning,
            stacklevel=3,  # attribute the warning to the code that called the constructor
        )

    def _retrieve_sampling_frequency_from_session_info(self) -> float:
        """
        Retrieve the sampling frequency from the `sessionInfo.mat` file when available.

        This function tries to locate a .sessionInfo.mat file corresponding to the current session. It then loads this
        file (either as a standard .mat file or as an HDF5 file if the former is not possible) and extracts the wideband
        sampling frequency from the 'rates' field of the 'sessionInfo' struct.

        As a side effect, `self.session_info_file_path` is replaced by the resolved absolute path.

        Returns
        -------
        float
            The wideband sampling frequency for the current session.
        """
        import h5py
        import scipy.io

        if self.session_info_file_path is None:
            self.session_info_file_path = self.folder_path / f"{self.session_id}.sessionInfo.mat"

        self.session_info_file_path = Path(self.session_info_file_path).absolute()
        assert (
            self.session_info_file_path.is_file()
        ), f"No {self.session_id}.sessionInfo.mat file found in the {self.folder_path}!, can't inferr sampling rate"

        def _wideband_rate(session_info_mat, read_as_hdf5):
            rates = session_info_mat["sessionInfo"]["rates"]
            wideband_in_rates = "wideband" in rates.keys()
            assert wideband_in_rates, "a 'sessionInfo' should contain a 'wideband' to extract the sampling frequency!"
            # Not to be confused with the lfpsamplingrate (also in Hz), which is present in rates as well
            sampling_frequency = rates["wideband"]
            # HDF5 datasets must be read into memory before the file is closed
            return sampling_frequency[()] if read_as_hdf5 else sampling_frequency

        try:
            session_info_mat = scipy.io.loadmat(file_name=str(self.session_info_file_path), simplify_cells=True)
        except NotImplementedError:
            with h5py.File(name=str(self.session_info_file_path), mode="r") as hdf5_file:
                return _wideband_rate(hdf5_file, read_as_hdf5=True)

        return _wideband_rate(session_info_mat, read_as_hdf5=False)
115
203
116
204
class CellExplorerSortingSegment(BaseSortingSegment):
    """
    Single sorting segment backed by a mapping from unit id to spike frames.

    Parameters
    ----------
    spiketrains_dict : dict
        Maps each unit id to a numpy array of spike times expressed in sample frames.
    unit_ids : iterable
        The unit ids of the sorting; stored as a list on `_unit_ids`.
    """

    def __init__(self, spiketrains_dict, unit_ids):
        self.spiketrains_dict = spiketrains_dict
        self._unit_ids = list(unit_ids)
        BaseSortingSegment.__init__(self)

    def get_unit_spike_train(
        self,
        unit_id,
        start_frame,
        end_frame,
    ) -> np.ndarray:
        """
        Return the spike frames of `unit_id` restricted to `[start_frame, end_frame)`.

        A bound given as None is not applied. The end bound is exclusive, so that
        consecutive windows never return the same spike twice.
        """
        spike_frames = self.spiketrains_dict[unit_id]

        if start_frame is not None:
            spike_frames = spike_frames[spike_frames >= start_frame]

        if end_frame is not None:
            # strict comparison: a spike located exactly on end_frame belongs to the next window
            spike_frames = spike_frames[spike_frames < end_frame]

        return spike_frames
136
225
137
226
138
227
# Module-level function `read_cellexplorer`, built from CellExplorerSortingExtractor by
# `define_function_from_class`; presumably a thin wrapper that instantiates the class
# with the given arguments -- confirm against core_tools.
read_cellexplorer = define_function_from_class (source_class = CellExplorerSortingExtractor , name = "read_cellexplorer" )
0 commit comments