|
| 1 | +import itertools |
| 2 | + |
1 | 3 | import matplotlib.pyplot as plt |
2 | 4 | import pandas as pd |
3 | 5 |
|
@@ -153,23 +155,21 @@ def _plot_comparison_estimator(self, x, y): |
153 | 155 | if self.data_source is not None: |
154 | 156 | title += f" on {self.data_source} set" |
155 | 157 |
|
156 | | - self.ax_.scatter(x=x_data, y=y_data) |
157 | | - self.ax_.set(title=title, xlabel=x_label_text, ylabel=y_label_text) |
158 | | - |
159 | | - # Add labels to the points with a small offset |
| 158 | + # Use a set of markers and colors for each data point |
160 | 159 | text = self.summarize_data.columns.tolist() |
| 160 | + markers = itertools.cycle(("o", "s", "^", "D", "v", "P", "*", "X", "h", "8")) |
| 161 | + colors = itertools.cycle(plt.rcParams["axes.prop_cycle"].by_key()["color"]) |
| 162 | + |
| 163 | + handles = [] |
161 | 164 | for label, x_coord, y_coord in zip(text, x_data, y_data, strict=False): |
162 | | - self.ax_.annotate( |
163 | | - label, |
164 | | - (x_coord, y_coord), |
165 | | - textcoords="offset points", |
166 | | - xytext=(10, 0), |
167 | | - bbox=dict( |
168 | | - boxstyle="round,pad=0.3", |
169 | | - edgecolor="gray", |
170 | | - facecolor="white", |
171 | | - alpha=0.7, |
172 | | - ), |
| 165 | + marker = next(markers) |
| 166 | + color = next(colors) |
| 167 | + sc = self.ax_.scatter( |
| 168 | + x_coord, y_coord, marker=marker, color=color, label=label |
173 | 169 | ) |
| 170 | + handles.append(sc) |
| 171 | + |
| 172 | + self.ax_.set(title=title, xlabel=x_label_text, ylabel=y_label_text) |
| 173 | + self.ax_.legend(title="Models", loc="best") |
174 | 174 |
|
175 | 175 | return self.figure_, self.ax_ |
0 commit comments