From aa20336deac10f8881e1371516ba8970bae27d32 Mon Sep 17 00:00:00 2001 From: Marie Date: Fri, 4 Apr 2025 16:03:46 +0200 Subject: [PATCH 01/20] docs: objective of the branch --- .../plot_skore_getting_started.py | 29 +++++++++++++++++++ 1 file changed, 29 insertions(+) diff --git a/examples/getting_started/plot_skore_getting_started.py b/examples/getting_started/plot_skore_getting_started.py index c546b15a26..d1508ef7e8 100644 --- a/examples/getting_started/plot_skore_getting_started.py +++ b/examples/getting_started/plot_skore_getting_started.py @@ -199,6 +199,35 @@ # %% comparator.metrics.report_metrics(pos_label=1) +# %% +# We can highlight the performance metric gain against timings +comparator.metrics.report_metrics().loc[ + ["Fit time", "Brier score"] +].T.reset_index().plot(kind="scatter", x="Fit time", y="Brier score") +plt.tight_layout() + +# %% +scatter_data = ( + comparator.metrics.report_metrics().loc[["Fit time", "Brier score"]].T.reset_index() +) +scatter_data.plot(kind="scatter", x="Fit time", y="Brier score") + +# Add labels to the points with a small offset +text = scatter_data["Estimator"] +x = scatter_data["Fit time"] +y = scatter_data["Brier score"] +for label, x_coord, y_coord in zip(text, x, y): + plt.annotate( + label, + (x_coord, y_coord), + textcoords="offset points", + xytext=(10, 10), + bbox=dict( + boxstyle="round,pad=0.3", edgecolor="gray", facecolor="white", alpha=0.7 + ), + ) +plt.tight_layout() + # %% # Thus, we easily have the result of our benchmark for several recommended metrics. From 881e2b6457294a55d36166051f31ac3e3ab38b29 Mon Sep 17 00:00:00 2001 From: Marie Date: Fri, 4 Apr 2025 16:31:40 +0200 Subject: [PATCH 02/20] feat: start the function --- .../plot_skore_getting_started.py | 9 +-- skore/src/skore/sklearn/_comparison/report.py | 69 +++++++++++++++++++ 2 files changed, 71 insertions(+), 7 deletions(-) diff --git a/examples/getting_started/plot_skore_getting_started.py b/examples/getting_started/plot_skore_getting_started.py index d1508ef7e8..e2ca198c9b 100644 --- a/examples/getting_started/plot_skore_getting_started.py +++ b/examples/getting_started/plot_skore_getting_started.py @@ -201,15 +201,10 @@ # %% # We can highlight the performance metric gain against timings -comparator.metrics.report_metrics().loc[ - ["Fit time", "Brier score"] -].T.reset_index().plot(kind="scatter", x="Fit time", y="Brier score") -plt.tight_layout() +comparator.plot_perf_against_time(perf_metric="Brier score") # %% -scatter_data = ( - comparator.metrics.report_metrics().loc[["Fit time", "Brier score"]].T.reset_index() -) +scatter_data = comparator.metrics.report_metrics().T.reset_index() scatter_data.plot(kind="scatter", x="Fit time", y="Brier score") # Add labels to the points with a small offset diff --git a/skore/src/skore/sklearn/_comparison/report.py b/skore/src/skore/sklearn/_comparison/report.py index 1cf403608e..b42bbd14aa 100644 --- a/skore/src/skore/sklearn/_comparison/report.py +++ b/skore/src/skore/sklearn/_comparison/report.py @@ -5,6 +5,7 @@ from typing import TYPE_CHECKING, Any, Literal, Optional, Union import joblib +import matplotlib.pyplot as plt import numpy as np from numpy.typing import ArrayLike @@ -350,6 +351,74 @@ def get_predictions( for report in self.estimator_reports_ ] + def plot_perf_against_time( + self, + perf_metric: str, + data_source: Literal["test", "train", "X_y"] = "test", + time_metric: Literal["fit", "predict"] = "predict", + ): + """ + Plot a given performance metric against a time metric. 
+ + Parameters + ---------- + perf_metric : str + + data_source : {"test", "train", "X_y"}, default="test" + The data source to use. + + - "test" : use the test set provided when creating the report. + - "train" : use the train set provided when creating the report. + - "X_y" : use the provided `X` and `y` to compute the metric. + + perf_metric : str + + time_metric: {"fit", "predict"}, default = "predict" + The time metric to use in the plot. + + + Returns + ------- + A matplotlib plot. + + """ + # Border cases to handle: + # - what if a metrics in not computed on all the estimators? + # - what if a metrics need pos_label? + # - what if time_metric = "fit", and data_source != "train"? + + # TODO + # - add example + # - add test + # - add kwargs + + scatter_data = self.metrics.report_metrics().T.reset_index() + scatter_data.plot( + kind="scatter", + x="Fit time", + y="Brier score", + title="Performance vs Time (s)", + ) + + # Add labels to the points with a small offset + text = scatter_data["Estimator"] + x = scatter_data["Fit time"] + y = scatter_data["Brier score"] + for label, x_coord, y_coord in zip(text, x, y): + plt.annotate( + label, + (x_coord, y_coord), + textcoords="offset points", + xytext=(10, 0), + bbox=dict( + boxstyle="round,pad=0.3", + edgecolor="gray", + facecolor="white", + alpha=0.7, + ), + ) + plt.tight_layout() + #################################################################################### # Methods related to the help and repr #################################################################################### From aa168725e85811cade6848e5c874f55cce34c3d9 Mon Sep 17 00:00:00 2001 From: Marie Date: Fri, 4 Apr 2025 16:32:47 +0200 Subject: [PATCH 03/20] add comment --- skore/src/skore/sklearn/_comparison/report.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/skore/src/skore/sklearn/_comparison/report.py b/skore/src/skore/sklearn/_comparison/report.py index b42bbd14aa..a67461a346 100644 --- a/skore/src/skore/sklearn/_comparison/report.py +++ b/skore/src/skore/sklearn/_comparison/report.py @@ -387,6 +387,10 @@ def plot_perf_against_time( # - what if a metrics need pos_label? # - what if time_metric = "fit", and data_source != "train"? + # Question + # should this become an accessor method, e.g. `plots`, + # the equivalent to `metrics`? + # TODO # - add example # - add test From 987bb13f5d0bab878646a5cc2e523f51b8de624c Mon Sep 17 00:00:00 2001 From: Marie Date: Fri, 4 Apr 2025 16:49:11 +0200 Subject: [PATCH 04/20] save before we --- skore/src/skore/sklearn/_comparison/report.py | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/skore/src/skore/sklearn/_comparison/report.py b/skore/src/skore/sklearn/_comparison/report.py index a67461a346..866ad189b0 100644 --- a/skore/src/skore/sklearn/_comparison/report.py +++ b/skore/src/skore/sklearn/_comparison/report.py @@ -388,25 +388,32 @@ def plot_perf_against_time( # - what if time_metric = "fit", and data_source != "train"? # Question - # should this become an accessor method, e.g. `plots`, - # the equivalent to `metrics`? + # - should this become an accessor method, e.g. `plots`, + # the equivalent to `metrics`? + # - how to deal with perf metric? should it be consistent with + # the metric name or the column name in metrics report? 
# TODO # - add example # - add test # - add kwargs + if time_metric == "fit": + x_label = "Fit time" + elif time_metric == "predict": + x_label = "Predict time" + scatter_data = self.metrics.report_metrics().T.reset_index() scatter_data.plot( kind="scatter", - x="Fit time", + x=x_label, y="Brier score", title="Performance vs Time (s)", ) # Add labels to the points with a small offset text = scatter_data["Estimator"] - x = scatter_data["Fit time"] + x = scatter_data[x_label] y = scatter_data["Brier score"] for label, x_coord, y_coord in zip(text, x, y): plt.annotate( From 63b878eb98c04b680050ea213ebb37d06b30cea3 Mon Sep 17 00:00:00 2001 From: Marie Date: Fri, 11 Apr 2025 14:49:36 +0200 Subject: [PATCH 05/20] notes feedback from Guillaume --- skore/src/skore/sklearn/_comparison/report.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/skore/src/skore/sklearn/_comparison/report.py b/skore/src/skore/sklearn/_comparison/report.py index 866ad189b0..0553fe39e9 100644 --- a/skore/src/skore/sklearn/_comparison/report.py +++ b/skore/src/skore/sklearn/_comparison/report.py @@ -386,17 +386,22 @@ def plot_perf_against_time( # - what if a metrics in not computed on all the estimators? # - what if a metrics need pos_label? # - what if time_metric = "fit", and data_source != "train"? + # > data source does not apply to fit. Let's be clear in the axis labels. # Question # - should this become an accessor method, e.g. `plots`, # the equivalent to `metrics`? + # > should be in the accessor model_selection # - how to deal with perf metric? should it be consistent with # the metric name or the column name in metrics report? + # > available in _SCORE_OR_LOSS_INFO # TODO # - add example # - add test # - add kwargs + # - turn into display + # - change name to sth like `pairwise_plot` if time_metric == "fit": x_label = "Fit time" From c8c5ddbd0d7e8e6195b09909fbc99845aa91cb75 Mon Sep 17 00:00:00 2001 From: Marie Date: Tue, 13 May 2025 17:34:21 +0200 Subject: [PATCH 06/20] add utils docstring --- .../plot_skore_getting_started.py | 2 +- skore/src/skore/sklearn/_comparison/report.py | 40 ++++++++++--------- skore/src/skore/sklearn/utils.py | 16 ++++++++ 3 files changed, 39 insertions(+), 19 deletions(-) create mode 100644 skore/src/skore/sklearn/utils.py diff --git a/examples/getting_started/plot_skore_getting_started.py b/examples/getting_started/plot_skore_getting_started.py index e2ca198c9b..9a0cb3c4d6 100644 --- a/examples/getting_started/plot_skore_getting_started.py +++ b/examples/getting_started/plot_skore_getting_started.py @@ -201,7 +201,7 @@ # %% # We can highlight the performance metric gain against timings -comparator.plot_perf_against_time(perf_metric="Brier score") +comparator.pairwise_plot(perf_metric_x="brier_score", perf_metric_y="fit_time") # %% scatter_data = comparator.metrics.report_metrics().T.reset_index() diff --git a/skore/src/skore/sklearn/_comparison/report.py b/skore/src/skore/sklearn/_comparison/report.py index 0553fe39e9..ad068dffcb 100644 --- a/skore/src/skore/sklearn/_comparison/report.py +++ b/skore/src/skore/sklearn/_comparison/report.py @@ -12,6 +12,7 @@ from skore.externals._pandas_accessors import DirNamesMixin from skore.sklearn._base import _BaseReport from skore.sklearn._estimator.report import EstimatorReport +from skore.sklearn.utils import _SCORE_OR_LOSS_INFO from skore.utils._progress_bar import progress_decorator if TYPE_CHECKING: @@ -351,18 +352,22 @@ def get_predictions( for report in self.estimator_reports_ ] - def plot_perf_against_time( + def pairwise_plot( 
self, - perf_metric: str, + perf_metric_x: str, + perf_metric_y: str, data_source: Literal["test", "train", "X_y"] = "test", - time_metric: Literal["fit", "predict"] = "predict", + pos_label: Optional[Any] = None, ): - """ - Plot a given performance metric against a time metric. + """Plot a given performance metric against another. Parameters ---------- - perf_metric : str + perf_metric_x : str + The performance metric to plot on the abscissa axis. + + perf_metric_y : str + The performance metrics to plot on the ordinates axis. data_source : {"test", "train", "X_y"}, default="test" The data source to use. @@ -371,11 +376,12 @@ def plot_perf_against_time( - "train" : use the train set provided when creating the report. - "X_y" : use the provided `X` and `y` to compute the metric. - perf_metric : str - - time_metric: {"fit", "predict"}, default = "predict" - The time metric to use in the plot. - + pos_label : int, float, bool or str, default=None + The positive class when it comes to binary classification. When + `response_method="predict_proba"`, it will select the column corresponding + to the positive class. When `response_method="decision_function"`, it will + negate the decision function if `pos_label` is different from + `estimator.classes_[1]`. Returns ------- @@ -403,23 +409,21 @@ def plot_perf_against_time( # - turn into display # - change name to sth like `pairwise_plot` - if time_metric == "fit": - x_label = "Fit time" - elif time_metric == "predict": - x_label = "Predict time" + x_label = _SCORE_OR_LOSS_INFO[perf_metric_x].get("name", perf_metric_x) + y_label = _SCORE_OR_LOSS_INFO[perf_metric_y].get("name", perf_metric_y) - scatter_data = self.metrics.report_metrics().T.reset_index() + scatter_data = self.metrics.report_metrics(pos_label=pos_label).T.reset_index() scatter_data.plot( kind="scatter", x=x_label, - y="Brier score", + y=y_label, title="Performance vs Time (s)", ) # Add labels to the points with a small offset text = scatter_data["Estimator"] x = scatter_data[x_label] - y = scatter_data["Brier score"] + y = scatter_data[y_label] for label, x_coord, y_coord in zip(text, x, y): plt.annotate( label, diff --git a/skore/src/skore/sklearn/utils.py b/skore/src/skore/sklearn/utils.py new file mode 100644 index 0000000000..9a482baf44 --- /dev/null +++ b/skore/src/skore/sklearn/utils.py @@ -0,0 +1,16 @@ +"""Utility functions for Skore and Scikit-learn integration.""" + +_SCORE_OR_LOSS_INFO: dict[str, dict[str, str]] = { + "fit_time": {"name": "Fit time (s)", "icon": "(↘︎)"}, + "predict_time": {"name": "Predict time (s)", "icon": "(↘︎)"}, + "accuracy": {"name": "Accuracy", "icon": "(↗︎)"}, + "precision": {"name": "Precision", "icon": "(↗︎)"}, + "recall": {"name": "Recall", "icon": "(↗︎)"}, + "brier_score": {"name": "Brier score", "icon": "(↘︎)"}, + "roc_auc": {"name": "ROC AUC", "icon": "(↗︎)"}, + "log_loss": {"name": "Log loss", "icon": "(↘︎)"}, + "r2": {"name": "R²", "icon": "(↗︎)"}, + "rmse": {"name": "RMSE", "icon": "(↘︎)"}, + "custom_metric": {"name": "Custom metric", "icon": ""}, + "report_metrics": {"name": "Report metrics", "icon": ""}, +} From 1524d4633da2e46a98f5eca3e2ae088d24a6621e Mon Sep 17 00:00:00 2001 From: Marie Date: Tue, 13 May 2025 17:51:15 +0200 Subject: [PATCH 07/20] feat pairwise: handle missing pos label --- .../plot_skore_getting_started.py | 20 ------------- skore/src/skore/sklearn/_comparison/report.py | 28 ++++++++++--------- 2 files changed, 15 insertions(+), 33 deletions(-) diff --git a/examples/getting_started/plot_skore_getting_started.py 
b/examples/getting_started/plot_skore_getting_started.py index b8690ad672..e2dd2ce262 100644 --- a/examples/getting_started/plot_skore_getting_started.py +++ b/examples/getting_started/plot_skore_getting_started.py @@ -211,26 +211,6 @@ # We can highlight the performance metric gain against timings comparator.pairwise_plot(perf_metric_x="brier_score", perf_metric_y="fit_time") -# %% -scatter_data = comparator.metrics.report_metrics().T.reset_index() -scatter_data.plot(kind="scatter", x="Fit time", y="Brier score") - -# Add labels to the points with a small offset -text = scatter_data["Estimator"] -x = scatter_data["Fit time"] -y = scatter_data["Brier score"] -for label, x_coord, y_coord in zip(text, x, y): - plt.annotate( - label, - (x_coord, y_coord), - textcoords="offset points", - xytext=(10, 10), - bbox=dict( - boxstyle="round,pad=0.3", edgecolor="gray", facecolor="white", alpha=0.7 - ), - ) -plt.tight_layout() - # %% # Thus, we easily have the result of our benchmark for several recommended metrics. diff --git a/skore/src/skore/sklearn/_comparison/report.py b/skore/src/skore/sklearn/_comparison/report.py index 5ad9905b8f..6ecaef8c45 100644 --- a/skore/src/skore/sklearn/_comparison/report.py +++ b/skore/src/skore/sklearn/_comparison/report.py @@ -438,31 +438,33 @@ def pairwise_plot( # - what if a metrics need pos_label? # - what if time_metric = "fit", and data_source != "train"? # > data source does not apply to fit. Let's be clear in the axis labels. - - # Question - # - should this become an accessor method, e.g. `plots`, - # the equivalent to `metrics`? - # > should be in the accessor model_selection - # - how to deal with perf metric? should it be consistent with - # the metric name or the column name in metrics report? - # > available in _SCORE_OR_LOSS_INFO + # - what happens if the metric is a user-created metric? # TODO - # - add example # - add test - # - add kwargs + # - add kwargs (later) # - turn into display - # - change name to sth like `pairwise_plot` x_label = _SCORE_OR_LOSS_INFO[perf_metric_x].get("name", perf_metric_x) y_label = _SCORE_OR_LOSS_INFO[perf_metric_y].get("name", perf_metric_y) - scatter_data = self.metrics.report_metrics(pos_label=pos_label).T.reset_index() + x = scatter_data[x_label] + y = scatter_data[y_label] + + if len(x.shape) > 1: + raise ValueError( + "The perf metric x requires to add a positive label parameter." + ) + if len(y.shape) > 1: + raise ValueError( + "The perf metric y requires to add a positive label parameter." + ) + scatter_data.plot( kind="scatter", x=x_label, y=y_label, - title="Performance vs Time (s)", + title=f"{x_label} vs {y_label}", ) # Add labels to the points with a small offset From 653326fa56eef689012ccf4f429b2c39d24bb0c8 Mon Sep 17 00:00:00 2001 From: Marie Date: Wed, 14 May 2025 10:57:50 +0200 Subject: [PATCH 08/20] improve feat --- skore/src/skore/sklearn/_comparison/report.py | 58 +++++++++++++++---- 1 file changed, 46 insertions(+), 12 deletions(-) diff --git a/skore/src/skore/sklearn/_comparison/report.py b/skore/src/skore/sklearn/_comparison/report.py index 6ecaef8c45..6955f58ab7 100644 --- a/skore/src/skore/sklearn/_comparison/report.py +++ b/skore/src/skore/sklearn/_comparison/report.py @@ -403,6 +403,8 @@ def pairwise_plot( perf_metric_y: str, data_source: Literal["test", "train", "X_y"] = "test", pos_label: Optional[Any] = None, + X: Optional[ArrayLike] = None, + y: Optional[ArrayLike] = None, ): """Plot a given performance metric against another. @@ -433,24 +435,43 @@ def pairwise_plot( A matplotlib plot. 
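        Examples
        --------
        A minimal usage sketch, mirroring the getting-started example above.
        The dataset, estimators and report names are purely illustrative::

            from sklearn.datasets import make_classification
            from sklearn.linear_model import LogisticRegression
            from sklearn.model_selection import train_test_split
            from sklearn.tree import DecisionTreeClassifier
            from skore import ComparisonReport, EstimatorReport

            X, y = make_classification(random_state=0)
            X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)
            reports = {
                name: EstimatorReport(
                    est,
                    X_train=X_train,
                    y_train=y_train,
                    X_test=X_test,
                    y_test=y_test,
                )
                for name, est in [
                    ("Logistic Regression", LogisticRegression()),
                    ("Decision Tree", DecisionTreeClassifier()),
                ]
            }
            comparator = ComparisonReport(reports)
            comparator.pairwise_plot(
                perf_metric_x="brier_score", perf_metric_y="fit_time"
            )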
""" - # Border cases to handle: - # - what if a metrics in not computed on all the estimators? - # - what if a metrics need pos_label? - # - what if time_metric = "fit", and data_source != "train"? - # > data source does not apply to fit. Let's be clear in the axis labels. - # - what happens if the metric is a user-created metric? - # TODO # - add test # - add kwargs (later) # - turn into display - x_label = _SCORE_OR_LOSS_INFO[perf_metric_x].get("name", perf_metric_x) - y_label = _SCORE_OR_LOSS_INFO[perf_metric_y].get("name", perf_metric_y) - scatter_data = self.metrics.report_metrics(pos_label=pos_label).T.reset_index() + # translate the parameters into column names + x_label = _SCORE_OR_LOSS_INFO.get(perf_metric_x, {}).get("name", perf_metric_x) + y_label = _SCORE_OR_LOSS_INFO.get(perf_metric_y, {}).get("name", perf_metric_y) + scatter_data = self.metrics.report_metrics( + pos_label=pos_label, data_source=data_source, X=X, y=y + ).T.reset_index() + + # Check that the metrics are in the report + # If the metric is not in the report, help the user by suggesting + # supported metrics + reverse_score_info = { + value["name"]: key for key, value in _SCORE_OR_LOSS_INFO.items() + } + available_columns = scatter_data.columns.get_level_values(0).to_list() + available_columns.remove("Estimator") + supported_metrics = [ + reverse_score_info.get(col, col) for col in available_columns + ] + if perf_metric_x not in supported_metrics: + raise ValueError( + f"Performance metric {perf_metric_x} not found in the report. " + f"Supported metrics are: {supported_metrics}." + ) + if perf_metric_y not in supported_metrics: + raise ValueError( + f"Performance metric {perf_metric_y} not found in the report. " + f"Supported metrics are: {supported_metrics}." + ) + + # Check that x and y are 1D arrays (i.e. the metrics don't need pos_label) x = scatter_data[x_label] y = scatter_data[y_label] - if len(x.shape) > 1: raise ValueError( "The perf metric x requires to add a positive label parameter." @@ -460,12 +481,25 @@ def pairwise_plot( "The perf metric y requires to add a positive label parameter." 
) + # Make it clear in the axis labels that we are using the train set + if perf_metric_x == "fit_time" and data_source != "train": + x_label_text = x_label + " on train set" + else: + x_label_text = x_label + if perf_metric_y == "fit_time" and data_source != "train": + y_label_text = y_label + " on train set" + else: + y_label_text = y_label + + # Create the scatter plot scatter_data.plot( kind="scatter", x=x_label, y=y_label, - title=f"{x_label} vs {y_label}", + title=f"{x_label} vs {y_label} on {data_source} data", ) + plt.xlabel(x_label_text) + plt.ylabel(y_label_text) # Add labels to the points with a small offset text = scatter_data["Estimator"] From e28856122c8004b5a61bf7d5fd4d2cb6b419ee4c Mon Sep 17 00:00:00 2001 From: Marie Date: Thu, 22 May 2025 17:01:54 +0200 Subject: [PATCH 09/20] turn into display --- skore/src/skore/sklearn/_comparison/report.py | 88 +------ skore/src/skore/sklearn/_plot/__init__.py | 2 + .../skore/sklearn/_plot/metrics/__init__.py | 2 + .../skore/sklearn/_plot/metrics/pair_plot.py | 225 ++++++++++++++++++ 4 files changed, 237 insertions(+), 80 deletions(-) create mode 100644 skore/src/skore/sklearn/_plot/metrics/pair_plot.py diff --git a/skore/src/skore/sklearn/_comparison/report.py b/skore/src/skore/sklearn/_comparison/report.py index 03d569e698..c5ed286bde 100644 --- a/skore/src/skore/sklearn/_comparison/report.py +++ b/skore/src/skore/sklearn/_comparison/report.py @@ -6,7 +6,6 @@ from typing import TYPE_CHECKING, Any, Literal, Optional, Union, cast import joblib -import matplotlib.pyplot as plt import numpy as np from numpy.typing import ArrayLike @@ -14,7 +13,7 @@ from skore.sklearn._base import _BaseReport from skore.sklearn._cross_validation.report import CrossValidationReport from skore.sklearn._estimator.report import EstimatorReport -from skore.sklearn.utils import _SCORE_OR_LOSS_INFO +from skore.sklearn._plot import PairPlotDisplay from skore.utils._progress_bar import progress_decorator if TYPE_CHECKING: @@ -441,85 +440,14 @@ def pairwise_plot( # - add kwargs (later) # - turn into display - # translate the parameters into column names - x_label = _SCORE_OR_LOSS_INFO.get(perf_metric_x, {}).get("name", perf_metric_x) - y_label = _SCORE_OR_LOSS_INFO.get(perf_metric_y, {}).get("name", perf_metric_y) - scatter_data = self.metrics.report_metrics( - pos_label=pos_label, data_source=data_source, X=X, y=y - ).T.reset_index() - - # Check that the metrics are in the report - # If the metric is not in the report, help the user by suggesting - # supported metrics - reverse_score_info = { - value["name"]: key for key, value in _SCORE_OR_LOSS_INFO.items() - } - available_columns = scatter_data.columns.get_level_values(0).to_list() - available_columns.remove("Estimator") - supported_metrics = [ - reverse_score_info.get(col, col) for col in available_columns - ] - if perf_metric_x not in supported_metrics: - raise ValueError( - f"Performance metric {perf_metric_x} not found in the report. " - f"Supported metrics are: {supported_metrics}." - ) - if perf_metric_y not in supported_metrics: - raise ValueError( - f"Performance metric {perf_metric_y} not found in the report. " - f"Supported metrics are: {supported_metrics}." - ) - - # Check that x and y are 1D arrays (i.e. the metrics don't need pos_label) - x = scatter_data[x_label] - y = scatter_data[y_label] - if len(x.shape) > 1: - raise ValueError( - "The perf metric x requires to add a positive label parameter." 
- ) - if len(y.shape) > 1: - raise ValueError( - "The perf metric y requires to add a positive label parameter." - ) - - # Make it clear in the axis labels that we are using the train set - if perf_metric_x == "fit_time" and data_source != "train": - x_label_text = x_label + " on train set" - else: - x_label_text = x_label - if perf_metric_y == "fit_time" and data_source != "train": - y_label_text = y_label + " on train set" - else: - y_label_text = y_label - - # Create the scatter plot - scatter_data.plot( - kind="scatter", - x=x_label, - y=y_label, - title=f"{x_label} vs {y_label} on {data_source} data", + return PairPlotDisplay.from_metrics( + metrics=self.metrics.report_metrics( + pos_label=pos_label, data_source=data_source, X=X, y=y + ).T.reset_index(), + perf_metric_x=perf_metric_x, + perf_metric_y=perf_metric_y, + data_source=data_source, ) - plt.xlabel(x_label_text) - plt.ylabel(y_label_text) - - # Add labels to the points with a small offset - text = scatter_data["Estimator"] - x = scatter_data[x_label] - y = scatter_data[y_label] - for label, x_coord, y_coord in zip(text, x, y): - plt.annotate( - label, - (x_coord, y_coord), - textcoords="offset points", - xytext=(10, 0), - bbox=dict( - boxstyle="round,pad=0.3", - edgecolor="gray", - facecolor="white", - alpha=0.7, - ), - ) - plt.tight_layout() #################################################################################### # Methods related to the help and repr diff --git a/skore/src/skore/sklearn/_plot/__init__.py b/skore/src/skore/sklearn/_plot/__init__.py index c53a218fdd..26c7458eb1 100644 --- a/skore/src/skore/sklearn/_plot/__init__.py +++ b/skore/src/skore/sklearn/_plot/__init__.py @@ -1,5 +1,6 @@ from skore.sklearn._plot.metrics import ( ConfusionMatrixDisplay, + PairPlotDisplay, PrecisionRecallCurveDisplay, PredictionErrorDisplay, RocCurveDisplay, @@ -10,4 +11,5 @@ "RocCurveDisplay", "PrecisionRecallCurveDisplay", "PredictionErrorDisplay", + "PairPlotDisplay", ] diff --git a/skore/src/skore/sklearn/_plot/metrics/__init__.py b/skore/src/skore/sklearn/_plot/metrics/__init__.py index cdd5b05e71..0748464706 100644 --- a/skore/src/skore/sklearn/_plot/metrics/__init__.py +++ b/skore/src/skore/sklearn/_plot/metrics/__init__.py @@ -1,4 +1,5 @@ from skore.sklearn._plot.metrics.confusion_matrix import ConfusionMatrixDisplay +from skore.sklearn._plot.metrics.pair_plot import PairPlotDisplay from skore.sklearn._plot.metrics.precision_recall_curve import ( PrecisionRecallCurveDisplay, ) @@ -10,4 +11,5 @@ "PrecisionRecallCurveDisplay", "PredictionErrorDisplay", "RocCurveDisplay", + "PairPlotDisplay", ] diff --git a/skore/src/skore/sklearn/_plot/metrics/pair_plot.py b/skore/src/skore/sklearn/_plot/metrics/pair_plot.py new file mode 100644 index 0000000000..1407651b93 --- /dev/null +++ b/skore/src/skore/sklearn/_plot/metrics/pair_plot.py @@ -0,0 +1,225 @@ +import matplotlib.pyplot as plt + +from skore.sklearn._plot.base import Display +from skore.sklearn._plot.style import StyleDisplayMixin +from skore.sklearn.utils import _SCORE_OR_LOSS_INFO + + +class PairPlotDisplay(Display): + """Display for pair plot. + + Parameters + ---------- + scatter_data : + + x_column : str + + y_column : str + + display_label_x : str, default=None + + display_label_y : str, default=None + + data_source : str, default=None + + + Attributes + ---------- + figure_ : matplotlib Figure + Figure containing the confusion matrix. + + ax_ : matplotlib Axes + Axes with confusion matrix. 
+ + text_ : ndarray of shape (n_classes, n_classes), dtype=matplotlib Text or \ + None + Array of matplotlib text elements containing the values in the + confusion matrix. + """ + + @StyleDisplayMixin.style_plot + def __init__( + self, + scatter_data, + *, + x_column=None, + y_column=None, + display_label_x=None, + display_label_y=None, + data_source=None, + display_labels=None, + ): + self.scatter_data = scatter_data + self.x_column = scatter_data.columns[0] + self.y_column = scatter_data.columns[1] + self.display_label_x = display_label_x + self.display_label_y = display_label_y + self.data_source = data_source + self.figure_ = None + self.ax_ = None + self.text_ = None + + def plot(self, ax=None, **kwargs): + """Plot a given performance metric against another. + + Parameters + ---------- + ax : matplotlib axes, default=None + Axes object to plot on. If None, a new figure and axes is created. + + **kwargs : dict + Additional keyword arguments to be passed to matplotlib's + `ax.imshow`. + + Returns + ------- + self : PairPlotDisplay + Configured with the confusion matrix. + """ + if ax is None: + fig, ax = plt.subplots() + else: + fig = ax.figure + + scatter_data = self.scatter_data + + ax.scatter( + x=scatter_data[self.x_column], + y=scatter_data[self.y_column], + title=f"{self.display_label_x} vs {self.display_label_x} on \ + {self.data_source} data", + ) + ax.set_xlabel(self.display_label_x) + ax.set_ylabel(self.display_label_y) + + # Add labels to the points with a small offset + text = scatter_data["Estimator"] + x = scatter_data[self.x_column] + y = scatter_data[self.y_column] + for label, x_coord, y_coord in zip(text, x, y): + ax.annotate( + label, + (x_coord, y_coord), + textcoords="offset points", + xytext=(10, 0), + bbox=dict( + boxstyle="round,pad=0.3", + edgecolor="gray", + facecolor="white", + alpha=0.7, + ), + ) + + self.figure_, self.ax_ = fig, ax + return self + + @classmethod + def from_metrics( + cls, + metrics, + perf_metric_x, + perf_metric_y, + data_source, + ): + """Create a confusion matrix display from predictions. + + Parameters + ---------- + y_true : array-like of shape (n_samples,) + True labels. + + y_pred : array-like of shape (n_samples,) + Predicted labels, as returned by a classifier. + + sample_weight : array-like of shape (n_samples,), default=None + Sample weights. + + display_labels : list of str, default=None + Target names used for plotting. By default, labels will be inferred + from y_true. + + include_values : bool, default=True + Includes values in confusion matrix. + + normalize : {'true', 'pred', 'all'}, default=None + Normalizes confusion matrix over the true (rows), predicted (columns) + conditions or all the population. If None, confusion matrix will not be + normalized. + + values_format : str, default=None + Format specification for values in confusion matrix. If None, the format + specification is 'd' or '.2g' whichever is shorter. + + Returns + ------- + display : :class:`PairPlotDisplay` + The confusion matrix display. 
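        Examples
        --------
        A sketch of the intended call. The frame layout is assumed to mirror
        the transposed output of ``report_metrics()`` built in
        ``ComparisonReport.pairwise_plot``; the column names and values below
        are illustrative only::

            import pandas as pd

            metrics = pd.DataFrame(
                {
                    "Estimator": ["Logistic Regression", "Decision Tree"],
                    "Fit time (s)": [0.01, 0.03],
                    "Brier score": [0.08, 0.15],
                }
            )
            display = PairPlotDisplay.from_metrics(
                metrics,
                perf_metric_x="fit_time",
                perf_metric_y="brier_score",
                data_source="test",
            )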
+ """ + x_label = _SCORE_OR_LOSS_INFO.get(perf_metric_x, {}).get("name", perf_metric_x) + y_label = _SCORE_OR_LOSS_INFO.get(perf_metric_y, {}).get("name", perf_metric_y) + scatter_data = metrics + + # Check that the metrics are in the report + # If the metric is not in the report, help the user by suggesting + # supported metrics + reverse_score_info = { + value["name"]: key for key, value in _SCORE_OR_LOSS_INFO.items() + } + available_columns = scatter_data.columns.get_level_values(0).to_list() + available_columns.remove("Estimator") + supported_metrics = [ + reverse_score_info.get(col, col) for col in available_columns + ] + if perf_metric_x not in supported_metrics: + raise ValueError( + f"Performance metric {perf_metric_x} not found in the report. " + f"Supported metrics are: {supported_metrics}." + ) + if perf_metric_y not in supported_metrics: + raise ValueError( + f"Performance metric {perf_metric_y} not found in the report. " + f"Supported metrics are: {supported_metrics}." + ) + + # Check that x and y are 1D arrays (i.e. the metrics don't need pos_label) + x = scatter_data[x_label] + y = scatter_data[y_label] + if len(x.shape) > 1: + raise ValueError( + "The perf metric x requires to add a positive label parameter." + ) + if len(y.shape) > 1: + raise ValueError( + "The perf metric y requires to add a positive label parameter." + ) + + # Make it clear in the axis labels that we are using the train set + if perf_metric_x == "fit_time" and data_source != "train": + x_label_text = x_label + " on train set" + else: + x_label_text = x_label + if perf_metric_y == "fit_time" and data_source != "train": + y_label_text = y_label + " on train set" + else: + y_label_text = y_label + + disp = cls( + scatter_data=scatter_data, + x_column=None, + y_column=None, + display_label_x=x_label_text, + display_label_y=y_label_text, + data_source=None, + ) + + return disp + + def frame(self): + """Return the confusion matrix as a dataframe. + + Returns + ------- + scatter_data : pandas.DataFrame + The dataframe used to create the scatter plot. + """ + return self.scatter_data From 43e83a869ec3d15201281eb7746181be6570e802 Mon Sep 17 00:00:00 2001 From: Marie Date: Thu, 22 May 2025 17:30:27 +0200 Subject: [PATCH 10/20] complete docstrings --- skore/src/skore/sklearn/_comparison/report.py | 1 - .../skore/sklearn/_plot/metrics/pair_plot.py | 120 +++++++++--------- 2 files changed, 58 insertions(+), 63 deletions(-) diff --git a/skore/src/skore/sklearn/_comparison/report.py b/skore/src/skore/sklearn/_comparison/report.py index c5ed286bde..edca8436a5 100644 --- a/skore/src/skore/sklearn/_comparison/report.py +++ b/skore/src/skore/sklearn/_comparison/report.py @@ -438,7 +438,6 @@ def pairwise_plot( # TODO # - add test # - add kwargs (later) - # - turn into display return PairPlotDisplay.from_metrics( metrics=self.metrics.report_metrics( diff --git a/skore/src/skore/sklearn/_plot/metrics/pair_plot.py b/skore/src/skore/sklearn/_plot/metrics/pair_plot.py index 1407651b93..2648204055 100644 --- a/skore/src/skore/sklearn/_plot/metrics/pair_plot.py +++ b/skore/src/skore/sklearn/_plot/metrics/pair_plot.py @@ -10,18 +10,25 @@ class PairPlotDisplay(Display): Parameters ---------- - scatter_data : + scatter_data : pandas.DataFrame + Dataframe containing the data to plot. x_column : str + The name of the column to plot on the x-axis. + If None, the first column of the dataframe is used. y_column : str + The name of the column to plot on the y-axis. + If None, the second column of the dataframe is used. 
display_label_x : str, default=None + The label to use for the x-axis. If None, the name of the column will be used. display_label_y : str, default=None + The label to use for the y-axis. If None, the name of the column will be used. data_source : str, default=None - + To specify the data source for the plot. Attributes ---------- @@ -30,11 +37,6 @@ class PairPlotDisplay(Display): ax_ : matplotlib Axes Axes with confusion matrix. - - text_ : ndarray of shape (n_classes, n_classes), dtype=matplotlib Text or \ - None - Array of matplotlib text elements containing the values in the - confusion matrix. """ @StyleDisplayMixin.style_plot @@ -47,13 +49,20 @@ def __init__( display_label_x=None, display_label_y=None, data_source=None, - display_labels=None, ): self.scatter_data = scatter_data - self.x_column = scatter_data.columns[0] - self.y_column = scatter_data.columns[1] - self.display_label_x = display_label_x - self.display_label_y = display_label_y + if x_column is None: + x_column = scatter_data.columns[0] + self.x_column = x_column + if y_column is None: + y_column = scatter_data.columns[1] + self.y_column = y_column + self.display_label_x = ( + display_label_x if display_label_x is not None else self.x_column + ) + self.display_label_y = ( + display_label_y if display_label_y is not None else self.y_column + ) self.data_source = data_source self.figure_ = None self.ax_ = None @@ -74,7 +83,6 @@ def plot(self, ax=None, **kwargs): Returns ------- self : PairPlotDisplay - Configured with the confusion matrix. """ if ax is None: fig, ax = plt.subplots() @@ -83,33 +91,15 @@ def plot(self, ax=None, **kwargs): scatter_data = self.scatter_data + title = f"{self.display_label_x} vs {self.display_label_x}" + if self.data_source is not None: + title += f" on {self.data_source} data" ax.scatter( - x=scatter_data[self.x_column], - y=scatter_data[self.y_column], - title=f"{self.display_label_x} vs {self.display_label_x} on \ - {self.data_source} data", + x=scatter_data[self.x_column], y=scatter_data[self.y_column], title=title ) ax.set_xlabel(self.display_label_x) ax.set_ylabel(self.display_label_y) - # Add labels to the points with a small offset - text = scatter_data["Estimator"] - x = scatter_data[self.x_column] - y = scatter_data[self.y_column] - for label, x_coord, y_coord in zip(text, x, y): - ax.annotate( - label, - (x_coord, y_coord), - textcoords="offset points", - xytext=(10, 0), - bbox=dict( - boxstyle="round,pad=0.3", - edgecolor="gray", - facecolor="white", - alpha=0.7, - ), - ) - self.figure_, self.ax_ = fig, ax return self @@ -119,41 +109,29 @@ def from_metrics( metrics, perf_metric_x, perf_metric_y, - data_source, + data_source=None, ): """Create a confusion matrix display from predictions. Parameters ---------- - y_true : array-like of shape (n_samples,) - True labels. - - y_pred : array-like of shape (n_samples,) - Predicted labels, as returned by a classifier. - - sample_weight : array-like of shape (n_samples,), default=None - Sample weights. + metrics : pandas.DataFrame + Dataframe containing the data to plot. The dataframe should + contain the performance metrics for each estimator. - display_labels : list of str, default=None - Target names used for plotting. By default, labels will be inferred - from y_true. + perf_metric_x : str + The name of the column to plot on the x-axis. - include_values : bool, default=True - Includes values in confusion matrix. + perf_metric_y : str + The name of the column to plot on the y-axis. 
- normalize : {'true', 'pred', 'all'}, default=None - Normalizes confusion matrix over the true (rows), predicted (columns) - conditions or all the population. If None, confusion matrix will not be - normalized. - - values_format : str, default=None - Format specification for values in confusion matrix. If None, the format - specification is 'd' or '.2g' whichever is shorter. + data_source : str + To specify the data source for the plot. Returns ------- display : :class:`PairPlotDisplay` - The confusion matrix display. + The scatter plot display. """ x_label = _SCORE_OR_LOSS_INFO.get(perf_metric_x, {}).get("name", perf_metric_x) y_label = _SCORE_OR_LOSS_INFO.get(perf_metric_y, {}).get("name", perf_metric_y) @@ -205,17 +183,35 @@ def from_metrics( disp = cls( scatter_data=scatter_data, - x_column=None, - y_column=None, + x_column=x_label, + y_column=y_label, display_label_x=x_label_text, display_label_y=y_label_text, - data_source=None, + data_source=data_source, ) + # Add labels to the points with a small offset + ax = disp.ax_ + text = scatter_data["Estimator"] + for label, x_coord, y_coord in zip(text, x, y): + ax.annotate( + label, + (x_coord, y_coord), + textcoords="offset points", + xytext=(10, 0), + bbox=dict( + boxstyle="round,pad=0.3", + edgecolor="gray", + facecolor="white", + alpha=0.7, + ), + ) + + disp.ax_ = ax return disp def frame(self): - """Return the confusion matrix as a dataframe. + """Return the dataframe used for the pair plot. Returns ------- From 86da68e22fd76c405f2cf6366ede0717595ce643 Mon Sep 17 00:00:00 2001 From: Marie Date: Thu, 22 May 2025 17:53:07 +0200 Subject: [PATCH 11/20] add tests --- .../tests/unit/sklearn/plot/test_pair_plot.py | 151 ++++++++++++++++++ 1 file changed, 151 insertions(+) create mode 100644 skore/tests/unit/sklearn/plot/test_pair_plot.py diff --git a/skore/tests/unit/sklearn/plot/test_pair_plot.py b/skore/tests/unit/sklearn/plot/test_pair_plot.py new file mode 100644 index 0000000000..39aaf74f5e --- /dev/null +++ b/skore/tests/unit/sklearn/plot/test_pair_plot.py @@ -0,0 +1,151 @@ +import pytest +from sklearn.datasets import make_classification +from sklearn.linear_model import LogisticRegression +from sklearn.model_selection import train_test_split +from skore import EstimatorReport +from skore.sklearn._plot import PairPlotDisplay + + +@pytest.fixture +def binary_classification_data(): + X, y = make_classification(class_sep=0.1, random_state=42) + X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=42) + return ( + LogisticRegression().fit(X_train, y_train), + X_train, + X_test, + y_train, + y_test, + ) + + +@pytest.fixture +def multiclass_classification_data(): + X, y = make_classification( + class_sep=0.1, + n_classes=3, + n_clusters_per_class=1, + random_state=42, + ) + X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=42) + return ( + LogisticRegression().fit(X_train, y_train), + X_train, + X_test, + y_train, + y_test, + ) + + +@pytest.fixture +def regression_data(): + X, y = make_classification( + n_samples=100, + n_features=10, + n_informative=5, + n_redundant=2, + random_state=42, + ) + X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=42) + return ( + LogisticRegression().fit(X_train, y_train), + X_train, + X_test, + y_train, + y_test, + ) + + +def test_pairplot_test_set( + pyplot, + binary_classification_data, +): + estimator, X_train, X_test, y_train, y_test = binary_classification_data + report = EstimatorReport( + estimator, + X_train=X_train, + y_train=y_train, + 
X_test=X_test, + y_test=y_test, + ) + display = report.metrics.pairwise_plot( + perf_metric_x="fit_time", + perf_metric_y="predict_time", + data_source="train", + ) + + display.plot() + + assert isinstance(display, PairPlotDisplay) + assert hasattr(display, "figure_") + assert hasattr(display, "ax_") + + +def test_pairplot_train_set( + pyplot, + binary_classification_data, +): + estimator, X_train, X_test, y_train, y_test = binary_classification_data + report = EstimatorReport( + estimator, + X_train=X_train, + y_train=y_train, + X_test=X_test, + y_test=y_test, + ) + + display = report.metrics.pairwise_plot( + perf_metric_x="fit_time", + perf_metric_y="roc_auc", + data_source="train", + pos_label=1, + ) + + display.plot() + + assert isinstance(display, PairPlotDisplay) + assert hasattr(display, "figure_") + assert hasattr(display, "ax_") + + +def test_pairplot_missing_col( + pyplot, + binary_classification_data, +): + estimator, X_train, X_test, y_train, y_test = binary_classification_data + report = EstimatorReport( + estimator, + X_train=X_train, + y_train=y_train, + X_test=X_test, + y_test=y_test, + ) + + with pytest.raises(ValueError): + report.metrics.pairwise_plot( + perf_metric_x="fit_time", + perf_metric_y="an_invented_column", + data_source="train", + pos_label=1, + ) + + +def test_pairplot_missing_pos_label( + pyplot, + binary_classification_data, +): + estimator, X_train, X_test, y_train, y_test = binary_classification_data + report = EstimatorReport( + estimator, + X_train=X_train, + y_train=y_train, + X_test=X_test, + y_test=y_test, + ) + + with pytest.raises(ValueError): + report.metrics.pairwise_plot( + perf_metric_x="fit_time", + perf_metric_y="recall", + data_source="train", + ) From fdc7dfb8f90bcd270ae21b8be076f9484d1c75eb Mon Sep 17 00:00:00 2001 From: Marie Date: Fri, 23 May 2025 10:23:32 +0200 Subject: [PATCH 12/20] correct docstring and comments --- skore/src/skore/sklearn/_comparison/report.py | 1 - skore/src/skore/sklearn/_plot/metrics/pair_plot.py | 4 ++-- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/skore/src/skore/sklearn/_comparison/report.py b/skore/src/skore/sklearn/_comparison/report.py index edca8436a5..57706b3e82 100644 --- a/skore/src/skore/sklearn/_comparison/report.py +++ b/skore/src/skore/sklearn/_comparison/report.py @@ -436,7 +436,6 @@ def pairwise_plot( """ # TODO - # - add test # - add kwargs (later) return PairPlotDisplay.from_metrics( diff --git a/skore/src/skore/sklearn/_plot/metrics/pair_plot.py b/skore/src/skore/sklearn/_plot/metrics/pair_plot.py index 2648204055..7e32b70bd9 100644 --- a/skore/src/skore/sklearn/_plot/metrics/pair_plot.py +++ b/skore/src/skore/sklearn/_plot/metrics/pair_plot.py @@ -33,10 +33,10 @@ class PairPlotDisplay(Display): Attributes ---------- figure_ : matplotlib Figure - Figure containing the confusion matrix. + Figure containing the pair plot. ax_ : matplotlib Axes - Axes with confusion matrix. + Axes with pair plot. 
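    Examples
    --------
    A minimal sketch with a hand-built frame; the column names and values are
    illustrative only, as the frame normally comes from ``from_metrics``::

        import pandas as pd

        frame = pd.DataFrame(
            {
                "Estimator": ["Logistic Regression", "Decision Tree"],
                "Fit time (s)": [0.01, 0.03],
                "Brier score": [0.08, 0.15],
            }
        )
        display = PairPlotDisplay(
            frame,
            x_column="Fit time (s)",
            y_column="Brier score",
            data_source="test",
        )
        display.plot()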
""" @StyleDisplayMixin.style_plot From de68fd3d90a8972caa7cbb87c8e74b405443b5c2 Mon Sep 17 00:00:00 2001 From: Marie Date: Fri, 23 May 2025 10:46:15 +0200 Subject: [PATCH 13/20] bugfix --- skore/src/skore/sklearn/_plot/metrics/pair_plot.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/skore/src/skore/sklearn/_plot/metrics/pair_plot.py b/skore/src/skore/sklearn/_plot/metrics/pair_plot.py index 7e32b70bd9..2180b0da88 100644 --- a/skore/src/skore/sklearn/_plot/metrics/pair_plot.py +++ b/skore/src/skore/sklearn/_plot/metrics/pair_plot.py @@ -94,9 +94,8 @@ def plot(self, ax=None, **kwargs): title = f"{self.display_label_x} vs {self.display_label_x}" if self.data_source is not None: title += f" on {self.data_source} data" - ax.scatter( - x=scatter_data[self.x_column], y=scatter_data[self.y_column], title=title - ) + ax.scatter(x=scatter_data[self.x_column], y=scatter_data[self.y_column]) + ax.set_title(title) ax.set_xlabel(self.display_label_x) ax.set_ylabel(self.display_label_y) @@ -188,7 +187,7 @@ def from_metrics( display_label_x=x_label_text, display_label_y=y_label_text, data_source=data_source, - ) + ).plot() # Add labels to the points with a small offset ax = disp.ax_ From 37ddfd97b4dcc19770e0795660c6beeed7b01ee7 Mon Sep 17 00:00:00 2001 From: Marie Date: Fri, 23 May 2025 10:57:23 +0200 Subject: [PATCH 14/20] fix tests --- .../tests/unit/sklearn/plot/test_pair_plot.py | 121 ++++++++++-------- 1 file changed, 66 insertions(+), 55 deletions(-) diff --git a/skore/tests/unit/sklearn/plot/test_pair_plot.py b/skore/tests/unit/sklearn/plot/test_pair_plot.py index 39aaf74f5e..138b2e94ae 100644 --- a/skore/tests/unit/sklearn/plot/test_pair_plot.py +++ b/skore/tests/unit/sklearn/plot/test_pair_plot.py @@ -2,7 +2,8 @@ from sklearn.datasets import make_classification from sklearn.linear_model import LogisticRegression from sklearn.model_selection import train_test_split -from skore import EstimatorReport +from sklearn.tree import DecisionTreeClassifier +from skore import ComparisonReport, EstimatorReport from skore.sklearn._plot import PairPlotDisplay @@ -11,44 +12,6 @@ def binary_classification_data(): X, y = make_classification(class_sep=0.1, random_state=42) X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=42) return ( - LogisticRegression().fit(X_train, y_train), - X_train, - X_test, - y_train, - y_test, - ) - - -@pytest.fixture -def multiclass_classification_data(): - X, y = make_classification( - class_sep=0.1, - n_classes=3, - n_clusters_per_class=1, - random_state=42, - ) - X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=42) - return ( - LogisticRegression().fit(X_train, y_train), - X_train, - X_test, - y_train, - y_test, - ) - - -@pytest.fixture -def regression_data(): - X, y = make_classification( - n_samples=100, - n_features=10, - n_informative=5, - n_redundant=2, - random_state=42, - ) - X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=42) - return ( - LogisticRegression().fit(X_train, y_train), X_train, X_test, y_train, @@ -60,15 +23,27 @@ def test_pairplot_test_set( pyplot, binary_classification_data, ): - estimator, X_train, X_test, y_train, y_test = binary_classification_data - report = EstimatorReport( - estimator, + X_train, X_test, y_train, y_test = binary_classification_data + est_1 = LogisticRegression() + est_2 = DecisionTreeClassifier() + report1 = EstimatorReport( + est_1, + X_train=X_train, + y_train=y_train, + X_test=X_test, + y_test=y_test, + ) + report2 = EstimatorReport( + 
est_2, X_train=X_train, y_train=y_train, X_test=X_test, y_test=y_test, ) - display = report.metrics.pairwise_plot( + comparison = ComparisonReport( + {"Logistic Regression": report1, "Decision Tree": report2} + ) + display = comparison.pairwise_plot( perf_metric_x="fit_time", perf_metric_y="predict_time", data_source="train", @@ -85,16 +60,28 @@ def test_pairplot_train_set( pyplot, binary_classification_data, ): - estimator, X_train, X_test, y_train, y_test = binary_classification_data - report = EstimatorReport( - estimator, + X_train, X_test, y_train, y_test = binary_classification_data + est_1 = LogisticRegression() + est_2 = DecisionTreeClassifier() + report1 = EstimatorReport( + est_1, + X_train=X_train, + y_train=y_train, + X_test=X_test, + y_test=y_test, + ) + report2 = EstimatorReport( + est_2, X_train=X_train, y_train=y_train, X_test=X_test, y_test=y_test, ) + comparison = ComparisonReport( + {"Logistic Regression": report1, "Decision Tree": report2} + ) - display = report.metrics.pairwise_plot( + display = comparison.pairwise_plot( perf_metric_x="fit_time", perf_metric_y="roc_auc", data_source="train", @@ -112,17 +99,29 @@ def test_pairplot_missing_col( pyplot, binary_classification_data, ): - estimator, X_train, X_test, y_train, y_test = binary_classification_data - report = EstimatorReport( - estimator, + X_train, X_test, y_train, y_test = binary_classification_data + est_1 = LogisticRegression() + est_2 = DecisionTreeClassifier() + report1 = EstimatorReport( + est_1, X_train=X_train, y_train=y_train, X_test=X_test, y_test=y_test, ) + report2 = EstimatorReport( + est_2, + X_train=X_train, + y_train=y_train, + X_test=X_test, + y_test=y_test, + ) + comparison = ComparisonReport( + {"Logistic Regression": report1, "Decision Tree": report2} + ) with pytest.raises(ValueError): - report.metrics.pairwise_plot( + comparison.pairwise_plot( perf_metric_x="fit_time", perf_metric_y="an_invented_column", data_source="train", @@ -134,17 +133,29 @@ def test_pairplot_missing_pos_label( pyplot, binary_classification_data, ): - estimator, X_train, X_test, y_train, y_test = binary_classification_data - report = EstimatorReport( - estimator, + X_train, X_test, y_train, y_test = binary_classification_data + est_1 = LogisticRegression() + est_2 = DecisionTreeClassifier() + report1 = EstimatorReport( + est_1, X_train=X_train, y_train=y_train, X_test=X_test, y_test=y_test, ) + report2 = EstimatorReport( + est_2, + X_train=X_train, + y_train=y_train, + X_test=X_test, + y_test=y_test, + ) + comparison = ComparisonReport( + {"Logistic Regression": report1, "Decision Tree": report2} + ) with pytest.raises(ValueError): - report.metrics.pairwise_plot( + comparison.pairwise_plot( perf_metric_x="fit_time", perf_metric_y="recall", data_source="train", From ae1fcb90c0444580a6c152cb566b723fef679c6a Mon Sep 17 00:00:00 2001 From: Marie Sacksick <79304610+MarieSacksick@users.noreply.github.com> Date: Fri, 23 May 2025 12:44:50 +0200 Subject: [PATCH 15/20] Update skore/src/skore/sklearn/_comparison/report.py Co-authored-by: Auguste Baum --- skore/src/skore/sklearn/_comparison/report.py | 1 - 1 file changed, 1 deletion(-) diff --git a/skore/src/skore/sklearn/_comparison/report.py b/skore/src/skore/sklearn/_comparison/report.py index 7d58754a5d..6cc5860d5b 100644 --- a/skore/src/skore/sklearn/_comparison/report.py +++ b/skore/src/skore/sklearn/_comparison/report.py @@ -433,7 +433,6 @@ def pairwise_plot( Returns ------- A matplotlib plot. 
- """ # TODO # - add kwargs (later) From cd083ac58324aa73fd5a4ccbdd7e1102b39d10d8 Mon Sep 17 00:00:00 2001 From: Marie Date: Fri, 23 May 2025 12:47:01 +0200 Subject: [PATCH 16/20] remove traces from inspiration --- skore/src/skore/sklearn/_plot/metrics/pair_plot.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/skore/src/skore/sklearn/_plot/metrics/pair_plot.py b/skore/src/skore/sklearn/_plot/metrics/pair_plot.py index 2180b0da88..5ed0e747a8 100644 --- a/skore/src/skore/sklearn/_plot/metrics/pair_plot.py +++ b/skore/src/skore/sklearn/_plot/metrics/pair_plot.py @@ -110,7 +110,7 @@ def from_metrics( perf_metric_y, data_source=None, ): - """Create a confusion matrix display from predictions. + """Create a pair plot display from metrics. Parameters ---------- From 72eb0dc7e1d2d509e968b724288f99e9e85fcd91 Mon Sep 17 00:00:00 2001 From: Marie Sacksick <79304610+MarieSacksick@users.noreply.github.com> Date: Wed, 28 May 2025 15:05:28 +0200 Subject: [PATCH 17/20] Update skore/src/skore/sklearn/_plot/metrics/pair_plot.py Co-authored-by: Guillaume Lemaitre --- skore/src/skore/sklearn/_plot/metrics/pair_plot.py | 1 + 1 file changed, 1 insertion(+) diff --git a/skore/src/skore/sklearn/_plot/metrics/pair_plot.py b/skore/src/skore/sklearn/_plot/metrics/pair_plot.py index 5ed0e747a8..6872b248a0 100644 --- a/skore/src/skore/sklearn/_plot/metrics/pair_plot.py +++ b/skore/src/skore/sklearn/_plot/metrics/pair_plot.py @@ -2,6 +2,7 @@ from skore.sklearn._plot.base import Display from skore.sklearn._plot.style import StyleDisplayMixin +from skore.sklearn._plot.utils import HelpDisplayMixin from skore.sklearn.utils import _SCORE_OR_LOSS_INFO From 254648f9fc0e50b624ea02dd8f0f4a3bf10c3e10 Mon Sep 17 00:00:00 2001 From: Marie Sacksick <79304610+MarieSacksick@users.noreply.github.com> Date: Wed, 28 May 2025 15:05:37 +0200 Subject: [PATCH 18/20] Update skore/src/skore/sklearn/_plot/metrics/pair_plot.py Co-authored-by: Guillaume Lemaitre --- skore/src/skore/sklearn/_plot/metrics/pair_plot.py | 1 - 1 file changed, 1 deletion(-) diff --git a/skore/src/skore/sklearn/_plot/metrics/pair_plot.py b/skore/src/skore/sklearn/_plot/metrics/pair_plot.py index 6872b248a0..0703a12045 100644 --- a/skore/src/skore/sklearn/_plot/metrics/pair_plot.py +++ b/skore/src/skore/sklearn/_plot/metrics/pair_plot.py @@ -40,7 +40,6 @@ class PairPlotDisplay(Display): Axes with pair plot. """ - @StyleDisplayMixin.style_plot def __init__( self, scatter_data, From 3294dc557c281a002406c850f6961e1de52d308f Mon Sep 17 00:00:00 2001 From: Marie Sacksick <79304610+MarieSacksick@users.noreply.github.com> Date: Wed, 28 May 2025 15:05:52 +0200 Subject: [PATCH 19/20] Update skore/src/skore/sklearn/_plot/metrics/pair_plot.py Co-authored-by: Guillaume Lemaitre --- skore/src/skore/sklearn/_plot/metrics/pair_plot.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/skore/src/skore/sklearn/_plot/metrics/pair_plot.py b/skore/src/skore/sklearn/_plot/metrics/pair_plot.py index 0703a12045..685ff72978 100644 --- a/skore/src/skore/sklearn/_plot/metrics/pair_plot.py +++ b/skore/src/skore/sklearn/_plot/metrics/pair_plot.py @@ -6,7 +6,7 @@ from skore.sklearn.utils import _SCORE_OR_LOSS_INFO -class PairPlotDisplay(Display): +class PairPlotDisplay(HelpDisplayMixin, StyleDisplayMixin): """Display for pair plot. 
Parameters From f482c5feb4a6b668231f9f81d37bb8c96c32c003 Mon Sep 17 00:00:00 2001 From: Marie Date: Wed, 28 May 2025 15:08:21 +0200 Subject: [PATCH 20/20] linting --- skore/src/skore/sklearn/_plot/metrics/pair_plot.py | 1 - 1 file changed, 1 deletion(-) diff --git a/skore/src/skore/sklearn/_plot/metrics/pair_plot.py b/skore/src/skore/sklearn/_plot/metrics/pair_plot.py index 685ff72978..b6fc4c9a42 100644 --- a/skore/src/skore/sklearn/_plot/metrics/pair_plot.py +++ b/skore/src/skore/sklearn/_plot/metrics/pair_plot.py @@ -1,6 +1,5 @@ import matplotlib.pyplot as plt -from skore.sklearn._plot.base import Display from skore.sklearn._plot.style import StyleDisplayMixin from skore.sklearn._plot.utils import HelpDisplayMixin from skore.sklearn.utils import _SCORE_OR_LOSS_INFO