Skip to content

Commit 12f0b83

Browse files
committed
change annotation to legend
1 parent 2bfbf33 commit 12f0b83

File tree

2 files changed

+33
-15
lines changed

2 files changed

+33
-15
lines changed

skore/src/skore/sklearn/_plot/metrics/metrics_summary_display.py

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import itertools
2+
13
import matplotlib.pyplot as plt
24
import pandas as pd
35

@@ -153,23 +155,21 @@ def _plot_comparison_estimator(self, x, y):
153155
if self.data_source is not None:
154156
title += f" on {self.data_source} set"
155157

156-
self.ax_.scatter(x=x_data, y=y_data)
157-
self.ax_.set(title=title, xlabel=x_label_text, ylabel=y_label_text)
158-
159-
# Add labels to the points with a small offset
158+
# Use a set of markers and colors for each data point
160159
text = self.summarize_data.columns.tolist()
160+
markers = itertools.cycle(("o", "s", "^", "D", "v", "P", "*", "X", "h", "8"))
161+
colors = itertools.cycle(plt.rcParams["axes.prop_cycle"].by_key()["color"])
162+
163+
handles = []
161164
for label, x_coord, y_coord in zip(text, x_data, y_data, strict=False):
162-
self.ax_.annotate(
163-
label,
164-
(x_coord, y_coord),
165-
textcoords="offset points",
166-
xytext=(10, 0),
167-
bbox=dict(
168-
boxstyle="round,pad=0.3",
169-
edgecolor="gray",
170-
facecolor="white",
171-
alpha=0.7,
172-
),
165+
marker = next(markers)
166+
color = next(colors)
167+
sc = self.ax_.scatter(
168+
x_coord, y_coord, marker=marker, color=color, label=label
173169
)
170+
handles.append(sc)
171+
172+
self.ax_.set(title=title, xlabel=x_label_text, ylabel=y_label_text)
173+
self.ax_.legend(title="Models", loc="best")
174174

175175
return self.figure_, self.ax_

test.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
# %%
2+
from skore import EstimatorReport, ComparisonReport
3+
from sklearn.datasets import load_breast_cancer
4+
from sklearn.linear_model import LogisticRegression
5+
from sklearn.ensemble import HistGradientBoostingClassifier
6+
from skore import train_test_split
7+
X, y = load_breast_cancer(return_X_y=True)
8+
split_data = train_test_split(X=X, y=y, random_state=0, as_dict=True)
9+
classifier = LogisticRegression()
10+
report_a = EstimatorReport(classifier, pos_label=1, **split_data)
11+
classifier = HistGradientBoostingClassifier()
12+
report_b = EstimatorReport(classifier, pos_label=1, **split_data)
13+
comparison_report = ComparisonReport(
14+
{"report_a": report_a, "report_b": report_b}
15+
)
16+
display = comparison_report.metrics.summarize()
17+
display.plot(x="roc_auc", y="fit_time")
18+
# %%

0 commit comments

Comments
 (0)