1- import os
2- from collections import namedtuple
31from functools import partial
42
53import numpy as np
86from sklearn .metrics import precision_score , recall_score
97
108from pytorch_lightning .metrics import Precision , Recall
11- from tests .metrics .classification .utils import (
9+ from tests .metrics .classification .inputs import (
1210 _binary_inputs ,
1311 _binary_prob_inputs ,
1412 _multiclass_inputs ,
1816 _multilabel_inputs ,
1917 _multilabel_prob_inputs ,
2018)
21- from tests .metrics .utils import BATCH_SIZE , NUM_BATCHES , NUM_CLASSES , NUM_PROCESSES , THRESHOLD , MetricTester
19+ from tests .metrics .utils import NUM_CLASSES , THRESHOLD , MetricTester
2220
2321torch .manual_seed (42 )
2422
2523
26- def _binary_prob_sk_metric (preds , target , sk_fn = precision_score , average = 'micro' ):
24+ def _sk_prec_recall_binary_prob (preds , target , sk_fn = precision_score , average = 'micro' ):
2725 sk_preds = (preds .view (- 1 ).numpy () >= THRESHOLD ).astype (np .uint8 )
2826 sk_target = target .view (- 1 ).numpy ()
2927
3028 return sk_fn (y_true = sk_target , y_pred = sk_preds , average = 'binary' )
3129
3230
33- def _binary_sk_metric (preds , target , sk_fn = precision_score , average = 'micro' ):
31+ def _sk_prec_recall_binary (preds , target , sk_fn = precision_score , average = 'micro' ):
3432 sk_preds = preds .view (- 1 ).numpy ()
3533 sk_target = target .view (- 1 ).numpy ()
3634
3735 return sk_fn (y_true = sk_target , y_pred = sk_preds , average = 'binary' )
3836
3937
40- def _multilabel_prob_sk_metric (preds , target , sk_fn = precision_score , average = 'micro' ):
38+ def _sk_prec_recall_multilabel_prob (preds , target , sk_fn = precision_score , average = 'micro' ):
4139 sk_preds = (preds .view (- 1 , NUM_CLASSES ).numpy () >= THRESHOLD ).astype (np .uint8 )
4240 sk_target = target .view (- 1 , NUM_CLASSES ).numpy ()
4341
4442 return sk_fn (y_true = sk_target , y_pred = sk_preds , average = average )
4543
4644
47- def _multilabel_sk_metric (preds , target , sk_fn = precision_score , average = 'micro' ):
45+ def _sk_prec_recall_multilabel (preds , target , sk_fn = precision_score , average = 'micro' ):
4846 sk_preds = preds .view (- 1 , NUM_CLASSES ).numpy ()
4947 sk_target = target .view (- 1 , NUM_CLASSES ).numpy ()
5048
5149 return sk_fn (y_true = sk_target , y_pred = sk_preds , average = average )
5250
5351
54- def _multiclass_prob_sk_metric (preds , target , sk_fn = precision_score , average = 'micro' ):
52+ def _sk_prec_recall_multiclass_prob (preds , target , sk_fn = precision_score , average = 'micro' ):
5553 sk_preds = torch .argmax (preds , dim = len (preds .shape ) - 1 ).view (- 1 ).numpy ()
5654 sk_target = target .view (- 1 ).numpy ()
5755
5856 return sk_fn (y_true = sk_target , y_pred = sk_preds , average = average )
5957
6058
61- def _multiclass_sk_metric (preds , target , sk_fn = precision_score , average = 'micro' ):
59+ def _sk_prec_recall_multiclass (preds , target , sk_fn = precision_score , average = 'micro' ):
6260 sk_preds = preds .view (- 1 ).numpy ()
6361 sk_target = target .view (- 1 ).numpy ()
6462
6563 return sk_fn (y_true = sk_target , y_pred = sk_preds , average = average )
6664
6765
68- def _multidim_multiclass_prob_sk_metric (preds , target , sk_fn = precision_score , average = 'micro' ):
66+ def _sk_prec_recall_multidim_multiclass_prob (preds , target , sk_fn = precision_score , average = 'micro' ):
6967 sk_preds = torch .argmax (preds , dim = len (preds .shape ) - 2 ).view (- 1 ).numpy ()
7068 sk_target = target .view (- 1 ).numpy ()
7169
7270 return sk_fn (y_true = sk_target , y_pred = sk_preds , average = average )
7371
7472
75- def _multidim_multiclass_sk_metric (preds , target , sk_fn = precision_score , average = 'micro' ):
73+ def _sk_prec_recall_multidim_multiclass (preds , target , sk_fn = precision_score , average = 'micro' ):
7674 sk_preds = preds .view (- 1 ).numpy ()
7775 sk_target = target .view (- 1 ).numpy ()
7876
@@ -85,25 +83,25 @@ def _multidim_multiclass_sk_metric(preds, target, sk_fn=precision_score, average
8583@pytest .mark .parametrize (
8684 "preds, target, sk_metric, num_classes, multilabel" ,
8785 [
88- (_binary_prob_inputs .preds , _binary_prob_inputs .target , _binary_prob_sk_metric , 1 , False ),
89- (_binary_inputs .preds , _binary_inputs .target , _binary_sk_metric , 1 , False ),
90- (_multilabel_prob_inputs .preds , _multilabel_prob_inputs .target , _multilabel_prob_sk_metric , NUM_CLASSES , True ),
91- (_multilabel_inputs .preds , _multilabel_inputs .target , _multilabel_sk_metric , NUM_CLASSES , True ),
92- (_multiclass_prob_inputs .preds , _multiclass_prob_inputs .target , _multiclass_prob_sk_metric , NUM_CLASSES , False ),
93- (_multiclass_inputs .preds , _multiclass_inputs .target , _multiclass_sk_metric , NUM_CLASSES , False ),
86+ (_binary_prob_inputs .preds , _binary_prob_inputs .target , _sk_prec_recall_binary_prob , 1 , False ),
87+ (_binary_inputs .preds , _binary_inputs .target , _sk_prec_recall_binary , 1 , False ),
88+ (_multilabel_prob_inputs .preds , _multilabel_prob_inputs .target , _sk_prec_recall_multilabel_prob , NUM_CLASSES , True ),
89+ (_multilabel_inputs .preds , _multilabel_inputs .target , _sk_prec_recall_multilabel , NUM_CLASSES , True ),
90+ (_multiclass_prob_inputs .preds , _multiclass_prob_inputs .target , _sk_prec_recall_multiclass_prob , NUM_CLASSES , False ),
91+ (_multiclass_inputs .preds , _multiclass_inputs .target , _sk_prec_recall_multiclass , NUM_CLASSES , False ),
9492 (
95- _multidim_multiclass_prob_inputs .preds ,
96- _multidim_multiclass_prob_inputs .target ,
97- _multidim_multiclass_prob_sk_metric ,
98- NUM_CLASSES ,
99- False ,
93+ _multidim_multiclass_prob_inputs .preds ,
94+ _multidim_multiclass_prob_inputs .target ,
95+ _sk_prec_recall_multidim_multiclass_prob ,
96+ NUM_CLASSES ,
97+ False ,
10098 ),
10199 (
102- _multidim_multiclass_inputs .preds ,
103- _multidim_multiclass_inputs .target ,
104- _multidim_multiclass_sk_metric ,
105- NUM_CLASSES ,
106- False ,
100+ _multidim_multiclass_inputs .preds ,
101+ _multidim_multiclass_inputs .target ,
102+ _sk_prec_recall_multidim_multiclass ,
103+ NUM_CLASSES ,
104+ False ,
107105 ),
108106 ],
109107)
0 commit comments