Closed

Commits
25 commits
aa20336  docs: objective of the branch (MarieSacksick, Apr 4, 2025)
881e2b6  feat: start the function (MarieSacksick, Apr 4, 2025)
aa16872  add comment (MarieSacksick, Apr 4, 2025)
987bb13  save before we (MarieSacksick, Apr 4, 2025)
63b878e  notes feedback from Guillaume (MarieSacksick, Apr 11, 2025)
c8c5ddb  add utils docstring (MarieSacksick, May 13, 2025)
f8fee6d  Merge branch 'main' into timing_plot (MarieSacksick, May 13, 2025)
1524d46  feat pairwise: handle missing pos label (MarieSacksick, May 13, 2025)
653326f  improve feat (MarieSacksick, May 14, 2025)
b8866ac  Merge branch 'main' into timing_plot (MarieSacksick, May 22, 2025)
e288561  turn into display (MarieSacksick, May 22, 2025)
43e83a8  complete docstrings (MarieSacksick, May 22, 2025)
86da68e  add tests (MarieSacksick, May 22, 2025)
fdc7dfb  correct docstring and comments (MarieSacksick, May 23, 2025)
de68fd3  bugfix (MarieSacksick, May 23, 2025)
37ddfd9  fix tests (MarieSacksick, May 23, 2025)
e71d9db  Merge branch 'main' into timing_plot (MarieSacksick, May 23, 2025)
ae1fcb9  Update skore/src/skore/sklearn/_comparison/report.py (MarieSacksick, May 23, 2025)
cd083ac  remove traces from inspiration (MarieSacksick, May 23, 2025)
cc2f43c  Merge branch 'main' into timing_plot (MarieSacksick, May 26, 2025)
72eb0dc  Update skore/src/skore/sklearn/_plot/metrics/pair_plot.py (MarieSacksick, May 28, 2025)
254648f  Update skore/src/skore/sklearn/_plot/metrics/pair_plot.py (MarieSacksick, May 28, 2025)
3294dc5  Update skore/src/skore/sklearn/_plot/metrics/pair_plot.py (MarieSacksick, May 28, 2025)
f482c5f  linting (MarieSacksick, May 28, 2025)
a63dd79  Merge branch 'main' into timing_plot (MarieSacksick, Jun 2, 2025)
4 changes: 4 additions & 0 deletions examples/getting_started/plot_skore_getting_started.py
@@ -207,6 +207,10 @@
# %%
comparator.metrics.report_metrics(indicator_favorability=True)

# %%
# We can plot a performance metric against the timings to highlight the trade-off
comparator.pairwise_plot(perf_metric_x="brier_score", perf_metric_y="fit_time")

# %%
# Thus, we easily have the result of our benchmark for several recommended metrics.

50 changes: 50 additions & 0 deletions skore/src/skore/sklearn/_comparison/report.py
@@ -13,6 +13,7 @@
from skore.sklearn._base import _BaseReport
from skore.sklearn._cross_validation.report import CrossValidationReport
from skore.sklearn._estimator.report import EstimatorReport
from skore.sklearn._plot import PairPlotDisplay
from skore.utils._progress_bar import progress_decorator

if TYPE_CHECKING:
@@ -398,6 +399,55 @@ def get_predictions(
for report in self.reports_
]

def pairwise_plot(
self,
perf_metric_x: str,
perf_metric_y: str,
data_source: Literal["test", "train", "X_y"] = "test",
pos_label: Optional[Any] = None,
X: Optional[ArrayLike] = None,
y: Optional[ArrayLike] = None,
):
"""Plot a given performance metric against another.

Parameters
----------
perf_metric_x : str
The performance metric to plot on the abscissa axis.

perf_metric_y : str
The performance metric to plot on the ordinate axis.

data_source : {"test", "train", "X_y"}, default="test"
The data source to use.

- "test" : use the test set provided when creating the report.
- "train" : use the train set provided when creating the report.
- "X_y" : use the provided `X` and `y` to compute the metric.

pos_label : int, float, bool or str, default=None
The positive class when it comes to binary classification. When
`response_method="predict_proba"`, it will select the column corresponding
to the positive class. When `response_method="decision_function"`, it will
negate the decision function if `pos_label` is different from
`estimator.classes_[1]`.

X : array-like of shape (n_samples, n_features), default=None
New data on which to compute the metric, when `data_source="X_y"`.

y : array-like of shape (n_samples,), default=None
New target on which to compute the metric, when `data_source="X_y"`.

Returns
-------
display : PairPlotDisplay
The pair plot display.
"""
# TODO
# - add kwargs (later)

return PairPlotDisplay.from_metrics(
metrics=self.metrics.report_metrics(
pos_label=pos_label, data_source=data_source, X=X, y=y
).T.reset_index(),
perf_metric_x=perf_metric_x,
perf_metric_y=perf_metric_y,
data_source=data_source,
)

Review comment (Member), on the report_metrics call above:

One thing I realized with the implementation now is that we are going to want most of these parameters only to pass them through to report_metrics.

Now, I'm thinking this would mean that the PairPlotDisplay is just one kind of plot associated with report_metrics. In short, I think it would make sense to be able to write:

report.metrics.report_metrics().plot(kind="pair", x="fit_time", y="accuracy")

but also

report.metrics.report_metrics().plot(kind="bar")

And it would allow passing the arguments as:

report.metrics.report_metrics(data_source="train", ...).plot(kind="pair")
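A minimal sketch of what such a .plot(kind=...) accessor could look like; MetricsSummaryDisplay and its layout are illustrative assumptions for discussion, not part of skore's current API.

# Hypothetical sketch: MetricsSummaryDisplay is an illustrative name, not part of skore.
import matplotlib.pyplot as plt
import pandas as pd


class MetricsSummaryDisplay:
    """Wrap the metrics DataFrame and expose a single plot(kind=...) entry point."""

    def __init__(self, summary):
        self.summary = summary

    def plot(self, kind="bar", x=None, y=None):
        fig, ax = plt.subplots()
        if kind == "pair":
            # One point per estimator: metric x against metric y, annotated with its name.
            ax.scatter(self.summary[x], self.summary[y])
            for name, xv, yv in zip(self.summary["Estimator"], self.summary[x], self.summary[y]):
                ax.annotate(name, (xv, yv), textcoords="offset points", xytext=(10, 0))
            ax.set_xlabel(x)
            ax.set_ylabel(y)
        elif kind == "bar":
            # One bar group per estimator, one bar per metric column.
            self.summary.set_index("Estimator").plot.bar(ax=ax)
        else:
            raise ValueError(f"Unknown kind: {kind!r}")
        return fig, ax


# A toy metrics table with one row per estimator, roughly the shape produced
# by report_metrics().T.reset_index() in this PR.
summary = pd.DataFrame(
    {
        "Estimator": ["LogisticRegression", "RandomForestClassifier"],
        "fit_time": [0.02, 0.35],
        "accuracy": [0.87, 0.91],
    }
)
MetricsSummaryDisplay(summary).plot(kind="pair", x="fit_time", y="accuracy")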

####################################################################################
# Methods related to the help and repr
####################################################################################
2 changes: 2 additions & 0 deletions skore/src/skore/sklearn/_plot/__init__.py
@@ -1,5 +1,6 @@
from skore.sklearn._plot.metrics import (
ConfusionMatrixDisplay,
PairPlotDisplay,
PrecisionRecallCurveDisplay,
PredictionErrorDisplay,
RocCurveDisplay,
@@ -10,4 +11,5 @@
"RocCurveDisplay",
"PrecisionRecallCurveDisplay",
"PredictionErrorDisplay",
"PairPlotDisplay",
]
2 changes: 2 additions & 0 deletions skore/src/skore/sklearn/_plot/metrics/__init__.py
@@ -1,4 +1,5 @@
from skore.sklearn._plot.metrics.confusion_matrix import ConfusionMatrixDisplay
from skore.sklearn._plot.metrics.pair_plot import PairPlotDisplay
from skore.sklearn._plot.metrics.precision_recall_curve import (
PrecisionRecallCurveDisplay,
)
@@ -10,4 +11,5 @@
"PrecisionRecallCurveDisplay",
"PredictionErrorDisplay",
"RocCurveDisplay",
"PairPlotDisplay",
]
219 changes: 219 additions & 0 deletions skore/src/skore/sklearn/_plot/metrics/pair_plot.py
@@ -0,0 +1,219 @@
import matplotlib.pyplot as plt

from skore.sklearn._plot.style import StyleDisplayMixin
from skore.sklearn._plot.utils import HelpDisplayMixin
from skore.sklearn.utils import _SCORE_OR_LOSS_INFO


class PairPlotDisplay(HelpDisplayMixin, StyleDisplayMixin):
"""Display for pair plot.

Parameters
----------
scatter_data : pandas.DataFrame
Dataframe containing the data to plot.

x_column : str
The name of the column to plot on the x-axis.
If None, the first column of the dataframe is used.

y_column : str
The name of the column to plot on the y-axis.
If None, the second column of the dataframe is used.

display_label_x : str, default=None
The label to use for the x-axis. If None, the name of the column will be used.

display_label_y : str, default=None
The label to use for the y-axis. If None, the name of the column will be used.

data_source : str, default=None
The data source used to compute the metrics. If provided, it is mentioned in the plot title.

Attributes
----------
figure_ : matplotlib Figure
Figure containing the pair plot.

ax_ : matplotlib Axes
Axes with pair plot.
"""

def __init__(
self,
scatter_data,
*,
x_column=None,
y_column=None,
display_label_x=None,
display_label_y=None,
data_source=None,
):
self.scatter_data = scatter_data
if x_column is None:
x_column = scatter_data.columns[0]
self.x_column = x_column
if y_column is None:
y_column = scatter_data.columns[1]
self.y_column = y_column
self.display_label_x = (
display_label_x if display_label_x is not None else self.x_column
)
self.display_label_y = (
display_label_y if display_label_y is not None else self.y_column
)
self.data_source = data_source
self.figure_ = None
self.ax_ = None
self.text_ = None
Review comment (Member), on lines +66 to +68:

For consistency, those are created only when plot is called. We can see in a subsequent PR whether we want to make this behaviour consistent with an initialization.


def plot(self, ax=None, **kwargs):
Review comment (Member): It is this method that would benefit from the style.

Suggested change:
- def plot(self, ax=None, **kwargs):
+ @StyleDisplayMixin.style_plot
+ def plot(self, ax=None, **kwargs):

Review comment (Member): We don't need ax anymore. We decided with @auguste-probabl to reduce the API here.

Review comment (Member): You can also remove kwargs because it is unused.

"""Plot a given performance metric against another.

Parameters
----------
ax : matplotlib axes, default=None
Axes object to plot on. If None, a new figure and axes is created.

**kwargs : dict
Additional keyword arguments; currently unused.

Returns
-------
self : PairPlotDisplay
"""
if ax is None:
fig, ax = plt.subplots()
else:
fig = ax.figure

scatter_data = self.scatter_data

title = f"{self.display_label_x} vs {self.display_label_x}"
if self.data_source is not None:
title += f" on {self.data_source} data"
ax.scatter(x=scatter_data[self.x_column], y=scatter_data[self.y_column])
ax.set_title(title)
ax.set_xlabel(self.display_label_x)
ax.set_ylabel(self.display_label_y)

self.figure_, self.ax_ = fig, ax
return self

@classmethod
def from_metrics(
cls,
metrics,
perf_metric_x,
perf_metric_y,
data_source=None,
):
Review comment (Member), on lines +104 to +111:

The display should not expose this public function. The idea is that the reporters will be the only objects that can create an instance of a Display.

You can have a look at RocCurveDisplay (or PrecisionRecallDisplay). Basically, I think we should keep the name _compute_data_for_display; however, we can adapt the input parameters.

"""Create a pair plot display from metrics.

Parameters
----------
metrics : pandas.DataFrame
Dataframe containing the data to plot. The dataframe should
contain the performance metrics for each estimator.

perf_metric_x : str
The name of the column to plot on the x-axis.

perf_metric_y : str
The name of the column to plot on the y-axis.

data_source : str, default=None
The data source used to compute the metrics. If provided, it is mentioned in the plot title.

Returns
-------
display : :class:`PairPlotDisplay`
The scatter plot display.
"""
x_label = _SCORE_OR_LOSS_INFO.get(perf_metric_x, {}).get("name", perf_metric_x)
y_label = _SCORE_OR_LOSS_INFO.get(perf_metric_y, {}).get("name", perf_metric_y)
Review comment (Member), on lines +134 to +135:

I think those should be passed directly by the methods from the report. It would be handy because we would then have access to the _SCORE_OR_LOSS_INFO dictionary on the report side.

scatter_data = metrics

# Check that the metrics are in the report
# If the metric is not in the report, help the user by suggesting
# supported metrics
reverse_score_info = {
value["name"]: key for key, value in _SCORE_OR_LOSS_INFO.items()
}
available_columns = scatter_data.columns.get_level_values(0).to_list()
available_columns.remove("Estimator")
supported_metrics = [
reverse_score_info.get(col, col) for col in available_columns
]
if perf_metric_x not in supported_metrics:
raise ValueError(
f"Performance metric {perf_metric_x} not found in the report. "
f"Supported metrics are: {supported_metrics}."
)
if perf_metric_y not in supported_metrics:
raise ValueError(
f"Performance metric {perf_metric_y} not found in the report. "
f"Supported metrics are: {supported_metrics}."
)

# Check that x and y are 1D arrays (i.e. the metrics don't need pos_label)
x = scatter_data[x_label]
y = scatter_data[y_label]
if len(x.shape) > 1:
raise ValueError(
"The perf metric x requires to add a positive label parameter."
)
if len(y.shape) > 1:
raise ValueError(
"The perf metric y requires to add a positive label parameter."
)

# Make it clear in the axis labels that we are using the train set
if perf_metric_x == "fit_time" and data_source != "train":
x_label_text = x_label + " on train set"
else:
x_label_text = x_label
if perf_metric_y == "fit_time" and data_source != "train":
y_label_text = y_label + " on train set"
else:
y_label_text = y_label

disp = cls(
scatter_data=scatter_data,
x_column=x_label,
y_column=y_label,
display_label_x=x_label_text,
display_label_y=y_label_text,
data_source=data_source,
).plot()

# Add labels to the points with a small offset
ax = disp.ax_
text = scatter_data["Estimator"]
for label, x_coord, y_coord in zip(text, x, y):
ax.annotate(
label,
(x_coord, y_coord),
textcoords="offset points",
xytext=(10, 0),
bbox=dict(
boxstyle="round,pad=0.3",
edgecolor="gray",
facecolor="white",
alpha=0.7,
),
)

disp.ax_ = ax
return disp

def frame(self):
"""Return the dataframe used for the pair plot.

Returns
-------
scatter_data : pandas.DataFrame
The dataframe used to create the scatter plot.
"""
return self.scatter_data
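
For context, a rough end-to-end usage sketch of the method added in this PR, following the getting-started example above; the EstimatorReport/ComparisonReport construction mirrors the existing skore API, but the exact constructor signatures should be checked against main.

from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split

from skore import ComparisonReport, EstimatorReport

X, y = make_classification(n_samples=500, random_state=0)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

# One EstimatorReport per candidate model, sharing the same train/test split.
reports = [
    EstimatorReport(
        LogisticRegression(max_iter=1_000),
        X_train=X_train, y_train=y_train, X_test=X_test, y_test=y_test,
    ),
    EstimatorReport(
        RandomForestClassifier(random_state=0),
        X_train=X_train, y_train=y_train, X_test=X_test, y_test=y_test,
    ),
]
comparator = ComparisonReport(reports)

# Brier score (lower is better) against fit time, one annotated point per estimator.
display = comparator.pairwise_plot(perf_metric_x="brier_score", perf_metric_y="fit_time")
display.figure_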
16 changes: 16 additions & 0 deletions skore/src/skore/sklearn/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
"""Utility functions for Skore and Scikit-learn integration."""

_SCORE_OR_LOSS_INFO: dict[str, dict[str, str]] = {
"fit_time": {"name": "Fit time (s)", "icon": "(↘︎)"},
"predict_time": {"name": "Predict time (s)", "icon": "(↘︎)"},
"accuracy": {"name": "Accuracy", "icon": "(↗︎)"},
"precision": {"name": "Precision", "icon": "(↗︎)"},
"recall": {"name": "Recall", "icon": "(↗︎)"},
"brier_score": {"name": "Brier score", "icon": "(↘︎)"},
"roc_auc": {"name": "ROC AUC", "icon": "(↗︎)"},
"log_loss": {"name": "Log loss", "icon": "(↘︎)"},
"r2": {"name": "R²", "icon": "(↗︎)"},
"rmse": {"name": "RMSE", "icon": "(↘︎)"},
"custom_metric": {"name": "Custom metric", "icon": ""},
"report_metrics": {"name": "Report metrics", "icon": ""},
}
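
As a quick illustration of how this table is consumed by PairPlotDisplay.from_metrics above, the name lookup falls back to the raw metric key for metrics that are not listed (self-contained toy version of the dictionary):

_SCORE_OR_LOSS_INFO = {
    "brier_score": {"name": "Brier score", "icon": "(↘︎)"},
    "fit_time": {"name": "Fit time (s)", "icon": "(↘︎)"},
}

# Known keys map to a human-readable axis label; unknown (e.g. custom) metrics keep their key.
for metric in ("brier_score", "my_custom_metric"):
    label = _SCORE_OR_LOSS_INFO.get(metric, {}).get("name", metric)
    print(metric, "->", label)
# brier_score -> Brier score
# my_custom_metric -> my_custom_metric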