Skip to content

Commit 604c049

Browse files
authored
Remove classes from extractor and preprocessing __init__ (#3898)
* wip--remove class from init * fix __all__ * add __all__ to extractors * different strategy for displaying dicts * different strategy again for __all__ * fix sorters * fix ironclust * WIP --fix testing * finish fixing testing * additional fixes * Heberto discussion for nwb_timeseries * Alessio for how the full_dict should work * preprocessing return function too * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * list -> classes hopefully I fixed all the files, but the CI will tell me. * more list-> classes * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * another classes * fix two docs * add dep warning to classes * fix * import for si.full * heberto suggestion * Heberto suggestion again * more doc fixes * add dep warning to neo classes thank you Chris * oops * heberto feedback * better comment + typo fix * typo from fixing merge conflict * Update src/spikeinterface/extractors/extractor_classes.py * final heberto comments ---------
1 parent 92a3d45 commit 604c049

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

41 files changed

+463
-286
lines changed

examples/tutorials/core/plot_4_sorting_analyzer.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,9 +43,8 @@
4343
##############################################################################
4444
# Let's now instantiate the recording and sorting objects:
4545

46-
recording = se.MEArecRecordingExtractor(local_path)
46+
recording, sorting = se.read_mearec(local_path)
4747
print(recording)
48-
sorting = se.MEArecSortingExtractor(local_path)
4948
print(sorting)
5049

5150
###############################################################################

examples/tutorials/extractors/plot_1_read_various_formats.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@
6161
# :py:class:`~spikeinterface.extractors.Spike2RecordingExtractor` object:
6262
#
6363

64-
recording = se.Spike2RecordingExtractor(spike2_file_path, stream_id="0")
64+
recording = se.read_spike2(spike2_file_path, stream_id="0")
6565
print(recording)
6666

6767
##############################################################################
@@ -75,11 +75,6 @@
7575
print(sorting)
7676
print(type(sorting))
7777

78-
##############################################################################
79-
# The :py:func:`~spikeinterface.extractors.read_mearec` function is equivalent to:
80-
81-
recording = se.MEArecRecordingExtractor(mearec_folder_path)
82-
sorting = se.MEArecSortingExtractor(mearec_folder_path)
8378

8479
##############################################################################
8580
# SI objects (:py:class:`~spikeinterface.core.BaseRecording` and :py:class:`~spikeinterface.core.BaseSorting`)
Lines changed: 58 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,64 @@
1-
from .extractorlist import *
1+
from .extractor_classes import *
22

3-
from .toy_example import toy_example
4-
from .bids import read_bids
3+
from .toy_example import toy_example as toy_example
4+
from .bids import read_bids as read_bids
55

66

77
from .neuropixels_utils import get_neuropixels_channel_groups, get_neuropixels_sample_shifts
88

99
from .neoextractors import get_neo_num_blocks, get_neo_streams
10+
11+
from warnings import warn
12+
13+
14+
# deprecation of class import idea from neuroconv
15+
# this __getattr__ is only triggered if the normal lookup fails so import
16+
# any of our functions is fine but if someone tries to import a class this raises
17+
# the warning and then returns the "function" version which will look the same
18+
# to the end-user
19+
# to be removed after version 0.105.0
20+
def __getattr__(extractor_name):
21+
# we need this trick to allow us to use import * for spikeinterface.full
22+
if extractor_name == "__all__":
23+
__all__ = []
24+
for imp in globals():
25+
# need to remove a bunch of builtins etc that shouldn't be part of all
26+
if imp[0] != "_" and imp != "warn" and imp != "extractor_name":
27+
__all__.append(imp)
28+
return __all__
29+
all_extractors = list(recording_extractor_full_dict.values())
30+
all_extractors += list(sorting_extractor_full_dict.values())
31+
all_extractors += list(event_extractor_full_dict.values())
32+
all_extractors += list(snippets_extractor_full_dict.values())
33+
# special cases because they don't have simple wrappers
34+
# instead a single wrapper maps to multiple classes so we return
35+
# each class to check it
36+
from .neoextractors import (
37+
MEArecRecordingExtractor,
38+
MEArecSortingExtractor,
39+
OpenEphysBinaryEventExtractor,
40+
OpenEphysBinaryRecordingExtractor,
41+
OpenEphysLegacyRecordingExtractor,
42+
SpikeGLXEventExtractor,
43+
)
44+
45+
all_extractors += [
46+
MEArecRecordingExtractor,
47+
MEArecSortingExtractor,
48+
OpenEphysBinaryEventExtractor,
49+
OpenEphysBinaryRecordingExtractor,
50+
OpenEphysLegacyRecordingExtractor,
51+
SpikeGLXEventExtractor,
52+
]
53+
for reading_function in all_extractors:
54+
if extractor_name == reading_function.__name__:
55+
dep_msg = (
56+
"Importing classes at __init__ has been deprecated in favor of only importing function-size wrappers "
57+
"and will be removed in 0.105.0. For developers that prefer working with the class versions of extractors "
58+
"they can be imported from spikeinterface.extractors.extractor_classes"
59+
)
60+
warn(dep_msg)
61+
return reading_function
62+
# this is necessary for objects that we don't support
63+
# normally this is an ImportError but since this is in the __getattr__ pytest needs an AttributeError
64+
raise AttributeError(f"cannot import name '{extractor_name}' from '{__name__}'")
Lines changed: 202 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,202 @@
1+
from __future__ import annotations
2+
3+
4+
# most important extractor are in spikeinterface.core
5+
from spikeinterface.core import (
6+
BinaryFolderRecording,
7+
BinaryRecordingExtractor,
8+
NumpyRecording,
9+
NpzSortingExtractor,
10+
NumpySorting,
11+
NpySnippetsExtractor,
12+
ZarrRecordingExtractor,
13+
ZarrSortingExtractor,
14+
read_binary,
15+
read_zarr,
16+
read_npz_sorting,
17+
read_npy_snippets,
18+
)
19+
20+
# sorting/recording/event from neo
21+
from .neoextractors import *
22+
23+
# non-NEO objects implemented in neo folder
24+
# keep for reference Currently pulling from neoextractor __init__
25+
# from .neoextractors import NeuroScopeSortingExtractor, MaxwellEventExtractor
26+
27+
# NWB sorting/recording/event
28+
from .nwbextractors import (
29+
NwbRecordingExtractor,
30+
NwbSortingExtractor,
31+
NwbTimeSeriesExtractor,
32+
read_nwb,
33+
read_nwb_recording,
34+
read_nwb_sorting,
35+
read_nwb_timeseries,
36+
)
37+
38+
from .cbin_ibl import CompressedBinaryIblExtractor, read_cbin_ibl
39+
from .iblextractors import IblRecordingExtractor, IblSortingExtractor, read_ibl_recording, read_ibl_sorting
40+
from .mcsh5extractors import MCSH5RecordingExtractor, read_mcsh5
41+
from .whitematterrecordingextractor import WhiteMatterRecordingExtractor, read_whitematter
42+
43+
# sorting extractors in relation with a sorter
44+
from .cellexplorersortingextractor import CellExplorerSortingExtractor, read_cellexplorer
45+
from .klustaextractors import KlustaSortingExtractor, read_klusta
46+
from .hdsortextractors import HDSortSortingExtractor, read_hdsort
47+
from .mclustextractors import MClustSortingExtractor, read_mclust
48+
from .waveclustextractors import WaveClusSortingExtractor, read_waveclus
49+
from .yassextractors import YassSortingExtractor, read_yass
50+
from .combinatoextractors import CombinatoSortingExtractor, read_combinato
51+
from .tridesclousextractors import TridesclousSortingExtractor, read_tridesclous
52+
from .spykingcircusextractors import SpykingCircusSortingExtractor, read_spykingcircus
53+
from .herdingspikesextractors import HerdingspikesSortingExtractor, read_herdingspikes
54+
from .mdaextractors import MdaRecordingExtractor, MdaSortingExtractor, read_mda_recording, read_mda_sorting
55+
from .phykilosortextractors import PhySortingExtractor, KiloSortSortingExtractor, read_phy, read_kilosort
56+
from .sinapsrecordingextractors import (
57+
SinapsResearchPlatformRecordingExtractor,
58+
SinapsResearchPlatformH5RecordingExtractor,
59+
read_sinaps_research_platform,
60+
read_sinaps_research_platform_h5,
61+
)
62+
63+
# sorting in relation with simulator
64+
from .shybridextractors import (
65+
SHYBRIDRecordingExtractor,
66+
SHYBRIDSortingExtractor,
67+
read_shybrid_recording,
68+
read_shybrid_sorting,
69+
)
70+
71+
# snippers
72+
from .waveclussnippetstextractors import WaveClusSnippetsExtractor, read_waveclus_snippets
73+
74+
75+
# misc
76+
from .alfsortingextractor import ALFSortingExtractor, read_alf_sorting
77+
78+
79+
###############################################################################################
80+
# the following code is necessary for controlling what the end user imports from spikeinterface.
81+
# The strategy has three goals:
82+
#
83+
# * A mapping from the original class to its wrapper (because that's what we want to expose)
84+
# * A mapping from the original class to its wrapper string (because of __all__)
85+
# * A mapping from format to the class wrapper for convenience (exposed to users for ease of use)
86+
#
87+
# To achieve these there goals we do the following:
88+
#
89+
# 1) we line up each class with its wrapper that returns a snakecase version of the class (in some docs called
90+
# the "function" version, although this is just a wrapper of the underlying class)
91+
# 2) we do (1) by creating nested dicts where the key is the original class and the values are a nested dict with
92+
# 3) a "wrapper_class" key which returns the wrapper to be exposed to the end user and
93+
# 4) a "wrapper_string" which is added to the __all__ attribute of the __init__. This is necessary because __all__
94+
# can only accept a list of strings
95+
# 5) Finally we create dictionaries exposed to the user where we return a formatted file format as a key along
96+
# with the value being the wrapper (see the comment below for examples for this dict)
97+
#
98+
# Note that some formats (e.g. binary and numpy) still use the class format as they aren't read-only (i.e. they
99+
# have no wrapper)
100+
101+
_recording_extractor_full_dict = {
102+
# core extractors that are returned as classes
103+
BinaryFolderRecording: dict(wrapper_string="BinaryFolderRecording", wrapper_class=BinaryFolderRecording),
104+
BinaryRecordingExtractor: dict(wrapper_string="BinaryRecordingExtractor", wrapper_class=BinaryRecordingExtractor),
105+
ZarrRecordingExtractor: dict(wrapper_string="ZarrRecordingExtractor", wrapper_class=ZarrRecordingExtractor),
106+
# natively implemented in spikeinterface.extractors
107+
NumpyRecording: dict(wrapper_string="NumpyRecording", wrapper_class=NumpyRecording),
108+
SHYBRIDRecordingExtractor: dict(wrapper_string="read_shybrid_recording", wrapper_class=read_shybrid_recording),
109+
MdaRecordingExtractor: dict(wrapper_string="read_mda_recording", wrapper_class=read_mda_recording),
110+
NwbRecordingExtractor: dict(wrapper_string="read_nwb_recording", wrapper_class=read_nwb_recording),
111+
NwbTimeSeriesExtractor: dict(wrapper_string="read_nwb_timeseries", wrapper_class=read_nwb_timeseries),
112+
# others
113+
CompressedBinaryIblExtractor: dict(wrapper_string="read_cbin_ibl", wrapper_class=read_cbin_ibl),
114+
IblRecordingExtractor: dict(wrapper_string="read_ibl_recording", wrapper_class=read_ibl_recording),
115+
MCSH5RecordingExtractor: dict(wrapper_string="read_mcsh5", wrapper_class=read_mcsh5),
116+
SinapsResearchPlatformRecordingExtractor: dict(
117+
wrapper_string="read_sinaps_research_platform", wrapper_class=read_sinaps_research_platform
118+
),
119+
SinapsResearchPlatformH5RecordingExtractor: dict(
120+
wrapper_string="read_sinaps_research_platform_h5", wrapper_class=read_sinaps_research_platform_h5
121+
),
122+
WhiteMatterRecordingExtractor: dict(wrapper_string="read_whitematter", wrapper_class=read_whitematter),
123+
}
124+
_recording_extractor_full_dict.update(neo_recording_extractors_dict)
125+
126+
_sorting_extractor_full_dict = {
127+
NpzSortingExtractor: dict(wrapper_string="read_npz_sorting", wrapper_class=read_npz_sorting),
128+
ZarrSortingExtractor: dict(wrapper_string="ZarrSortingExtractor", wrapper_class=ZarrSortingExtractor),
129+
NumpySorting: dict(wrapper_string="NumpySorting", wrapper_class=NumpySorting),
130+
# natively implemented in spikeinterface.extractors
131+
MdaSortingExtractor: dict(wrapper_string="read_mda_sorting", wrapper_class=read_mda_sorting),
132+
SHYBRIDSortingExtractor: dict(wrapper_string="read_shybrid_sorting", wrapper_class=read_shybrid_sorting),
133+
ALFSortingExtractor: dict(wrapper_string="read_alf_sorting", wrapper_class=read_alf_sorting),
134+
KlustaSortingExtractor: dict(wrapper_string="read_klusta", wrapper_class=read_klusta),
135+
HDSortSortingExtractor: dict(wrapper_string="read_hdsort", wrapper_class=read_hdsort),
136+
MClustSortingExtractor: dict(wrapper_string="read_mclust", wrapper_class=read_mclust),
137+
WaveClusSortingExtractor: dict(wrapper_string="read_waveclus", wrapper_class=read_waveclus),
138+
YassSortingExtractor: dict(wrapper_string="read_yass", wrapper_class=read_yass),
139+
CombinatoSortingExtractor: dict(wrapper_string="read_combinato", wrapper_class=read_combinato),
140+
TridesclousSortingExtractor: dict(wrapper_string="read_tridesclous", wrapper_class=read_tridesclous),
141+
SpykingCircusSortingExtractor: dict(wrapper_string="read_spykingcircus", wrapper_class=read_spykingcircus),
142+
HerdingspikesSortingExtractor: dict(wrapper_string="read_herdingspikes", wrapper_class=read_herdingspikes),
143+
KiloSortSortingExtractor: dict(wrapper_string="read_kilosort", wrapper_class=read_kilosort),
144+
PhySortingExtractor: dict(wrapper_string="read_phy", wrapper_class=read_phy),
145+
NwbSortingExtractor: dict(wrapper_string="read_nwb_sorting", wrapper_class=read_nwb_sorting),
146+
IblSortingExtractor: dict(wrapper_string="read_ibl_sorting", wrapper_class=read_ibl_sorting),
147+
CellExplorerSortingExtractor: dict(wrapper_string="read_cellexplorer", wrapper_class=read_cellexplorer),
148+
}
149+
_sorting_extractor_full_dict.update(neo_sorting_extractors_dict)
150+
151+
# events only from neo
152+
_event_extractor_full_dict = neo_event_extractors_dict
153+
154+
_snippets_extractor_full_dict = {
155+
NpySnippetsExtractor: dict(wrapper_string="read_npy_snippets", wrapper_class=read_npy_snippets),
156+
WaveClusSnippetsExtractor: dict(wrapper_string="read_waveclus_snippets", wrapper_class=read_waveclus_snippets),
157+
}
158+
159+
############################################################################################################
160+
# Organize the possible extractors into a user facing format with keys being extractor names
161+
# (e.g. 'intan' , 'kilosort') and values being the appropriate Extractor class returned as its wrapper
162+
# (e.g. IntanRecordingExtractor, KiloSortSortingExtractor)
163+
# An important note is the the formats are returned after performing `.lower()` so a format like
164+
# SpikeGLX will be a key of 'spikeglx'
165+
# for example if we wanted to create a recording from an intan file we could do the following:
166+
# >>> recording = se.recording_extractor_full_dict['intan'](file_path='path/to/data.rhd')
167+
168+
169+
recording_extractor_full_dict = {
170+
rec_class.__name__.replace("Recording", "").replace("Extractor", "").lower(): rec_func["wrapper_class"]
171+
for rec_class, rec_func in _recording_extractor_full_dict.items()
172+
}
173+
sorting_extractor_full_dict = {
174+
sort_class.__name__.replace("Sorting", "").replace("Extractor", "").lower(): sort_func["wrapper_class"]
175+
for sort_class, sort_func in _sorting_extractor_full_dict.items()
176+
}
177+
event_extractor_full_dict = {
178+
event_class.__name__.replace("Event", "").replace("Extractor", "").lower(): event_func["wrapper_class"]
179+
for event_class, event_func in _event_extractor_full_dict.items()
180+
}
181+
snippets_extractor_full_dict = {
182+
snippets_class.__name__.replace("Snippets", "").replace("Extractor", "").lower(): snippets_func["wrapper_class"]
183+
for snippets_class, snippets_func in _snippets_extractor_full_dict.items()
184+
}
185+
186+
187+
# we only do the functions in the init rather than pull in the classes
188+
__all__ = [func["wrapper_string"] for func in _recording_extractor_full_dict.values()]
189+
__all__ += [func["wrapper_string"] for func in _sorting_extractor_full_dict.values()]
190+
__all__ += [func["wrapper_string"] for func in _event_extractor_full_dict.values()]
191+
__all__ += [func["wrapper_string"] for func in _snippets_extractor_full_dict.values()]
192+
__all__.extend(
193+
[
194+
"read_nwb", # convenience function for multiple nwb formats
195+
"recording_extractor_full_dict",
196+
"sorting_extractor_full_dict",
197+
"event_extractor_full_dict",
198+
"snippets_extractor_full_dict",
199+
"read_binary", # convenience function for binary formats
200+
"read_zarr",
201+
]
202+
)

0 commit comments

Comments
 (0)