Skip to content

Commit 0fd1f4a

Browse files
YoussefMAllam, glemaitre, thomass-dev
authored
fix: Raise an error when CrossValidationReport fails before at least one estimator is fitted (#1574)
Co-authored-by: Guillaume Lemaitre <[email protected]> Co-authored-by: Thomas S. <[email protected]>
1 parent d35228d commit 0fd1f4a

File tree

3 files changed

+99
-56
lines changed

3 files changed

+99
-56
lines changed

skore/src/skore/sklearn/_cross_validation/report.py

Lines changed: 56 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -29,15 +29,18 @@ def _generate_estimator_report(
2929
y: Optional[ArrayLike],
3030
train_indices: ArrayLike,
3131
test_indices: ArrayLike,
32-
) -> EstimatorReport:
33-
return EstimatorReport(
34-
estimator,
35-
fit=True,
36-
X_train=_safe_indexing(X, train_indices),
37-
y_train=_safe_indexing(y, train_indices),
38-
X_test=_safe_indexing(X, test_indices),
39-
y_test=_safe_indexing(y, test_indices),
40-
)
32+
) -> Union[EstimatorReport, KeyboardInterrupt, Exception]:
33+
try:
34+
return EstimatorReport(
35+
estimator,
36+
fit=True,
37+
X_train=_safe_indexing(X, train_indices),
38+
y_train=_safe_indexing(y, train_indices),
39+
X_test=_safe_indexing(X, test_indices),
40+
y_test=_safe_indexing(y, test_indices),
41+
)
42+
except (KeyboardInterrupt, Exception) as e:
43+
return e
4144

4245

4346
class CrossValidationReport(_BaseReport, DirNamesMixin):
@@ -198,31 +201,55 @@ def _fit_estimator_reports(self) -> list[EstimatorReport]:
198201
)
199202

200203
estimator_reports = []
201-
try:
202-
for report in generator:
203-
estimator_reports.append(report)
204-
progress.update(task, advance=1, refresh=True)
205-
except (Exception, KeyboardInterrupt) as e:
206-
from skore import console # avoid circular import
204+
for report in generator:
205+
estimator_reports.append(report)
206+
progress.update(task, advance=1, refresh=True)
207+
208+
warn_msg = None
209+
if not any (
210+
isinstance(report, EstimatorReport)
211+
for report in estimator_reports
212+
):
213+
traceback_msg = "\n".join(str(exc) for exc in estimator_reports)
214+
raise RuntimeError(
215+
"Cross-validation failed: no estimators were successfully fitted. "
216+
"Please check your data, estimator, or cross-validation setup.\n"
217+
f"Traceback: \n{traceback_msg}"
218+
)
219+
elif any(isinstance(report, Exception) for report in estimator_reports):
220+
msg_traceback = "\n".join(
221+
str(exc) for exc in estimator_reports if isinstance(exc, Exception)
222+
)
223+
warn_msg = (
224+
"Cross-validation process was interrupted by an error before "
225+
"all estimators could be fitted; CrossValidationReport object "
226+
"might not contain all the expected results.\n"
227+
f"Traceback: \n{msg_traceback}"
228+
)
229+
estimator_reports = [
230+
report
231+
for report in estimator_reports
232+
if not isinstance(report, Exception)
233+
]
234+
if any(isinstance(report, KeyboardInterrupt) for report in estimator_reports):
235+
warn_msg = (
236+
"Cross-validation process was interrupted manually before all "
237+
"estimators could be fitted; CrossValidationReport object "
238+
"might not contain all the expected results."
239+
)
240+
estimator_reports = [
241+
report
242+
for report in estimator_reports
243+
if not isinstance(report, KeyboardInterrupt)
244+
]
207245

208-
if isinstance(e, KeyboardInterrupt):
209-
message = (
210-
"Cross-validation process was interrupted manually before all "
211-
"estimators could be fitted; CrossValidationReport object "
212-
"might not contain all the expected results."
213-
)
214-
else:
215-
message = (
216-
"Cross-validation process was interrupted by an error before "
217-
"all estimators could be fitted; CrossValidationReport object "
218-
"might not contain all the expected results. "
219-
f"Traceback: \n{e}"
220-
)
246+
if warn_msg is not None:
247+
from skore import console # avoid circular import
221248

