Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions docs/src/_static/custom.css
Original file line number Diff line number Diff line change
Expand Up @@ -133,3 +133,9 @@ div.sphx-glr-download a:hover {
font-weight: bold;
font-style: italic;
}

/* Bold small caps */
.bsc {
font-weight: bold;
font-variant: small-caps;
}
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@

# These should be defined by your use-case
import torch
from datasets import load_dataset # type: ignore[import-untyped]
from datasets import load_dataset # type: ignore[import-untyped, unused-ignore]

calib_set = load_dataset("ego-thales/cifar10", name="calibration")["unique_split"]
calib_data, calib_labels, _ = calib_set.with_format("torch")[:].values()
Expand Down
10 changes: 9 additions & 1 deletion scio/eval/classification/discriminative_power.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,15 @@ def __call__(self, labels: ArrayLike, scores: ArrayLike) -> float:


class AUC(BaseDiscriminativePower):
"""AUC for ROC, potentially partial — in which case normalized.
r"""AUC for ROC, potentially partial — in which case normalized.

With the default arguments, one has

.. math::

AUC = \mathbb{P}(\text{score}_{\text{OoD}}<\text{score}_{\text{InD}}),

when sampling from the reference population.

Arguments
---------
Expand Down
93 changes: 61 additions & 32 deletions scio/eval/classification/roc.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,17 +13,16 @@


class ROC:
"""ROC utility for Discriminative Power and visualization.
r"""ROC utility for Discriminative Power and visualization.

We recall that a :ref:`Discriminative Power <discriminative_power>`
only depends on the Pareto front of all the :math:`(FP, TP)` tuples
when thresholding with every possible threshold. Per convention:

#. The thresholding test is ``score <= tau``;
#. The thresholding test is ``score <= threshold``.
#. **Positive** (*i.e.* OoD) samples should verify this and thus
have a **low score**;
#. Scores must not be ``nan`` or ``-inf`` (ensuring validity of the
first note in :attr:`pareto`);
have a **low score**.
#. Scores must not be ``nan``.

Arguments
---------
Expand All @@ -33,6 +32,25 @@ class ROC:
scores: ``ArrayLike``
The score of samples. Shape ``(n_samples,)``.

Raises
------
:exc:`AssertionError`
If there is no positive (*resp.* negative) labels.
:exc:`AssertionError`
If there is at least one ``nan`` score.

Note
----
.. role:: bsc
:class: bsc

If a negative (*i.e.* InD) sample has a score of :math:`-\infty`,
then the ROC curve would theoretically start with a *nonzero*
:attr:`~ROC.FPR`. In this case, for consistency in
:ref:`discriminative_power` definitions, we artificially add the
point :math:`(0, 0)`, corresponding to the trivial :bsc:`False`
classifier.

"""

def __init__(self, labels: ArrayLike, scores: ArrayLike) -> None:
Expand All @@ -49,7 +67,6 @@ def _preprocess(self, labels: ArrayLike, scores: ArrayLike) -> None:
check(labels_np.any())
check(not labels_np.all())
check(not np.isnan(scores).any())
check(-np.inf < scores_np.min())

sorter = np.argsort(scores)
self._scores = scores_np[sorter]
Expand All @@ -67,18 +84,20 @@ def _compute_front(self) -> None:
# ``inf`` (considered self equal). Rests on ``scores`` being
# sorted. Faster than ``np.unique`` which keeps first occurrence
unique_mask = scores != np.r_[scores[1:], np.nan]
PP = np.where(unique_mask)[0]
TP = self._labels.cumsum()[PP]
FP = PP - TP + 1
attainable_fptp = np.insert(np.c_[FP, TP], 0, 0, 0)
pareto_mask = (np.diff(attainable_fptp[:, 0], append=inf) > 0) & (
np.diff(attainable_fptp[:, 1], prepend=-inf) > 0
)
self._pareto = attainable_fptp[pareto_mask]
unique_thresholds = np.insert(scores[unique_mask], unique_mask.sum(), inf)
PP = np.where(unique_mask)[0] + 1 # Predicted Positive
TP = self._labels.cumsum()[PP - 1]
FP = PP - TP

# Add ``(0, 0)``
unique_thresholds = np.insert(unique_thresholds, 0, -inf)
FP = np.insert(FP, 0, 0)
TP = np.insert(TP, 0, 0)

pareto_mask = (np.diff(FP, append=inf) > 0) & (np.diff(TP, prepend=-inf) > 0)
pareto_idx = np.where(pareto_mask)[0]
self._thresholds = np.insert(scores[unique_mask], [0, len(PP)], [-inf, inf])[
[pareto_idx, pareto_idx + 1]
].T
self._pareto = np.c_[FP, TP][pareto_mask]
self._thresholds = unique_thresholds[[pareto_idx, pareto_idx + 1]].T
self._N, self._P = int(FP[-1]), int(TP[-1])

def _compute_convex_hull(self) -> None:
Expand Down Expand Up @@ -107,33 +126,43 @@ def P(self) -> int:

@property
def pareto(self) -> NDArray[np.integer]:
"""Ordered :math:`(FP, TP)` tuples defining the Pareto front.
r"""Ordered :math:`(FP, TP)` tuples defining the Pareto front.

Returns
-------
pareto: ``NDArray[np.integer]``
Shape ``(n_points_pareto, 2)``.
Shape ``(n_pareto_points, 2)``.

Note
----
The following are always true:

- ``self.pareto[0, 0] == 0`` since ``-inf`` scores are
prohibited;
- ``self.pareto[0, 0] == 0`` (see :class:`ROC` note);
- ``self.pareto[-1, 1] == self.P``.

"""
return self._pareto

