Skip to content

Commit 772fd4b

Browse files
authored
Merge pull request #4098 from samuelgarcia/benchmark_plots
More benchmark improvements
2 parents b3cf122 + adb8b41 commit 772fd4b

File tree

1 file changed

+45
-29
lines changed

1 file changed

+45
-29
lines changed

src/spikeinterface/benchmark/benchmark_peak_detection.py

Lines changed: 45 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ def compute_result(self, **result_params):
5757
spikes, self.recording.sampling_frequency, unit_ids=self.recording.channel_ids
5858
)
5959

60-
self.result["gt_comparison"] = GroundTruthComparison(
60+
self.result["gt_comparison_by_channels"] = GroundTruthComparison(
6161
self.result["gt_on_channels"], self.result["peak_on_channels"], exhaustive_gt=self.exhaustive_gt
6262
)
6363

@@ -82,35 +82,34 @@ def compute_result(self, **result_params):
8282
sorting["segment_index"] = peaks[detected_matches]["segment_index"]
8383
order = np.lexsort((sorting["sample_index"], sorting["segment_index"]))
8484
sorting = sorting[order]
85-
self.result["sliced_gt_sorting"] = NumpySorting(
85+
self.result["matched_sorting"] = NumpySorting(
8686
sorting, self.recording.sampling_frequency, self.gt_sorting.unit_ids
8787
)
88-
self.result["sliced_gt_comparison"] = GroundTruthComparison(
89-
self.gt_sorting, self.result["sliced_gt_sorting"], exhaustive_gt=self.exhaustive_gt
88+
self.result["gt_comparison"] = GroundTruthComparison(
89+
self.gt_sorting, self.result["matched_sorting"], exhaustive_gt=self.exhaustive_gt
9090
)
9191

9292
ratio = 100 * len(gt_matches) / len(times2)
9393
print("Only {0:.2f}% of gt peaks are matched to detected peaks".format(ratio))
9494

9595
sorting_analyzer = create_sorting_analyzer(
96-
self.result["sliced_gt_sorting"], self.recording, format="memory", sparse=False, **job_kwargs
96+
self.result["matched_sorting"], self.recording, format="memory", sparse=False, **job_kwargs
9797
)
9898
sorting_analyzer.compute("random_spikes")
9999
sorting_analyzer.compute("templates", **job_kwargs)
100100

101-
self.result["templates"] = sorting_analyzer.get_extension("templates").get_data()
101+
self.result["matched_templates"] = sorting_analyzer.get_extension("templates").get_data()
102102

103103
_run_key_saved = [("peaks", "npy")]
104104

105105
_result_key_saved = [
106+
("gt_comparison_by_channels", "pickle"),
107+
("matched_sorting", "sorting"),
106108
("gt_comparison", "pickle"),
107-
("sliced_gt_sorting", "sorting"),
108-
("sliced_gt_comparison", "pickle"),
109-
("sliced_gt_sorting", "sorting"),
110109
("peak_on_channels", "sorting"),
111110
("gt_on_channels", "sorting"),
112111
("matches", "pickle"),
113-
("templates", "npy"),
112+
("matched_templates", "npy"),
114113
("gt_amplitudes", "npy"),
115114
("gt_templates", "npy"),
116115
]
@@ -128,6 +127,11 @@ def create_benchmark(self, key):
128127
benchmark = PeakDetectionBenchmark(recording, gt_sorting, params, **init_kwargs)
129128
return benchmark
130129

130+
def plot_performances_vs_snr(self, **kwargs):
131+
from .benchmark_plot_tools import plot_performances_vs_snr
132+
133+
return plot_performances_vs_snr(self, **kwargs)
134+
131135
def plot_agreements_by_channels(self, case_keys=None, figsize=(15, 15)):
132136
if case_keys is None:
133137
case_keys = list(self.cases.keys())
@@ -138,7 +142,7 @@ def plot_agreements_by_channels(self, case_keys=None, figsize=(15, 15)):
138142
for count, key in enumerate(case_keys):
139143
ax = axs[0, count]
140144
ax.set_title(self.cases[key]["label"])
141-
plot_agreement_matrix(self.get_result(key)["gt_comparison"], ax=ax)
145+
plot_agreement_matrix(self.get_result(key)["gt_comparison_by_channels"], ax=ax)
142146

143147
def plot_agreements_by_units(self, case_keys=None, figsize=(15, 15)):
144148
if case_keys is None:
@@ -150,37 +154,49 @@ def plot_agreements_by_units(self, case_keys=None, figsize=(15, 15)):
150154
for count, key in enumerate(case_keys):
151155
ax = axs[0, count]
152156
ax.set_title(self.cases[key]["label"])
153-
plot_agreement_matrix(self.get_result(key)["sliced_gt_comparison"], ax=ax)
157+
plot_agreement_matrix(self.get_result(key)["gt_comparison"], ax=ax)
154158

155-
def plot_detected_amplitudes(self, case_keys=None, figsize=(15, 5), detect_threshold=None, axs=None):
159+
def plot_detected_amplitude_distributions(
160+
self, case_keys=None, show_legend=True, detect_threshold=None, figsize=(15, 5), ax=None
161+
):
156162

157163
if case_keys is None:
158164
case_keys = list(self.cases.keys())
159165
import matplotlib.pyplot as plt
160166

161-
if axs is None:
162-
fig, axs = plt.subplots(ncols=len(case_keys), figsize=figsize, squeeze=False)
167+
if ax is None:
168+
fig, ax = plt.subplots(figsize=figsize, squeeze=False)
163169
else:
164-
fig = axs[0].get_figure()
165-
assert len(axs) == len(case_keys), "axs should be the same length as case_keys"
170+
fig = ax.get_figure()
171+
172+
# plot only the first key for gt amplitude
173+
# TODO make a loop for all of then
174+
key0 = case_keys[0]
175+
data2 = self.get_result(key0)["gt_amplitudes"]
176+
bins = np.linspace(data2.min(), data2.max(), 100)
177+
ax.hist(data2, bins=bins, alpha=0.1, label="gt", color="k")
166178

167179
for count, key in enumerate(case_keys):
168-
ax = axs[count]
169180
despine(ax)
170181
data1 = self.get_result(key)["peaks"]["amplitude"]
171-
data2 = self.get_result(key)["gt_amplitudes"]
182+
172183
color = self.get_colors()[key]
173-
bins = np.linspace(data2.min(), data2.max(), 100)
174-
ax.hist(data1, bins=bins, label="detected", histtype="step", color=color, linewidth=2)
175-
ax.hist(data2, bins=bins, alpha=0.1, label="gt", color="k")
176-
ax.set_yscale("log")
184+
185+
label = self.cases[key]["label"]
186+
ax.hist(data1, bins=bins, label=label, histtype="step", color=color, linewidth=2)
187+
177188
# ax.set_title(self.cases[key]["label"])
189+
190+
ax.set_yscale("log")
191+
192+
if detect_threshold is not None:
193+
noise_levels = get_noise_levels(self.benchmarks[key].recording, return_in_uV=False).mean()
194+
ymin, ymax = ax.get_ylim()
195+
abs_threshold = -detect_threshold * noise_levels
196+
ax.plot([abs_threshold, abs_threshold], [ymin, ymax], "k--")
197+
198+
if show_legend:
178199
ax.legend()
179-
if detect_threshold is not None:
180-
noise_levels = get_noise_levels(self.benchmarks[key].recording, return_in_uV=False).mean()
181-
ymin, ymax = ax.get_ylim()
182-
abs_threshold = -detect_threshold * noise_levels
183-
ax.plot([abs_threshold, abs_threshold], [ymin, ymax], "k--")
184200

185201
return fig
186202

@@ -266,7 +282,7 @@ def plot_template_similarities(self, case_keys=None, metric="l2", figsize=(15, 5
266282
import sklearn.metrics
267283

268284
gt_templates = self.get_result(key)["gt_templates"]
269-
found_templates = self.get_result(key)["templates"]
285+
found_templates = self.get_result(key)["matched_templates"]
270286
num_templates = len(gt_templates)
271287
distances = np.zeros(num_templates)
272288

0 commit comments

Comments
 (0)