Skip to content

Commit ad3e924

Browse files
authored
Merge pull request #3138 from chrishalcrow/fix-nn-calculations
Fix nn pca_metric computation and update tests
2 parents eba9f68 + edb8003 commit ad3e924

File tree

4 files changed

+109
-142
lines changed

4 files changed

+109
-142
lines changed

src/spikeinterface/qualitymetrics/pca_metrics.py

Lines changed: 59 additions & 40 deletions
Original file line number | Diff line number | Diff line change
@@ -17,7 +17,6 @@
1717
from ..core import get_random_data_chunks, compute_sparsity
1818
from ..core.template_tools import get_template_extremum_channel
1919

20-
2120
_possible_pc_metric_names = [
2221
"isolation_distance",
2322
"l_ratio",
@@ -90,7 +89,7 @@ def compute_pc_metrics(
9089
sorting = sorting_analyzer.sorting
9190

9291
if metric_names is None:
93-
metric_names = _possible_pc_metric_names
92+
metric_names = _possible_pc_metric_names.copy()
9493
if qm_params is None:
9594
qm_params = _default_params
9695

@@ -110,8 +109,13 @@ def compute_pc_metrics(
110109
if "nn_isolation" in metric_names:
111110
pc_metrics["nn_unit_id"] = {}
112111

112+
possible_nn_metrics = ["nn_isolation", "nn_noise_overlap"]
113+
114+
nn_metrics = list(set(metric_names).intersection(possible_nn_metrics))
115+
non_nn_metrics = list(set(metric_names).difference(possible_nn_metrics))
116+
113117
# Compute nspikes and firing rate outside of main loop for speed
114-
if any([n in metric_names for n in ["nn_isolation", "nn_noise_overlap"]]):
118+
if nn_metrics:
115119
n_spikes_all_units = compute_num_spikes(sorting_analyzer, unit_ids=unit_ids)
116120
fr_all_units = compute_firing_rates(sorting_analyzer, unit_ids=unit_ids)
117121
else:
@@ -120,9 +124,6 @@ def compute_pc_metrics(
120124

121125
run_in_parallel = n_jobs > 1
122126

123-
if run_in_parallel:
124-
parallel_functions = []
125-
126127
# this get dense projection for selected unit_ids
127128
dense_projections, spike_unit_indices = pca_ext.get_some_projections(channel_ids=None, unit_ids=unit_ids)
128129
all_labels = sorting.unit_ids[spike_unit_indices]
@@ -146,7 +147,7 @@ def compute_pc_metrics(
146147
func_args = (
147148
pcs_flat,
148149
labels,
149-
metric_names,
150+
non_nn_metrics,
150151
unit_id,
151152
unit_ids,
152153
qm_params,
@@ -156,16 +157,16 @@ def compute_pc_metrics(
156157
)
157158
items.append(func_args)
158159

159-
if not run_in_parallel:
160+
if not run_in_parallel and non_nn_metrics:
160161
units_loop = enumerate(unit_ids)
161162
if progress_bar:
162-
units_loop = tqdm(units_loop, desc="calculate_pc_metrics", total=len(unit_ids))
163+
units_loop = tqdm(units_loop, desc="calculate pc_metrics", total=len(unit_ids))
163164

164165
for unit_ind, unit_id in units_loop:
165166
pca_metrics_unit = pca_metrics_one_unit(items[unit_ind])
166167
for metric_name, metric in pca_metrics_unit.items():
167168
pc_metrics[metric_name][unit_id] = metric
168-
else:
169+
elif run_in_parallel and non_nn_metrics:
169170
with ProcessPoolExecutor(n_jobs) as executor:
170171
results = executor.map(pca_metrics_one_unit, items)
171172
if progress_bar:
@@ -176,6 +177,37 @@ def compute_pc_metrics(
176177
for metric_name, metric in pca_metrics_unit.items():
177178
pc_metrics[metric_name][unit_id] = metric
178179

180+
for metric_name in nn_metrics:
181+
units_loop = enumerate(unit_ids)
182+
if progress_bar:
183+
units_loop = tqdm(units_loop, desc=f"calculate {metric_name} metric", total=len(unit_ids))
184+
185+
func = _nn_metric_name_to_func[metric_name]
186+
metric_params = qm_params[metric_name] if metric_name in qm_params else {}
187+
188+
for _, unit_id in units_loop:
189+
try:
190+
res = func(
191+
sorting_analyzer,
192+
unit_id,
193+
seed=seed,
194+
n_spikes_all_units=n_spikes_all_units,
195+
fr_all_units=fr_all_units,
196+
**metric_params,
197+
)
198+
except:
199+
if metric_name == "nn_isolation":
200+
res = (np.nan, np.nan)
201+
elif metric_name == "nn_noise_overlap":
202+
res = np.nan
203+
204+
if metric_name == "nn_isolation":
205+
nn_isolation, nn_unit_id = res
206+
pc_metrics["nn_isolation"][unit_id] = nn_isolation
207+
pc_metrics["nn_unit_id"][unit_id] = nn_unit_id
208+
elif metric_name == "nn_noise_overlap":
209+
pc_metrics["nn_noise_overlap"][unit_id] = res
210+
179211
return pc_metrics
180212

181213

@@ -677,6 +709,14 @@ def nearest_neighbors_noise_overlap(
677709
templates_ext = sorting_analyzer.get_extension("templates")
678710
assert templates_ext is not None, "nearest_neighbors_isolation() need extension 'templates'"
679711

712+
try:
713+
sorting_analyzer.get_extension("templates").get_data(operator="median")
714+
except KeyError:
715+
warnings.warn(
716+
"nearest_neighbors_isolation() need extension 'templates' calculated with the 'median' operator."
717+
"You can run sorting_analyzer.compute('templates', operators=['average', 'median']) to calculate templates based on both average and median modes."
718+
)
719+
680720
if n_spikes_all_units is None:
681721
n_spikes_all_units = compute_num_spikes(sorting_analyzer)
682722
if fr_all_units is None:
@@ -955,11 +995,13 @@ def pca_metrics_one_unit(args):
955995
pc_metrics = {}
956996
# metrics
957997
if "isolation_distance" in metric_names or "l_ratio" in metric_names:
998+
958999
try:
9591000
isolation_distance, l_ratio = mahalanobis_metrics(pcs_flat, labels, unit_id)
9601001
except:
9611002
isolation_distance = np.nan
9621003
l_ratio = np.nan
1004+
9631005
if "isolation_distance" in metric_names:
9641006
pc_metrics["isolation_distance"] = isolation_distance
9651007
if "l_ratio" in metric_names:
@@ -973,6 +1015,7 @@ def pca_metrics_one_unit(args):
9731015
d_prime = lda_metrics(pcs_flat, labels, unit_id)
9741016
except:
9751017
d_prime = np.nan
1018+
9761019
pc_metrics["d_prime"] = d_prime
9771020

9781021
if "nearest_neighbor" in metric_names:
@@ -986,36 +1029,6 @@ def pca_metrics_one_unit(args):
9861029
pc_metrics["nn_hit_rate"] = nn_hit_rate
9871030
pc_metrics["nn_miss_rate"] = nn_miss_rate
9881031

989-
if "nn_isolation" in metric_names:
990-
try:
991-
nn_isolation, nn_unit_id = nearest_neighbors_isolation(
992-
we,
993-
unit_id,
994-
seed=seed,
995-
n_spikes_all_units=n_spikes_all_units,
996-
fr_all_units=fr_all_units,
997-
**qm_params["nn_isolation"],
998-
)
999-
except:
1000-
nn_isolation = np.nan
1001-
nn_unit_id = np.nan
1002-
pc_metrics["nn_isolation"] = nn_isolation
1003-
pc_metrics["nn_unit_id"] = nn_unit_id
1004-
1005-
if "nn_noise_overlap" in metric_names:
1006-
try:
1007-
nn_noise_overlap = nearest_neighbors_noise_overlap(
1008-
we,
1009-
unit_id,
1010-
n_spikes_all_units=n_spikes_all_units,
1011-
fr_all_units=fr_all_units,
1012-
seed=seed,
1013-
**qm_params["nn_noise_overlap"],
1014-
)
1015-
except:
1016-
nn_noise_overlap = np.nan
1017-
pc_metrics["nn_noise_overlap"] = nn_noise_overlap
1018-
10191032
if "silhouette" in metric_names:
10201033
silhouette_method = qm_params["silhouette"]["method"]
10211034
if "simplified" in silhouette_method:
@@ -1032,3 +1045,9 @@ def pca_metrics_one_unit(args):
10321045
pc_metrics["silhouette_full"] = unit_silhouette_score
10331046

10341047
return pc_metrics
1048+
1049+
1050+
_nn_metric_name_to_func = {
1051+
"nn_isolation": nearest_neighbors_isolation,
1052+
"nn_noise_overlap": nearest_neighbors_noise_overlap,
1053+
}
Lines changed: 37 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,37 @@
1+
import pytest
2+
3+
from spikeinterface.core import (
4+
generate_ground_truth_recording,
5+
create_sorting_analyzer,
6+
)
7+
8+
9+
def _small_sorting_analyzer():
10+
recording, sorting = generate_ground_truth_recording(
11+
durations=[2.0],
12+
num_units=10,
13+
seed=1205,
14+
)
15+
16+
sorting = sorting.select_units([2, 7, 0], ["#3", "#9", "#4"])
17+
18+
sorting_analyzer = create_sorting_analyzer(recording=recording, sorting=sorting, format="memory")
19+
20+
extensions_to_compute = {
21+
"random_spikes": {"seed": 1205},
22+
"noise_levels": {"seed": 1205},
23+
"waveforms": {},
24+
"templates": {"operators": ["average", "median"]},
25+
"spike_amplitudes": {},
26+
"spike_locations": {},
27+
"principal_components": {},
28+
}
29+
30+
sorting_analyzer.compute(extensions_to_compute)
31+
32+
return sorting_analyzer
33+
34+
35+
@pytest.fixture(scope="module")
36+
def small_sorting_analyzer():
37+
return _small_sorting_analyzer()

src/spikeinterface/qualitymetrics/tests/test_metrics_functions.py

Lines changed: 4 additions & 35 deletions
Original file line number | Diff line number | Diff line change
@@ -47,37 +47,6 @@
4747
job_kwargs = dict(n_jobs=2, progress_bar=True, chunk_duration="1s")
4848

4949

50-
def _small_sorting_analyzer():
51-
recording, sorting = generate_ground_truth_recording(
52-
durations=[2.0],
53-
num_units=4,
54-
seed=1205,
55-
)
56-
57-
sorting = sorting.select_units([3, 2, 0], ["#3", "#9", "#4"])
58-
59-
sorting_analyzer = create_sorting_analyzer(recording=recording, sorting=sorting, format="memory")
60-
61-
extensions_to_compute = {
62-
"random_spikes": {"seed": 1205},
63-
"noise_levels": {"seed": 1205},
64-
"waveforms": {},
65-
"templates": {},
66-
"spike_amplitudes": {},
67-
"spike_locations": {},
68-
"principal_components": {},
69-
}
70-
71-
sorting_analyzer.compute(extensions_to_compute)
72-
73-
return sorting_analyzer
74-
75-
76-
@pytest.fixture(scope="module")
77-
def small_sorting_analyzer():
78-
return _small_sorting_analyzer()
79-
80-
8150
def test_unit_structure_in_output(small_sorting_analyzer):
8251

8352
qm_params = {
@@ -126,7 +95,7 @@ def test_unit_id_order_independence(small_sorting_analyzer):
12695
"""
12796

12897
recording = small_sorting_analyzer.recording
129-
sorting = small_sorting_analyzer.sorting.select_units(["#4", "#9", "#3"], [0, 2, 3])
98+
sorting = small_sorting_analyzer.sorting.select_units(["#4", "#9", "#3"], [1, 7, 2])
13099

131100
small_sorting_analyzer_2 = create_sorting_analyzer(recording=recording, sorting=sorting, format="memory")
132101

@@ -161,9 +130,9 @@ def test_unit_id_order_independence(small_sorting_analyzer):
161130
)
162131

163132
for metric, metric_1_data in quality_metrics_1.items():
164-
assert quality_metrics_2[metric][3] == metric_1_data["#3"]
165-
assert quality_metrics_2[metric][2] == metric_1_data["#9"]
166-
assert quality_metrics_2[metric][0] == metric_1_data["#4"]
133+
assert quality_metrics_2[metric][2] == metric_1_data["#3"]
134+
assert quality_metrics_2[metric][7] == metric_1_data["#9"]
135+
assert quality_metrics_2[metric][1] == metric_1_data["#4"]
167136

168137

169138
def _sorting_analyzer_simple():
Lines changed: 9 additions & 67 deletions
Original file line number | Diff line number | Diff line change
@@ -1,82 +1,24 @@
11
import pytest
2-
from pathlib import Path
32
import numpy as np
4-
from spikeinterface.core import (
5-
generate_ground_truth_recording,
6-
create_sorting_analyzer,
7-
)
8-
93

104
from spikeinterface.qualitymetrics import (
115
compute_pc_metrics,
12-
nearest_neighbors_isolation,
13-
nearest_neighbors_noise_overlap,
146
)
157

168

17-
job_kwargs = dict(n_jobs=2, progress_bar=True, chunk_duration="1s")
18-
19-
20-
def _sorting_analyzer_simple():
21-
recording, sorting = generate_ground_truth_recording(
22-
durations=[
23-
50.0,
24-
],
25-
sampling_frequency=30_000.0,
26-
num_channels=6,
27-
num_units=10,
28-
generate_sorting_kwargs=dict(firing_rates=6.0, refractory_period_ms=4.0),
29-
noise_kwargs=dict(noise_levels=5.0, strategy="tile_pregenerated"),
30-
seed=2205,
31-
)
32-
33-
sorting_analyzer = create_sorting_analyzer(sorting, recording, format="memory", sparse=True)
34-
35-
sorting_analyzer.compute("random_spikes", max_spikes_per_unit=300, seed=2205)
36-
sorting_analyzer.compute("noise_levels")
37-
sorting_analyzer.compute("waveforms", **job_kwargs)
38-
sorting_analyzer.compute("templates", operators=["average", "std", "median"])
39-
sorting_analyzer.compute("principal_components", n_components=5, mode="by_channel_local", **job_kwargs)
40-
sorting_analyzer.compute("spike_amplitudes", **job_kwargs)
41-
42-
return sorting_analyzer
43-
44-
45-
@pytest.fixture(scope="module")
46-
def sorting_analyzer_simple():
47-
return _sorting_analyzer_simple()
48-
49-
50-
def test_calculate_pc_metrics(sorting_analyzer_simple):
9+
def test_calculate_pc_metrics(small_sorting_analyzer):
5110
import pandas as pd
5211

53-
sorting_analyzer = sorting_analyzer_simple
54-
res1 = compute_pc_metrics(sorting_analyzer, n_jobs=1, progress_bar=True)
12+
sorting_analyzer = small_sorting_analyzer
13+
res1 = compute_pc_metrics(sorting_analyzer, n_jobs=1, progress_bar=True, seed=1205)
5514
res1 = pd.DataFrame(res1)
5615

57-
res2 = compute_pc_metrics(sorting_analyzer, n_jobs=2, progress_bar=True)
16+
res2 = compute_pc_metrics(sorting_analyzer, n_jobs=2, progress_bar=True, seed=1205)
5817
res2 = pd.DataFrame(res2)
5918

60-
for k in res1.columns:
61-
mask = ~np.isnan(res1[k].values)
62-
if np.any(mask):
63-
assert np.array_equal(res1[k].values[mask], res2[k].values[mask])
64-
65-
66-
def test_nearest_neighbors_isolation(sorting_analyzer_simple):
67-
sorting_analyzer = sorting_analyzer_simple
68-
this_unit_id = sorting_analyzer.unit_ids[0]
69-
nearest_neighbors_isolation(sorting_analyzer, this_unit_id)
70-
71-
72-
def test_nearest_neighbors_noise_overlap(sorting_analyzer_simple):
73-
sorting_analyzer = sorting_analyzer_simple
74-
this_unit_id = sorting_analyzer.unit_ids[0]
75-
nearest_neighbors_noise_overlap(sorting_analyzer, this_unit_id)
76-
19+
for metric_name in res1.columns:
20+
if metric_name != "nn_unit_id":
21+
assert not np.all(np.isnan(res1[metric_name].values))
22+
assert not np.all(np.isnan(res2[metric_name].values))
7723

78-
if __name__ == "__main__":
79-
sorting_analyzer = _sorting_analyzer_simple()
80-
test_calculate_pc_metrics(sorting_analyzer)
81-
test_nearest_neighbors_isolation(sorting_analyzer)
82-
test_nearest_neighbors_noise_overlap(sorting_analyzer)
24+
assert np.array_equal(res1[metric_name].values, res2[metric_name].values)

0 commit comments

Comments
 (0)