Skip to content

Commit c8c5ddb

Browse files
committed
add utils docstring
1 parent 63b878e commit c8c5ddb

File tree

3 files changed

+39
-19
lines changed

3 files changed

+39
-19
lines changed

examples/getting_started/plot_skore_getting_started.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -201,7 +201,7 @@
201201

202202
# %%
203203
# We can plot one performance metric against another (e.g. a score against a timing metric)
204-
comparator.plot_perf_against_time(perf_metric="Brier score")
204+
comparator.pairwise_plot(perf_metric_x="brier_score", perf_metric_y="fit_time")
205205

206206
# %%
207207
scatter_data = comparator.metrics.report_metrics().T.reset_index()

skore/src/skore/sklearn/_comparison/report.py

Lines changed: 22 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from skore.externals._pandas_accessors import DirNamesMixin
1313
from skore.sklearn._base import _BaseReport
1414
from skore.sklearn._estimator.report import EstimatorReport
15+
from skore.sklearn.utils import _SCORE_OR_LOSS_INFO
1516
from skore.utils._progress_bar import progress_decorator
1617

1718
if TYPE_CHECKING:
@@ -351,18 +352,22 @@ def get_predictions(
351352
for report in self.estimator_reports_
352353
]
353354

354-
def plot_perf_against_time(
355+
def pairwise_plot(
355356
self,
356-
perf_metric: str,
357+
perf_metric_x: str,
358+
perf_metric_y: str,
357359
data_source: Literal["test", "train", "X_y"] = "test",
358-
time_metric: Literal["fit", "predict"] = "predict",
360+
pos_label: Optional[Any] = None,
359361
):
360-
"""
361-
Plot a given performance metric against a time metric.
362+
"""Plot a given performance metric against another.
362363
363364
Parameters
364365
----------
365-
perf_metric : str
366+
perf_metric_x : str
367+
The performance metric to plot on the abscissa axis.
368+
369+
perf_metric_y : str
370+
The performance metric to plot on the ordinate axis.
366371
367372
data_source : {"test", "train", "X_y"}, default="test"
368373
The data source to use.
@@ -371,11 +376,12 @@ def plot_perf_against_time(
371376
- "train" : use the train set provided when creating the report.
372377
- "X_y" : use the provided `X` and `y` to compute the metric.
373378
374-
perf_metric : str
375-
376-
time_metric: {"fit", "predict"}, default = "predict"
377-
The time metric to use in the plot.
378-
379+
pos_label : int, float, bool or str, default=None
380+
The positive class when it comes to binary classification. Metrics that
381+
depend on a positive class (e.g. precision, recall) are computed with
382+
respect to this label; it is forwarded to the underlying
383+
`report_metrics` call. If None, the default positive class of each
384+
metric is used.
379385
380386
Returns
381387
-------
@@ -403,23 +409,21 @@ def plot_perf_against_time(
403409
# - turn into display
404410
# - change name to sth like `pairwise_plot`
405411

406-
if time_metric == "fit":
407-
x_label = "Fit time"
408-
elif time_metric == "predict":
409-
x_label = "Predict time"
412+
x_label = _SCORE_OR_LOSS_INFO[perf_metric_x].get("name", perf_metric_x)
413+
y_label = _SCORE_OR_LOSS_INFO[perf_metric_y].get("name", perf_metric_y)
410414

411-
scatter_data = self.metrics.report_metrics().T.reset_index()
415+
scatter_data = self.metrics.report_metrics(pos_label=pos_label).T.reset_index()
412416
scatter_data.plot(
413417
kind="scatter",
414418
x=x_label,
415-
y="Brier score",
419+
y=y_label,
416420
title="Performance vs Time (s)",
417421
)
418422

419423
# Add labels to the points with a small offset
420424
text = scatter_data["Estimator"]
421425
x = scatter_data[x_label]
422-
y = scatter_data["Brier score"]
426+
y = scatter_data[y_label]
423427
for label, x_coord, y_coord in zip(text, x, y):
424428
plt.annotate(
425429
label,

skore/src/skore/sklearn/utils.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
"""Utility functions for Skore and Scikit-learn integration."""
2+
3+
_SCORE_OR_LOSS_INFO: dict[str, dict[str, str]] = {
4+
"fit_time": {"name": "Fit time (s)", "icon": "(↘︎)"},
5+
"predict_time": {"name": "Predict time (s)", "icon": "(↘︎)"},
6+
"accuracy": {"name": "Accuracy", "icon": "(↗︎)"},
7+
"precision": {"name": "Precision", "icon": "(↗︎)"},
8+
"recall": {"name": "Recall", "icon": "(↗︎)"},
9+
"brier_score": {"name": "Brier score", "icon": "(↘︎)"},
10+
"roc_auc": {"name": "ROC AUC", "icon": "(↗︎)"},
11+
"log_loss": {"name": "Log loss", "icon": "(↘︎)"},
12+
"r2": {"name": "R²", "icon": "(↗︎)"},
13+
"rmse": {"name": "RMSE", "icon": "(↘︎)"},
14+
"custom_metric": {"name": "Custom metric", "icon": ""},
15+
"report_metrics": {"name": "Report metrics", "icon": ""},
16+
}

0 commit comments

Comments
 (0)