Commit 4e4af76
feat(CrossValidationReport): Add threshold averaging for roc plot
1 parent 36ef673 commit 4e4af76

7 files changed (+257 additions, −19 deletions)

skore/src/skore/sklearn/_cross_validation/metrics_accessor.py

Lines changed: 5 additions & 0 deletions
@@ -1100,6 +1100,7 @@ def _get_display(
         *,
         X: Optional[ArrayLike] = None,
         y: Optional[ArrayLike] = None,
+        average: Optional[Literal["threshold"]] = None,
         data_source: DataSource,
         response_method: str,
         display_class: type[
@@ -1208,6 +1209,7 @@ def _get_display(
         display = display_class._compute_data_for_display(
             y_true=y_true,
             y_pred=y_pred,
+            average=average,
             report_type="cross-validation",
             estimators=[
                 report.estimator_ for report in self._parent.estimator_reports_
@@ -1232,6 +1234,7 @@ def roc(
         data_source: DataSource = "test",
         X: Optional[ArrayLike] = None,
         y: Optional[ArrayLike] = None,
+        average: Optional[Literal["threshold"]] = None,
         pos_label: Optional[PositiveLabel] = None,
     ) -> RocCurveDisplay:
         """Plot the ROC curve.
@@ -1280,6 +1283,7 @@ def roc(
             data_source=data_source,
             X=X,
             y=y,
+            average=average,
             response_method=response_method,
             display_class=RocCurveDisplay,
             display_kwargs=display_kwargs,
@@ -1294,6 +1298,7 @@ def precision_recall(
         data_source: DataSource = "test",
         X: Optional[ArrayLike] = None,
         y: Optional[ArrayLike] = None,
+        average: Optional[Literal["threshold"]] = None,
         pos_label: Optional[PositiveLabel] = None,
     ) -> PrecisionRecallCurveDisplay:
         """Plot the precision-recall curve.

skore/src/skore/sklearn/_cross_validation/report.py

Lines changed: 1 addition & 4 deletions
@@ -206,10 +206,7 @@ def _fit_estimator_reports(self) -> list[EstimatorReport]:
             progress.update(task, advance=1, refresh=True)

         warn_msg = None
-        if not any (
-            isinstance(report, EstimatorReport)
-            for report in estimator_reports
-        ):
+        if not any(isinstance(report, EstimatorReport) for report in estimator_reports):
             traceback_msg = "\n".join(str(exc) for exc in estimator_reports)
             raise RuntimeError(
                 "Cross-validation failed: no estimators were successfully fitted. "

skore/src/skore/sklearn/_plot/metrics/precision_recall_curve.py

Lines changed: 1 addition & 0 deletions
@@ -549,6 +549,7 @@ def _compute_data_for_display(
         cls,
         y_true: Sequence[YPlotData],
         y_pred: Sequence[YPlotData],
+        average: Optional[Literal["threshold"]] = None,
         *,
         report_type: ReportType,
         estimators: Sequence[BaseEstimator],

skore/src/skore/sklearn/_plot/metrics/roc_curve.py

Lines changed: 189 additions & 13 deletions
@@ -2,6 +2,7 @@
 from typing import Any, Literal, Optional, Union, cast

 import matplotlib.pyplot as plt
+import numpy as np
 from matplotlib import colormaps
 from matplotlib.axes import Axes
 from matplotlib.lines import Line2D
@@ -387,6 +388,113 @@ def _plot_cross_validated_estimator(

         return self.ax_, lines, info_pos_label

+    def _plot_average_cross_validated_binary_estimator(
+        self,
+        *,
+        estimator_name: str,
+        roc_curve_kwargs: list[dict[str, Any]],
+        plot_chance_level: bool = True,
+        chance_level_kwargs: Optional[dict[str, Any]],
+    ) -> tuple[Axes, list[Line2D], Union[str, None]]:
+        """Plot the averaged ROC curve for a cross-validated binary estimator.
+
+        Parameters
+        ----------
+        estimator_name : str
+            The name of the estimator.
+
+        roc_curve_kwargs : list of dict
+            List of dictionaries containing keyword arguments to customize the ROC
+            curves. The length of the list should match the number of curves to plot.
+
+        plot_chance_level : bool, default=True
+            Whether to plot the chance level.
+
+        chance_level_kwargs : dict, default=None
+            Keyword arguments to be passed to matplotlib's `plot` for rendering
+            the chance level line.
+
+        Returns
+        -------
+        ax : matplotlib.axes.Axes
+            The axes with the ROC curves plotted.
+
+        lines : list of matplotlib.lines.Line2D
+            The plotted ROC curve lines.
+
+        info_pos_label : str or None
+            String containing positive label information for binary classification,
+            None for multiclass.
+        """
+        lines: list[Line2D] = []
+        average_type = self.roc_curve["average"].cat.categories.item()
+        n_folds: int = 0
+
+        # Draw each fold's curve in light grey as background context.
+        for split_idx in self.roc_curve["split_index"].cat.categories:
+            if split_idx is None:
+                continue
+            split_idx = int(split_idx)
+            query = f"label == {self.pos_label!r} & split_index == {split_idx}"
+            roc_curve = self.roc_curve.query(query)
+
+            line_kwargs_validated = _validate_style_kwargs(
+                {"color": "grey", "alpha": 0.3, "lw": 0.75}, roc_curve_kwargs[split_idx]
+            )
+
+            (line,) = self.ax_.plot(
+                roc_curve["fpr"],
+                roc_curve["tpr"],
+                **line_kwargs_validated,
+            )
+            lines.append(line)
+            n_folds += 1
+
+        # Overlay the averaged curve with its AUC in the legend.
+        query = f"label == {self.pos_label!r} & average == '{average_type}'"
+        average_roc_curve = self.roc_curve.query(query)
+        average_roc_auc = self.roc_auc.query(query)["roc_auc"].item()
+
+        line_kwargs_validated = _validate_style_kwargs({}, {})
+        line_kwargs_validated["label"] = (
+            f"{average_type.capitalize()} average of {n_folds} folds "
+            f"(AUC = {average_roc_auc:0.2f})"
+        )
+
+        (line,) = self.ax_.plot(
+            average_roc_curve["fpr"],
+            average_roc_curve["tpr"],
+            **line_kwargs_validated,
+        )
+        lines.append(line)
+
+        info_pos_label = (
+            f"\n(Positive label: {self.pos_label})"
+            if self.pos_label is not None
+            else ""
+        )
+
+        if plot_chance_level:
+            self.chance_level_ = _add_chance_level(
+                self.ax_,
+                chance_level_kwargs,
+                self._default_chance_level_kwargs,
+            )
+        else:
+            self.chance_level_ = None
+
+        if self.data_source in ("train", "test"):
+            title = f"{estimator_name} on $\\bf{{{self.data_source}}}$ set"
+        else:
+            title = f"{estimator_name} on $\\bf{{external}}$ set"
+        self.ax_.legend(bbox_to_anchor=(1.02, 1), title=title)
+
+        return self.ax_, lines, info_pos_label
+
     def _plot_comparison_estimator(
         self,
         *,
@@ -760,17 +868,30 @@ def plot(
                 chance_level_kwargs=chance_level_kwargs,
             )
         elif self.report_type == "cross-validation":
-            self.ax_, self.lines_, info_pos_label = (
-                self._plot_cross_validated_estimator(
-                    estimator_name=(
-                        estimator_name
-                        or self.roc_auc["estimator_name"].cat.categories.item()
-                    ),
-                    roc_curve_kwargs=roc_curve_kwargs,
-                    plot_chance_level=plot_chance_level,
-                    chance_level_kwargs=chance_level_kwargs,
+            if "average" in self.roc_auc.columns:
+                self.ax_, self.lines_, info_pos_label = (
+                    self._plot_average_cross_validated_binary_estimator(
+                        estimator_name=(
+                            estimator_name
+                            or self.roc_auc["estimator_name"].cat.categories.item()
+                        ),
+                        roc_curve_kwargs=roc_curve_kwargs,
+                        plot_chance_level=plot_chance_level,
+                        chance_level_kwargs=chance_level_kwargs,
+                    )
+                )
+            else:
+                self.ax_, self.lines_, info_pos_label = (
+                    self._plot_cross_validated_estimator(
+                        estimator_name=(
+                            estimator_name
+                            or self.roc_auc["estimator_name"].cat.categories.item()
+                        ),
+                        roc_curve_kwargs=roc_curve_kwargs,
+                        plot_chance_level=plot_chance_level,
+                        chance_level_kwargs=chance_level_kwargs,
+                    )
                 )
-            )
         elif self.report_type == "comparison-estimator":
             self.ax_, self.lines_, info_pos_label = self._plot_comparison_estimator(
                 estimator_names=self.roc_auc["estimator_name"].cat.categories,
@@ -812,6 +933,7 @@ def _compute_data_for_display(
         cls,
         y_true: Sequence[YPlotData],
         y_pred: Sequence[YPlotData],
+        average: Optional[Literal["threshold"]] = None,
         *,
         report_type: ReportType,
         estimators: Sequence[BaseEstimator],
@@ -869,6 +991,7 @@ def _compute_data_for_display(
        roc_auc_records = []

        if ml_task == "binary-classification":
+            pos_label_validated = cast(PositiveLabel, pos_label_validated)
             for y_true_i, y_pred_i in zip(y_true, y_pred):
                 fpr_i, tpr_i, thresholds_i = roc_curve(
                     y_true_i.y,
@@ -878,8 +1001,6 @@ def _compute_data_for_display(
                 )
                 roc_auc_i = auc(fpr_i, tpr_i)

-                pos_label_validated = cast(PositiveLabel, pos_label_validated)
-
                 for fpr, tpr, threshold in zip(fpr_i, tpr_i, thresholds_i):
                     roc_curve_records.append(
                         {
@@ -900,8 +1021,63 @@ def _compute_data_for_display(
                             "roc_auc": roc_auc_i,
                         }
                     )
+            if average is not None:
+                if average == "threshold":
+                    all_thresholds = []
+                    all_fprs = []
+                    all_tprs = []
+
+                    # Collect each fold's curve, sorted by decreasing threshold.
+                    roc_curves_df = DataFrame.from_records(roc_curve_records)
+                    for _, group in roc_curves_df.groupby("split_index"):
+                        sorted_group = group.sort_values("threshold", ascending=False)
+                        all_thresholds.append(
+                            np.array(sorted_group["threshold"].values)
+                        )
+                        all_fprs.append(np.array(sorted_group["fpr"].values))
+                        all_tprs.append(np.array(sorted_group["tpr"].values))
+
+                    average_fpr, average_tpr, average_threshold = (
+                        cls._threshold_average(
+                            xs=all_fprs,
+                            ys=all_tprs,
+                            thresholds=all_thresholds,
+                        )
+                    )
+                else:
+                    raise TypeError(
+                        "'threshold' is the only supported option for `average`, "
+                        f"but got {average} instead"
+                    )
+                average_roc_auc = auc(average_fpr, average_tpr)
+                # Averaged points are stored with `split_index=None` and an
+                # `average` tag so the plotting code can tell them apart.
+                for fpr, tpr, threshold in zip(
+                    average_fpr, average_tpr, average_threshold
+                ):
+                    roc_curve_records.append(
+                        {
+                            "estimator_name": y_true_i.estimator_name,
+                            "split_index": None,
+                            "label": pos_label_validated,
+                            "threshold": threshold,
+                            "fpr": fpr,
+                            "tpr": tpr,
+                            "average": "threshold",
+                        }
+                    )
+                roc_auc_records.append(
+                    {
+                        "estimator_name": y_true_i.estimator_name,
+                        "split_index": None,
+                        "label": pos_label_validated,
+                        "roc_auc": average_roc_auc,
+                        "average": "threshold",
+                    }
+                )

         else:  # multiclass-classification
+            if average is not None:
+                raise ValueError(
+                    "Averaging is not implemented for multiclass classification"
+                )
             # OvR fashion to collect fpr, tpr, and roc_auc
             for y_true_i, y_pred_i, est in zip(y_true, y_pred, estimators):
                 label_binarizer = LabelBinarizer().fit(est.classes_)
@@ -942,7 +1118,7 @@ def _compute_data_for_display(
             "estimator_name": "category",
             "split_index": "category",
             "label": "category",
-        }
+        } | ({"average": "category"} if average is not None else {})

         return cls(
             roc_curve=DataFrame.from_records(roc_curve_records).astype(dtypes),
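The averaged curve is stored alongside the per-fold records rather than in a separate structure: per-fold rows keep an integer `split_index`, while averaged rows use `split_index=None` and an `average == "threshold"` tag. A hedged sketch of how that data could be inspected (attribute names `roc_curve` and `roc_auc` are taken from the diff above; `report` is the object from the earlier example):

    display = report.metrics.roc(average="threshold")

    # Averaged rows are tagged in the `average` column; per-fold rows are not.
    averaged_curve = display.roc_curve.query("average == 'threshold'")
    averaged_auc = display.roc_auc.query("average == 'threshold'")["roc_auc"].item()
    print(f"Threshold-averaged AUC: {averaged_auc:0.2f}")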

skore/src/skore/sklearn/_plot/utils.py

Lines changed: 42 additions & 0 deletions
@@ -6,6 +6,7 @@
 import numpy as np
 from matplotlib.axes import Axes
 from matplotlib.colors import Colormap
+from numpy.typing import ArrayLike
 from rich.console import Console
 from rich.panel import Panel
 from rich.tree import Tree
@@ -241,6 +242,47 @@ def _validate_from_predictions_params(

         return pos_label

+    @staticmethod
+    def _threshold_average(
+        xs: list[ArrayLike], ys: list[ArrayLike], thresholds: list[ArrayLike]
+    ) -> tuple[list[float], list[float], list[float]]:
+        """Compute the threshold-averaged ROC or precision-recall curve.
+
+        Parameters
+        ----------
+        xs : list of array-like of shape (n_samples,)
+            False positive rates or precisions, one array per fold.
+
+        ys : list of array-like of shape (n_samples,)
+            True positive rates or recalls, one array per fold.
+
+        thresholds : list of array-like of shape (n_samples,)
+            Decision thresholds, one array per fold, sorted in decreasing order.
+        """
+        unique_thresholds = sorted(np.unique(np.concatenate(thresholds)), reverse=True)
+
+        average_x = []
+        average_y = []
+        average_threshold = []
+        for target_threshold in unique_thresholds:
+            threshold_x, threshold_y = [], []
+            for x, y, threshold in zip(xs, ys, thresholds):
+                # `threshold` is decreasing, so search its reversed (increasing)
+                # view; clamp at 0 so every target maps to a valid index.
+                closest_idx = max(
+                    np.searchsorted(threshold[::-1], target_threshold, side="right")
+                    - 1,
+                    0,
+                )
+                # Convert the index on the reversed view back to the original
+                # (decreasing) ordering.
+                closest_idx_inverted = (closest_idx + 1) * -1
+                threshold_x.append(x[closest_idx_inverted])
+                threshold_y.append(y[closest_idx_inverted])
+            average_x.append(np.mean(threshold_x))
+            average_y.append(np.mean(threshold_y))
+            average_threshold.append(target_threshold)
+        return average_x, average_y, average_threshold
+

 def _despine_matplotlib_axis(
     ax: Axes,
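In words: collect every threshold that appears in any fold, walk them from largest to smallest, and for each one take the operating point each fold reaches at (approximately) that threshold, then average the coordinates across folds. A self-contained sketch of the same idea, using a simpler nearest-threshold lookup in place of the `searchsorted` indexing above (toy data, independent of skore internals):

    import numpy as np
    from sklearn.metrics import roc_curve

    rng = np.random.default_rng(0)
    curves = []
    for _ in range(3):  # three synthetic "folds"
        y_true = rng.integers(0, 2, size=100)
        y_score = rng.random(100)
        curves.append(roc_curve(y_true, y_score))  # (fpr, tpr, thresholds)

    # Union of finite thresholds across folds, largest first.
    grid = np.unique(np.concatenate([thr[np.isfinite(thr)] for *_, thr in curves]))[::-1]

    avg_fpr, avg_tpr = [], []
    for t in grid:
        points = []
        for fpr, tpr, thr in curves:
            idx = int(np.argmin(np.abs(thr - t)))  # nearest operating point in this fold
            points.append((fpr[idx], tpr[idx]))
        avg_fpr.append(float(np.mean([p[0] for p in points])))
        avg_tpr.append(float(np.mean([p[1] for p in points])))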

skore/tests/unit/sklearn/cross_validation/test_cross_validation.py

Lines changed: 3 additions & 2 deletions
@@ -320,11 +320,12 @@ def test_cross_validation_report_metrics_data_source_external(
 ########################################################################################


-def test_cross_validation_report_plot_roc(binary_classification_data):
+@pytest.mark.parametrize("average", [None, "threshold"])
+def test_cross_validation_report_plot_roc(binary_classification_data, average):
     """Check that the ROC plot method works."""
     estimator, X, y = binary_classification_data
     report = CrossValidationReport(estimator, X, y, cv_splitter=2)
-    assert isinstance(report.metrics.roc(), RocCurveDisplay)
+    assert isinstance(report.metrics.roc(average=average), RocCurveDisplay)


 @pytest.mark.parametrize("display", ["roc", "precision_recall"])
