@@ -57,7 +57,7 @@ def compute_result(self, **result_params):
             spikes, self.recording.sampling_frequency, unit_ids=self.recording.channel_ids
         )
 
-        self.result["gt_comparison"] = GroundTruthComparison(
+        self.result["gt_comparison_by_channels"] = GroundTruthComparison(
             self.result["gt_on_channels"], self.result["peak_on_channels"], exhaustive_gt=self.exhaustive_gt
         )
 
@@ -82,35 +82,34 @@ def compute_result(self, **result_params):
         sorting["segment_index"] = peaks[detected_matches]["segment_index"]
         order = np.lexsort((sorting["sample_index"], sorting["segment_index"]))
         sorting = sorting[order]
-        self.result["sliced_gt_sorting"] = NumpySorting(
+        self.result["matched_sorting"] = NumpySorting(
             sorting, self.recording.sampling_frequency, self.gt_sorting.unit_ids
         )
-        self.result["sliced_gt_comparison"] = GroundTruthComparison(
-            self.gt_sorting, self.result["sliced_gt_sorting"], exhaustive_gt=self.exhaustive_gt
+        self.result["gt_comparison"] = GroundTruthComparison(
+            self.gt_sorting, self.result["matched_sorting"], exhaustive_gt=self.exhaustive_gt
         )
 
         ratio = 100 * len(gt_matches) / len(times2)
         print("Only {0:.2f}% of gt peaks are matched to detected peaks".format(ratio))
 
         sorting_analyzer = create_sorting_analyzer(
-            self.result["sliced_gt_sorting"], self.recording, format="memory", sparse=False, **job_kwargs
+            self.result["matched_sorting"], self.recording, format="memory", sparse=False, **job_kwargs
         )
         sorting_analyzer.compute("random_spikes")
         sorting_analyzer.compute("templates", **job_kwargs)
 
-        self.result["templates"] = sorting_analyzer.get_extension("templates").get_data()
+        self.result["matched_templates"] = sorting_analyzer.get_extension("templates").get_data()
 
     _run_key_saved = [("peaks", "npy")]
 
     _result_key_saved = [
+        ("gt_comparison_by_channels", "pickle"),
+        ("matched_sorting", "sorting"),
         ("gt_comparison", "pickle"),
-        ("sliced_gt_sorting", "sorting"),
-        ("sliced_gt_comparison", "pickle"),
-        ("sliced_gt_sorting", "sorting"),
         ("peak_on_channels", "sorting"),
         ("gt_on_channels", "sorting"),
         ("matches", "pickle"),
-        ("templates", "npy"),
+        ("matched_templates", "npy"),
         ("gt_amplitudes", "npy"),
         ("gt_templates", "npy"),
     ]
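For orientation, a minimal sketch of how the renamed result keys read back downstream; `study` and `key` are placeholder names for a study instance and one of its case keys, not part of this diff:

    result = study.get_result(key)
    by_channels = result["gt_comparison_by_channels"]  # was "gt_comparison"
    unit_level = result["gt_comparison"]               # was "sliced_gt_comparison"
    matched_sorting = result["matched_sorting"]        # was "sliced_gt_sorting"
    matched_templates = result["matched_templates"]    # was "templates"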
@@ -128,6 +127,11 @@ def create_benchmark(self, key):
         benchmark = PeakDetectionBenchmark(recording, gt_sorting, params, **init_kwargs)
         return benchmark
 
+    def plot_performances_vs_snr(self, **kwargs):
+        from .benchmark_plot_tools import plot_performances_vs_snr
+
+        return plot_performances_vs_snr(self, **kwargs)
+
     def plot_agreements_by_channels(self, case_keys=None, figsize=(15, 15)):
         if case_keys is None:
             case_keys = list(self.cases.keys())
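The new method is a thin wrapper that forwards to the shared plotting helper in benchmark_plot_tools. A hedged usage sketch, assuming `study` is an instance of this study class:

    # hypothetical call; **kwargs are passed through to the helper
    fig = study.plot_performances_vs_snr()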
@@ -138,7 +142,7 @@ def plot_agreements_by_channels(self, case_keys=None, figsize=(15, 15)):
         for count, key in enumerate(case_keys):
             ax = axs[0, count]
             ax.set_title(self.cases[key]["label"])
-            plot_agreement_matrix(self.get_result(key)["gt_comparison"], ax=ax)
+            plot_agreement_matrix(self.get_result(key)["gt_comparison_by_channels"], ax=ax)
 
     def plot_agreements_by_units(self, case_keys=None, figsize=(15, 15)):
         if case_keys is None:
@@ -150,37 +154,49 @@ def plot_agreements_by_units(self, case_keys=None, figsize=(15, 15)):
         for count, key in enumerate(case_keys):
             ax = axs[0, count]
             ax.set_title(self.cases[key]["label"])
-            plot_agreement_matrix(self.get_result(key)["sliced_gt_comparison"], ax=ax)
+            plot_agreement_matrix(self.get_result(key)["gt_comparison"], ax=ax)
 
-    def plot_detected_amplitudes(self, case_keys=None, figsize=(15, 5), detect_threshold=None, axs=None):
+    def plot_detected_amplitude_distributions(
+        self, case_keys=None, show_legend=True, detect_threshold=None, figsize=(15, 5), ax=None
+    ):
 
         if case_keys is None:
             case_keys = list(self.cases.keys())
         import matplotlib.pyplot as plt
 
-        if axs is None:
-            fig, axs = plt.subplots(ncols=len(case_keys), figsize=figsize, squeeze=False)
+        if ax is None:
+            fig, ax = plt.subplots(figsize=figsize)
         else:
-            fig = axs[0].get_figure()
-            assert len(axs) == len(case_keys), "axs should be the same length as case_keys"
+            fig = ax.get_figure()
+
+        # plot only the first key for gt amplitude
+        # TODO make a loop for all of them
+        key0 = case_keys[0]
+        data2 = self.get_result(key0)["gt_amplitudes"]
+        bins = np.linspace(data2.min(), data2.max(), 100)
+        ax.hist(data2, bins=bins, alpha=0.1, label="gt", color="k")
 
         for count, key in enumerate(case_keys):
-            ax = axs[count]
             despine(ax)
             data1 = self.get_result(key)["peaks"]["amplitude"]
-            data2 = self.get_result(key)["gt_amplitudes"]
+
             color = self.get_colors()[key]
-            bins = np.linspace(data2.min(), data2.max(), 100)
-            ax.hist(data1, bins=bins, label="detected", histtype="step", color=color, linewidth=2)
-            ax.hist(data2, bins=bins, alpha=0.1, label="gt", color="k")
-            ax.set_yscale("log")
+
+            label = self.cases[key]["label"]
+            ax.hist(data1, bins=bins, label=label, histtype="step", color=color, linewidth=2)
+
             # ax.set_title(self.cases[key]["label"])
+
+        ax.set_yscale("log")
+
+        if detect_threshold is not None:
+            noise_levels = get_noise_levels(self.benchmarks[key].recording, return_in_uV=False).mean()
+            ymin, ymax = ax.get_ylim()
+            abs_threshold = -detect_threshold * noise_levels
+            ax.plot([abs_threshold, abs_threshold], [ymin, ymax], "k--")
+
+        if show_legend:
             ax.legend()
-            if detect_threshold is not None:
-                noise_levels = get_noise_levels(self.benchmarks[key].recording, return_in_uV=False).mean()
-                ymin, ymax = ax.get_ylim()
-                abs_threshold = -detect_threshold * noise_levels
-                ax.plot([abs_threshold, abs_threshold], [ymin, ymax], "k--")
 
         return fig
 
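A usage sketch of the reworked method: all cases are now overlaid on one axis, with the ground-truth histogram drawn once from the first case. The threshold value below is an illustrative assumption, interpreted as a multiple of the mean noise level and drawn as a dashed vertical line:

    # hypothetical call on a study instance; 5 is an example threshold
    fig = study.plot_detected_amplitude_distributions(detect_threshold=5, show_legend=True)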
@@ -266,7 +282,7 @@ def plot_template_similarities(self, case_keys=None, metric="l2", figsize=(15, 5
             import sklearn.metrics
 
             gt_templates = self.get_result(key)["gt_templates"]
-            found_templates = self.get_result(key)["templates"]
+            found_templates = self.get_result(key)["matched_templates"]
             num_templates = len(gt_templates)
             distances = np.zeros(num_templates)
 
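For reference, a self-contained sketch of the per-unit distance this hunk feeds, under the metric="l2" default; the names mirror the diff, but the flattening and the normalization by the ground-truth norm are assumptions, not the file's confirmed implementation:

    import numpy as np

    def template_l2_distances(gt_templates, found_templates):
        # one scalar distance per ground-truth unit, on flattened waveforms
        num_templates = len(gt_templates)
        distances = np.zeros(num_templates)
        for i in range(num_templates):
            a = gt_templates[i].flatten()
            b = found_templates[i].flatten()
            # assumed normalization: relative l2 error w.r.t. the gt template
            distances[i] = np.linalg.norm(a - b) / np.linalg.norm(a)
        return distances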