61 changes: 48 additions & 13 deletions skore/src/skore/sklearn/_cross_validation/metrics_accessor.py
@@ -1106,6 +1106,7 @@ def _get_display(
*,
X: Optional[ArrayLike] = None,
y: Optional[ArrayLike] = None,
average: Optional[Literal["threshold"]] = None,
data_source: DataSource,
response_method: str,
display_class: type[
@@ -1155,7 +1156,11 @@ def _get_display(
if "seed" in display_kwargs and display_kwargs["seed"] is None:
cache_key = None
else:
cache_key_parts: list[Any] = [self._parent._hash, display_class.__name__]
cache_key_parts: list[Any] = [
self._parent._hash,
display_class.__name__,
average,
]
cache_key_parts.extend(display_kwargs.values())
if data_source_hash is not None:
cache_key_parts.append(data_source_hash)
@@ -1211,18 +1216,34 @@ def _get_display(
)
progress.update(main_task, advance=1, refresh=True)

display = display_class._compute_data_for_display(
y_true=y_true,
y_pred=y_pred,
report_type="cross-validation",
estimators=[
report.estimator_ for report in self._parent.estimator_reports_
],
estimator_names=[self._parent.estimator_name_],
ml_task=self._parent._ml_task,
data_source=data_source,
**display_kwargs,
)
if display_class == RocCurveDisplay:
display_class = cast(type[RocCurveDisplay], display_class)
display = display_class._compute_data_for_display(
y_true=y_true,
y_pred=y_pred,
average=average,
Author: Mypy is still complaining about this line, even after casting:

    Unexpected keyword argument "average" for "_compute_data_for_display" of "PrecisionRecallCurveDisplay"

Without casting it shows the same error twice. Any suggestions on why?

Contributor: Hm, the cast looks correct to me... maybe reveal_type could help?

Contributor: Side-note: This is a sign that _get_display should stop existing, although that should be the subject of another PR.

Author: > Hm, the cast looks correct to me... maybe reveal_type could help?

Thanks, though it looks how I would expect after casting:

    Revealed type is "type[skore.sklearn._plot.metrics.roc_curve.RocCurveDisplay]"

report_type="cross-validation",
estimators=[
report.estimator_ for report in self._parent.estimator_reports_
],
estimator_names=[self._parent.estimator_name_],
ml_task=self._parent._ml_task,
data_source=data_source,
**display_kwargs,
)
else:
display = display_class._compute_data_for_display(
y_true=y_true,
y_pred=y_pred,
report_type="cross-validation",
estimators=[
report.estimator_ for report in self._parent.estimator_reports_
],
estimator_names=[self._parent.estimator_name_],
ml_task=self._parent._ml_task,
data_source=data_source,
**display_kwargs,
)
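
A minimal standalone sketch of the narrowing pattern under discussion, for context: the class names and signatures below are placeholders, not skore's API, and whether binding the cast to a fresh name resolves the exact error in the thread above is untested.

    from typing import Optional, Union, cast

    class RocLike:
        @classmethod
        def compute(cls, *, average: Optional[str] = None) -> str:
            return f"roc(average={average})"

    class PrecisionRecallLike:
        @classmethod
        def compute(cls) -> str:
            return "pr"

    def get_display(display_class: type[Union[RocLike, PrecisionRecallLike]]) -> str:
        if display_class is RocLike:
            # Bind the cast to a fresh local name instead of reassigning the
            # parameter, so the narrowed type cannot be widened back to the
            # declared union when `display_class` is analyzed again later.
            roc_class = cast(type[RocLike], display_class)
            return roc_class.compute(average="threshold")
        return display_class.compute()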

if cache_key is not None:
# Unless seed is an int (i.e. the call is deterministic),
@@ -1238,6 +1259,7 @@ def roc(
data_source: DataSource = "test",
X: Optional[ArrayLike] = None,
y: Optional[ArrayLike] = None,
average: Optional[Literal["threshold"]] = None,
pos_label: Optional[PositiveLabel] = None,
) -> RocCurveDisplay:
"""Plot the ROC curve.
@@ -1259,6 +1281,12 @@ def roc(
New target on which to compute the metric. By default, we use the target
provided when creating the report.

average : {"threshold"}, default=None
Method to use for averaging cross-validation ROC curves.
Possible values are:
- `None`: No averaging.
- `"threshold"`: Threshold averaging [1]_.

pos_label : int, float, bool or str, default=None
The positive class.

@@ -1267,6 +1295,12 @@ def roc(
RocCurveDisplay
The ROC curve display.

References
----------
.. [1] T. Fawcett, "An introduction to ROC analysis", Pattern Recognition
Letters, 27(8), 861–874, 2006.

Examples
--------
>>> from sklearn.datasets import load_breast_cancer
@@ -1286,6 +1320,7 @@ def roc(
data_source=data_source,
X=X,
y=y,
average=average,
response_method=response_method,
display_class=RocCurveDisplay,
display_kwargs=display_kwargs,
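A hedged usage sketch of the new parameter, following the CrossValidationReport example style used elsewhere in skore's docstrings (the docstring's own example is collapsed in this view, so the exact setup here is an assumption):

    from sklearn.datasets import load_breast_cancer
    from sklearn.linear_model import LogisticRegression
    from skore import CrossValidationReport

    X, y = load_breast_cancer(return_X_y=True)
    report = CrossValidationReport(LogisticRegression(max_iter=10_000), X=X, y=y)

    # Threshold-averaged ROC curve across the cross-validation folds.
    display = report.metrics.roc(average="threshold")
    display.plot()
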
207 changes: 194 additions & 13 deletions skore/src/skore/sklearn/_plot/metrics/roc_curve.py
@@ -2,6 +2,7 @@
from typing import Any, Literal, Optional, Union, cast

import matplotlib.pyplot as plt
import numpy as np
from matplotlib import colormaps
from matplotlib.axes import Axes
from matplotlib.lines import Line2D
@@ -387,6 +388,115 @@ def _plot_cross_validated_estimator(

return self.ax_, lines, info_pos_label

def _plot_average_cross_validated_binary_estimator(
self,
*,
estimator_name: str,
roc_curve_kwargs: list[dict[str, Any]],
plot_chance_level: bool = True,
chance_level_kwargs: Optional[dict[str, Any]],
) -> tuple[Axes, list[Line2D], Union[str, None]]:
"""Plot average ROC curve for a cross-validated binary estimator.

Includes the underlying ROC curves from cross-validation.

Parameters
----------
estimator_name : str
The name of the estimator.

roc_curve_kwargs : list of dict
List of dictionaries containing keyword arguments to customize the ROC
curves. The length of the list should match the number of curves to plot.

plot_chance_level : bool, default=True
Whether to plot the chance level.

chance_level_kwargs : dict, default=None
Keyword arguments to be passed to matplotlib's `plot` for rendering
the chance level line.

Returns
-------
ax : matplotlib.axes.Axes
The axes with the ROC curves plotted.

lines : list of matplotlib.lines.Line2D
The plotted ROC curve lines.

info_pos_label : str or None
String containing positive label information for binary classification,
None for multiclass.
"""
lines: list[Line2D] = []
average_type = self.roc_curve["average"].cat.categories.item()
n_folds: int = 0

for split_idx in self.roc_curve["split_index"].cat.categories:
if split_idx is None:
continue
split_idx = int(split_idx)
query = f"label == {self.pos_label!r} & split_index == {split_idx}"
roc_curve = self.roc_curve.query(query)

line_kwargs_validated = _validate_style_kwargs(
{"color": "grey", "alpha": 0.3, "lw": 0.75}, roc_curve_kwargs[split_idx]
)

(line,) = self.ax_.plot(
roc_curve["fpr"],
roc_curve["tpr"],
**line_kwargs_validated,
)
lines.append(line)
n_folds += 1

info_pos_label = (
f"\n(Positive label: {self.pos_label})"
if self.pos_label is not None
else ""
)

query = f"label == {self.pos_label!r} & average == '{average_type}'"
average_roc_curve = self.roc_curve.query(query)
average_roc_auc = self.roc_auc.query(query)["roc_auc"].item()

line_kwargs_validated = _validate_style_kwargs({}, {})
Author: I'm not sure how best to take style kwargs for the average line. Other lines slice using split_idx, but this is None for the average line

Contributor (@auguste-probabl, May 30, 2025): Yeah, it might be time to make roc_curve_kwargs more specific in this case. Something like

    T = TypeVar("T")
    Kwargs = dict[str, Any]
    OneOrMore = Union[T, list[T]]

    class RocCurveKwargs(TypedDict):
        splits: OneOrMore[Kwargs]
        average: Optional[Kwargs]

line_kwargs_validated["label"] = (
    f"{average_type.capitalize()} average of {n_folds} folds "
    f"(AUC = {average_roc_auc:0.2f})"
)

(line,) = self.ax_.plot(
average_roc_curve["fpr"],
average_roc_curve["tpr"],
**line_kwargs_validated,
)
lines.append(line)

if plot_chance_level:
self.chance_level_ = _add_chance_level(
self.ax_,
chance_level_kwargs,
self._default_chance_level_kwargs,
)
else:
self.chance_level_ = None

if self.data_source in ("train", "test"):
title = f"{estimator_name} on $\\bf{{{self.data_source}}}$ set"
else:
title = f"{estimator_name} on $\\bf{{external}}$ set"
self.ax_.legend(bbox_to_anchor=(1.02, 1), title=title)

return self.ax_, lines, info_pos_label
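
The method above layers thin grey per-fold curves under a single emphasized average line; a stripped-down matplotlib sketch of that layering with synthetic curves (no skore types involved):

    import matplotlib.pyplot as plt
    import numpy as np

    rng = np.random.default_rng(0)
    fig, ax = plt.subplots()

    # Five synthetic "fold" curves, drawn thin and grey in the background.
    x = np.linspace(0, 1, 100)
    curves = [np.clip(x ** rng.uniform(0.3, 0.6), 0, 1) for _ in range(5)]
    for y in curves:
        ax.plot(x, y, color="grey", alpha=0.3, lw=0.75)

    # The averaged curve is drawn last, on top, with a descriptive label.
    ax.plot(x, np.mean(curves, axis=0), lw=2,
            label=f"Threshold average of {len(curves)} folds")
    ax.plot([0, 1], [0, 1], "k--", label="Chance level (AUC = 0.5)")
    ax.legend(bbox_to_anchor=(1.02, 1))
    plt.show()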

def _plot_comparison_estimator(
self,
*,
@@ -760,17 +870,30 @@ def plot(
chance_level_kwargs=chance_level_kwargs,
)
elif self.report_type == "cross-validation":
self.ax_, self.lines_, info_pos_label = (
self._plot_cross_validated_estimator(
estimator_name=(
estimator_name
or self.roc_auc["estimator_name"].cat.categories.item()
),
roc_curve_kwargs=roc_curve_kwargs,
plot_chance_level=plot_chance_level,
chance_level_kwargs=chance_level_kwargs,
if "average" in self.roc_auc.columns:
self.ax_, self.lines_, info_pos_label = (
self._plot_average_cross_validated_binary_estimator(
estimator_name=(
estimator_name
or self.roc_auc["estimator_name"].cat.categories.item()
),
roc_curve_kwargs=roc_curve_kwargs,
plot_chance_level=plot_chance_level,
chance_level_kwargs=chance_level_kwargs,
)
)
else:
self.ax_, self.lines_, info_pos_label = (
self._plot_cross_validated_estimator(
estimator_name=(
estimator_name
or self.roc_auc["estimator_name"].cat.categories.item()
),
roc_curve_kwargs=roc_curve_kwargs,
plot_chance_level=plot_chance_level,
chance_level_kwargs=chance_level_kwargs,
)
)
)
elif self.report_type == "comparison-estimator":
self.ax_, self.lines_, info_pos_label = self._plot_comparison_estimator(
estimator_names=self.roc_auc["estimator_name"].cat.categories,
@@ -812,6 +935,7 @@ def _compute_data_for_display(
cls,
y_true: Sequence[YPlotData],
y_pred: Sequence[YPlotData],
average: Optional[Literal["threshold"]] = None,
*,
report_type: ReportType,
estimators: Sequence[BaseEstimator],
@@ -833,6 +957,9 @@ def _compute_data_for_display(
confidence values, or non-thresholded measure of decisions (as returned by
"decision_function" on some classifiers).

average : {"threshold"}, default=None
Method to use for averaging cross-validation ROC curves.

report_type : {"comparison-cross-validation", "comparison-estimator", \
"cross-validation", "estimator"}
The type of report.
@@ -869,6 +996,7 @@ def _compute_data_for_display(
roc_auc_records = []

if ml_task == "binary-classification":
pos_label_validated = cast(PositiveLabel, pos_label_validated)
for y_true_i, y_pred_i in zip(y_true, y_pred):
fpr_i, tpr_i, thresholds_i = roc_curve(
y_true_i.y,
Expand All @@ -878,8 +1006,6 @@ def _compute_data_for_display(
)
roc_auc_i = auc(fpr_i, tpr_i)

pos_label_validated = cast(PositiveLabel, pos_label_validated)

for fpr, tpr, threshold in zip(fpr_i, tpr_i, thresholds_i):
roc_curve_records.append(
{
Expand All @@ -900,8 +1026,63 @@ def _compute_data_for_display(
"roc_auc": roc_auc_i,
}
)
if average is not None:
if average == "threshold":
all_thresholds = []
all_fprs = []
all_tprs = []

roc_curves_df = DataFrame.from_records(roc_curve_records)
for _, group in roc_curves_df.groupby("split_index"):
sorted_group = group.sort_values("threshold", ascending=False)
all_thresholds.append(
np.array(sorted_group["threshold"].values)
)
all_fprs.append(np.array(sorted_group["fpr"].values))
all_tprs.append(np.array(sorted_group["tpr"].values))

average_fpr, average_tpr, average_threshold = (
cls._threshold_average(
xs=all_fprs,
ys=all_tprs,
thresholds=all_thresholds,
)
)
else:
raise ValueError(
    'average must be "threshold" or None, '
    f"got {average}"
)
average_roc_auc = auc(average_fpr, average_tpr)
for fpr, tpr, threshold in zip(
average_fpr, average_tpr, average_threshold
):
roc_curve_records.append(
{
"estimator_name": y_true_i.estimator_name,
"split_index": None,
"label": pos_label_validated,
"threshold": threshold,
"fpr": fpr,
"tpr": tpr,
"average": "threshold",
}
)
roc_auc_records.append(
{
"estimator_name": y_true_i.estimator_name,
"split_index": None,
"label": pos_label_validated,
"roc_auc": average_roc_auc,
"average": "threshold",
}
)

else: # multiclass-classification
if average is not None:
raise ValueError(
"Averaging is not implemented for multi class classification"
)
# OvR fashion to collect fpr, tpr, and roc_auc
for y_true_i, y_pred_i, est in zip(y_true, y_pred, estimators):
label_binarizer = LabelBinarizer().fit(est.classes_)
@@ -942,7 +1123,7 @@ def _compute_data_for_display(
"estimator_name": "category",
"split_index": "category",
"label": "category",
}
} | ({"average": "category"} if average is not None else {})

return cls(
roc_curve=DataFrame.from_records(roc_curve_records).astype(dtypes),
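For reference, a minimal NumPy sketch of threshold averaging as described by Fawcett (2006). It illustrates the technique behind the cls._threshold_average call above, whose actual implementation is not part of this diff; the nearest-threshold matching used here is one common variant, not necessarily skore's.

    import numpy as np

    def threshold_average(xs, ys, thresholds, n_points=50):
        """Average ROC curves over a common threshold grid (Fawcett, 2006).

        xs, ys, thresholds: per-fold arrays of FPR, TPR and thresholds.
        Returns the averaged (fpr, tpr) arrays and the threshold grid.
        """
        # Build a common grid spanning every threshold seen in any fold,
        # ordered from the highest threshold (strictest) to the lowest.
        all_t = np.concatenate(thresholds)
        grid = np.linspace(all_t.max(), all_t.min(), n_points)

        avg_fpr, avg_tpr = [], []
        for t in grid:
            fprs, tprs = [], []
            for fpr, tpr, thr in zip(xs, ys, thresholds):
                # Take each fold's ROC point whose threshold is nearest to t.
                i = int(np.argmin(np.abs(np.asarray(thr) - t)))
                fprs.append(fpr[i])
                tprs.append(tpr[i])
            avg_fpr.append(np.mean(fprs))
            avg_tpr.append(np.mean(tprs))
        return np.asarray(avg_fpr), np.asarray(avg_tpr), grid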