|
1 | 1 | import numbers |
2 | 2 | from collections import namedtuple |
3 | | -from typing import Any, Literal |
| 3 | +from typing import Any, Literal, cast |
4 | 4 |
|
5 | 5 | import matplotlib.pyplot as plt |
6 | 6 | import numpy as np |
@@ -265,7 +265,7 @@ def _plot_single_estimator( |
265 | 265 | self.ax_.legend(handles, labels, loc="lower right") |
266 | 266 | self.ax_.set_title(f"Prediction Error for {estimator_name}") |
267 | 267 |
|
268 | | - return scatter |
| 268 | + return cast(list[Artist], scatter) |
269 | 269 |
|
270 | 270 | def _plot_cross_validated_estimator( |
271 | 271 | self, |
@@ -352,7 +352,7 @@ def _plot_cross_validated_estimator( |
352 | 352 | self.ax_.legend(handles, labels, loc="lower right", title=legend_title) |
353 | 353 | self.ax_.set_title(f"Prediction Error for {estimator_name}") |
354 | 354 |
|
355 | | - return scatter |
| 355 | + return cast(list[Artist], scatter) |
356 | 356 |
|
357 | 357 | def _plot_comparison_estimator( |
358 | 358 | self, |
@@ -435,7 +435,7 @@ def _plot_comparison_estimator( |
435 | 435 | self.ax_.legend(handles, labels, loc="lower right", title=legend_title) |
436 | 436 | self.ax_.set_title("Prediction Error") |
437 | 437 |
|
438 | | - return scatter |
| 438 | + return cast(list[Artist], scatter) |
439 | 439 |
|
440 | 440 | def _plot_comparison_cross_validation( |
441 | 441 | self, |
@@ -518,7 +518,7 @@ def _plot_comparison_cross_validation( |
518 | 518 | self.ax_.legend(handles, labels, loc="lower right", title=legend_title) |
519 | 519 | self.ax_.set_title("Prediction Error") |
520 | 520 |
|
521 | | - return scatter |
| 521 | + return cast(list[Artist], scatter) |
522 | 522 |
|
523 | 523 | @DisplayMixin.style_plot |
524 | 524 | def plot( |
@@ -824,9 +824,9 @@ def _compute_data_for_display( |
824 | 824 | } |
825 | 825 | ) |
826 | 826 | else: |
827 | | - y_true_sample = y_true_i.y |
828 | | - y_pred_sample = y_pred_i.y |
829 | | - residuals_sample = y_true_i.y - y_pred_i.y |
| 827 | + y_true_sample = cast(np.typing.NDArray, y_true_i.y) |
| 828 | + y_pred_sample = cast(np.typing.NDArray, y_pred_i.y) |
| 829 | + residuals_sample = y_true_sample - y_pred_sample |
830 | 830 |
|
831 | 831 | for y_true_sample_i, y_pred_sample_i, residuals_sample_i in zip( |
832 | 832 | y_true_sample, y_pred_sample, residuals_sample, strict=False |
|
0 commit comments