Skip to content

Commit b31435f

Browse files
committed
fix info in annotation
1 parent a3334e9 commit b31435f

File tree

1 file changed

+11
-28
lines changed

1 file changed

+11
-28
lines changed

neo/rawio/blackrockrawio.py

Lines changed: 11 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,7 @@
7373
_signal_channel_dtype,
7474
_signal_stream_dtype,
7575
_signal_buffer_dtype,
76+
_spike_channel_dtype,
7677
_event_channel_dtype,
7778
)
7879

@@ -132,19 +133,6 @@ class BlackrockRawIO(BaseRawIO):
132133
# We need to document the origin of this value
133134
main_sampling_rate = 30000.0
134135

135-
# Override spike channel dtype to include unit_class field specific to Blackrock
136-
_spike_channel_dtype = [
137-
("name", "U64"),
138-
("id", "U64"),
139-
# for waveform
140-
("wf_units", "U64"),
141-
("wf_gain", "float64"),
142-
("wf_offset", "float64"),
143-
("wf_left_sweep", "int64"),
144-
("wf_sampling_rate", "float64"),
145-
("unit_class", "U64"),
146-
]
147-
148136
def __init__(
149137
self, filename=None, nsx_override=None, nev_override=None, nsx_to_load=None, load_nev=True, verbose=False
150138
):
@@ -312,18 +300,7 @@ def _parse_header(self):
312300
# default value: threshold crossing after 10 samples of waveform
313301
wf_left_sweep = 10
314302
wf_sampling_rate = self.main_sampling_rate
315-
316-
# Map unit_class_nb to unit classification string
317-
if unit_id == 0:
318-
unit_class = "unclassified"
319-
elif 1 <= unit_id <= 16:
320-
unit_class = "sorted"
321-
elif unit_id == 255:
322-
unit_class = "noise"
323-
else: # 17-254 are reserved but treated as "non-spike-events"
324-
unit_class = "non-spike-events"
325-
326-
spike_channels.append((name, _id, wf_units, wf_gain, wf_offset, wf_left_sweep, wf_sampling_rate, unit_class))
303+
spike_channels.append((name, _id, wf_units, wf_gain, wf_offset, wf_left_sweep, wf_sampling_rate))
327304

328305
# scan events
329306
# NonNeural: serial and digital input
@@ -543,7 +520,7 @@ def _parse_header(self):
543520
self._sigs_t_starts = [None] * self._nb_segment
544521

545522
# finalize header
546-
spike_channels = np.array(spike_channels, dtype=self._spike_channel_dtype)
523+
spike_channels = np.array(spike_channels, dtype=_spike_channel_dtype)
547524
event_channels = np.array(event_channels, dtype=_event_channel_dtype)
548525
signal_channels = np.array(signal_channels, dtype=_signal_channel_dtype)
549526
signal_streams = np.array(signal_streams, dtype=_signal_stream_dtype)
@@ -620,10 +597,16 @@ def _parse_header(self):
620597
for c in range(spike_channels.size):
621598
st_ann = seg_ann["spikes"][c]
622599
channel_id, unit_id = self.internal_unit_ids[c]
623-
unit_tag = {0: "unclassified", 255: "noise"}.get(unit_id, str(unit_id))
624600
st_ann["channel_id"] = channel_id
625601
st_ann["unit_id"] = unit_id
626-
st_ann["unit_tag"] = unit_tag
602+
if unit_id == 0:
603+
st_ann["unit_classification"] = "unclassified"
604+
elif 1 <= unit_id <= 16:
605+
st_ann["unit_classification"] = "sorted"
606+
elif unit_id == 255:
607+
st_ann["unit_classification"] = "noise"
608+
else: # 17-254 are reserved
609+
st_ann["unit_classification"] = "reserved"
627610
st_ann["description"] = f"SpikeTrain channel_id: {channel_id}, unit_id: {unit_id}"
628611
st_ann["file_origin"] = self._filenames["nev"] + ".nev"
629612

0 commit comments

Comments
 (0)