222249
console.print(
223250
Panel(
224251
title="Cross-validation interrupted",
225-
renderable=message,
252+
renderable=warn_msg,
226253
style="orange1",
227254
border_style="cyan",
228255
)

skore/src/skore/utils/_testing.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
import contextlib
22
import copy
33

4+
import numpy as np
5+
from sklearn.base import BaseEstimator, ClassifierMixin
6+
47

58
@contextlib.contextmanager
69
def check_cache_changed(value):
@@ -16,3 +19,23 @@ def check_cache_unchanged(value):
1619
initial_value = copy.copy(value)
1720
yield
1821
assert value == initial_value
22+
23+
24+
class MockEstimator(ClassifierMixin, BaseEstimator):
25+
def __init__(self, *, error, n_call=0, fail_after_n_clone=3):
26+
self.error = error
27+
self.n_call = n_call
28+
self.fail_after_n_clone = fail_after_n_clone
29+
30+
def fit(self, X, y):
31+
if self.n_call > self.fail_after_n_clone:
32+
raise self.error
33+
self.classes_ = np.unique(y)
34+
return self
35+
36+
def __sklearn_clone__(self):
37+
self.n_call += 1
38+
return self
39+
40+
def predict(self, X):
41+
return np.ones(X.shape[0])

skore/tests/unit/sklearn/cross_validation/test_cross_validation.py

Lines changed: 20 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import numpy as np
55
import pandas as pd
66
import pytest
7-
from sklearn.base import BaseEstimator, ClassifierMixin, clone
7+
from sklearn.base import clone
88
from sklearn.datasets import make_classification, make_regression
99
from sklearn.ensemble import RandomForestClassifier
1010
from sklearn.exceptions import NotFittedError
@@ -26,6 +26,7 @@
2626
)
2727
from skore.sklearn._estimator import EstimatorReport
2828
from skore.sklearn._plot import RocCurveDisplay
29+
from skore.utils._testing import MockEstimator
2930

3031

3132
@pytest.fixture
@@ -895,38 +896,16 @@ def test_cross_validation_report_custom_metric(binary_classification_data):
895896
(KeyboardInterrupt(), "Cross-validation interrupted manually"),
896897
],
897898
)
899+
@pytest.mark.parametrize("n_jobs", [None, 1, 2])
898900
def test_cross_validation_report_interrupted(
899-
binary_classification_data, capsys, error, error_message
901+
binary_classification_data, capsys, error, error_message, n_jobs
900902
):
901903
"""Check that we can interrupt cross-validation without losing all
902904
data."""
903-
904-
class MockEstimator(ClassifierMixin, BaseEstimator):
905-
def __init__(self, n_call=0, fail_after_n_clone=3):
906-
self.n_call = n_call
907-
self.fail_after_n_clone = fail_after_n_clone
908-
909-
def fit(self, X, y):
910-
if self.n_call > self.fail_after_n_clone:
911-
raise error
912-
self.classes_ = np.unique(y)
913-
return self
914-
915-
def __sklearn_clone__(self):
916-
"""Do not clone the estimator
917-
918-
Instead, we increment a counter each time that
919-
`sklearn.clone` is called.
920-
"""
921-
self.n_call += 1
922-
return self
923-
924-
def predict(self, X):
925-
return np.ones(X.shape[0])
926-
927905
_, X, y = binary_classification_data
928906

929-
report = CrossValidationReport(MockEstimator(), X, y, cv_splitter=10)
907+
estimator = MockEstimator(error=error, n_call=0, fail_after_n_clone=8)
908+
report = CrossValidationReport(estimator, X, y, cv_splitter=10, n_jobs=n_jobs)
930909

931910
captured = capsys.readouterr()
932911
assert all(word in captured.out for word in error_message.split(" "))
@@ -990,6 +969,20 @@ def test_cross_validation_timings(
990969
assert timings.columns.tolist() == expected_columns
991970

992971

972+
@pytest.mark.parametrize("n_jobs", [None, 1, 2])
973+
def test_cross_validation_report_failure_all_splits(n_jobs):
974+
"""Check that we raise an error when no estimators were successfully fitted
975+
during the cross-validation process."""
976+
X, y = make_classification(n_samples=100, n_features=10, random_state=42)
977+
estimator = MockEstimator(
978+
error=ValueError("Intentional failure for testing"), fail_after_n_clone=0
979+
)
980+
981+
err_msg = "Cross-validation failed: no estimators were successfully fitted"
982+
with pytest.raises(RuntimeError, match=err_msg):
983+
CrossValidationReport(estimator, X, y, n_jobs=n_jobs)
984+
985+
993986
def test_cross_validation_timings_flat_index(binary_classification_data):
994987
"""Check the behaviour of the `timings` method display formatting."""
995988
estimator, X, y = binary_classification_data

0 commit comments

Comments
 (0)