Skip to content

Commit 4c2fe8b

Browse files
committed
chore(skore): Fix typings
1 parent a58f8d8 commit 4c2fe8b

File tree

12 files changed

+43
-26
lines changed

12 files changed

+43
-26
lines changed

skore/pyproject.toml

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -159,9 +159,16 @@ convention = "numpy"
159159
"tests/*" = ["D"]
160160

161161
[tool.mypy]
162-
ignore_missing_imports = true
163-
exclude = ["src/skore/_externals/.*", "hatch/*", "tests/*"]
162+
exclude = ["src/skore/_externals/", "hatch/", "tests/"]
164163

165164
[[tool.mypy.overrides]]
166-
module = ["sklearn.*"]
167165
ignore_missing_imports = true
166+
module = [
167+
"ipywidgets.*",
168+
"joblib.*",
169+
"pandas.*",
170+
"plotly.*",
171+
"seaborn.*",
172+
"sklearn.*",
173+
"skrub.*",
174+
]

skore/src/skore/_sklearn/_comparison/report.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -342,7 +342,7 @@ def get_predictions(
342342
] = "predict",
343343
X: ArrayLike | None = None,
344344
pos_label: PositiveLabel | None = _DEFAULT,
345-
) -> list[ArrayLike]:
345+
) -> list[ArrayLike | list[ArrayLike]]:
346346
"""Get predictions from the underlying reports.
347347
348348
This method has the advantage to reload from the cache if the predictions

skore/src/skore/_sklearn/_cross_validation/data_accessor.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ def _retrieve_data_as_frame(
4242
y = self._parent.y
4343

4444
if not sbd.is_dataframe(X):
45-
X = pd.DataFrame(X, columns=[f"Feature {i}" for i in range(X.shape[1])])
45+
X = pd.DataFrame(X, columns=[f"Feature {i}" for i in range(X.shape[1])]) # type: ignore
4646

4747
if with_y:
4848
if y is None:
@@ -52,10 +52,11 @@ def _retrieve_data_as_frame(
5252
name = y.name if y.name is not None else "Target"
5353
y = y.to_frame(name=name)
5454
elif not sbd.is_dataframe(y):
55-
if y.ndim == 1:
55+
if y.ndim == 1: # type: ignore
5656
columns = ["Target"]
5757
else:
58-
columns = [f"Target {i}" for i in range(y.shape[1])]
58+
columns = [f"Target {i}" for i in range(y.shape[1])] # type: ignore
59+
5960
y = pd.DataFrame(y, columns=columns)
6061

6162
return X, y

skore/src/skore/_sklearn/_cross_validation/metrics_accessor.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1143,16 +1143,17 @@ def _get_display(
11431143
X, y, _ = report.metrics._get_X_y_and_data_source_hash(
11441144
data_source=data_source
11451145
)
1146+
11461147
y_true.append(
11471148
YPlotData(
11481149
estimator_name=self._parent.estimator_name_,
11491150
split=report_idx,
1150-
y=y,
1151+
y=cast(ArrayLike, y),
11511152
)
11521153
)
11531154
results = _get_cached_response_values(
11541155
cache=report._cache,
1155-
estimator_hash=report._hash,
1156+
estimator_hash=int(report._hash),
11561157
estimator=report._estimator,
11571158
X=X,
11581159
response_method=response_method,

skore/src/skore/_sklearn/_estimator/feature_importance_accessor.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -547,7 +547,7 @@ def _feature_permutation(
547547
feature_names = (
548548
self._parent.estimator_.feature_names_in_
549549
if hasattr(self._parent.estimator_, "feature_names_in_")
550-
else [f"Feature #{i}" for i in range(X_.shape[1])]
550+
else [f"Feature #{i}" for i in range(X_.shape[1])] # type: ignore
551551
)
552552

553553
# If there is more than one metric

skore/src/skore/_sklearn/_estimator/metrics_accessor.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -482,7 +482,7 @@ def _compute_metric_scores(
482482

483483
results = _get_cached_response_values(
484484
cache=self._parent._cache,
485-
estimator_hash=self._parent._hash,
485+
estimator_hash=int(self._parent._hash),
486486
estimator=self._parent.estimator_,
487487
X=X,
488488
response_method=response_method,
@@ -1674,7 +1674,7 @@ def _get_display(
16741674
else:
16751675
results = _get_cached_response_values(
16761676
cache=self._parent._cache,
1677-
estimator_hash=self._parent._hash,
1677+
estimator_hash=int(self._parent._hash),
16781678
estimator=self._parent.estimator_,
16791679
X=X,
16801680
response_method=response_method,

skore/src/skore/_sklearn/_estimator/report.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -389,7 +389,7 @@ def get_predictions(
389389

390390
results = _get_cached_response_values(
391391
cache=self._cache,
392-
estimator_hash=self._hash,
392+
estimator_hash=int(self._hash),
393393
estimator=self._estimator,
394394
X=X_,
395395
response_method=response_method,

skore/src/skore/_sklearn/_plot/metrics/precision_recall_curve.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -865,17 +865,19 @@ def _compute_data_for_display(
865865
):
866866
label_binarizer = LabelBinarizer().fit(est.classes_)
867867
y_true_onehot_i: NDArray = label_binarizer.transform(y_true_i.y)
868+
y_pred_i_y = cast(NDArray, y_pred_i.y)
869+
868870
for class_idx, class_ in enumerate(est.classes_):
869871
precision_class_i, recall_class_i, thresholds_class_i = (
870872
precision_recall_curve(
871873
y_true_onehot_i[:, class_idx],
872-
y_pred_i.y[:, class_idx],
874+
y_pred_i_y[:, class_idx],
873875
pos_label=None,
874876
drop_intermediate=drop_intermediate,
875877
)
876878
)
877879
average_precision_class_i = average_precision_score(
878-
y_true_onehot_i[:, class_idx], y_pred_i.y[:, class_idx]
880+
y_true_onehot_i[:, class_idx], y_pred_i_y[:, class_idx]
879881
)
880882

881883
for precision, recall, threshold in zip(

skore/src/skore/_sklearn/_plot/metrics/prediction_error.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import numbers
22
from collections import namedtuple
3-
from typing import Any, Literal
3+
from typing import Any, Literal, cast
44

55
import matplotlib.pyplot as plt
66
import numpy as np
@@ -265,7 +265,7 @@ def _plot_single_estimator(
265265
self.ax_.legend(handles, labels, loc="lower right")
266266
self.ax_.set_title(f"Prediction Error for {estimator_name}")
267267

268-
return scatter
268+
return cast(list[Artist], scatter)
269269

270270
def _plot_cross_validated_estimator(
271271
self,
@@ -352,7 +352,7 @@ def _plot_cross_validated_estimator(
352352
self.ax_.legend(handles, labels, loc="lower right", title=legend_title)
353353
self.ax_.set_title(f"Prediction Error for {estimator_name}")
354354

355-
return scatter
355+
return cast(list[Artist], scatter)
356356

357357
def _plot_comparison_estimator(
358358
self,
@@ -435,7 +435,7 @@ def _plot_comparison_estimator(
435435
self.ax_.legend(handles, labels, loc="lower right", title=legend_title)
436436
self.ax_.set_title("Prediction Error")
437437

438-
return scatter
438+
return cast(list[Artist], scatter)
439439

440440
def _plot_comparison_cross_validation(
441441
self,
@@ -518,7 +518,7 @@ def _plot_comparison_cross_validation(
518518
self.ax_.legend(handles, labels, loc="lower right", title=legend_title)
519519
self.ax_.set_title("Prediction Error")
520520

521-
return scatter
521+
return cast(list[Artist], scatter)
522522

523523
@DisplayMixin.style_plot
524524
def plot(
@@ -824,9 +824,9 @@ def _compute_data_for_display(
824824
}
825825
)
826826
else:
827-
y_true_sample = y_true_i.y
828-
y_pred_sample = y_pred_i.y
829-
residuals_sample = y_true_i.y - y_pred_i.y
827+
y_true_sample = cast(np.typing.NDArray, y_true_i.y)
828+
y_pred_sample = cast(np.typing.NDArray, y_pred_i.y)
829+
residuals_sample = y_true_sample - y_pred_sample
830830

831831
for y_true_sample_i, y_pred_sample_i, residuals_sample_i in zip(
832832
y_true_sample, y_pred_sample, residuals_sample, strict=False

skore/src/skore/_sklearn/_plot/metrics/roc_curve.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,8 @@ def __init__(
148148
self.ml_task = ml_task
149149
self.report_type = report_type
150150

151+
self.chance_level_: Line2D | list[Line2D] | None
152+
151153
def _plot_single_estimator(
152154
self,
153155
*,
@@ -947,10 +949,12 @@ def _compute_data_for_display(
947949
):
948950
label_binarizer = LabelBinarizer().fit(est.classes_)
949951
y_true_onehot_i: NDArray = label_binarizer.transform(y_true_i.y)
952+
y_pred_i_y = cast(NDArray, y_pred_i.y)
953+
950954
for class_idx, class_ in enumerate(est.classes_):
951955
fpr_class_i, tpr_class_i, thresholds_class_i = roc_curve(
952956
y_true_onehot_i[:, class_idx],
953-
y_pred_i.y[:, class_idx],
957+
y_pred_i_y[:, class_idx],
954958
pos_label=None,
955959
drop_intermediate=drop_intermediate,
956960
)

0 commit comments

Comments
 (0)