Skip to content

Commit 4290c9e

Browse files
authored
Tests: clean metrics (#4152)
* namme inputs * sk rename * imports
1 parent dec31b3 commit 4290c9e

File tree

9 files changed

+88
-94
lines changed

9 files changed

+88
-94
lines changed

tests/metrics/__init__.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,2 @@
1-
import os
2-
3-
from tests.metrics.utils import NUM_BATCHES, NUM_PROCESSES, BATCH_SIZE, MetricTester
41
from tests.metrics.test_metric import Dummy
2+
from tests.metrics.utils import NUM_BATCHES, NUM_PROCESSES, BATCH_SIZE, MetricTester

tests/metrics/classification/utils.py renamed to tests/metrics/classification/inputs.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,12 @@
1-
import os
2-
import pytest
3-
import numpy as np
1+
from collections import namedtuple
2+
43
import torch
54

6-
from collections import namedtuple
75
from tests.metrics.utils import (
86
NUM_BATCHES,
9-
NUM_PROCESSES,
107
BATCH_SIZE,
118
NUM_CLASSES,
12-
EXTRA_DIM,
13-
THRESHOLD
9+
EXTRA_DIM
1410
)
1511

1612
Input = namedtuple('Input', ["preds", "target"])

tests/metrics/classification/test_accuracy.py

Lines changed: 19 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from sklearn.metrics import accuracy_score
55

66
from pytorch_lightning.metrics.classification.accuracy import Accuracy
7-
from tests.metrics.classification.utils import (
7+
from tests.metrics.classification.inputs import (
88
_binary_inputs,
99
_binary_prob_inputs,
1010
_multiclass_inputs,
@@ -19,56 +19,56 @@
1919
torch.manual_seed(42)
2020

2121

22-
def _binary_prob_sk_metric(preds, target):
22+
def _sk_accuracy_binary_prob(preds, target):
2323
sk_preds = (preds.view(-1).numpy() >= THRESHOLD).astype(np.uint8)
2424
sk_target = target.view(-1).numpy()
2525

2626
return accuracy_score(y_true=sk_target, y_pred=sk_preds)
2727

2828

29-
def _binary_sk_metric(preds, target):
29+
def _sk_accuracy_binary(preds, target):
3030
sk_preds = preds.view(-1).numpy()
3131
sk_target = target.view(-1).numpy()
3232

3333
return accuracy_score(y_true=sk_target, y_pred=sk_preds)
3434

3535

36-
def _multilabel_prob_sk_metric(preds, target):
36+
def _sk_accuracy_multilabel_prob(preds, target):
3737
sk_preds = (preds.view(-1).numpy() >= THRESHOLD).astype(np.uint8)
3838
sk_target = target.view(-1).numpy()
3939

4040
return accuracy_score(y_true=sk_target, y_pred=sk_preds)
4141

4242

43-
def _multilabel_sk_metric(preds, target):
43+
def _sk_accuracy_multilabel(preds, target):
4444
sk_preds = preds.view(-1).numpy()
4545
sk_target = target.view(-1).numpy()
4646

4747
return accuracy_score(y_true=sk_target, y_pred=sk_preds)
4848

4949

50-
def _multiclass_prob_sk_metric(preds, target):
50+
def _sk_accuracy_multiclass_prob(preds, target):
5151
sk_preds = torch.argmax(preds, dim=len(preds.shape) - 1).view(-1).numpy()
5252
sk_target = target.view(-1).numpy()
5353

5454
return accuracy_score(y_true=sk_target, y_pred=sk_preds)
5555

5656

57-
def _multiclass_sk_metric(preds, target):
57+
def _sk_accuracy_multiclass(preds, target):
5858
sk_preds = preds.view(-1).numpy()
5959
sk_target = target.view(-1).numpy()
6060

6161
return accuracy_score(y_true=sk_target, y_pred=sk_preds)
6262

6363

64-
def _multidim_multiclass_prob_sk_metric(preds, target):
64+
def _sk_accuracy_multidim_multiclass_prob(preds, target):
6565
sk_preds = torch.argmax(preds, dim=len(preds.shape) - 2).view(-1).numpy()
6666
sk_target = target.view(-1).numpy()
6767

6868
return accuracy_score(y_true=sk_target, y_pred=sk_preds)
6969

7070

71-
def _multidim_multiclass_sk_metric(preds, target):
71+
def _sk_accuracy_multidim_multiclass(preds, target):
7272
sk_preds = preds.view(-1).numpy()
7373
sk_target = target.view(-1).numpy()
7474

@@ -86,18 +86,18 @@ def test_accuracy_invalid_shape():
8686
@pytest.mark.parametrize(
8787
"preds, target, sk_metric",
8888
[
89-
(_binary_prob_inputs.preds, _binary_prob_inputs.target, _binary_prob_sk_metric),
90-
(_binary_inputs.preds, _binary_inputs.target, _binary_sk_metric),
91-
(_multilabel_prob_inputs.preds, _multilabel_prob_inputs.target, _multilabel_prob_sk_metric),
92-
(_multilabel_inputs.preds, _multilabel_inputs.target, _multilabel_sk_metric),
93-
(_multiclass_prob_inputs.preds, _multiclass_prob_inputs.target, _multiclass_prob_sk_metric),
94-
(_multiclass_inputs.preds, _multiclass_inputs.target, _multiclass_sk_metric),
89+
(_binary_prob_inputs.preds, _binary_prob_inputs.target, _sk_accuracy_binary_prob),
90+
(_binary_inputs.preds, _binary_inputs.target, _sk_accuracy_binary),
91+
(_multilabel_prob_inputs.preds, _multilabel_prob_inputs.target, _sk_accuracy_multilabel_prob),
92+
(_multilabel_inputs.preds, _multilabel_inputs.target, _sk_accuracy_multilabel),
93+
(_multiclass_prob_inputs.preds, _multiclass_prob_inputs.target, _sk_accuracy_multiclass_prob),
94+
(_multiclass_inputs.preds, _multiclass_inputs.target, _sk_accuracy_multiclass),
9595
(
96-
_multidim_multiclass_prob_inputs.preds,
97-
_multidim_multiclass_prob_inputs.target,
98-
_multidim_multiclass_prob_sk_metric,
96+
_multidim_multiclass_prob_inputs.preds,
97+
_multidim_multiclass_prob_inputs.target,
98+
_sk_accuracy_multidim_multiclass_prob,
9999
),
100-
(_multidim_multiclass_inputs.preds, _multidim_multiclass_inputs.target, _multidim_multiclass_sk_metric),
100+
(_multidim_multiclass_inputs.preds, _multidim_multiclass_inputs.target, _sk_accuracy_multidim_multiclass),
101101
],
102102
)
103103
class TestAccuracy(MetricTester):

tests/metrics/classification/test_f_beta.py

Lines changed: 26 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from sklearn.metrics import fbeta_score
77

88
from pytorch_lightning.metrics import Fbeta
9-
from tests.metrics.classification.utils import (
9+
from tests.metrics.classification.inputs import (
1010
_binary_inputs,
1111
_binary_prob_inputs,
1212
_multiclass_inputs,
@@ -16,61 +16,61 @@
1616
_multilabel_inputs,
1717
_multilabel_prob_inputs,
1818
)
19-
from tests.metrics.utils import BATCH_SIZE, NUM_BATCHES, NUM_CLASSES, NUM_PROCESSES, THRESHOLD, MetricTester
19+
from tests.metrics.utils import NUM_CLASSES, THRESHOLD, MetricTester
2020

2121
torch.manual_seed(42)
2222

2323

24-
def _binary_prob_sk_metric(preds, target, average='micro', beta=1.0):
24+
def _sk_fbeta_binary_prob(preds, target, average='micro', beta=1.0):
2525
sk_preds = (preds.view(-1).numpy() >= THRESHOLD).astype(np.uint8)
2626
sk_target = target.view(-1).numpy()
2727

2828
return fbeta_score(y_true=sk_target, y_pred=sk_preds, average='binary', beta=beta)
2929

3030

31-
def _binary_sk_metric(preds, target, average='micro', beta=1.0):
31+
def _sk_fbeta_binary(preds, target, average='micro', beta=1.0):
3232
sk_preds = preds.view(-1).numpy()
3333
sk_target = target.view(-1).numpy()
3434

3535
return fbeta_score(y_true=sk_target, y_pred=sk_preds, average='binary', beta=beta)
3636

3737

38-
def _multilabel_prob_sk_metric(preds, target, average='micro', beta=1.0):
38+
def _sk_fbeta_multilabel_prob(preds, target, average='micro', beta=1.0):
3939
sk_preds = (preds.view(-1, NUM_CLASSES).numpy() >= THRESHOLD).astype(np.uint8)
4040
sk_target = target.view(-1, NUM_CLASSES).numpy()
4141

4242
return fbeta_score(y_true=sk_target, y_pred=sk_preds, average=average, beta=beta)
4343

4444

45-
def _multilabel_sk_metric(preds, target, average='micro', beta=1.0):
45+
def _sk_fbeta_multilabel(preds, target, average='micro', beta=1.0):
4646
sk_preds = preds.view(-1, NUM_CLASSES).numpy()
4747
sk_target = target.view(-1, NUM_CLASSES).numpy()
4848

4949
return fbeta_score(y_true=sk_target, y_pred=sk_preds, average=average, beta=beta)
5050

5151

52-
def _multiclass_prob_sk_metric(preds, target, average='micro', beta=1.0):
52+
def _sk_fbeta_multiclass_prob(preds, target, average='micro', beta=1.0):
5353
sk_preds = torch.argmax(preds, dim=len(preds.shape) - 1).view(-1).numpy()
5454
sk_target = target.view(-1).numpy()
5555

5656
return fbeta_score(y_true=sk_target, y_pred=sk_preds, average=average, beta=beta)
5757

5858

59-
def _multiclass_sk_metric(preds, target, average='micro', beta=1.0):
59+
def _sk_fbeta_multiclass(preds, target, average='micro', beta=1.0):
6060
sk_preds = preds.view(-1).numpy()
6161
sk_target = target.view(-1).numpy()
6262

6363
return fbeta_score(y_true=sk_target, y_pred=sk_preds, average=average, beta=beta)
6464

6565

66-
def _multidim_multiclass_prob_sk_metric(preds, target, average='micro', beta=1.0):
66+
def _sk_fbeta_multidim_multiclass_prob(preds, target, average='micro', beta=1.0):
6767
sk_preds = torch.argmax(preds, dim=len(preds.shape) - 2).view(-1).numpy()
6868
sk_target = target.view(-1).numpy()
6969

7070
return fbeta_score(y_true=sk_target, y_pred=sk_preds, average=average, beta=beta)
7171

7272

73-
def _multidim_multiclass_sk_metric(preds, target, average='micro', beta=1.0):
73+
def _sk_fbeta_multidim_multiclass(preds, target, average='micro', beta=1.0):
7474
sk_preds = preds.view(-1).numpy()
7575
sk_target = target.view(-1).numpy()
7676

@@ -83,25 +83,25 @@ def _multidim_multiclass_sk_metric(preds, target, average='micro', beta=1.0):
8383
@pytest.mark.parametrize(
8484
"preds, target, sk_metric, num_classes, multilabel",
8585
[
86-
(_binary_prob_inputs.preds, _binary_prob_inputs.target, _binary_prob_sk_metric, 1, False),
87-
(_binary_inputs.preds, _binary_inputs.target, _binary_sk_metric, 1, False),
88-
(_multilabel_prob_inputs.preds, _multilabel_prob_inputs.target, _multilabel_prob_sk_metric, NUM_CLASSES, True),
89-
(_multilabel_inputs.preds, _multilabel_inputs.target, _multilabel_sk_metric, NUM_CLASSES, True),
90-
(_multiclass_prob_inputs.preds, _multiclass_prob_inputs.target, _multiclass_prob_sk_metric, NUM_CLASSES, False),
91-
(_multiclass_inputs.preds, _multiclass_inputs.target, _multiclass_sk_metric, NUM_CLASSES, False),
86+
(_binary_prob_inputs.preds, _binary_prob_inputs.target, _sk_fbeta_binary_prob, 1, False),
87+
(_binary_inputs.preds, _binary_inputs.target, _sk_fbeta_binary, 1, False),
88+
(_multilabel_prob_inputs.preds, _multilabel_prob_inputs.target, _sk_fbeta_multilabel_prob, NUM_CLASSES, True),
89+
(_multilabel_inputs.preds, _multilabel_inputs.target, _sk_fbeta_multilabel, NUM_CLASSES, True),
90+
(_multiclass_prob_inputs.preds, _multiclass_prob_inputs.target, _sk_fbeta_multiclass_prob, NUM_CLASSES, False),
91+
(_multiclass_inputs.preds, _multiclass_inputs.target, _sk_fbeta_multiclass, NUM_CLASSES, False),
9292
(
93-
_multidim_multiclass_prob_inputs.preds,
94-
_multidim_multiclass_prob_inputs.target,
95-
_multidim_multiclass_prob_sk_metric,
96-
NUM_CLASSES,
97-
False,
93+
_multidim_multiclass_prob_inputs.preds,
94+
_multidim_multiclass_prob_inputs.target,
95+
_sk_fbeta_multidim_multiclass_prob,
96+
NUM_CLASSES,
97+
False,
9898
),
9999
(
100-
_multidim_multiclass_inputs.preds,
101-
_multidim_multiclass_inputs.target,
102-
_multidim_multiclass_sk_metric,
103-
NUM_CLASSES,
104-
False,
100+
_multidim_multiclass_inputs.preds,
101+
_multidim_multiclass_inputs.target,
102+
_sk_fbeta_multidim_multiclass,
103+
NUM_CLASSES,
104+
False,
105105
),
106106
],
107107
)

tests/metrics/classification/test_precision_recall.py

Lines changed: 26 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
import os
2-
from collections import namedtuple
31
from functools import partial
42

53
import numpy as np
@@ -8,7 +6,7 @@
86
from sklearn.metrics import precision_score, recall_score
97

108
from pytorch_lightning.metrics import Precision, Recall
11-
from tests.metrics.classification.utils import (
9+
from tests.metrics.classification.inputs import (
1210
_binary_inputs,
1311
_binary_prob_inputs,
1412
_multiclass_inputs,
@@ -18,61 +16,61 @@
1816
_multilabel_inputs,
1917
_multilabel_prob_inputs,
2018
)
21-
from tests.metrics.utils import BATCH_SIZE, NUM_BATCHES, NUM_CLASSES, NUM_PROCESSES, THRESHOLD, MetricTester
19+
from tests.metrics.utils import NUM_CLASSES, THRESHOLD, MetricTester
2220

2321
torch.manual_seed(42)
2422

2523

26-
def _binary_prob_sk_metric(preds, target, sk_fn=precision_score, average='micro'):
24+
def _sk_prec_recall_binary_prob(preds, target, sk_fn=precision_score, average='micro'):
2725
sk_preds = (preds.view(-1).numpy() >= THRESHOLD).astype(np.uint8)
2826
sk_target = target.view(-1).numpy()
2927

3028
return sk_fn(y_true=sk_target, y_pred=sk_preds, average='binary')
3129

3230

33-
def _binary_sk_metric(preds, target, sk_fn=precision_score, average='micro'):
31+
def _sk_prec_recall_binary(preds, target, sk_fn=precision_score, average='micro'):
3432
sk_preds = preds.view(-1).numpy()
3533
sk_target = target.view(-1).numpy()
3634

3735
return sk_fn(y_true=sk_target, y_pred=sk_preds, average='binary')
3836

3937

40-
def _multilabel_prob_sk_metric(preds, target, sk_fn=precision_score, average='micro'):
38+
def _sk_prec_recall_multilabel_prob(preds, target, sk_fn=precision_score, average='micro'):
4139
sk_preds = (preds.view(-1, NUM_CLASSES).numpy() >= THRESHOLD).astype(np.uint8)
4240
sk_target = target.view(-1, NUM_CLASSES).numpy()
4341

4442
return sk_fn(y_true=sk_target, y_pred=sk_preds, average=average)
4543

4644

47-
def _multilabel_sk_metric(preds, target, sk_fn=precision_score, average='micro'):
45+
def _sk_prec_recall_multilabel(preds, target, sk_fn=precision_score, average='micro'):
4846
sk_preds = preds.view(-1, NUM_CLASSES).numpy()
4947
sk_target = target.view(-1, NUM_CLASSES).numpy()
5048

5149
return sk_fn(y_true=sk_target, y_pred=sk_preds, average=average)
5250

5351

54-
def _multiclass_prob_sk_metric(preds, target, sk_fn=precision_score, average='micro'):
52+
def _sk_prec_recall_multiclass_prob(preds, target, sk_fn=precision_score, average='micro'):
5553
sk_preds = torch.argmax(preds, dim=len(preds.shape) - 1).view(-1).numpy()
5654
sk_target = target.view(-1).numpy()
5755

5856
return sk_fn(y_true=sk_target, y_pred=sk_preds, average=average)
5957

6058

61-
def _multiclass_sk_metric(preds, target, sk_fn=precision_score, average='micro'):
59+
def _sk_prec_recall_multiclass(preds, target, sk_fn=precision_score, average='micro'):
6260
sk_preds = preds.view(-1).numpy()
6361
sk_target = target.view(-1).numpy()
6462

6563
return sk_fn(y_true=sk_target, y_pred=sk_preds, average=average)
6664

6765

68-
def _multidim_multiclass_prob_sk_metric(preds, target, sk_fn=precision_score, average='micro'):
66+
def _sk_prec_recall_multidim_multiclass_prob(preds, target, sk_fn=precision_score, average='micro'):
6967
sk_preds = torch.argmax(preds, dim=len(preds.shape) - 2).view(-1).numpy()
7068
sk_target = target.view(-1).numpy()
7169

7270
return sk_fn(y_true=sk_target, y_pred=sk_preds, average=average)
7371

7472

75-
def _multidim_multiclass_sk_metric(preds, target, sk_fn=precision_score, average='micro'):
73+
def _sk_prec_recall_multidim_multiclass(preds, target, sk_fn=precision_score, average='micro'):
7674
sk_preds = preds.view(-1).numpy()
7775
sk_target = target.view(-1).numpy()
7876

@@ -85,25 +83,25 @@ def _multidim_multiclass_sk_metric(preds, target, sk_fn=precision_score, average
8583
@pytest.mark.parametrize(
8684
"preds, target, sk_metric, num_classes, multilabel",
8785
[
88-
(_binary_prob_inputs.preds, _binary_prob_inputs.target, _binary_prob_sk_metric, 1, False),
89-
(_binary_inputs.preds, _binary_inputs.target, _binary_sk_metric, 1, False),
90-
(_multilabel_prob_inputs.preds, _multilabel_prob_inputs.target, _multilabel_prob_sk_metric, NUM_CLASSES, True),
91-
(_multilabel_inputs.preds, _multilabel_inputs.target, _multilabel_sk_metric, NUM_CLASSES, True),
92-
(_multiclass_prob_inputs.preds, _multiclass_prob_inputs.target, _multiclass_prob_sk_metric, NUM_CLASSES, False),
93-
(_multiclass_inputs.preds, _multiclass_inputs.target, _multiclass_sk_metric, NUM_CLASSES, False),
86+
(_binary_prob_inputs.preds, _binary_prob_inputs.target, _sk_prec_recall_binary_prob, 1, False),
87+
(_binary_inputs.preds, _binary_inputs.target, _sk_prec_recall_binary, 1, False),
88+
(_multilabel_prob_inputs.preds, _multilabel_prob_inputs.target, _sk_prec_recall_multilabel_prob, NUM_CLASSES, True),
89+
(_multilabel_inputs.preds, _multilabel_inputs.target, _sk_prec_recall_multilabel, NUM_CLASSES, True),
90+
(_multiclass_prob_inputs.preds, _multiclass_prob_inputs.target, _sk_prec_recall_multiclass_prob, NUM_CLASSES, False),
91+
(_multiclass_inputs.preds, _multiclass_inputs.target, _sk_prec_recall_multiclass, NUM_CLASSES, False),
9492
(
95-
_multidim_multiclass_prob_inputs.preds,
96-
_multidim_multiclass_prob_inputs.target,
97-
_multidim_multiclass_prob_sk_metric,
98-
NUM_CLASSES,
99-
False,
93+
_multidim_multiclass_prob_inputs.preds,
94+
_multidim_multiclass_prob_inputs.target,
95+
_sk_prec_recall_multidim_multiclass_prob,
96+
NUM_CLASSES,
97+
False,
10098
),
10199
(
102-
_multidim_multiclass_inputs.preds,
103-
_multidim_multiclass_inputs.target,
104-
_multidim_multiclass_sk_metric,
105-
NUM_CLASSES,
106-
False,
100+
_multidim_multiclass_inputs.preds,
101+
_multidim_multiclass_inputs.target,
102+
_sk_prec_recall_multidim_multiclass,
103+
NUM_CLASSES,
104+
False,
107105
),
108106
],
109107
)

tests/metrics/functional/test_regression.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
1+
from functools import partial
2+
from math import sqrt
3+
14
import numpy as np
25
import pytest
36
import torch
4-
from functools import partial
5-
from math import sqrt
67
from skimage.metrics import (
78
peak_signal_noise_ratio as ski_psnr,
89
structural_similarity as ski_ssim

0 commit comments

Comments
 (0)