feat(CrossValidationReport): Add threshold averaging for roc plot #1750
```diff
@@ -2,6 +2,7 @@
 from typing import Any, Literal, Optional, Union, cast

 import matplotlib.pyplot as plt
+import numpy as np
 from matplotlib import colormaps
 from matplotlib.axes import Axes
 from matplotlib.lines import Line2D
```
```diff
@@ -387,6 +388,115 @@ def _plot_cross_validated_estimator(

         return self.ax_, lines, info_pos_label

+    def _plot_average_cross_validated_binary_estimator(
+        self,
+        *,
+        estimator_name: str,
+        roc_curve_kwargs: list[dict[str, Any]],
+        plot_chance_level: bool = True,
+        chance_level_kwargs: Optional[dict[str, Any]],
+    ) -> tuple[Axes, list[Line2D], Union[str, None]]:
+        """Plot average ROC curve for a cross-validated binary estimator.
+
+        Includes the underlying ROC curves from cross-validation.
+
+        Parameters
+        ----------
+        estimator_name : str
+            The name of the estimator.
+
+        roc_curve_kwargs : list of dict
+            List of dictionaries containing keyword arguments to customize the
+            ROC curves. The length of the list should match the number of
+            curves to plot.
+
+        plot_chance_level : bool, default=True
+            Whether to plot the chance level.
+
+        chance_level_kwargs : dict, default=None
+            Keyword arguments to be passed to matplotlib's `plot` for rendering
+            the chance level line.
+
+        Returns
+        -------
+        ax : matplotlib.axes.Axes
+            The axes with the ROC curves plotted.
+
+        lines : list of matplotlib.lines.Line2D
+            The plotted ROC curve lines.
+
+        info_pos_label : str or None
+            String containing positive label information for binary
+            classification, None for multiclass.
+        """
+        lines: list[Line2D] = []
+        average_type = self.roc_curve["average"].cat.categories.item()
+        n_folds: int = 0
+
+        for split_idx in self.roc_curve["split_index"].cat.categories:
+            if split_idx is None:
+                continue
+            split_idx = int(split_idx)
+            query = f"label == {self.pos_label!r} & split_index == {split_idx}"
+            roc_curve = self.roc_curve.query(query)
+
+            line_kwargs_validated = _validate_style_kwargs(
+                {"color": "grey", "alpha": 0.3, "lw": 0.75},
+                roc_curve_kwargs[split_idx],
+            )
+
+            (line,) = self.ax_.plot(
+                roc_curve["fpr"],
+                roc_curve["tpr"],
+                **line_kwargs_validated,
+            )
+            lines.append(line)
+            n_folds += 1
+
+        info_pos_label = (
+            f"\n(Positive label: {self.pos_label})"
+            if self.pos_label is not None
+            else ""
+        )
+
+        query = f"label == {self.pos_label!r} & average == '{average_type}'"
+        average_roc_curve = self.roc_curve.query(query)
+        average_roc_auc = self.roc_auc.query(query)["roc_auc"].item()
+
+        line_kwargs_validated = _validate_style_kwargs({}, {})
```
**Review comment:** I'm not sure how best to take style kwargs for the average line. Other lines slice using […]

**Reply:** Yeah, it might be time to make a dedicated kwargs type:

```python
from typing import Any, Optional, TypedDict, TypeVar, Union

Kwargs = dict[str, Any]
T = TypeVar("T")
OneOrMore = Union[T, list[T]]  # reviewer's `oneOrMore[T]`, made valid Python

class RocCurveKwargs(TypedDict):
    splits: OneOrMore[Kwargs]
    average: Optional[Kwargs]
```
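If that suggestion were adopted, consumption could look something like the sketch below; `normalize_roc_curve_kwargs` and `n_splits` are illustrative names, not part of the PR:

```python
from typing import Any, Optional, TypedDict, TypeVar, Union

Kwargs = dict[str, Any]
T = TypeVar("T")
OneOrMore = Union[T, list[T]]

class RocCurveKwargs(TypedDict):
    splits: OneOrMore[Kwargs]
    average: Optional[Kwargs]

def normalize_roc_curve_kwargs(
    kwargs: RocCurveKwargs, n_splits: int
) -> tuple[list[Kwargs], Optional[Kwargs]]:
    """Hypothetical helper: expand RocCurveKwargs into per-split and average kwargs."""
    splits = kwargs["splits"]
    if isinstance(splits, dict):
        # A single dict is broadcast to every split.
        per_split = [dict(splits) for _ in range(n_splits)]
    else:
        # A list is taken as-is, one entry per split.
        per_split = list(splits)
    return per_split, kwargs.get("average")
```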
The new method continues:

```diff
+        line_kwargs_validated["label"] = (
+            f"{average_type.capitalize()} average of {n_folds} folds "
+            f"(AUC = {average_roc_auc:0.2f})"
+        )
+
+        (line,) = self.ax_.plot(
+            average_roc_curve["fpr"],
+            average_roc_curve["tpr"],
+            **line_kwargs_validated,
+        )
+        lines.append(line)
+
+        info_pos_label = (
+            f"\n(Positive label: {self.pos_label})"
+            if self.pos_label is not None
+            else ""
+        )
+
+        if plot_chance_level:
+            self.chance_level_ = _add_chance_level(
+                self.ax_,
+                chance_level_kwargs,
+                self._default_chance_level_kwargs,
+            )
+        else:
+            self.chance_level_ = None
+
+        if self.data_source in ("train", "test"):
+            title = f"{estimator_name} on $\\bf{{{self.data_source}}}$ set"
+        else:
+            title = f"{estimator_name} on $\\bf{{external}}$ set"
+        self.ax_.legend(bbox_to_anchor=(1.02, 1), title=title)
+
+        return self.ax_, lines, info_pos_label
+
     def _plot_comparison_estimator(
         self,
         *,
```
```diff
@@ -760,17 +870,30 @@ def plot(
                 chance_level_kwargs=chance_level_kwargs,
             )
         elif self.report_type == "cross-validation":
-            self.ax_, self.lines_, info_pos_label = (
-                self._plot_cross_validated_estimator(
-                    estimator_name=(
-                        estimator_name
-                        or self.roc_auc["estimator_name"].cat.categories.item()
-                    ),
-                    roc_curve_kwargs=roc_curve_kwargs,
-                    plot_chance_level=plot_chance_level,
-                    chance_level_kwargs=chance_level_kwargs,
-                )
-            )
+            if "average" in self.roc_auc.columns:
+                self.ax_, self.lines_, info_pos_label = (
+                    self._plot_average_cross_validated_binary_estimator(
+                        estimator_name=(
+                            estimator_name
+                            or self.roc_auc["estimator_name"].cat.categories.item()
+                        ),
+                        roc_curve_kwargs=roc_curve_kwargs,
+                        plot_chance_level=plot_chance_level,
+                        chance_level_kwargs=chance_level_kwargs,
+                    )
+                )
+            else:
+                self.ax_, self.lines_, info_pos_label = (
+                    self._plot_cross_validated_estimator(
+                        estimator_name=(
+                            estimator_name
+                            or self.roc_auc["estimator_name"].cat.categories.item()
+                        ),
+                        roc_curve_kwargs=roc_curve_kwargs,
+                        plot_chance_level=plot_chance_level,
+                        chance_level_kwargs=chance_level_kwargs,
+                    )
+                )
         elif self.report_type == "comparison-estimator":
             self.ax_, self.lines_, info_pos_label = self._plot_comparison_estimator(
                 estimator_names=self.roc_auc["estimator_name"].cat.categories,
```
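For orientation, a sketch of how this dispatch would be reached from user code, assuming the averaging option is eventually exposed through skore's metrics accessor; the `average="threshold"` argument to `roc()` is hypothetical here, since this diff only adds the private plumbing:

```python
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression

from skore import CrossValidationReport

X, y = make_classification(n_samples=500, random_state=0)
report = CrossValidationReport(LogisticRegression(), X, y)

# Hypothetical: an `average` option forwarded down to
# _compute_data_for_display, which adds the "average" column
# that plot() dispatches on above.
display = report.metrics.roc(average="threshold")
display.plot()  # grey per-fold curves plus one averaged curve
```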
```diff
@@ -812,6 +935,7 @@ def _compute_data_for_display(
         cls,
         y_true: Sequence[YPlotData],
         y_pred: Sequence[YPlotData],
+        average: Optional[Literal["threshold"]] = None,
         *,
         report_type: ReportType,
         estimators: Sequence[BaseEstimator],
```
```diff
@@ -833,6 +957,9 @@ def _compute_data_for_display(
             confidence values, or non-thresholded measure of decisions (as returned by
             "decision_function" on some classifiers).

+        average : {"threshold"}, default=None
+            Method to use for averaging cross-validation ROC curves.
+
         report_type : {"comparison-cross-validation", "comparison-estimator", \
                 "cross-validation", "estimator"}
             The type of report.
```
```diff
@@ -869,6 +996,7 @@ def _compute_data_for_display(
         roc_auc_records = []

         if ml_task == "binary-classification":
+            pos_label_validated = cast(PositiveLabel, pos_label_validated)
             for y_true_i, y_pred_i in zip(y_true, y_pred):
                 fpr_i, tpr_i, thresholds_i = roc_curve(
                     y_true_i.y,
```
```diff
@@ -878,8 +1006,6 @@ def _compute_data_for_display(
                 )
                 roc_auc_i = auc(fpr_i, tpr_i)

-                pos_label_validated = cast(PositiveLabel, pos_label_validated)
-
                 for fpr, tpr, threshold in zip(fpr_i, tpr_i, thresholds_i):
                     roc_curve_records.append(
                         {
```
```diff
@@ -900,8 +1026,63 @@ def _compute_data_for_display(
                         "roc_auc": roc_auc_i,
                     }
                 )
+            if average is not None:
+                if average == "threshold":
+                    all_thresholds = []
+                    all_fprs = []
+                    all_tprs = []
+
+                    roc_curves_df = DataFrame.from_records(roc_curve_records)
+                    for _, group in roc_curves_df.groupby("split_index"):
+                        sorted_group = group.sort_values("threshold", ascending=False)
+                        all_thresholds.append(
+                            np.array(sorted_group["threshold"].values)
+                        )
+                        all_fprs.append(np.array(sorted_group["fpr"].values))
+                        all_tprs.append(np.array(sorted_group["tpr"].values))
+
+                    average_fpr, average_tpr, average_threshold = (
+                        cls._threshold_average(
+                            xs=all_fprs,
+                            ys=all_tprs,
+                            thresholds=all_thresholds,
+                        )
+                    )
+                else:
+                    raise TypeError(
+                        'average must be "threshold" or None, '
+                        f"got {average}"
+                    )
+                average_roc_auc = auc(average_fpr, average_tpr)
+                for fpr, tpr, threshold in zip(
+                    average_fpr, average_tpr, average_threshold
+                ):
+                    roc_curve_records.append(
+                        {
+                            "estimator_name": y_true_i.estimator_name,
+                            "split_index": None,
+                            "label": pos_label_validated,
+                            "threshold": threshold,
+                            "fpr": fpr,
+                            "tpr": tpr,
+                            "average": "threshold",
+                        }
+                    )
+                roc_auc_records.append(
+                    {
+                        "estimator_name": y_true_i.estimator_name,
+                        "split_index": None,
+                        "label": pos_label_validated,
+                        "roc_auc": average_roc_auc,
+                        "average": "threshold",
+                    }
+                )
+
         else:  # multiclass-classification
+            if average is not None:
+                raise ValueError(
+                    "Averaging is not implemented for multiclass classification"
+                )
             # OvR fashion to collect fpr, tpr, and roc_auc
             for y_true_i, y_pred_i, est in zip(y_true, y_pred, estimators):
                 label_binarizer = LabelBinarizer().fit(est.classes_)
```
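`cls._threshold_average` is called above but its implementation is outside this hunk. As a rough sketch of the technique (threshold averaging in the sense of Fawcett's ROC analysis: evaluate every fold's curve at a shared grid of decision thresholds, then average the operating points), one plausible shape is shown below; the name, signature, and grid size are assumptions, not the PR's actual code:

```python
import numpy as np

def threshold_average(
    xs: list[np.ndarray],
    ys: list[np.ndarray],
    thresholds: list[np.ndarray],
    n_points: int = 100,
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Average per-fold ROC curves at a common grid of thresholds.

    Each fold's (xs[i], ys[i], thresholds[i]) is assumed sorted by
    decreasing threshold, as in the calling code above.
    """
    # Build a descending grid over the finite thresholds of all folds
    # (sklearn's roc_curve starts each fold at +inf, which linspace
    # cannot handle).
    all_t = np.concatenate(thresholds)
    all_t = all_t[np.isfinite(all_t)]
    grid = np.linspace(all_t.max(), all_t.min(), n_points)

    avg_x = np.zeros(n_points)
    avg_y = np.zeros(n_points)
    for x, y, t in zip(xs, ys, thresholds):
        # Index of the last operating point whose threshold is still >= the
        # grid value; negating turns the descending t into an ascending
        # array that searchsorted can handle.
        idx = np.searchsorted(-t, -grid, side="right") - 1
        idx = np.clip(idx, 0, len(t) - 1)
        avg_x += x[idx]
        avg_y += y[idx]

    n_folds = len(xs)
    return avg_x / n_folds, avg_y / n_folds, grid
```

This steps each fold's curve as a staircase; interpolating between neighboring operating points instead would give a smoother average, and either choice is compatible with the plotting code above.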
```diff
@@ -942,7 +1123,7 @@ def _compute_data_for_display(
             "estimator_name": "category",
             "split_index": "category",
             "label": "category",
-        }
+        } | ({"average": "category"} if average is not None else {})

         return cls(
             roc_curve=DataFrame.from_records(roc_curve_records).astype(dtypes),
```
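The trailing `|` here is Python 3.9's dict-union operator, used so the `average` dtype is declared only when averaging was requested; for example:

```python
dtypes = {
    "label": "category",
} | ({"average": "category"} if True else {})
print(dtypes)  # {'label': 'category', 'average': 'category'}
```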
**Review comment:** Mypy is still complaining about this line, even after casting:

```
Unexpected keyword argument "average" for "_compute_data_for_display" of "PrecisionRecallCurveDisplay"
```

Without casting it shows the same error twice. Any suggestions on why?

**Reply:** Hm, the cast looks correct to me... maybe `reveal_type` could help?

**Reply:** Side-note: this is a sign that `_get_display` should stop existing, although that should be the subject of another PR.

**Reply:** Thanks, though it looks how I would expect after casting:

```
Revealed type is "type[skore.sklearn._plot.metrics.roc_curve.RocCurveDisplay]"
```
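As a guess at the shape of the problem, here is a minimal sketch reproducing this class of mypy error outside skore; it assumes `_get_display` returns a union of display classes, which is not shown in this diff:

```python
# Hypothetical repro (not skore's actual code): calling a classmethod through
# a union of class objects makes mypy check the call against every union
# member, so it reports the member that lacks the keyword even when another
# member accepts it.
from typing import Optional, Union, reveal_type  # reveal_type: Python 3.11+


class RocCurveDisplay:
    @classmethod
    def _compute_data_for_display(
        cls, average: Optional[str] = None
    ) -> "RocCurveDisplay":
        return cls()


class PrecisionRecallCurveDisplay:
    @classmethod
    def _compute_data_for_display(cls) -> "PrecisionRecallCurveDisplay":
        return cls()


def _get_display(
    name: str,
) -> type[Union[RocCurveDisplay, PrecisionRecallCurveDisplay]]:
    return RocCurveDisplay if name == "roc" else PrecisionRecallCurveDisplay


display_cls = _get_display("roc")
reveal_type(display_cls)  # type[Union[RocCurveDisplay, PrecisionRecallCurveDisplay]]
display_cls._compute_data_for_display(average="threshold")
# mypy: Unexpected keyword argument "average" for "_compute_data_for_display"
# of "PrecisionRecallCurveDisplay"
```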