Draft
44 commits
3f1e97d
chore: Change to more explicit class name and add docs
MarieSacksick Jun 11, 2025
1e68910
Update sphinx/user_guide/reporters.rst
MarieSacksick Jun 11, 2025
39a66cb
alphabetic sorting
MarieSacksick Jun 11, 2025
3158ace
fix init for sphinx
MarieSacksick Jun 11, 2025
4665e9e
add function to be consistent with a display
MarieSacksick Jun 11, 2025
d778605
fix init for sphinx
MarieSacksick Jun 11, 2025
6a0b34d
docs: explain a bit more about display and their functions
MarieSacksick Jun 11, 2025
963bae0
merge
MarieSacksick Jun 9, 2025
e9e8d5d
value error based on literal
MarieSacksick Jun 9, 2025
20f7685
plot for comparison report for estimator
MarieSacksick Jun 9, 2025
b293389
linting
MarieSacksick Jun 11, 2025
22c0e49
introduce temporarily _SCORE_OR_LOSS_INFO in class
MarieSacksick Jun 11, 2025
00ae0d7
first version of plot for comp report ready
MarieSacksick Jun 12, 2025
c557d1e
linting
MarieSacksick Jun 12, 2025
866f367
adapt to name change from report_metrics to summarize
MarieSacksick Jun 12, 2025
ddb7344
add some tests
MarieSacksick Jun 13, 2025
b2f2250
add test not implemented error
MarieSacksick Jun 13, 2025
c6d063a
add tests
MarieSacksick Jun 13, 2025
42fb3a3
add data_source at display creation
MarieSacksick Jun 13, 2025
38a91bf
add new tests
MarieSacksick Jun 13, 2025
9274eda
add description to tests
MarieSacksick Jun 13, 2025
d188828
add example
MarieSacksick Jun 13, 2025
a21ca0e
Update skore/src/skore/sklearn/_plot/metrics/metrics_summary_display.py
MarieSacksick Jun 16, 2025
7fa3615
change not implemented error for estimator
MarieSacksick Jun 16, 2025
77c84f1
bugfix matplotlib ax set
MarieSacksick Jun 16, 2025
d348ca1
remove useless line
MarieSacksick Jun 16, 2025
bf74554
change annotation to legend
MarieSacksick Jun 16, 2025
7631798
Merge branch 'main' into plot_comp_report_metrics
MarieSacksick Jul 28, 2025
d3abdff
linting
MarieSacksick Jul 28, 2025
c7fb645
docs(skore): Change the name of features after preprocessing (#1901)
mrastgoo Jul 29, 2025
bc221c1
finish merge
MarieSacksick Jul 31, 2025
8cd4dbf
annot to legend
MarieSacksick Jul 31, 2025
c9f5dfa
Merge branch 'main' into plot_comp_report_metrics
MarieSacksick Jul 31, 2025
509d4af
change scale according to data range
MarieSacksick Jul 31, 2025
7abf9dc
extend to support custom metrics
MarieSacksick Aug 1, 2025
1d4b013
fix: remove useless subplot
MarieSacksick Aug 1, 2025
589ddb8
add test for axis
MarieSacksick Aug 4, 2025
e5ff840
fix test
MarieSacksick Aug 4, 2025
98f4829
Merge branch 'main' into plot_comp_report_metrics
MarieSacksick Aug 4, 2025
c155b43
Merge branch 'main' into plot_comp_report_metrics
MarieSacksick Oct 8, 2025
65f0c9a
linting and update import to be consistent with refactor
MarieSacksick Oct 8, 2025
2bfc789
add default plot function in metrics summary display
MarieSacksick Oct 8, 2025
95eeca9
fix docs with available metrics
MarieSacksick Oct 9, 2025
409e44a
Merge branch 'main' into plot_comp_report_metrics
MarieSacksick Oct 10, 2025
5 changes: 5 additions & 0 deletions examples/getting_started/plot_skore_getting_started.py
@@ -213,6 +213,11 @@
# %%
comparator.metrics.summarize(indicator_favorability=True).frame()

# %%
# To be more specific in our comparison, we can plot the ROC AUC against the fitting time.

# %%
comparator.metrics.summarize().plot(x="roc_auc", y="fit_time")
# %%
# Thus, we easily have the result of our benchmark for several recommended metrics.

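A minimal sketch of the intended usage, assuming a `ComparisonReport` built from two `EstimatorReport`s; the estimator choices and the split below are illustrative, not taken from the example:

from sklearn.datasets import make_classification
from sklearn.ensemble import HistGradientBoostingClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split

from skore import ComparisonReport, EstimatorReport

X, y = make_classification(random_state=0)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

# One EstimatorReport per model to benchmark.
reports = [
    EstimatorReport(
        LogisticRegression(),
        X_train=X_train, y_train=y_train, X_test=X_test, y_test=y_test,
    ),
    EstimatorReport(
        HistGradientBoostingClassifier(),
        X_train=X_train, y_train=y_train, X_test=X_test, y_test=y_test,
    ),
]
comparator = ComparisonReport(reports)

# One scatter point per model: ROC AUC on the x-axis, fit time on the y-axis.
comparator.metrics.summarize().plot(x="roc_auc", y="fit_time")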
20 changes: 19 additions & 1 deletion skore/src/skore/_sklearn/_comparison/metrics_accessor.py
@@ -25,6 +25,7 @@
_DEFAULT,
Aggregate,
PositiveLabel,
ReportType,
Scoring,
ScoringName,
YPlotData,
@@ -177,7 +178,24 @@ class is set to the one provided when creating the report. If `None`,
results.index = results.index.str.replace(
r"\((.*)\)$", r"\1", regex=True
)
return MetricsSummaryDisplay(results)

report_type: ReportType
if self._parent._reports_type == "EstimatorReport":
report_type = "comparison-estimator"
elif self._parent._reports_type == "CrossValidationReport":
report_type = "comparison-cross-validation"
else:
raise ValueError(
"Comparison should only apply to EstimatorReport or "
"CrossValidationReport"
)
return MetricsSummaryDisplay(
summarize_data=results,
report_type=report_type,
data_source=data_source,
scoring_names=scoring_names,
default_verbose_metric_names=self._score_or_loss_info,
)

@progress_decorator(description="Compute metric for each estimator")
def _compute_metric_scores(
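For reference, the `ReportType` values exercised by this dispatch; the alias below is a sketch inferred from the branches above rather than copied from its definition in `skore._sklearn.types`:

from typing import Literal

# Inferred from the comparison/cross-validation/estimator accessors in this PR;
# the real alias in skore._sklearn.types may carry additional values.
ReportType = Literal[
    "estimator",
    "cross-validation",
    "comparison-estimator",
    "comparison-cross-validation",
]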
@@ -173,7 +173,12 @@ class is set to the one provided when creating the report. If `None`,
results.index = results.index.str.replace(
r"\((.*)\)$", r"\1", regex=True
)
return MetricsSummaryDisplay(summarize_data=results)
return MetricsSummaryDisplay(
summarize_data=results,
report_type="cross-validation",
data_source=data_source,
default_verbose_metric_names=self._score_or_loss_info,
)

@progress_decorator(description="Compute metric for each split")
def _compute_metric_scores(
8 changes: 7 additions & 1 deletion skore/src/skore/_sklearn/_estimator/metrics_accessor.py
@@ -427,7 +427,13 @@ class is set to the one provided when creating the report. If `None`,
results.index = results.index.str.replace(
r"\((.*)\)$", r"\1", regex=True
)
return MetricsSummaryDisplay(summarize_data=results)

return MetricsSummaryDisplay(
summarize_data=results,
report_type="estimator",
data_source=data_source,
default_verbose_metric_names=self._score_or_loss_info,
)

def _compute_metric_scores(
self,
146 changes: 142 additions & 4 deletions skore/src/skore/_sklearn/_plot/metrics/metrics_summary_display.py
@@ -1,4 +1,11 @@
import itertools

import matplotlib.pyplot as plt
import pandas as pd

from skore._sklearn._plot.base import DisplayMixin
from skore._sklearn._plot.utils import _interval_max_min_ratio
from skore._sklearn.types import ReportType, ScoringName


class MetricsSummaryDisplay(DisplayMixin):
@@ -8,8 +15,20 @@ class MetricsSummaryDisplay(DisplayMixin):
This class should not be instantiated directly.
"""

def __init__(self, summarize_data):
def __init__(
self,
*,
summarize_data,
report_type: ReportType,
data_source: str = "test",
default_verbose_metric_names: dict[str, dict[str, str]],
scoring_names: ScoringName | list[ScoringName] | None = None,
):
self.summarize_data = summarize_data
self.report_type = report_type
self.data_source = data_source
self.scoring_names = scoring_names
self.default_verbose_metric_names = default_verbose_metric_names

def frame(self):
"""Return the summarize as a dataframe.
@@ -21,7 +40,126 @@ def frame(self):
"""
return self.summarize_data

def _plot_matplotlib(self, x: str, y: str) -> None:
"""Plot visualization.

Parameters
----------
x : str
The metric to display on the x-axis.

y : str
The metric to display on the y-axis.
"""
if self.report_type in ("cross-validation", "comparison-cross-validation"):
raise NotImplementedError("To come soon!")
elif self.report_type == "estimator":
# it does not make sense to plot the metrics for a single estimator
raise NotImplementedError()
elif self.report_type == "comparison-estimator":
self._plot_matplotlib_comparison_estimator(x, y)

def _plot_matplotlib_comparison_estimator(self, x, y):
_, ax = plt.subplots()

# Get verbose name from x and y
# if they are not verbose already
x_verbose = self.default_verbose_metric_names.get(x, {}).get("name", x)
y_verbose = self.default_verbose_metric_names.get(y, {}).get("name", y)

# Check that the metrics are in the report
# If the metric is not in the report, help the user by suggesting
# supported metrics
reverse_score_info = {
value["name"]: key
for key, value in self.default_verbose_metric_names.items()
}
available_metrics = self.summarize_data.index
if isinstance(available_metrics, pd.MultiIndex):
available_metrics = available_metrics.get_level_values(0).to_list()

# if scoring_names is provided, they are the supported metrics
# otherwise, the default verbose names apply.
if self.scoring_names is not None:
supported_metrics = self.scoring_names
else:
supported_metrics = [
reverse_score_info.get(col, col) for col in available_metrics
]

if x not in supported_metrics:
raise ValueError(
f"Performance metric '{x}' not found in the report. "
f"Supported metrics are: {supported_metrics}."
)
if y not in supported_metrics:
raise ValueError(
f"Performance metric '{y}' not found in the report. "
f"Supported metrics are: {supported_metrics}."
)

x_data = self.summarize_data.loc[x_verbose]
y_data = self.summarize_data.loc[y_verbose]
if len(x_data.shape) > 1:
if x_data.shape[0] == 1:
x_data = x_data.reset_index(drop=True).values[0]
else:
raise ValueError(
f"The metric {x!r} requires a positive label to be specified."
)
if len(y_data.shape) > 1:
if y_data.shape[0] == 1:
y_data = y_data.reset_index(drop=True).values[0]
else:
raise ValueError(
"The perf metric y requires to add a positive label parameter."
)

# Make it clear in the axis labels that we are using the train set
if x == "fit_time" and self.data_source != "train":
x_label_text = x_verbose + " on train set"
else:
x_label_text = x_verbose
if y == "fit_time" and self.data_source != "train":
y_label_text = y_verbose + " on train set"
else:
y_label_text = y_verbose

title = f"{x_verbose} vs {y_verbose}"
if self.data_source is not None:
title += f" on {self.data_source} set"

# Add legend
text = self.summarize_data.columns
markers = itertools.cycle(("o", "s", "^", "D", "v", "P", "*", "X", "h", "8"))
colors = itertools.cycle(plt.rcParams["axes.prop_cycle"].by_key()["color"])

for label, x_coord, y_coord in zip(text, x_data, y_data, strict=False):
marker = next(markers)
color = next(colors)
ax.scatter(x_coord, y_coord, marker=marker, color=color, label=label)

if _interval_max_min_ratio(x_data) > 5:
xscale = "symlog" if x_data.min() <= 0 else "log"
else:
xscale = "linear"

if _interval_max_min_ratio(y_data) > 5:
yscale = "symlog" if y_data.min() <= 0 else "log"
else:
yscale = "linear"

ax.set(
title=title,
xlabel=x_label_text,
ylabel=y_label_text,
xscale=xscale,
yscale=yscale,
)
ax.legend(title="Models", loc="best")

self.ax_ = ax

@DisplayMixin.style_plot
def plot(self):
"""Not yet implemented."""
raise NotImplementedError
def plot(self, x: str, y: str):
"""Plot the metric ``y`` against the metric ``x`` for each estimator."""
self._plot(x=x, y=y)
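To make the new constructor arguments concrete, here is a hypothetical, hand-built display; in practice the class is created by `summarize()` rather than instantiated directly, and the verbose metric names and model names below are made up for illustration:

import pandas as pd

from skore._sklearn._plot.metrics.metrics_summary_display import (
    MetricsSummaryDisplay,
)

# Rows are metrics (verbose names), columns are the compared models.
summary = pd.DataFrame(
    {
        "LogisticRegression": [0.93, 0.05],
        "HistGradientBoostingClassifier": [0.97, 0.40],
    },
    index=["ROC AUC", "Fit time"],
)

display = MetricsSummaryDisplay(
    summarize_data=summary,
    report_type="comparison-estimator",
    data_source="test",
    default_verbose_metric_names={
        "roc_auc": {"name": "ROC AUC"},
        "fit_time": {"name": "Fit time"},
    },
)
display.plot(x="roc_auc", y="fit_time")  # one scatter point per model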
11 changes: 11 additions & 0 deletions skore/src/skore/_sklearn/_plot/utils.py
@@ -348,3 +348,14 @@ def sample_mpl_colormap(
"""
indices = np.linspace(0, 1, n)
return [cmap(i) for i in indices]


def _interval_max_min_ratio(data):
"""Compute the ratio between the largest and smallest inter-point distances.

A value larger than 5 typically indicates that the value range is better
displayed with a log scale, while a linear scale is more suitable
otherwise.
"""
diff = np.diff(np.sort(data), axis=0)
return diff.max() / diff.min()
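As a quick illustration of this heuristic and of how the display uses it to choose an axis scale (the threshold of 5 mirrors the check in `_plot_matplotlib_comparison_estimator`; the sample fit times are made up):

import numpy as np


def _interval_max_min_ratio(data):
    # Same computation as in skore/src/skore/_sklearn/_plot/utils.py above.
    diff = np.diff(np.sort(data), axis=0)
    return diff.max() / diff.min()


# Roughly evenly spaced values -> ratio close to 1 -> linear scale.
print(_interval_max_min_ratio(np.array([0.1, 0.2, 0.3, 0.4])))  # ~1.0

# One model is orders of magnitude slower -> large ratio -> log scale.
fit_times = np.array([0.05, 0.06, 0.07, 5.0])
ratio = _interval_max_min_ratio(fit_times)
# The plotting code falls back to "symlog" when the data contain non-positive values.
scale = "log" if ratio > 5 else "linear"
print(ratio, scale)  # ~493.0 log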
@@ -0,0 +1,48 @@
import pytest
from sklearn.datasets import make_classification
from sklearn.ensemble import HistGradientBoostingClassifier
from sklearn.model_selection import train_test_split
from skore import CrossValidationReport, EstimatorReport


@pytest.fixture
def estimator_report_classification():
X, y = make_classification(random_state=0)
X_train, X_test, y_train, y_test = train_test_split(
X, y, test_size=0.2, random_state=0
)

estimator_report = EstimatorReport(
estimator=HistGradientBoostingClassifier(),
X_train=X_train,
y_train=y_train,
X_test=X_test,
y_test=y_test,
)
return estimator_report


def test_not_implemented_estimator(estimator_report_classification):
"""
Check that calling `summarize().plot()` on a single EstimatorReport raises
NotImplementedError, since plotting a comparison of a lone estimator is not meaningful.
"""
with pytest.raises(NotImplementedError):
estimator_report_classification.metrics.summarize().plot(
x="accuracy", y="f1_score"
)


def test_not_implemented_other_categories():
"""
Check that calling `summarize().plot()` on a CrossValidationReport raises
NotImplementedError ("To come soon!"), as this report type is not yet supported.
"""
X, y = make_classification(random_state=0)
cv_report = CrossValidationReport(
estimator=HistGradientBoostingClassifier(),
X=X,
y=y,
)
with pytest.raises(NotImplementedError, match="To come soon!"):
cv_report.metrics.summarize().plot(x="accuracy", y="f1_score")
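A sketch of a complementary test for the supported comparison-estimator path could look like the following; it assumes that `ComparisonReport` accepts a list of `EstimatorReport`s and reuses the imports and fixture from this test module:

import matplotlib.pyplot as plt
from skore import ComparisonReport


def test_plot_comparison_estimator(estimator_report_classification):
    """
    Check that plotting a comparison of two estimator reports sets `ax_`
    with the legend created in MetricsSummaryDisplay.
    """
    X, y = make_classification(random_state=0)
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=0.2, random_state=0
    )
    other_report = EstimatorReport(
        estimator=HistGradientBoostingClassifier(),
        X_train=X_train,
        y_train=y_train,
        X_test=X_test,
        y_test=y_test,
    )
    comparison = ComparisonReport([estimator_report_classification, other_report])

    display = comparison.metrics.summarize()
    display.plot(x="roc_auc", y="fit_time")

    # The plotting code sets `ax_` and a legend titled "Models".
    assert display.ax_.get_legend().get_title().get_text() == "Models"
    plt.close("all")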