@@ -1,8 +1,15 @@
 import numpy as np


-def get_random_data_chunks(recording, return_scaled=False, num_chunks_per_segment=20,
-                           chunk_size=10000, concatenated=True, seed=0, margin_frames=0):
+def get_random_data_chunks(
+    recording,
+    return_scaled=False,
+    num_chunks_per_segment=20,
+    chunk_size=10000,
+    concatenated=True,
+    seed=0,
+    margin_frames=0,
+):
     """
     Extract random chunks across segments

@@ -31,22 +38,30 @@ def get_random_data_chunks(recording, return_scaled=False, num_chunks_per_segmen
     # Should be done by changing kwargs with total_num_chunks=XXX and total_duration=YYYY
     # And randomize the number of chunk per segment weighted by segment duration

-    # check chunk size
+    # check chunk size
     for segment_index in range(recording.get_num_segments()):
-        assert chunk_size < recording.get_num_samples(segment_index), (f"chunk_size is greater than the number "
-                                                                       f"of samples for segment index {segment_index}. "
-                                                                       f"Use a smaller chunk_size!")
+        assert chunk_size < recording.get_num_samples(segment_index), (
+            f"chunk_size is greater than the number "
+            f"of samples for segment index {segment_index}. "
+            f"Use a smaller chunk_size!"
+        )

     chunk_list = []
     for segment_index in range(recording.get_num_segments()):
         length = recording.get_num_frames(segment_index)
-
-        random_starts = np.random.RandomState(seed=seed).randint(margin_frames, length - chunk_size - margin_frames, size=num_chunks_per_segment)
+
+        random_starts = np.random.RandomState(seed=seed).randint(
+            margin_frames,
+            length - chunk_size - margin_frames,
+            size=num_chunks_per_segment,
+        )
         for start_frame in random_starts:
-            chunk = recording.get_traces(start_frame=start_frame,
-                                         end_frame=start_frame + chunk_size,
-                                         segment_index=segment_index,
-                                         return_scaled=return_scaled)
+            chunk = recording.get_traces(
+                start_frame=start_frame,
+                end_frame=start_frame + chunk_size,
+                segment_index=segment_index,
+                return_scaled=return_scaled,
+            )
             chunk_list.append(chunk)
     if concatenated:
         return np.concatenate(chunk_list, axis=0)
@@ -59,7 +74,9 @@ def get_channel_distances(recording):
     Distance between channel pairs
     """
     locations = recording.get_channel_locations()
-    channel_distances = np.linalg.norm(locations[:, np.newaxis] - locations[np.newaxis, :], axis=2)
+    channel_distances = np.linalg.norm(
+        locations[:, np.newaxis] - locations[np.newaxis, :], axis=2
+    )

     return channel_distances

@@ -95,40 +112,59 @@ def get_closest_channels(recording, channel_ids=None, num_channels=None):
     for i in range(locations.shape[0]):
         distances = np.linalg.norm(locations[i, :] - locations, axis=1)
         order = np.argsort(distances)
-        closest_channels_inds.append(order[1:num_channels + 1])
-        dists.append(distances[order][1:num_channels + 1])
+        closest_channels_inds.append(order[1 : num_channels + 1])
+        dists.append(distances[order][1 : num_channels + 1])

     return np.array(closest_channels_inds), np.array(dists)


 def get_noise_levels(recording, return_scaled=True, **random_chunk_kwargs):
     """
     Estimate noise for each channel using MAD methods.
-
+
     Internally, it samples chunks across segments.
     Then it uses the MAD estimator (more robust than STD).
-
+
     """
-    random_chunks = get_random_data_chunks(recording, return_scaled=return_scaled, **random_chunk_kwargs)
+    random_chunks = get_random_data_chunks(
+        recording, return_scaled=return_scaled, **random_chunk_kwargs
+    )
     med = np.median(random_chunks, axis=0, keepdims=True)
     # hard-coded so that core doesn't depend on scipy
-    noise_levels = np.median(np.abs(random_chunks - med), axis=0) / 0.6744897501960817
+    noise_levels = (
+        np.median(np.abs(random_chunks - med), axis=0) / 0.6744897501960817
+    )
     return noise_levels


-def get_chunk_with_margin(rec_segment, start_frame, end_frame,
-                          channel_indices, margin, add_zeros=False,
-                          window_on_margin=False, dtype=None):
+def get_chunk_with_margin(
+    rec_segment,
+    start_frame,
+    end_frame,
+    channel_indices,
+    margin,
+    add_zeros=False,
+    add_reflect_padding=False,
+    window_on_margin=False,
+    dtype=None,
+):
     """
     Helper to get chunk with margin
+
+    The margin is extracted from the recording when possible. If
+    the chunk is at the edge of the recording, no margin is used
+    unless `add_zeros` or `add_reflect_padding` is True. In the
+    first case the chunk is zero-padded; in the second, np.pad is
+    called with mode="reflect".
     """
     length = rec_segment.get_num_samples()

     if channel_indices is None:
         channel_indices = slice(None)

-    if not add_zeros:
-        assert not window_on_margin, 'window_on_margin can be used only for add_zeros=True'
+    if not (add_zeros or add_reflect_padding):
+        if window_on_margin and not add_zeros:
+            raise ValueError("window_on_margin requires add_zeros=True")
         if start_frame is None:
             left_margin = 0
             start_frame = 0
@@ -144,10 +180,14 @@ def get_chunk_with_margin(rec_segment, start_frame, end_frame,
             right_margin = length - end_frame
         else:
             right_margin = margin
-        traces_chunk = rec_segment.get_traces(start_frame - left_margin, end_frame + right_margin, channel_indices)
+        traces_chunk = rec_segment.get_traces(
+            start_frame - left_margin,
+            end_frame + right_margin,
+            channel_indices,
+        )

     else:
-        # add_zeros=True
+        # either add_zeros or reflect_padding
         assert start_frame is not None
         assert end_frame is not None
         chunk_size = end_frame - start_frame
@@ -167,41 +207,66 @@ def get_chunk_with_margin(rec_segment, start_frame, end_frame,
             end_frame2 = end_frame + margin
             right_pad = 0

-        traces_chunk = rec_segment.get_traces(start_frame2, end_frame2, channel_indices)
-
-
-        if dtype is not None or window_on_margin or left_pad > 0 or right_pad > 0:
+        traces_chunk = rec_segment.get_traces(
+            start_frame2, end_frame2, channel_indices
+        )
+
+        if (
+            dtype is not None
+            or window_on_margin
+            or left_pad > 0
+            or right_pad > 0
+        ):
             need_copy = True
         else:
             need_copy = False

+        left_margin = margin
+        right_margin = margin
+
         if need_copy:
             if dtype is None:
                 dtype = traces_chunk.dtype
-            traces_chunk2 = np.zeros((full_size, traces_chunk.shape[1]), dtype=dtype)
-            i0 = left_pad
-            i1 = left_pad + traces_chunk.shape[0]
-            traces_chunk2[i0: i1, :] = traces_chunk
+
             left_margin = margin
             if end_frame < (length + margin):
                 right_margin = margin
             else:
                 right_margin = end_frame + margin - length
-            if window_on_margin:
-                # apply inplace taper on border
-                taper = (1 - np.cos(np.arange(margin) / margin * np.pi)) / 2
-                taper = taper[:, np.newaxis]
-                traces_chunk2[:margin] *= taper
-                traces_chunk2[-margin:] *= taper[::-1]
-            traces_chunk = traces_chunk2
-        else:
-            left_margin = margin
-            right_margin = margin
+
+            if add_zeros:
+                traces_chunk2 = np.zeros(
+                    (full_size, traces_chunk.shape[1]), dtype=dtype
+                )
+                i0 = left_pad
+                i1 = left_pad + traces_chunk.shape[0]
+                traces_chunk2[i0:i1, :] = traces_chunk
+                if window_on_margin:
+                    # apply inplace taper on border
+                    taper = (
+                        1 - np.cos(np.arange(margin) / margin * np.pi)
+                    ) / 2
+                    taper = taper[:, np.newaxis]
+                    traces_chunk2[:margin] *= taper
+                    traces_chunk2[-margin:] *= taper[::-1]
+                traces_chunk = traces_chunk2
+            elif add_reflect_padding:
+                # in this case, we don't want to taper
+                traces_chunk = np.pad(
+                    traces_chunk.astype(dtype),
+                    [(left_pad, right_pad), (0, 0)],
+                    mode="reflect",
+                )
+            else:
+                # we need a copy to change the dtype
+                traces_chunk = np.asarray(traces_chunk, dtype=dtype)
     return traces_chunk, left_margin, right_margin



-def order_channels_by_depth(recording, channel_ids=None, dimensions=('x', 'y')):
+def order_channels_by_depth(
+    recording, channel_ids=None, dimensions=("x", "y")
+):
     """
     Order channels by depth, by first ordering the x-axis, and then the y-axis.

@@ -213,7 +278,7 @@ def order_channels_by_depth(recording, channel_ids=None, dimensions=('x', 'y')):
         If given, a subset of channels to order locations for
     dimensions : str or tuple
         If str, it needs to be 'x', 'y', 'z'.
-        If tuple, it sorts the locations in two dimensions using lexsort.
+        If tuple, it sorts the locations in two dimensions using lexsort.
         This approach is recommended since there is less ambiguity, by default ('x', 'y')

     Returns
@@ -229,18 +294,20 @@ def order_channels_by_depth(recording, channel_ids=None, dimensions=('x', 'y')):
         locations = locations[channel_inds, :]

     if isinstance(dimensions, str):
-        dim = ['x', 'y', 'z'].index(dimensions)
+        dim = ["x", "y", "z"].index(dimensions)
         assert dim < ndim, "Invalid dimensions!"
-        order_f = np.argsort(locations[:, dim])
+        order_f = np.argsort(locations[:, dim], kind="stable")
     else:
-        assert isinstance(dimensions, tuple), "dimensions can be a str or a tuple"
+        assert isinstance(
+            dimensions, tuple
+        ), "dimensions can be a str or a tuple"
         locations_to_sort = ()
         for dim in dimensions:
-            dim = ['x', 'y', 'z'].index(dim)
+            dim = ["x", "y", "z"].index(dim)
             assert dim < ndim, "Invalid dimensions!"
-            locations_to_sort += (locations[:, dim], )
+            locations_to_sort += (locations[:, dim],)
         order_f = np.lexsort(locations_to_sort)
-    order_r = np.argsort(order_f)
+    order_r = np.argsort(order_f, kind="stable")

     return order_f, order_r

@@ -253,21 +320,27 @@ def check_probe_do_not_overlap(probes):
     for i in range(len(probes)):
         probe_i = probes[i]
         # check that all positions in probe_j are outside probe_i boundaries
-        x_bounds_i = [np.min(probe_i.contact_positions[:, 0]),
-                      np.max(probe_i.contact_positions[:, 0])]
-        y_bounds_i = [np.min(probe_i.contact_positions[:, 1]),
-                      np.max(probe_i.contact_positions[:, 1])]
+        x_bounds_i = [
+            np.min(probe_i.contact_positions[:, 0]),
+            np.max(probe_i.contact_positions[:, 0]),
+        ]
+        y_bounds_i = [
+            np.min(probe_i.contact_positions[:, 1]),
+            np.max(probe_i.contact_positions[:, 1]),
+        ]

         for j in range(i + 1, len(probes)):
             probe_j = probes[j]

-            if np.any(np.array([x_bounds_i[0] < cp[0] < x_bounds_i[1] and
-                                y_bounds_i[0] < cp[1] < y_bounds_i[1]
-                                for cp in probe_j.contact_positions])):
+            if np.any(
+                np.array(
+                    [
+                        x_bounds_i[0] < cp[0] < x_bounds_i[1]
+                        and y_bounds_i[0] < cp[1] < y_bounds_i[1]
+                        for cp in probe_j.contact_positions
+                    ]
+                )
+            ):
                 raise Exception(
-                    "Probes are overlapping! Retrieve locations of single probes separately")
-
-
-
-
-
+                    "Probes are overlapping! Retrieve locations of single probes separately"
+                )
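
For reference, the noise estimate computed in get_noise_levels above is a median absolute deviation (MAD) rescaled by 0.6744897501960817, the 0.75 quantile of the standard normal, so that it matches the standard deviation for Gaussian noise. A standalone NumPy sketch of that step, not part of this diff; the synthetic traces array stands in for the chunks that get_random_data_chunks would return:

import numpy as np

# synthetic "random chunks": 10000 samples x 4 channels of Gaussian noise
rng = np.random.default_rng(0)
true_sigma = np.array([1.0, 2.0, 5.0, 10.0])
traces = rng.normal(0.0, true_sigma, size=(10000, 4))

# per-channel MAD, rescaled to a std-like value; 0.6744897501960817 is
# scipy.stats.norm.ppf(0.75), hard-coded above so core does not need scipy
med = np.median(traces, axis=0, keepdims=True)
noise_levels = np.median(np.abs(traces - med), axis=0) / 0.6744897501960817

print(noise_levels)  # close to [1, 2, 5, 10], and robust to outlier samples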
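
Similarly, a minimal sketch of what the new add_reflect_padding branch in get_chunk_with_margin does when the requested chunk touches a segment edge. Again standalone NumPy: the toy traces array stands in for rec_segment.get_traces(), and the variable names mirror the diff.

import numpy as np

traces = np.arange(20, dtype="float32").reshape(10, 2)  # 10 samples, 2 channels

# chunk starts at the very beginning of the segment, so `margin` samples
# are missing on the left and must be synthesized
start_frame, end_frame, margin = 0, 5, 3
left_pad = margin - start_frame   # 3 samples cannot come from the data
right_pad = 0                     # the right margin fits inside the segment

chunk = traces[max(start_frame - margin, 0):end_frame + margin]

# mirror the missing samples instead of zero-padding them
padded = np.pad(chunk, [(left_pad, right_pad), (0, 0)], mode="reflect")

print(padded.shape)  # (end_frame - start_frame + 2 * margin, 2) -> (11, 2)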