Skip to content

Commit 6a0c010

Browse files
authored
Merge pull request #1714 from NeuralEnsemble/black-formatting
Black formatting
2 parents 1e09dd7 + 5e29166 commit 6a0c010

File tree

2 files changed

+31
-21
lines changed

2 files changed

+31
-21
lines changed

neo/rawio/spikeglxrawio.py

Lines changed: 21 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ class SpikeGLXRawIO(BaseRawWithBufferApiIO):
7878
The spikeglx folder containing meta/bin files
7979
load_sync_channel: bool, default: False
8080
Can be used to load the synch stream as the last channel of the neural data.
81-
This option is deprecated and will be removed in version 0.15.
81+
This option is deprecated and will be removed in version 0.15.
8282
From versions higher than 0.14.1 the sync channel is always loaded as a separate stream.
8383
load_channel_location: bool, default: False
8484
If True probeinterface is used to load the channel locations from the directory
@@ -116,7 +116,8 @@ def __init__(self, dirname="", load_sync_channel=False, load_channel_location=Fa
116116
warn(
117117
"The load_sync_channel=True option is deprecated and will be removed in version 0.15 \n"
118118
"The sync channel is now loaded as a separate stream by default and should be accessed as such. ",
119-
DeprecationWarning, stacklevel=2
119+
DeprecationWarning,
120+
stacklevel=2,
120121
)
121122
self.load_channel_location = load_channel_location
122123

@@ -162,7 +163,7 @@ def _parse_header(self):
162163
signal_streams = []
163164
signal_channels = []
164165
sync_stream_id_to_buffer_id = {}
165-
166+
166167
for stream_name in stream_names:
167168
# take first segment
168169
info = self.signals_info_dict[0, stream_name]
@@ -179,16 +180,21 @@ def _parse_header(self):
179180
for local_chan in range(info["num_chan"]):
180181
chan_name = info["channel_names"][local_chan]
181182
chan_id = f"{stream_name}#{chan_name}"
182-
183+
183184
# Sync channel
184-
if "nidq" not in stream_name and "SY0" in chan_name and not self.load_sync_channel and local_chan == info["num_chan"] - 1:
185+
if (
186+
"nidq" not in stream_name
187+
and "SY0" in chan_name
188+
and not self.load_sync_channel
189+
and local_chan == info["num_chan"] - 1
190+
):
185191
# This is a sync channel and should be added as its own stream
186192
sync_stream_id = f"{stream_name}-SYNC"
187193
sync_stream_id_to_buffer_id[sync_stream_id] = buffer_id
188194
stream_id_for_chan = sync_stream_id
189195
else:
190196
stream_id_for_chan = stream_id
191-
197+
192198
signal_channels.append(
193199
(
194200
chan_name,
@@ -205,26 +211,26 @@ def _parse_header(self):
205211

206212
# all channel by default unless load_sync_channel=False
207213
self._stream_buffer_slice[stream_id] = None
208-
214+
209215
# check sync channel validity
210216
if "nidq" not in stream_name:
211217
if not self.load_sync_channel and info["has_sync_trace"]:
212218
# the last channel is removed from the stream but not from the buffer
213219
self._stream_buffer_slice[stream_id] = slice(0, -1)
214-
220+
215221
# Add a buffer slice for the sync channel
216222
sync_stream_id = f"{stream_name}-SYNC"
217223
self._stream_buffer_slice[sync_stream_id] = slice(-1, None)
218-
224+
219225
if self.load_sync_channel and not info["has_sync_trace"]:
220226
raise ValueError("SYNC channel is not present in the recording. " "Set load_sync_channel to False")
221227

222228
signal_buffers = np.array(signal_buffers, dtype=_signal_buffer_dtype)
223-
229+
224230
# Add sync channels as their own streams
225231
for sync_stream_id, buffer_id in sync_stream_id_to_buffer_id.items():
226232
signal_streams.append((sync_stream_id, sync_stream_id, buffer_id))
227-
233+
228234
signal_streams = np.array(signal_streams, dtype=_signal_stream_dtype)
229235
signal_channels = np.array(signal_channels, dtype=_signal_channel_dtype)
230236

@@ -266,14 +272,14 @@ def _parse_header(self):
266272
t_start = frame_start / sampling_frequency
267273

268274
self._t_starts[stream_name][seg_index] = t_start
269-
275+
270276
# This need special logic because sync not present in stream_names
271277
if f"{stream_name}-SYNC" in signal_streams["name"]:
272278
sync_stream_name = f"{stream_name}-SYNC"
273279
if sync_stream_name not in self._t_starts:
274280
self._t_starts[sync_stream_name] = {}
275281
self._t_starts[sync_stream_name][seg_index] = t_start
276-
282+
277283
t_stop = info["sample_length"] / info["sampling_rate"]
278284
self._t_stops[seg_index] = max(self._t_stops[seg_index], t_stop)
279285

@@ -302,11 +308,11 @@ def _parse_header(self):
302308
if self.load_channel_location:
303309
# need probeinterface to be installed
304310
import probeinterface
305-
311+
306312
# Skip for sync streams
307313
if "SYNC" in stream_name:
308314
continue
309-
315+
310316
info = self.signals_info_dict[seg_index, stream_name]
311317
if "imroTbl" in info["meta"] and info["stream_kind"] == "ap":
312318
# only for ap channel

neo/test/rawiotest/test_spikeglxrawio.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -137,22 +137,26 @@ def test_sync_channel_as_separate_stream(self):
137137
# Test with load_sync_channel=False (default)
138138
rawio_no_sync = SpikeGLXRawIO(self.get_local_path("spikeglx/NP2_with_sync"), load_sync_channel=False)
139139
rawio_no_sync.parse_header()
140-
140+
141141
# Get stream names
142142
stream_names = rawio_no_sync.header["signal_streams"]["name"].tolist()
143-
143+
144144
# Check if there's a sync channel stream (should contain "SY0" or "SYNC" in the name)
145145
sync_streams = [name for name in stream_names if "SY0" in name or "SYNC" in name]
146146
assert len(sync_streams) > 0, "No sync channel stream found when load_sync_channel=False"
147-
147+
148148
# Test deprecation warning when load_sync_channel=True
149149
with warnings.catch_warnings(record=True) as w:
150150
warnings.simplefilter("always")
151151
rawio_with_sync = SpikeGLXRawIO(self.get_local_path("spikeglx/NP2_with_sync"), load_sync_channel=True)
152-
152+
153153
# Check if deprecation warning was raised
154-
assert any(issubclass(warning.category, DeprecationWarning) for warning in w), "No deprecation warning raised"
155-
assert any("will be removed in version 0.15" in str(warning.message) for warning in w), "Deprecation warning message is incorrect"
154+
assert any(
155+
issubclass(warning.category, DeprecationWarning) for warning in w
156+
), "No deprecation warning raised"
157+
assert any(
158+
"will be removed in version 0.15" in str(warning.message) for warning in w
159+
), "Deprecation warning message is incorrect"
156160

157161
def test_t_start_reading(self):
158162
"""Test that t_start values are correctly read for all streams and segments."""

0 commit comments

Comments
 (0)