17
17
from ..core import get_random_data_chunks , compute_sparsity
18
18
from ..core .template_tools import get_template_extremum_channel
19
19
20
-
21
20
_possible_pc_metric_names = [
22
21
"isolation_distance" ,
23
22
"l_ratio" ,
@@ -90,7 +89,7 @@ def compute_pc_metrics(
90
89
sorting = sorting_analyzer .sorting
91
90
92
91
if metric_names is None :
93
- metric_names = _possible_pc_metric_names
92
+ metric_names = _possible_pc_metric_names . copy ()
94
93
if qm_params is None :
95
94
qm_params = _default_params
96
95
@@ -110,8 +109,13 @@ def compute_pc_metrics(
110
109
if "nn_isolation" in metric_names :
111
110
pc_metrics ["nn_unit_id" ] = {}
112
111
112
+ possible_nn_metrics = ["nn_isolation" , "nn_noise_overlap" ]
113
+
114
+ nn_metrics = list (set (metric_names ).intersection (possible_nn_metrics ))
115
+ non_nn_metrics = list (set (metric_names ).difference (possible_nn_metrics ))
116
+
113
117
# Compute nspikes and firing rate outside of main loop for speed
114
- if any ([ n in metric_names for n in [ "nn_isolation" , "nn_noise_overlap" ]]) :
118
+ if nn_metrics :
115
119
n_spikes_all_units = compute_num_spikes (sorting_analyzer , unit_ids = unit_ids )
116
120
fr_all_units = compute_firing_rates (sorting_analyzer , unit_ids = unit_ids )
117
121
else :
@@ -120,9 +124,6 @@ def compute_pc_metrics(
120
124
121
125
run_in_parallel = n_jobs > 1
122
126
123
- if run_in_parallel :
124
- parallel_functions = []
125
-
126
127
# this get dense projection for selected unit_ids
127
128
dense_projections , spike_unit_indices = pca_ext .get_some_projections (channel_ids = None , unit_ids = unit_ids )
128
129
all_labels = sorting .unit_ids [spike_unit_indices ]
@@ -146,7 +147,7 @@ def compute_pc_metrics(
146
147
func_args = (
147
148
pcs_flat ,
148
149
labels ,
149
- metric_names ,
150
+ non_nn_metrics ,
150
151
unit_id ,
151
152
unit_ids ,
152
153
qm_params ,
@@ -156,16 +157,16 @@ def compute_pc_metrics(
156
157
)
157
158
items .append (func_args )
158
159
159
- if not run_in_parallel :
160
+ if not run_in_parallel and non_nn_metrics :
160
161
units_loop = enumerate (unit_ids )
161
162
if progress_bar :
162
- units_loop = tqdm (units_loop , desc = "calculate_pc_metrics " , total = len (unit_ids ))
163
+ units_loop = tqdm (units_loop , desc = "calculate pc_metrics " , total = len (unit_ids ))
163
164
164
165
for unit_ind , unit_id in units_loop :
165
166
pca_metrics_unit = pca_metrics_one_unit (items [unit_ind ])
166
167
for metric_name , metric in pca_metrics_unit .items ():
167
168
pc_metrics [metric_name ][unit_id ] = metric
168
- else :
169
+ elif run_in_parallel and non_nn_metrics :
169
170
with ProcessPoolExecutor (n_jobs ) as executor :
170
171
results = executor .map (pca_metrics_one_unit , items )
171
172
if progress_bar :
@@ -176,6 +177,37 @@ def compute_pc_metrics(
176
177
for metric_name , metric in pca_metrics_unit .items ():
177
178
pc_metrics [metric_name ][unit_id ] = metric
178
179
180
+ for metric_name in nn_metrics :
181
+ units_loop = enumerate (unit_ids )
182
+ if progress_bar :
183
+ units_loop = tqdm (units_loop , desc = f"calculate { metric_name } metric" , total = len (unit_ids ))
184
+
185
+ func = _nn_metric_name_to_func [metric_name ]
186
+ metric_params = qm_params [metric_name ] if metric_name in qm_params else {}
187
+
188
+ for _ , unit_id in units_loop :
189
+ try :
190
+ res = func (
191
+ sorting_analyzer ,
192
+ unit_id ,
193
+ seed = seed ,
194
+ n_spikes_all_units = n_spikes_all_units ,
195
+ fr_all_units = fr_all_units ,
196
+ ** metric_params ,
197
+ )
198
+ except :
199
+ if metric_name == "nn_isolation" :
200
+ res = (np .nan , np .nan )
201
+ elif metric_name == "nn_noise_overlap" :
202
+ res = np .nan
203
+
204
+ if metric_name == "nn_isolation" :
205
+ nn_isolation , nn_unit_id = res
206
+ pc_metrics ["nn_isolation" ][unit_id ] = nn_isolation
207
+ pc_metrics ["nn_unit_id" ][unit_id ] = nn_unit_id
208
+ elif metric_name == "nn_noise_overlap" :
209
+ pc_metrics ["nn_noise_overlap" ][unit_id ] = res
210
+
179
211
return pc_metrics
180
212
181
213
@@ -677,6 +709,14 @@ def nearest_neighbors_noise_overlap(
677
709
templates_ext = sorting_analyzer .get_extension ("templates" )
678
710
assert templates_ext is not None , "nearest_neighbors_isolation() need extension 'templates'"
679
711
712
+ try :
713
+ sorting_analyzer .get_extension ("templates" ).get_data (operator = "median" )
714
+ except KeyError :
715
+ warnings .warn (
716
+ "nearest_neighbors_isolation() need extension 'templates' calculated with the 'median' operator."
717
+ "You can run sorting_analyzer.compute('templates', operators=['average', 'median']) to calculate templates based on both average and median modes."
718
+ )
719
+
680
720
if n_spikes_all_units is None :
681
721
n_spikes_all_units = compute_num_spikes (sorting_analyzer )
682
722
if fr_all_units is None :
@@ -955,11 +995,13 @@ def pca_metrics_one_unit(args):
955
995
pc_metrics = {}
956
996
# metrics
957
997
if "isolation_distance" in metric_names or "l_ratio" in metric_names :
998
+
958
999
try :
959
1000
isolation_distance , l_ratio = mahalanobis_metrics (pcs_flat , labels , unit_id )
960
1001
except :
961
1002
isolation_distance = np .nan
962
1003
l_ratio = np .nan
1004
+
963
1005
if "isolation_distance" in metric_names :
964
1006
pc_metrics ["isolation_distance" ] = isolation_distance
965
1007
if "l_ratio" in metric_names :
@@ -973,6 +1015,7 @@ def pca_metrics_one_unit(args):
973
1015
d_prime = lda_metrics (pcs_flat , labels , unit_id )
974
1016
except :
975
1017
d_prime = np .nan
1018
+
976
1019
pc_metrics ["d_prime" ] = d_prime
977
1020
978
1021
if "nearest_neighbor" in metric_names :
@@ -986,36 +1029,6 @@ def pca_metrics_one_unit(args):
986
1029
pc_metrics ["nn_hit_rate" ] = nn_hit_rate
987
1030
pc_metrics ["nn_miss_rate" ] = nn_miss_rate
988
1031
989
- if "nn_isolation" in metric_names :
990
- try :
991
- nn_isolation , nn_unit_id = nearest_neighbors_isolation (
992
- we ,
993
- unit_id ,
994
- seed = seed ,
995
- n_spikes_all_units = n_spikes_all_units ,
996
- fr_all_units = fr_all_units ,
997
- ** qm_params ["nn_isolation" ],
998
- )
999
- except :
1000
- nn_isolation = np .nan
1001
- nn_unit_id = np .nan
1002
- pc_metrics ["nn_isolation" ] = nn_isolation
1003
- pc_metrics ["nn_unit_id" ] = nn_unit_id
1004
-
1005
- if "nn_noise_overlap" in metric_names :
1006
- try :
1007
- nn_noise_overlap = nearest_neighbors_noise_overlap (
1008
- we ,
1009
- unit_id ,
1010
- n_spikes_all_units = n_spikes_all_units ,
1011
- fr_all_units = fr_all_units ,
1012
- seed = seed ,
1013
- ** qm_params ["nn_noise_overlap" ],
1014
- )
1015
- except :
1016
- nn_noise_overlap = np .nan
1017
- pc_metrics ["nn_noise_overlap" ] = nn_noise_overlap
1018
-
1019
1032
if "silhouette" in metric_names :
1020
1033
silhouette_method = qm_params ["silhouette" ]["method" ]
1021
1034
if "simplified" in silhouette_method :
@@ -1032,3 +1045,9 @@ def pca_metrics_one_unit(args):
1032
1045
pc_metrics ["silhouette_full" ] = unit_silhouette_score
1033
1046
1034
1047
return pc_metrics
1048
+
1049
+
1050
+ _nn_metric_name_to_func = {
1051
+ "nn_isolation" : nearest_neighbors_isolation ,
1052
+ "nn_noise_overlap" : nearest_neighbors_noise_overlap ,
1053
+ }
0 commit comments