feat: Turn report_metrics of ComparisonReport into Displays #1520
| Original file line number | Diff line number | Diff line change | ||||||
|---|---|---|---|---|---|---|---|---|
| @@ -0,0 +1,219 @@ | ||||||||
| import matplotlib.pyplot as plt | ||||||||
|
|
||||||||
| from skore.sklearn._plot.style import StyleDisplayMixin | ||||||||
MarieSacksick marked this conversation as resolved.
|
||||||||
| from skore.sklearn._plot.utils import HelpDisplayMixin | ||||||||
| from skore.sklearn.utils import _SCORE_OR_LOSS_INFO | ||||||||
|
|
||||||||
|
|
||||||||
| class PairPlotDisplay(HelpDisplayMixin, StyleDisplayMixin): | ||||||||
| """Display for pair plot. | ||||||||
|
|
||||||||
| Parameters | ||||||||
| ---------- | ||||||||
| scatter_data : pandas.DataFrame | ||||||||
| Dataframe containing the data to plot. | ||||||||
|
|
||||||||
| x_column : str, default=None | ||||||||
| The name of the column to plot on the x-axis. | ||||||||
| If None, the first column of the dataframe is used. | ||||||||
|
|
||||||||
| y_column : str, default=None | ||||||||
| The name of the column to plot on the y-axis. | ||||||||
| If None, the second column of the dataframe is used. | ||||||||
|
|
||||||||
| display_label_x : str, default=None | ||||||||
| The label to use for the x-axis. If None, the name of the column will be used. | ||||||||
|
|
||||||||
| display_label_y : str, default=None | ||||||||
| The label to use for the y-axis. If None, the name of the column will be used. | ||||||||
|
|
||||||||
| data_source : str, default=None | ||||||||
| Name of the data source; if given, it is appended to the plot title. | ||||||||
|
|
||||||||
| Attributes | ||||||||
| ---------- | ||||||||
| figure_ : matplotlib Figure | ||||||||
| Figure containing the pair plot. | ||||||||
|
|
||||||||
| ax_ : matplotlib Axes | ||||||||
| Axes with pair plot. | ||||||||
| """ | ||||||||
|
|
||||||||
| def __init__( | ||||||||
| self, | ||||||||
| scatter_data, | ||||||||
| *, | ||||||||
| x_column=None, | ||||||||
| y_column=None, | ||||||||
| display_label_x=None, | ||||||||
| display_label_y=None, | ||||||||
| data_source=None, | ||||||||
| ): | ||||||||
| self.scatter_data = scatter_data | ||||||||
| if x_column is None: | ||||||||
| x_column = scatter_data.columns[0] | ||||||||
| self.x_column = x_column | ||||||||
| if y_column is None: | ||||||||
| y_column = scatter_data.columns[1] | ||||||||
| self.y_column = y_column | ||||||||
| self.display_label_x = ( | ||||||||
| display_label_x if display_label_x is not None else self.x_column | ||||||||
| ) | ||||||||
| self.display_label_y = ( | ||||||||
| display_label_y if display_label_y is not None else self.y_column | ||||||||
| ) | ||||||||
| self.data_source = data_source | ||||||||
| self.figure_ = None | ||||||||
| self.ax_ = None | ||||||||
| self.text_ = None | ||||||||
|
Comment on lines +66 to +68
Member
For consistency, those are created only when `plot` is called. We can see in a subsequent PR whether we want to make this behaviour consistent with an initialization. |
||||||||
|
|
||||||||
| def plot(self, ax=None, **kwargs): | ||||||||
|
Member
It is this method that would benefit from the style.
Suggested change
Member
we don't need
Member
You can also remove `kwargs` because it is unused. |
||||||||
| """Plot a given performance metric against another. | ||||||||
|
|
||||||||
| Parameters | ||||||||
| ---------- | ||||||||
| ax : matplotlib axes, default=None | ||||||||
| Axes object to plot on. If None, a new figure and axes are created. | ||||||||
|
|
||||||||
| **kwargs : dict | ||||||||
| Additional keyword arguments. Currently unused: they are not | ||||||||
| forwarded to `ax.scatter`. | ||||||||
|
|
||||||||
| Returns | ||||||||
| ------- | ||||||||
| self : PairPlotDisplay | ||||||||
| """ | ||||||||
| if ax is None: | ||||||||
| fig, ax = plt.subplots() | ||||||||
| else: | ||||||||
| fig = ax.figure | ||||||||
|
|
||||||||
| scatter_data = self.scatter_data | ||||||||
|
|
||||||||
| title = f"{self.display_label_x} vs {self.display_label_y}" | ||||||||
| if self.data_source is not None: | ||||||||
| title += f" on {self.data_source} data" | ||||||||
| ax.scatter(x=scatter_data[self.x_column], y=scatter_data[self.y_column]) | ||||||||
| ax.set_title(title) | ||||||||
| ax.set_xlabel(self.display_label_x) | ||||||||
| ax.set_ylabel(self.display_label_y) | ||||||||
|
|
||||||||
| self.figure_, self.ax_ = fig, ax | ||||||||
| return self | ||||||||
|
|
||||||||
| @classmethod | ||||||||
| def from_metrics( | ||||||||
| cls, | ||||||||
| metrics, | ||||||||
| perf_metric_x, | ||||||||
| perf_metric_y, | ||||||||
| data_source=None, | ||||||||
| ): | ||||||||
|
Comment on lines +104 to +111
Member
The display should not expose this public function. The idea is that the reporters will be the only object that can create an instance of You can have a look at |
||||||||
| """Create a pair plot display from metrics. | ||||||||
|
|
||||||||
| Parameters | ||||||||
| ---------- | ||||||||
| metrics : pandas.DataFrame | ||||||||
| Dataframe containing the data to plot. The dataframe should | ||||||||
| contain the performance metrics for each estimator. | ||||||||
|
|
||||||||
| perf_metric_x : str | ||||||||
| The name of the column to plot on the x-axis. | ||||||||
|
|
||||||||
| perf_metric_y : str | ||||||||
| The name of the column to plot on the y-axis. | ||||||||
|
|
||||||||
| data_source : str | ||||||||
| Name of the data source; if given, it is appended to the plot title. | ||||||||
|
|
||||||||
| Returns | ||||||||
| ------- | ||||||||
| display : :class:`PairPlotDisplay` | ||||||||
| The scatter plot display. | ||||||||
| """ | ||||||||
| x_label = _SCORE_OR_LOSS_INFO.get(perf_metric_x, {}).get("name", perf_metric_x) | ||||||||
| y_label = _SCORE_OR_LOSS_INFO.get(perf_metric_y, {}).get("name", perf_metric_y) | ||||||||
|
Comment on lines +134 to +135
Member
I think that those should be passed directly by the methods from the report. It would be handy because we would have access to the dictionary |
||||||||
| scatter_data = metrics | ||||||||
|
|
||||||||
| # Check that the metrics are in the report | ||||||||
| # If the metric is not in the report, help the user by suggesting | ||||||||
| # supported metrics | ||||||||
| reverse_score_info = { | ||||||||
| value["name"]: key for key, value in _SCORE_OR_LOSS_INFO.items() | ||||||||
| } | ||||||||
| available_columns = scatter_data.columns.get_level_values(0).to_list() | ||||||||
| available_columns.remove("Estimator") | ||||||||
| supported_metrics = [ | ||||||||
| reverse_score_info.get(col, col) for col in available_columns | ||||||||
| ] | ||||||||
| if perf_metric_x not in supported_metrics: | ||||||||
| raise ValueError( | ||||||||
| f"Performance metric {perf_metric_x} not found in the report. " | ||||||||
| f"Supported metrics are: {supported_metrics}." | ||||||||
| ) | ||||||||
| if perf_metric_y not in supported_metrics: | ||||||||
| raise ValueError( | ||||||||
| f"Performance metric {perf_metric_y} not found in the report. " | ||||||||
| f"Supported metrics are: {supported_metrics}." | ||||||||
| ) | ||||||||
|
|
||||||||
| # Check that x and y are 1D arrays (i.e. the metrics don't need pos_label) | ||||||||
| x = scatter_data[x_label] | ||||||||
| y = scatter_data[y_label] | ||||||||
| if len(x.shape) > 1: | ||||||||
| raise ValueError( | ||||||||
| f"Performance metric {perf_metric_x} requires a positive label (pos_label)." | ||||||||
| ) | ||||||||
| if len(y.shape) > 1: | ||||||||
| raise ValueError( | ||||||||
| f"Performance metric {perf_metric_y} requires a positive label (pos_label)." | ||||||||
| ) | ||||||||
|
|
||||||||
| # Make it clear in the axis labels that we are using the train set | ||||||||
| if perf_metric_x == "fit_time" and data_source != "train": | ||||||||
| x_label_text = x_label + " on train set" | ||||||||
| else: | ||||||||
| x_label_text = x_label | ||||||||
| if perf_metric_y == "fit_time" and data_source != "train": | ||||||||
| y_label_text = y_label + " on train set" | ||||||||
| else: | ||||||||
| y_label_text = y_label | ||||||||
|
|
||||||||
| disp = cls( | ||||||||
| scatter_data=scatter_data, | ||||||||
| x_column=x_label, | ||||||||
| y_column=y_label, | ||||||||
| display_label_x=x_label_text, | ||||||||
| display_label_y=y_label_text, | ||||||||
| data_source=data_source, | ||||||||
| ).plot() | ||||||||
|
|
||||||||
| # Add labels to the points with a small offset | ||||||||
| ax = disp.ax_ | ||||||||
| text = scatter_data["Estimator"] | ||||||||
| for label, x_coord, y_coord in zip(text, x, y): | ||||||||
| ax.annotate( | ||||||||
| label, | ||||||||
| (x_coord, y_coord), | ||||||||
| textcoords="offset points", | ||||||||
| xytext=(10, 0), | ||||||||
| bbox=dict( | ||||||||
| boxstyle="round,pad=0.3", | ||||||||
| edgecolor="gray", | ||||||||
| facecolor="white", | ||||||||
| alpha=0.7, | ||||||||
| ), | ||||||||
| ) | ||||||||
|
|
||||||||
| disp.ax_ = ax | ||||||||
| return disp | ||||||||
|
|
||||||||
| def frame(self): | ||||||||
| """Return the dataframe used for the pair plot. | ||||||||
|
|
||||||||
| Returns | ||||||||
| ------- | ||||||||
| scatter_data : pandas.DataFrame | ||||||||
| The dataframe used to create the scatter plot. | ||||||||
| """ | ||||||||
| return self.scatter_data | ||||||||
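
The two `fit_time` branches near the end of `from_metrics` apply the same rule to each axis. A minimal sketch of that rule as a standalone function follows. The name `axis_label_text` is hypothetical and not part of this PR. The comment in the diff suggests that fit time always refers to the train set, which is the assumption made here.

```python
def axis_label_text(metric, label, data_source):
    """Return the axis label for one metric.

    Hypothetical helper (not in the PR): it mirrors the duplicated
    ``fit_time`` branches of ``from_metrics``.
    """
    # Assumption taken from the diff's comment: fit time refers to the
    # train set, so the label says so when another source is plotted.
    if metric == "fit_time" and data_source != "train":
        return label + " on train set"
    return label


# Example calls for the x-axis and y-axis labels of one plot.
x_label_text = axis_label_text("fit_time", "Fit time (s)", "test")
y_label_text = axis_label_text("accuracy", "Accuracy", "test")
print(x_label_text)
print(y_label_text)
```

Calling the same function for both axes would remove the duplicated `if`/`else` blocks while keeping the behaviour shown in the diff.
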
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,16 @@ | ||
| """Utility functions for Skore and Scikit-learn integration.""" | ||
|
|
||
| _SCORE_OR_LOSS_INFO: dict[str, dict[str, str]] = { | ||
| "fit_time": {"name": "Fit time (s)", "icon": "(↘︎)"}, | ||
| "predict_time": {"name": "Predict time (s)", "icon": "(↘︎)"}, | ||
| "accuracy": {"name": "Accuracy", "icon": "(↗︎)"}, | ||
| "precision": {"name": "Precision", "icon": "(↗︎)"}, | ||
| "recall": {"name": "Recall", "icon": "(↗︎)"}, | ||
| "brier_score": {"name": "Brier score", "icon": "(↘︎)"}, | ||
| "roc_auc": {"name": "ROC AUC", "icon": "(↗︎)"}, | ||
| "log_loss": {"name": "Log loss", "icon": "(↘︎)"}, | ||
| "r2": {"name": "R²", "icon": "(↗︎)"}, | ||
| "rmse": {"name": "RMSE", "icon": "(↘︎)"}, | ||
| "custom_metric": {"name": "Custom metric", "icon": ""}, | ||
| "report_metrics": {"name": "Report metrics", "icon": ""}, | ||
| } |
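
To illustrate how `from_metrics` uses this mapping, here is a small sketch of the forward lookup (metric identifier to column label) and the reverse lookup (column label back to identifier). It uses only a subset of the entries above, with the `icon` fields omitted for brevity. The function names `display_name` and `supported_metrics` are hypothetical and introduced only for this example.

```python
# Subset of the mapping from the diff (icons omitted for brevity).
_SCORE_OR_LOSS_INFO = {
    "fit_time": {"name": "Fit time (s)"},
    "accuracy": {"name": "Accuracy"},
    "roc_auc": {"name": "ROC AUC"},
}

# Invert the mapping: column label -> metric identifier.
reverse_score_info = {
    info["name"]: key for key, info in _SCORE_OR_LOSS_INFO.items()
}


def display_name(metric):
    """Return the column label for a metric, or the metric itself if unknown."""
    return _SCORE_OR_LOSS_INFO.get(metric, {}).get("name", metric)


def supported_metrics(columns):
    """Translate column labels back to metric identifiers.

    The "Estimator" column is skipped; unknown labels are kept unchanged.
    """
    return [
        reverse_score_info.get(col, col) for col in columns if col != "Estimator"
    ]


columns = ["Estimator", "Accuracy", "Fit time (s)", "My score"]
print(supported_metrics(columns))
print(display_name("roc_auc"))
```

Falling back to the raw key or label in both directions means custom metrics, which have no entry in the mapping, pass through unchanged.
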
One thing that I realized with the current implementation is that we are going to want to pass most of the parameters to `report_metrics`. Now, I'm thinking that this would mean that the `PairPlotDisplay` is just a kind of plot associated with `report_metrics`. In short, I think that it would make sense to be able to write:
but also
And it allows passing the arguments as: