diff --git a/src/spikeinterface/extractors/neoextractors/neobaseextractor.py b/src/spikeinterface/extractors/neoextractors/neobaseextractor.py index d66ce79aa3..71b7b736e3 100644 --- a/src/spikeinterface/extractors/neoextractors/neobaseextractor.py +++ b/src/spikeinterface/extractors/neoextractors/neobaseextractor.py @@ -665,14 +665,20 @@ def get_unit_spike_train(self, unit_id, start_frame, end_frame): class NeoBaseEventExtractor(_NeoBaseExtractor, BaseEvent): handle_event_frame_directly = False - def __init__(self, block_index=None, **neo_kwargs): + def __init__(self, block_index=None, use_names_as_ids: bool = False, **neo_kwargs): _NeoBaseExtractor.__init__(self, block_index, **neo_kwargs) # TODO load feature from neo array_annotations event_channels = self.neo_reader.header["event_channels"] - channel_ids = event_channels["id"] + if use_names_as_ids: + channel_ids = event_channels["name"] + assert ( + event_channels.size == np.unique(channel_ids).size + ), "use_name_as_ids=True is not possible, channel names are not unique" + else: + channel_ids = event_channels["id"] BaseEvent.__init__(self, channel_ids, structured_dtype=_neo_event_dtype) @@ -684,21 +690,23 @@ def __init__(self, block_index=None, **neo_kwargs): else: t_start = self.neo_reader.get_signal_t_start(self.block_index, segment_index, stream_index=0) - event_segment = NeoEventSegment(self.neo_reader, self.block_index, segment_index, t_start) + event_segment = NeoEventSegment(self.neo_reader, self.block_index, segment_index, t_start, use_names_as_ids) self.add_event_segment(event_segment) class NeoEventSegment(BaseEventSegment): - def __init__(self, neo_reader, block_index, segment_index, t_start): + def __init__(self, neo_reader, block_index, segment_index, t_start, use_names_as_ids): BaseEventSegment.__init__(self) self.neo_reader = neo_reader self.segment_index = segment_index self.block_index = block_index self._t_start = t_start self._natural_ids = None + self.use_names_as_ids = use_names_as_ids def get_events(self, channel_id, start_time, end_time): - channel_index = list(self.neo_reader.header["event_channels"]["id"]).index(channel_id) + id_or_name = "name" if self.use_names_as_ids else "id" + channel_index = list(self.neo_reader.header["event_channels"][id_or_name]).index(channel_id) event_timestamps, event_duration, event_labels = self.neo_reader.get_event_timestamps( block_index=self.block_index, seg_index=self.segment_index, event_channel_index=channel_index diff --git a/src/spikeinterface/extractors/neoextractors/plexon2.py b/src/spikeinterface/extractors/neoextractors/plexon2.py index e0604f7496..ad0afb7890 100644 --- a/src/spikeinterface/extractors/neoextractors/plexon2.py +++ b/src/spikeinterface/extractors/neoextractors/plexon2.py @@ -118,14 +118,20 @@ class Plexon2EventExtractor(NeoBaseEventExtractor): Parameters ---------- folder_path : str + Path to the .pl2 file. + block_index : int, default: None + Block index to read from, by default None. + use_names_as_ids : bool, default: False + If True, use channel names as identifiers instead of channel IDs. + Channel names must be unique when this option is enabled. """ NeoRawIOClass = "Plexon2RawIO" - def __init__(self, folder_path, block_index=None): + def __init__(self, folder_path, block_index=None, use_names_as_ids=False): neo_kwargs = self.map_to_neo_kwargs(folder_path) - NeoBaseEventExtractor.__init__(self, block_index=block_index, **neo_kwargs) + NeoBaseEventExtractor.__init__(self, block_index=block_index, use_names_as_ids=use_names_as_ids, **neo_kwargs) @classmethod def map_to_neo_kwargs(cls, folder_path):