diff --git a/skore/src/skore/sklearn/_cross_validation/metrics_accessor.py b/skore/src/skore/sklearn/_cross_validation/metrics_accessor.py
index cdc994e03e..293647902b 100644
--- a/skore/src/skore/sklearn/_cross_validation/metrics_accessor.py
+++ b/skore/src/skore/sklearn/_cross_validation/metrics_accessor.py
@@ -1106,6 +1106,7 @@ def _get_display(
         *,
         X: Optional[ArrayLike] = None,
         y: Optional[ArrayLike] = None,
+        average: Optional[Literal["threshold"]] = None,
         data_source: DataSource,
         response_method: str,
         display_class: type[
@@ -1155,7 +1156,11 @@ def _get_display(
         if "seed" in display_kwargs and display_kwargs["seed"] is None:
             cache_key = None
         else:
-            cache_key_parts: list[Any] = [self._parent._hash, display_class.__name__]
+            cache_key_parts: list[Any] = [
+                self._parent._hash,
+                display_class.__name__,
+                average,
+            ]
             cache_key_parts.extend(display_kwargs.values())
             if data_source_hash is not None:
                 cache_key_parts.append(data_source_hash)
@@ -1211,18 +1216,34 @@ def _get_display(
             )
             progress.update(main_task, advance=1, refresh=True)

-            display = display_class._compute_data_for_display(
-                y_true=y_true,
-                y_pred=y_pred,
-                report_type="cross-validation",
-                estimators=[
-                    report.estimator_ for report in self._parent.estimator_reports_
-                ],
-                estimator_names=[self._parent.estimator_name_],
-                ml_task=self._parent._ml_task,
-                data_source=data_source,
-                **display_kwargs,
-            )
+            if display_class == RocCurveDisplay:
+                display_class = cast(type[RocCurveDisplay], display_class)
+                display = display_class._compute_data_for_display(
+                    y_true=y_true,
+                    y_pred=y_pred,
+                    average=average,
+                    report_type="cross-validation",
+                    estimators=[
+                        report.estimator_ for report in self._parent.estimator_reports_
+                    ],
+                    estimator_names=[self._parent.estimator_name_],
+                    ml_task=self._parent._ml_task,
+                    data_source=data_source,
+                    **display_kwargs,
+                )
+            else:
+                display = display_class._compute_data_for_display(
+                    y_true=y_true,
+                    y_pred=y_pred,
+                    report_type="cross-validation",
+                    estimators=[
+                        report.estimator_ for report in self._parent.estimator_reports_
+                    ],
+                    estimator_names=[self._parent.estimator_name_],
+                    ml_task=self._parent._ml_task,
+                    data_source=data_source,
+                    **display_kwargs,
+                )

         if cache_key is not None:
             # Unless seed is an int (i.e. the call is deterministic),
@@ -1238,6 +1259,7 @@ def roc(
         data_source: DataSource = "test",
         X: Optional[ArrayLike] = None,
         y: Optional[ArrayLike] = None,
+        average: Optional[Literal["threshold"]] = None,
         pos_label: Optional[PositiveLabel] = None,
     ) -> RocCurveDisplay:
         """Plot the ROC curve.
@@ -1259,6 +1281,12 @@ def roc(
             New target on which to compute the metric. By default, we use the
             target provided when creating the report.

+        average : {"threshold"}, default=None
+            Method to use for averaging cross-validation ROC curves.
+            Possible values are:
+            - `None`: No averaging.
+            - `"threshold"`: Threshold averaging [1]_.
+
         pos_label : int, float, bool or str, default=None
             The positive class.

@@ -1267,6 +1295,12 @@ def roc(
         RocCurveDisplay
             The ROC curve display.

+        References
+        ----------
+
+        .. [1] T. Fawcett, "An introduction to ROC analysis", Pattern Recognition
+           Letters, 27(8), 861–874, 2006.
+
         Examples
         --------
         >>> from sklearn.datasets import load_breast_cancer
@@ -1286,6 +1320,7 @@ def roc(
             data_source=data_source,
             X=X,
             y=y,
+            average=average,
             response_method=response_method,
             display_class=RocCurveDisplay,
             display_kwargs=display_kwargs,
diff --git a/skore/src/skore/sklearn/_plot/metrics/roc_curve.py b/skore/src/skore/sklearn/_plot/metrics/roc_curve.py
index 4f4e67b0d2..f531927f7c 100644
--- a/skore/src/skore/sklearn/_plot/metrics/roc_curve.py
+++ b/skore/src/skore/sklearn/_plot/metrics/roc_curve.py
@@ -2,6 +2,7 @@
 from typing import Any, Literal, Optional, Union, cast

 import matplotlib.pyplot as plt
+import numpy as np
 from matplotlib import colormaps
 from matplotlib.axes import Axes
 from matplotlib.lines import Line2D
@@ -387,6 +388,115 @@ def _plot_cross_validated_estimator(
         return self.ax_, lines, info_pos_label

+    def _plot_average_cross_validated_binary_estimator(
+        self,
+        *,
+        estimator_name: str,
+        roc_curve_kwargs: list[dict[str, Any]],
+        plot_chance_level: bool = True,
+        chance_level_kwargs: Optional[dict[str, Any]],
+    ) -> tuple[Axes, list[Line2D], Union[str, None]]:
+        """Plot average ROC curve for a cross-validated binary estimator.
+
+        Includes the underlying ROC curves from cross-validation.
+
+        Parameters
+        ----------
+        estimator_name : str
+            The name of the estimator.
+
+        roc_curve_kwargs : list of dict
+            List of dictionaries containing keyword arguments to customize the ROC
+            curves. The length of the list should match the number of curves to plot.
+
+        plot_chance_level : bool, default=True
+            Whether to plot the chance level.
+
+        chance_level_kwargs : dict, default=None
+            Keyword arguments to be passed to matplotlib's `plot` for rendering
+            the chance level line.
+
+        Returns
+        -------
+        ax : matplotlib.axes.Axes
+            The axes with the ROC curves plotted.
+
+        lines : list of matplotlib.lines.Line2D
+            The plotted ROC curve lines.
+
+        info_pos_label : str or None
+            String containing positive label information for binary classification,
+            None for multiclass.
+        """
+        lines: list[Line2D] = []
+        average_type = self.roc_curve["average"].cat.categories.item()
+        n_folds: int = 0
+
+        for split_idx in self.roc_curve["split_index"].cat.categories:
+            if split_idx is None:
+                continue
+            split_idx = int(split_idx)
+            query = f"label == {self.pos_label!r} & split_index == {split_idx}"
+            roc_curve = self.roc_curve.query(query)
+
+            line_kwargs_validated = _validate_style_kwargs(
+                {"color": "grey", "alpha": 0.3, "lw": 0.75}, roc_curve_kwargs[split_idx]
+            )
+
+            (line,) = self.ax_.plot(
+                roc_curve["fpr"],
+                roc_curve["tpr"],
+                **line_kwargs_validated,
+            )
+            lines.append(line)
+            n_folds += 1
+
+        info_pos_label = (
+            f"\n(Positive label: {self.pos_label})"
+            if self.pos_label is not None
+            else ""
+        )
+
+        query = f"label == {self.pos_label!r} & average == '{average_type}'"
+        average_roc_curve = self.roc_curve.query(query)
+        average_roc_auc = self.roc_auc.query(query)["roc_auc"].item()
+
+        line_kwargs_validated = _validate_style_kwargs({}, {})
+        line_kwargs_validated["label"] = (
+            f"{average_type.capitalize()} average of {n_folds} folds "
+            f"(AUC = {average_roc_auc:0.2f})"
+        )
+
+        (line,) = self.ax_.plot(
+            average_roc_curve["fpr"],
+            average_roc_curve["tpr"],
+            **line_kwargs_validated,
+        )
+        lines.append(line)
+
+        info_pos_label = (
+            f"\n(Positive label: {self.pos_label})"
+            if self.pos_label is not None
+            else ""
+        )
+
+        if plot_chance_level:
+            self.chance_level_ = _add_chance_level(
+                self.ax_,
+                chance_level_kwargs,
+                self._default_chance_level_kwargs,
+            )
+        else:
+            self.chance_level_ = None
+
+        if self.data_source in ("train", "test"):
+            title = f"{estimator_name} on $\\bf{{{self.data_source}}}$ set"
+        else:
+            title = f"{estimator_name} on $\\bf{{external}}$ set"
+        self.ax_.legend(bbox_to_anchor=(1.02, 1), title=title)
+
+        return self.ax_, lines, info_pos_label
+
     def _plot_comparison_estimator(
         self,
         *,
@@ -760,17 +870,30 @@ def plot(
                 chance_level_kwargs=chance_level_kwargs,
             )
         elif self.report_type == "cross-validation":
-            self.ax_, self.lines_, info_pos_label = (
-                self._plot_cross_validated_estimator(
-                    estimator_name=(
-                        estimator_name
-                        or self.roc_auc["estimator_name"].cat.categories.item()
-                    ),
-                    roc_curve_kwargs=roc_curve_kwargs,
-                    plot_chance_level=plot_chance_level,
-                    chance_level_kwargs=chance_level_kwargs,
+            if "average" in self.roc_auc.columns:
+                self.ax_, self.lines_, info_pos_label = (
+                    self._plot_average_cross_validated_binary_estimator(
+                        estimator_name=(
+                            estimator_name
+                            or self.roc_auc["estimator_name"].cat.categories.item()
+                        ),
+                        roc_curve_kwargs=roc_curve_kwargs,
+                        plot_chance_level=plot_chance_level,
+                        chance_level_kwargs=chance_level_kwargs,
+                    )
+                )
+            else:
+                self.ax_, self.lines_, info_pos_label = (
+                    self._plot_cross_validated_estimator(
+                        estimator_name=(
+                            estimator_name
+                            or self.roc_auc["estimator_name"].cat.categories.item()
+                        ),
+                        roc_curve_kwargs=roc_curve_kwargs,
+                        plot_chance_level=plot_chance_level,
+                        chance_level_kwargs=chance_level_kwargs,
+                    )
                 )
-            )
         elif self.report_type == "comparison-estimator":
             self.ax_, self.lines_, info_pos_label = self._plot_comparison_estimator(
                 estimator_names=self.roc_auc["estimator_name"].cat.categories,
@@ -812,6 +935,7 @@ def _compute_data_for_display(
         cls,
         y_true: Sequence[YPlotData],
         y_pred: Sequence[YPlotData],
+        average: Optional[Literal["threshold"]] = None,
         *,
         report_type: ReportType,
         estimators: Sequence[BaseEstimator],
@@ -833,6 +957,9 @@ def _compute_data_for_display(
             confidence values, or non-thresholded measure of decisions (as returned
             by "decision_function" on some classifiers).

+        average : {"threshold"}, default=None
+            Method to use for averaging cross-validation ROC curves.
+
         report_type : {"comparison-cross-validation", "comparison-estimator", \
                 "cross-validation", "estimator"}
             The type of report.
@@ -869,6 +996,7 @@ def _compute_data_for_display(
         roc_auc_records = []

         if ml_task == "binary-classification":
+            pos_label_validated = cast(PositiveLabel, pos_label_validated)
             for y_true_i, y_pred_i in zip(y_true, y_pred):
                 fpr_i, tpr_i, thresholds_i = roc_curve(
                     y_true_i.y,
@@ -878,8 +1006,6 @@ def _compute_data_for_display(
                 )
                 roc_auc_i = auc(fpr_i, tpr_i)

-                pos_label_validated = cast(PositiveLabel, pos_label_validated)
-
                 for fpr, tpr, threshold in zip(fpr_i, tpr_i, thresholds_i):
                     roc_curve_records.append(
                         {
@@ -900,8 +1026,63 @@ def _compute_data_for_display(
                         "roc_auc": roc_auc_i,
                     }
                 )
+            if average is not None:
+                if average == "threshold":
+                    all_thresholds = []
+                    all_fprs = []
+                    all_tprs = []
+
+                    roc_curves_df = DataFrame.from_records(roc_curve_records)
+                    for _, group in roc_curves_df.groupby("split_index"):
+                        sorted_group = group.sort_values("threshold", ascending=False)
+                        all_thresholds.append(
+                            np.array(sorted_group["threshold"].values)
+                        )
+                        all_fprs.append(np.array(sorted_group["fpr"].values))
+                        all_tprs.append(np.array(sorted_group["tpr"].values))
+
+                    average_fpr, average_tpr, average_threshold = (
+                        cls._threshold_average(
+                            xs=all_fprs,
+                            ys=all_tprs,
+                            thresholds=all_thresholds,
+                        )
+                    )
+                else:
+                    raise TypeError(
+                        'average must be "threshold" or None, '
+                        f"got {average}"
+                    )
+                average_roc_auc = auc(average_fpr, average_tpr)
+                for fpr, tpr, threshold in zip(
+                    average_fpr, average_tpr, average_threshold
+                ):
+                    roc_curve_records.append(
+                        {
+                            "estimator_name": y_true_i.estimator_name,
+                            "split_index": None,
+                            "label": pos_label_validated,
+                            "threshold": threshold,
+                            "fpr": fpr,
+                            "tpr": tpr,
+                            "average": "threshold",
+                        }
+                    )
+                roc_auc_records.append(
+                    {
+                        "estimator_name": y_true_i.estimator_name,
+                        "split_index": None,
+                        "label": pos_label_validated,
+                        "roc_auc": average_roc_auc,
+                        "average": "threshold",
+                    }
+                )
         else:  # multiclass-classification
+            if average is not None:
+                raise ValueError(
+                    "Averaging is not implemented for multiclass classification"
+                )
             # OvR fashion to collect fpr, tpr, and roc_auc
             for y_true_i, y_pred_i, est in zip(y_true, y_pred, estimators):
                 label_binarizer = LabelBinarizer().fit(est.classes_)
@@ -942,7 +1123,7 @@ def _compute_data_for_display(
             "estimator_name": "category",
             "split_index": "category",
             "label": "category",
-        }
+        } | ({"average": "category"} if average is not None else {})

         return cls(
             roc_curve=DataFrame.from_records(roc_curve_records).astype(dtypes),
diff --git a/skore/src/skore/sklearn/_plot/utils.py b/skore/src/skore/sklearn/_plot/utils.py
index 61bee3558f..120fdddae9 100644
--- a/skore/src/skore/sklearn/_plot/utils.py
+++ b/skore/src/skore/sklearn/_plot/utils.py
@@ -6,6 +6,7 @@
 import numpy as np
 from matplotlib.axes import Axes
 from matplotlib.colors import Colormap
+from numpy.typing import ArrayLike
 from rich.console import Console
 from rich.panel import Panel
 from rich.tree import Tree
@@ -241,6 +242,47 @@
         return pos_label

+    @staticmethod
+    def _threshold_average(
+        xs: list[ArrayLike], ys: list[ArrayLike], thresholds: list[ArrayLike]
+    ) -> tuple[list[float], list[float], list[float]]:
+        """Compute the threshold-averaged ROC or precision-recall curve.
+
+        Parameters
+        ----------
+        xs : list of array-like of shape (n_samples,)
+            False positive rates or precisions, one array per curve.
+        ys : list of array-like of shape (n_samples,)
+            True positive rates or recalls, one array per curve.
+        thresholds : list of array-like of shape (n_samples,)
+            Thresholds, sorted in decreasing order.
+        """
+        unique_thresholds = sorted(np.unique(np.concatenate(thresholds)), reverse=True)
+
+        average_x = []
+        average_y = []
+        average_threshold = []
+        for target_threshold in unique_thresholds:
+            threshold_x, threshold_y = [], []
+            for x, y, threshold in zip(
+                xs,
+                ys,
+                thresholds,
+            ):
+                closest_idx = max(
+                    np.searchsorted(threshold[::-1], target_threshold, side="right")
+                    - 1,
+                    0,
+                )
+                closest_idx_inverted = (closest_idx + 1) * -1
+                threshold_x.append(x[closest_idx_inverted])
+                threshold_y.append(y[closest_idx_inverted])
+            average_x.append(np.mean(threshold_x))
+            average_y.append(np.mean(threshold_y))
+            average_threshold.append(target_threshold)
+        return average_x, average_y, average_threshold
+

 def _despine_matplotlib_axis(
     ax: Axes,
diff --git a/skore/tests/unit/sklearn/cross_validation/test_cross_validation.py b/skore/tests/unit/sklearn/cross_validation/test_cross_validation.py
index 7f102fb0be..35df59fc0a 100644
--- a/skore/tests/unit/sklearn/cross_validation/test_cross_validation.py
+++ b/skore/tests/unit/sklearn/cross_validation/test_cross_validation.py
@@ -321,11 +321,12 @@ def test_cross_validation_report_metrics_data_source_external(
 ########################################################################################


-def test_cross_validation_report_plot_roc(binary_classification_data):
+@pytest.mark.parametrize("average", [None, "threshold"])
+def test_cross_validation_report_plot_roc(binary_classification_data, average):
     """Check that the ROC plot method works."""
     estimator, X, y = binary_classification_data
     report = CrossValidationReport(estimator, X, y, cv_splitter=2)
-    assert isinstance(report.metrics.roc(), RocCurveDisplay)
+    assert isinstance(report.metrics.roc(average=average), RocCurveDisplay)


 @pytest.mark.parametrize("display", ["roc", "precision_recall"])
diff --git a/skore/tests/unit/sklearn/test_utils.py b/skore/tests/unit/sklearn/test_utils.py
index 208f077712..6764e41100 100644
--- a/skore/tests/unit/sklearn/test_utils.py
+++ b/skore/tests/unit/sklearn/test_utils.py
@@ -1,6 +1,7 @@
 import numpy
 import pandas
 import pytest
+from numpy.testing import assert_array_equal
 from sklearn.cluster import KMeans
 from sklearn.datasets import (
     make_classification,
@@ -10,6 +11,7 @@
 from sklearn.dummy import DummyClassifier, DummyRegressor
 from sklearn.linear_model import LinearRegression, LogisticRegression
 from sklearn.multioutput import MultiOutputClassifier
+from skore.sklearn._plot.utils import _ClassifierCurveDisplayMixin
 from skore.sklearn.find_ml_task import _find_ml_task
@@ -118,3 +120,17 @@

 def test_find_ml_task_string():
     assert _find_ml_task(["0", "1", "2"], None) == "multiclass-classification"
+
+
+class Test_ClassifierCurveDisplayMixin:
+    def test__threshold_average(self):
+        xs = [numpy.array([3, 2, 1]), numpy.array([3, 2, 1])]
+        ys = [numpy.array([3, 2, 1]), numpy.array([3, 2, 1])]
+        thresholds = [numpy.array([4, 3, 1]), numpy.array([5, 3, 2])]
+        x, y, threshold = _ClassifierCurveDisplayMixin._threshold_average(
+            xs, ys, thresholds
+        )
+        expected = numpy.array([3, 2.5, 2, 1, 1])
+        assert_array_equal(x, expected)
+        assert_array_equal(y, expected)
+        assert_array_equal(threshold, numpy.array([5, 4, 3, 2, 1]))
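A minimal usage sketch of the `average="threshold"` option introduced above. The dataset, estimator, and `cv_splitter` value are illustrative, and the top-level `skore.CrossValidationReport` import is assumed from skore's public API; the `roc(average=...)` call and `plot()` method come from this diff.

    from sklearn.datasets import make_classification
    from sklearn.linear_model import LogisticRegression

    from skore import CrossValidationReport

    # Toy binary-classification problem evaluated with 5-fold cross-validation.
    X, y = make_classification(n_samples=500, random_state=0)
    report = CrossValidationReport(LogisticRegression(), X, y, cv_splitter=5)

    # average=None (the default) keeps the previous per-fold behaviour;
    # average="threshold" additionally draws the threshold-averaged ROC curve
    # (Fawcett, 2006) computed by _threshold_average, with its own AUC in the legend.
    display = report.metrics.roc(average="threshold")
    display.plot()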