Skip to content

Commit 8d08dac

Browse files
feat(ComparisonReport[CVReport]): Add ROC curve (#1669)
1 parent d32b7d5 commit 8d08dac

File tree

16 files changed

+1479
-891
lines changed

16 files changed

+1479
-891
lines changed

skore/src/skore/sklearn/_comparison/metrics_accessor.py

Lines changed: 110 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -1263,9 +1263,6 @@ def _get_display(
12631263
display : display_class
12641264
The display.
12651265
"""
1266-
if self._parent._reports_type == "CrossValidationReport":
1267-
raise NotImplementedError()
1268-
12691266
if "seed" in display_kwargs and display_kwargs["seed"] is None:
12701267
cache_key = None
12711268
else:
@@ -1288,55 +1285,121 @@ def _get_display(
12881285
y_true: list[YPlotData] = []
12891286
y_pred: list[YPlotData] = []
12901287

1291-
for report, report_name in zip(
1292-
self._parent.reports_, self._parent.report_names_
1293-
):
1294-
report_X, report_y, _ = report.metrics._get_X_y_and_data_source_hash(
1295-
data_source=data_source,
1296-
X=X,
1297-
y=y,
1298-
)
1288+
if self._parent._reports_type == "EstimatorReport":
1289+
for report, report_name in zip(
1290+
self._parent.reports_, self._parent.report_names_
1291+
):
1292+
report_X, report_y, _ = (
1293+
report.metrics._get_X_y_and_data_source_hash(
1294+
data_source=data_source,
1295+
X=X,
1296+
y=y,
1297+
)
1298+
)
12991299

1300-
y_true.append(
1301-
YPlotData(
1302-
estimator_name=report_name,
1303-
split_index=None,
1304-
y=report_y,
1300+
y_true.append(
1301+
YPlotData(
1302+
estimator_name=report_name,
1303+
split_index=None,
1304+
y=report_y,
1305+
)
13051306
)
1306-
)
1307-
results = _get_cached_response_values(
1308-
cache=report._cache,
1309-
estimator_hash=report._hash,
1310-
estimator=report._estimator,
1311-
X=report_X,
1312-
response_method=response_method,
1307+
results = _get_cached_response_values(
1308+
cache=report._cache,
1309+
estimator_hash=report._hash,
1310+
estimator=report._estimator,
1311+
X=report_X,
1312+
response_method=response_method,
1313+
data_source=data_source,
1314+
data_source_hash=None,
1315+
pos_label=display_kwargs.get("pos_label"),
1316+
)
1317+
for key, value, is_cached in results:
1318+
if not is_cached:
1319+
report._cache[key] = value
1320+
if key[-1] != "predict_time":
1321+
y_pred.append(
1322+
YPlotData(
1323+
estimator_name=report_name,
1324+
split_index=None,
1325+
y=value,
1326+
)
1327+
)
1328+
1329+
progress.update(main_task, advance=1, refresh=True)
1330+
1331+
display = display_class._compute_data_for_display(
1332+
y_true=y_true,
1333+
y_pred=y_pred,
1334+
report_type="comparison-estimator",
1335+
estimators=[report.estimator_ for report in self._parent.reports_],
1336+
estimator_names=self._parent.report_names_,
1337+
ml_task=self._parent._ml_task,
13131338
data_source=data_source,
1314-
data_source_hash=None,
1315-
pos_label=display_kwargs.get("pos_label"),
1339+
**display_kwargs,
13161340
)
1317-
for key, value, is_cached in results:
1318-
if not is_cached:
1319-
report._cache[key] = value
1320-
if key[-1] != "predict_time":
1321-
y_pred.append(
1341+
1342+
else:
1343+
for report, report_name in zip(
1344+
self._parent.reports_, self._parent.report_names_
1345+
):
1346+
for split_index, estimator_report in enumerate(
1347+
report.estimator_reports_
1348+
):
1349+
report_X, report_y, _ = (
1350+
estimator_report.metrics._get_X_y_and_data_source_hash(
1351+
data_source=data_source,
1352+
X=X,
1353+
y=y,
1354+
)
1355+
)
1356+
1357+
y_true.append(
13221358
YPlotData(
13231359
estimator_name=report_name,
1324-
split_index=None,
1325-
y=value,
1360+
split_index=split_index,
1361+
y=report_y,
13261362
)
13271363
)
1328-
progress.update(main_task, advance=1, refresh=True)
13291364

1330-
display = display_class._compute_data_for_display(
1331-
y_true=y_true,
1332-
y_pred=y_pred,
1333-
report_type="comparison-estimator",
1334-
estimators=[report.estimator_ for report in self._parent.reports_],
1335-
estimator_names=self._parent.report_names_,
1336-
ml_task=self._parent._ml_task,
1337-
data_source=data_source,
1338-
**display_kwargs,
1339-
)
1365+
results = _get_cached_response_values(
1366+
cache=estimator_report._cache,
1367+
estimator_hash=estimator_report._hash,
1368+
estimator=estimator_report.estimator_,
1369+
X=report_X,
1370+
response_method=response_method,
1371+
data_source=data_source,
1372+
data_source_hash=None,
1373+
pos_label=display_kwargs.get("pos_label"),
1374+
)
1375+
for key, value, is_cached in results:
1376+
if not is_cached:
1377+
estimator_report._cache[key] = value
1378+
if key[-1] != "predict_time":
1379+
y_pred.append(
1380+
YPlotData(
1381+
estimator_name=report_name,
1382+
split_index=split_index,
1383+
y=value,
1384+
)
1385+
)
1386+
1387+
progress.update(main_task, advance=1, refresh=True)
1388+
1389+
display = display_class._compute_data_for_display(
1390+
y_true=y_true,
1391+
y_pred=y_pred,
1392+
report_type="comparison-cross-validation",
1393+
estimators=[
1394+
estimator_report.estimator_
1395+
for report in self._parent.reports_
1396+
for estimator_report in report.estimator_reports_
1397+
],
1398+
estimator_names=self._parent.report_names_,
1399+
ml_task=self._parent._ml_task,
1400+
data_source=data_source,
1401+
**display_kwargs,
1402+
)
13401403

13411404
if cache_key is not None:
13421405
# Unless seed is an int (i.e. the call is deterministic),
@@ -1476,6 +1539,8 @@ def precision_recall(
14761539
>>> display = comparison_report.metrics.precision_recall()
14771540
>>> display.plot()
14781541
"""
1542+
if self._parent._reports_type == "CrossValidationReport":
1543+
raise NotImplementedError()
14791544
response_method = ("predict_proba", "decision_function")
14801545
display_kwargs = {"pos_label": pos_label}
14811546
display = cast(
@@ -1560,6 +1625,8 @@ def prediction_error(
15601625
>>> display = comparison_report.metrics.prediction_error()
15611626
>>> display.plot(kind="actual_vs_predicted")
15621627
"""
1628+
if self._parent._reports_type == "CrossValidationReport":
1629+
raise NotImplementedError()
15631630
display_kwargs = {"subsample": subsample, "seed": seed}
15641631
display = cast(
15651632
PredictionErrorDisplay,

skore/src/skore/sklearn/_plot/metrics/precision_recall_curve.py

Lines changed: 18 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
_validate_style_kwargs,
2222
sample_mpl_colormap,
2323
)
24-
from skore.sklearn.types import MLTask, PositiveLabel, YPlotData
24+
from skore.sklearn.types import MLTask, PositiveLabel, ReportType, YPlotData
2525

2626

2727
class PrecisionRecallCurveDisplay(
@@ -81,7 +81,8 @@ class PrecisionRecallCurveDisplay(
8181
ml_task : {"binary-classification", "multiclass-classification"}
8282
The machine learning task.
8383
84-
report_type : {"comparison-estimator", "cross-validation", "estimator"}
84+
report_type : {"comparison-cross-validation", "comparison-estimator", \
85+
"cross-validation", "estimator"}
8586
The type of report.
8687
8788
Attributes
@@ -121,7 +122,7 @@ def __init__(
121122
pos_label: Optional[PositiveLabel],
122123
data_source: Literal["train", "test", "X_y"],
123124
ml_task: MLTask,
124-
report_type: Literal["comparison-estimator", "cross-validation", "estimator"],
125+
report_type: ReportType,
125126
) -> None:
126127
self.precision = precision
127128
self.recall = recall
@@ -480,10 +481,15 @@ def plot(
480481
if pr_curve_kwargs is None:
481482
pr_curve_kwargs = self._default_pr_curve_kwargs
482483

484+
if self.ml_task == "binary-classification":
485+
n_curves = len(self.average_precision[self.pos_label])
486+
else:
487+
n_curves = len(self.average_precision)
488+
483489
pr_curve_kwargs = self._validate_curve_kwargs(
484490
curve_param_name="pr_curve_kwargs",
485491
curve_kwargs=pr_curve_kwargs,
486-
metric=self.average_precision,
492+
n_curves=n_curves,
487493
report_type=self.report_type,
488494
)
489495

@@ -512,10 +518,13 @@ def plot(
512518
estimator_names=self.estimator_names,
513519
pr_curve_kwargs=pr_curve_kwargs,
514520
)
521+
elif self.report_type == "comparison-cross-validation":
522+
raise NotImplementedError()
515523
else:
516524
raise ValueError(
517-
f"`report_type` should be one of 'estimator', 'cross-validation', "
518-
f"or 'comparison-estimator'. Got '{self.report_type}' instead."
525+
"`report_type` should be one of 'estimator', 'cross-validation', "
526+
"'comparison-cross-validation' or 'comparison-estimator'. "
527+
f"Got '{self.report_type}' instead."
519528
)
520529

521530
xlabel = "Recall"
@@ -541,7 +550,7 @@ def _compute_data_for_display(
541550
y_true: Sequence[YPlotData],
542551
y_pred: Sequence[YPlotData],
543552
*,
544-
report_type: Literal["comparison-estimator", "cross-validation", "estimator"],
553+
report_type: ReportType,
545554
estimators: Sequence[BaseEstimator],
546555
estimator_names: list[str],
547556
ml_task: MLTask,
@@ -561,7 +570,8 @@ def _compute_data_for_display(
561570
confidence values, or non-thresholded measure of decisions (as returned by
562571
"decision_function" on some classifiers).
563572
564-
report_type : {"comparison-estimator", "cross-validation", "estimator"}
573+
report_type : {"comparison-cross-validation", "comparison-estimator", \
574+
"cross-validation", "estimator"}
565575
The type of report.
566576
567577
estimators : list of estimator instances

skore/src/skore/sklearn/_plot/metrics/prediction_error.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
_validate_style_kwargs,
1818
sample_mpl_colormap,
1919
)
20-
from skore.sklearn.types import MLTask, YPlotData
20+
from skore.sklearn.types import MLTask, ReportType, YPlotData
2121

2222
RangeData = namedtuple("RangeData", ["min", "max"])
2323

@@ -62,7 +62,8 @@ class PredictionErrorDisplay(StyleDisplayMixin, HelpDisplayMixin):
6262
ml_task : {"regression", "multioutput-regression"}
6363
The machine learning task.
6464
65-
report_type : {"cross-validation", "estimator", "comparison-estimator"}
65+
report_type : {"comparison-cross-validation", "comparison-estimator", \
66+
"cross-validation", "estimator"}
6667
The type of report.
6768
6869
Attributes
@@ -113,7 +114,7 @@ def __init__(
113114
estimator_names: list[str],
114115
data_source: Literal["train", "test", "X_y"],
115116
ml_task: MLTask,
116-
report_type: Literal["cross-validation", "estimator", "comparison-estimator"],
117+
report_type: ReportType,
117118
) -> None:
118119
self.y_true = y_true
119120
self.y_pred = y_pred
@@ -557,7 +558,7 @@ def _compute_data_for_display(
557558
y_true: list[YPlotData],
558559
y_pred: list[YPlotData],
559560
*,
560-
report_type: Literal["cross-validation", "estimator", "comparison-estimator"],
561+
report_type: ReportType,
561562
estimator_names: list[str],
562563
ml_task: MLTask,
563564
data_source: Literal["train", "test", "X_y"],
@@ -575,6 +576,10 @@ def _compute_data_for_display(
575576
y_pred : list of array-like of shape (n_samples,)
576577
Predicted target values.
577578
579+
report_type : {"comparison-cross-validation", "comparison-estimator", \
580+
"cross-validation", "estimator"}
581+
The type of report.
582+
578583
estimators : list of estimator instances
579584
The estimators from which `y_pred` is obtained.
580585

0 commit comments

Comments
 (0)