@property
def thresholds(self) -> NDArray[np.floating]:
"""The threshold intervals associated with Pareto points.
r"""The threshold intervals associated with Pareto points.

Returns
-------
thresholds: ``NDArray[np.floating]``
Convention: lower bound is included, higher bound is
excluded (unless ``inf``). Shape ``(n_points_pareto, 2)``.
Intervals for thresholds, to achieve the corresponding
:math:`(FP, TP)` point from :attr:`~ROC.pareto`. The lower
bound is included and the upper bound is excluded, with the
two following exceptions.

1. A :math:`+\infty` upper bound is included if and only
if ``self.pareto[-1, 0] == self.N``.
2. A :math:`-\infty` upper bound is a special case for the
point :math:`(0, 0)`, when it is not attainable via
thresholding because a negative sample has a score of
:math:`-\infty`.

Shape ``(n_pareto_points, 2)``.

"""
return self._thresholds
Expand All @@ -160,7 +189,7 @@ def FP(self) -> NDArray[np.integer]:
Returns
-------
FP: ``NDArray[np.integer]``
Shape ``(n_points_pareto,)``.
Shape ``(n_pareto_points,)``.

"""
return self.pareto[:, 0]
Expand All @@ -177,7 +206,7 @@ def TP(self) -> NDArray[np.integer]:
Returns
-------
TP: ``NDArray[np.integer]``
Shape ``(n_points_pareto,)``.
Shape ``(n_pareto_points,)``.

"""
return self.pareto[:, 1]
Expand All @@ -194,7 +223,7 @@ def FN(self) -> NDArray[np.integer]:
Returns
-------
FN: ``NDArray[np.integer]``
Shape ``(n_points_pareto,)``.
Shape ``(n_pareto_points,)``.

"""
return self.P - self.TP
Expand All @@ -211,7 +240,7 @@ def TN(self) -> NDArray[np.integer]:
Returns
-------
TN: ``NDArray[np.integer]``
Shape ``(n_points_pareto,)``.
Shape ``(n_pareto_points,)``.

"""
return self.N - self.FP
Expand All @@ -228,7 +257,7 @@ def FPR(self) -> NDArray[np.floating]:
Returns
-------
FPR: ``NDArray[np.floating]``
Shape ``(n_points_pareto,)``.
Shape ``(n_pareto_points,)``.

"""
return self.FP / self.N
Expand All @@ -245,7 +274,7 @@ def TPR(self) -> NDArray[np.floating]:
Returns
-------
TPR: ``NDArray[np.floating]``
Shape ``(n_points_pareto,)``.
Shape ``(n_pareto_points,)``.

"""
return self.TP / self.P
Expand All @@ -262,7 +291,7 @@ def FNR(self) -> NDArray[np.floating]:
Returns
-------
FNR: ``NDArray[np.floating]``
Shape ``(n_points_pareto,)``.
Shape ``(n_pareto_points,)``.

"""
return self.FN / self.P
Expand All @@ -279,7 +308,7 @@ def TNR(self) -> NDArray[np.floating]:
Returns
-------
TNR: ``NDArray[np.floating]``
Shape ``(n_points_pareto,)``.
Shape ``(n_pareto_points,)``.

"""
return self.TN / self.N
Expand Down
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified test/expected/test_roc_plot/plot_legend=True-with_ax=False.fig0.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified test/expected/test_roc_plot/plot_legend=True-with_ax=True.fig0.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading