diff --git a/examples/getting_started/plot_skore_getting_started.py b/examples/getting_started/plot_skore_getting_started.py
index de46939ab0..033da087e1 100644
--- a/examples/getting_started/plot_skore_getting_started.py
+++ b/examples/getting_started/plot_skore_getting_started.py
@@ -207,6 +207,10 @@
 # %%
 comparator.metrics.report_metrics(indicator_favorability=True)
 
+# %%
+# We can also visualize the trade-off between a performance metric and fit time.
+comparator.pairwise_plot(perf_metric_x="brier_score", perf_metric_y="fit_time")
+
 # %%
 # Thus, we easily have the result of our benchmark for several recommended metrics.
 
diff --git a/skore/src/skore/sklearn/_comparison/report.py b/skore/src/skore/sklearn/_comparison/report.py
index ac93d83163..1402345a9a 100644
--- a/skore/src/skore/sklearn/_comparison/report.py
+++ b/skore/src/skore/sklearn/_comparison/report.py
@@ -13,6 +13,7 @@
 from skore.sklearn._base import _BaseReport
 from skore.sklearn._cross_validation.report import CrossValidationReport
 from skore.sklearn._estimator.report import EstimatorReport
+from skore.sklearn._plot import PairPlotDisplay
 from skore.utils._progress_bar import progress_decorator
 
 if TYPE_CHECKING:
@@ -398,6 +399,59 @@ def get_predictions(
             for report in self.reports_
         ]
 
+    def pairwise_plot(
+        self,
+        perf_metric_x: str,
+        perf_metric_y: str,
+        data_source: Literal["test", "train", "X_y"] = "test",
+        pos_label: Optional[Any] = None,
+        X: Optional[ArrayLike] = None,
+        y: Optional[ArrayLike] = None,
+    ):
+        """Plot a given performance metric against another.
+
+        Parameters
+        ----------
+        perf_metric_x : str
+            The performance metric to plot on the x-axis.
+
+        perf_metric_y : str
+            The performance metric to plot on the y-axis.
+
+        data_source : {"test", "train", "X_y"}, default="test"
+            The data source to use.
+
+            - "test" : use the test set provided when creating the report.
+            - "train" : use the train set provided when creating the report.
+            - "X_y" : use the provided `X` and `y` to compute the metric.
+
+        pos_label : int, float, bool or str, default=None
+            The positive class in binary classification. Required by metrics
+            that are computed per class, such as precision and recall.
+
+        X : array-like of shape (n_samples, n_features), default=None
+            Input data, used only when `data_source="X_y"`.
+
+        y : array-like of shape (n_samples,), default=None
+            Target data, used only when `data_source="X_y"`.
+
+        Returns
+        -------
+        display : PairPlotDisplay
+            The pair plot display.
+        """
+        # TODO
+        # - add kwargs (later)
+
+        return PairPlotDisplay.from_metrics(
+            metrics=self.metrics.report_metrics(
+                pos_label=pos_label, data_source=data_source, X=X, y=y
+            ).T.reset_index(),
+            perf_metric_x=perf_metric_x,
+            perf_metric_y=perf_metric_y,
+            data_source=data_source,
+        )
+
     ####################################################################################
     # Methods related to the help and repr
     ####################################################################################
diff --git a/skore/src/skore/sklearn/_plot/__init__.py b/skore/src/skore/sklearn/_plot/__init__.py
index c53a218fdd..26c7458eb1 100644
--- a/skore/src/skore/sklearn/_plot/__init__.py
+++ b/skore/src/skore/sklearn/_plot/__init__.py
@@ -1,5 +1,6 @@
 from skore.sklearn._plot.metrics import (
     ConfusionMatrixDisplay,
+    PairPlotDisplay,
     PrecisionRecallCurveDisplay,
     PredictionErrorDisplay,
     RocCurveDisplay,
@@ -10,4 +11,5 @@
     "RocCurveDisplay",
     "PrecisionRecallCurveDisplay",
     "PredictionErrorDisplay",
+    "PairPlotDisplay",
 ]
diff --git a/skore/src/skore/sklearn/_plot/metrics/__init__.py b/skore/src/skore/sklearn/_plot/metrics/__init__.py
index cdd5b05e71..0748464706 100644
--- a/skore/src/skore/sklearn/_plot/metrics/__init__.py
+++ b/skore/src/skore/sklearn/_plot/metrics/__init__.py
@@ -1,4 +1,5 @@
 from skore.sklearn._plot.metrics.confusion_matrix import ConfusionMatrixDisplay
+from skore.sklearn._plot.metrics.pair_plot import PairPlotDisplay
 from skore.sklearn._plot.metrics.precision_recall_curve import (
     PrecisionRecallCurveDisplay,
 )
@@ -10,4 +11,5 @@
     "PrecisionRecallCurveDisplay",
     "PredictionErrorDisplay",
     "RocCurveDisplay",
+    "PairPlotDisplay",
 ]
diff --git a/skore/src/skore/sklearn/_plot/metrics/pair_plot.py b/skore/src/skore/sklearn/_plot/metrics/pair_plot.py
new file mode 100644
index 0000000000..b6fc4c9a42
--- /dev/null
+++ b/skore/src/skore/sklearn/_plot/metrics/pair_plot.py
@@ -0,0 +1,219 @@
+import matplotlib.pyplot as plt
+
+from skore.sklearn._plot.style import StyleDisplayMixin
+from skore.sklearn._plot.utils import HelpDisplayMixin
+from skore.sklearn.utils import _SCORE_OR_LOSS_INFO
+
+
+class PairPlotDisplay(HelpDisplayMixin, StyleDisplayMixin):
+    """Display for pair plot.
+
+    Parameters
+    ----------
+    scatter_data : pandas.DataFrame
+        Dataframe containing the data to plot.
+
+    x_column : str
+        The name of the column to plot on the x-axis.
+        If None, the first column of the dataframe is used.
+
+    y_column : str
+        The name of the column to plot on the y-axis.
+        If None, the second column of the dataframe is used.
+
+    display_label_x : str, default=None
+        The label to use for the x-axis. If None, the name of the column will be used.
+
+    display_label_y : str, default=None
+        The label to use for the y-axis. If None, the name of the column will be used.
+
+    data_source : str, default=None
+        To specify the data source for the plot.
+
+    Attributes
+    ----------
+    figure_ : matplotlib Figure
+        Figure containing the pair plot.
+
+    ax_ : matplotlib Axes
+        Axes with pair plot.
+    """
+
+    def __init__(
+        self,
+        scatter_data,
+        *,
+        x_column=None,
+        y_column=None,
+        display_label_x=None,
+        display_label_y=None,
+        data_source=None,
+    ):
+        self.scatter_data = scatter_data
+        if x_column is None:
+            x_column = scatter_data.columns[0]
+        self.x_column = x_column
+        if y_column is None:
+            y_column = scatter_data.columns[1]
+        self.y_column = y_column
+        self.display_label_x = (
+            display_label_x if display_label_x is not None else self.x_column
+        )
+        self.display_label_y = (
+            display_label_y if display_label_y is not None else self.y_column
+        )
+        self.data_source = data_source
+        self.figure_ = None
+        self.ax_ = None
+        self.text_ = None
+
+    def plot(self, ax=None, **kwargs):
+        """Plot a given performance metric against another.
+
+        Parameters
+        ----------
+        ax : matplotlib axes, default=None
+            Axes object to plot on. If None, a new figure and axes are created.
+
+        **kwargs : dict
+            Additional keyword arguments to be passed to matplotlib's
+            `ax.scatter`.
+
+        Returns
+        -------
+        self : PairPlotDisplay
+        """
+        if ax is None:
+            fig, ax = plt.subplots()
+        else:
+            fig = ax.figure
+
+        scatter_data = self.scatter_data
+
+        title = f"{self.display_label_x} vs {self.display_label_y}"
+        if self.data_source is not None:
+            title += f" on {self.data_source} data"
+        ax.scatter(scatter_data[self.x_column], scatter_data[self.y_column], **kwargs)
+        ax.set_title(title)
+        ax.set_xlabel(self.display_label_x)
+        ax.set_ylabel(self.display_label_y)
+
+        self.figure_, self.ax_ = fig, ax
+        return self
+
+    @classmethod
+    def from_metrics(
+        cls,
+        metrics,
+        perf_metric_x,
+        perf_metric_y,
+        data_source=None,
+    ):
+        """Create a pair plot display from metrics.
+
+        Parameters
+        ----------
+        metrics : pandas.DataFrame
+            Dataframe containing the data to plot. The dataframe should
+            contain the performance metrics for each estimator.
+
+        perf_metric_x : str
+            The name of the column to plot on the x-axis.
+
+        perf_metric_y : str
+            The name of the column to plot on the y-axis.
+
+        data_source : str, default=None
+            The data source used to compute the metrics; shown in the plot title.
+
+        Returns
+        -------
+        display : :class:`PairPlotDisplay`
+            The scatter plot display.
+        """
+        x_label = _SCORE_OR_LOSS_INFO.get(perf_metric_x, {}).get("name", perf_metric_x)
+        y_label = _SCORE_OR_LOSS_INFO.get(perf_metric_y, {}).get("name", perf_metric_y)
+        scatter_data = metrics
+
+        # Check that the metrics are in the report
+        # If the metric is not in the report, help the user by suggesting
+        # supported metrics
+        reverse_score_info = {
+            value["name"]: key for key, value in _SCORE_OR_LOSS_INFO.items()
+        }
+        available_columns = scatter_data.columns.get_level_values(0).to_list()
+        available_columns.remove("Estimator")
+        supported_metrics = [
+            reverse_score_info.get(col, col) for col in available_columns
+        ]
+        if perf_metric_x not in supported_metrics:
+            raise ValueError(
+                f"Performance metric {perf_metric_x} not found in the report. "
+                f"Supported metrics are: {supported_metrics}."
+            )
+        if perf_metric_y not in supported_metrics:
+            raise ValueError(
+                f"Performance metric {perf_metric_y} not found in the report. "
+                f"Supported metrics are: {supported_metrics}."
+            )
+
+        # Check that x and y are 1D arrays (i.e. the metrics don't need pos_label)
+        x = scatter_data[x_label]
+        y = scatter_data[y_label]
+        if len(x.shape) > 1:
+            raise ValueError(
+                "The performance metric on the x-axis requires `pos_label`."
+            )
+        if len(y.shape) > 1:
+            raise ValueError(
+                "The performance metric on the y-axis requires `pos_label`."
+            )
+
+        # Make it clear in the axis labels that we are using the train set
+        if perf_metric_x == "fit_time" and data_source != "train":
+            x_label_text = x_label + " on train set"
+        else:
+            x_label_text = x_label
+        if perf_metric_y == "fit_time" and data_source != "train":
+            y_label_text = y_label + " on train set"
+        else:
+            y_label_text = y_label
+
+        disp = cls(
+            scatter_data=scatter_data,
+            x_column=x_label,
+            y_column=y_label,
+            display_label_x=x_label_text,
+            display_label_y=y_label_text,
+            data_source=data_source,
+        ).plot()
+
+        # Add labels to the points with a small offset
+        ax = disp.ax_
+        text = scatter_data["Estimator"]
+        for label, x_coord, y_coord in zip(text, x, y):
+            ax.annotate(
+                label,
+                (x_coord, y_coord),
+                textcoords="offset points",
+                xytext=(10, 0),
+                bbox=dict(
+                    boxstyle="round,pad=0.3",
+                    edgecolor="gray",
+                    facecolor="white",
+                    alpha=0.7,
+                ),
+            )
+
+        disp.ax_ = ax
+        return disp
+
+    def frame(self):
+        """Return the dataframe used for the pair plot.
+
+        Returns
+        -------
+        scatter_data : pandas.DataFrame
+            The dataframe used to create the scatter plot.
+        """
+        return self.scatter_data
diff --git a/skore/src/skore/sklearn/utils.py b/skore/src/skore/sklearn/utils.py
new file mode 100644
index 0000000000..9a482baf44
--- /dev/null
+++ b/skore/src/skore/sklearn/utils.py
@@ -0,0 +1,16 @@
+"""Utility functions for Skore and Scikit-learn integration."""
+
+_SCORE_OR_LOSS_INFO: dict[str, dict[str, str]] = {
+    "fit_time": {"name": "Fit time (s)", "icon": "(↘︎)"},
+    "predict_time": {"name": "Predict time (s)", "icon": "(↘︎)"},
+    "accuracy": {"name": "Accuracy", "icon": "(↗︎)"},
+    "precision": {"name": "Precision", "icon": "(↗︎)"},
+    "recall": {"name": "Recall", "icon": "(↗︎)"},
+    "brier_score": {"name": "Brier score", "icon": "(↘︎)"},
+    "roc_auc": {"name": "ROC AUC", "icon": "(↗︎)"},
+    "log_loss": {"name": "Log loss", "icon": "(↘︎)"},
+    "r2": {"name": "R²", "icon": "(↗︎)"},
+    "rmse": {"name": "RMSE", "icon": "(↘︎)"},
+    "custom_metric": {"name": "Custom metric", "icon": ""},
+    "report_metrics": {"name": "Report metrics", "icon": ""},
+}
diff --git a/skore/tests/unit/sklearn/plot/test_pair_plot.py b/skore/tests/unit/sklearn/plot/test_pair_plot.py
new file mode 100644
index 0000000000..138b2e94ae
--- /dev/null
+++ b/skore/tests/unit/sklearn/plot/test_pair_plot.py
@@ -0,0 +1,162 @@
+import pytest
+from sklearn.datasets import make_classification
+from sklearn.linear_model import LogisticRegression
+from sklearn.model_selection import train_test_split
+from sklearn.tree import DecisionTreeClassifier
+from skore import ComparisonReport, EstimatorReport
+from skore.sklearn._plot import PairPlotDisplay
+
+
+@pytest.fixture
+def binary_classification_data():
+    X, y = make_classification(class_sep=0.1, random_state=42)
+    X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=42)
+    return (
+        X_train,
+        X_test,
+        y_train,
+        y_test,
+    )
+
+
+def test_pairplot_test_set(
+    pyplot,
+    binary_classification_data,
+):
+    X_train, X_test, y_train, y_test = binary_classification_data
+    est_1 = LogisticRegression()
+    est_2 = DecisionTreeClassifier()
+    report1 = EstimatorReport(
+        est_1,
+        X_train=X_train,
+        y_train=y_train,
+        X_test=X_test,
+        y_test=y_test,
+    )
+    report2 = EstimatorReport(
+        est_2,
+        X_train=X_train,
+        y_train=y_train,
+        X_test=X_test,
+        y_test=y_test,
+    )
+    comparison = ComparisonReport(
+        {"Logistic Regression": report1, "Decision Tree": report2}
+    )
+    display = comparison.pairwise_plot(
+        perf_metric_x="fit_time",
+        perf_metric_y="predict_time",
+        data_source="test",
+    )
+
+    display.plot()
+
+    assert isinstance(display, PairPlotDisplay)
+    assert hasattr(display, "figure_")
+    assert hasattr(display, "ax_")
+
+
+def test_pairplot_train_set(
+    pyplot,
+    binary_classification_data,
+):
+    X_train, X_test, y_train, y_test = binary_classification_data
+    est_1 = LogisticRegression()
+    est_2 = DecisionTreeClassifier()
+    report1 = EstimatorReport(
+        est_1,
+        X_train=X_train,
+        y_train=y_train,
+        X_test=X_test,
+        y_test=y_test,
+    )
+    report2 = EstimatorReport(
+        est_2,
+        X_train=X_train,
+        y_train=y_train,
+        X_test=X_test,
+        y_test=y_test,
+    )
+    comparison = ComparisonReport(
+        {"Logistic Regression": report1, "Decision Tree": report2}
+    )
+
+    display = comparison.pairwise_plot(
+        perf_metric_x="fit_time",
+        perf_metric_y="roc_auc",
+        data_source="train",
+        pos_label=1,
+    )
+
+    display.plot()
+
+    assert isinstance(display, PairPlotDisplay)
+    assert hasattr(display, "figure_")
+    assert hasattr(display, "ax_")
+
+
+def test_pairplot_missing_col(
+    pyplot,
+    binary_classification_data,
+):
+    X_train, X_test, y_train, y_test = binary_classification_data
+    est_1 = LogisticRegression()
+    est_2 = DecisionTreeClassifier()
+    report1 = EstimatorReport(
+        est_1,
+        X_train=X_train,
+        y_train=y_train,
+        X_test=X_test,
+        y_test=y_test,
+    )
+    report2 = EstimatorReport(
+        est_2,
+        X_train=X_train,
+        y_train=y_train,
+        X_test=X_test,
+        y_test=y_test,
+    )
+    comparison = ComparisonReport(
+        {"Logistic Regression": report1, "Decision Tree": report2}
+    )
+
+    with pytest.raises(ValueError):
+        comparison.pairwise_plot(
+            perf_metric_x="fit_time",
+            perf_metric_y="an_invented_column",
+            data_source="train",
+            pos_label=1,
+        )
+
+
+def test_pairplot_missing_pos_label(
+    pyplot,
+    binary_classification_data,
+):
+    X_train, X_test, y_train, y_test = binary_classification_data
+    est_1 = LogisticRegression()
+    est_2 = DecisionTreeClassifier()
+    report1 = EstimatorReport(
+        est_1,
+        X_train=X_train,
+        y_train=y_train,
+        X_test=X_test,
+        y_test=y_test,
+    )
+    report2 = EstimatorReport(
+        est_2,
+        X_train=X_train,
+        y_train=y_train,
+        X_test=X_test,
+        y_test=y_test,
+    )
+    comparison = ComparisonReport(
+        {"Logistic Regression": report1, "Decision Tree": report2}
+    )
+
+    with pytest.raises(ValueError):
+        comparison.pairwise_plot(
+            perf_metric_x="fit_time",
+            perf_metric_y="recall",
+            data_source="train",
+        )
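
Reviewer note: below is a minimal end-to-end sketch of how the API added in this patch is meant to be used, put together from the getting-started example and the unit tests above. It only relies on names introduced or already present in the change (`ComparisonReport`, `EstimatorReport`, `pairwise_plot`, `plot`, `frame`); the dataset and the two estimators are illustrative placeholders, not part of the change.

from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

from skore import ComparisonReport, EstimatorReport

# Placeholder data and estimators; any comparable set of reports would do.
X, y = make_classification(random_state=42)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=42)

reports = {
    "Logistic Regression": EstimatorReport(
        LogisticRegression(),
        X_train=X_train,
        y_train=y_train,
        X_test=X_test,
        y_test=y_test,
    ),
    "Decision Tree": EstimatorReport(
        DecisionTreeClassifier(),
        X_train=X_train,
        y_train=y_train,
        X_test=X_test,
        y_test=y_test,
    ),
}
comparison = ComparisonReport(reports)

# One annotated point per estimator: Brier score on the x-axis, fit time on the
# y-axis. `data_source` defaults to "test"; `pos_label` is only needed for
# per-class metrics such as precision or recall.
display = comparison.pairwise_plot(
    perf_metric_x="brier_score", perf_metric_y="fit_time"
)
display.plot()

# The underlying metrics dataframe used for the scatter plot.
print(display.frame())

After `plot()` is called, the display exposes `figure_` and `ax_`, which is what the assertions in the unit tests check.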