Skip to content
85 changes: 56 additions & 29 deletions skore/src/skore/sklearn/_cross_validation/report.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,15 +29,18 @@ def _generate_estimator_report(
y: Optional[ArrayLike],
train_indices: ArrayLike,
test_indices: ArrayLike,
) -> EstimatorReport:
return EstimatorReport(
estimator,
fit=True,
X_train=_safe_indexing(X, train_indices),
y_train=_safe_indexing(y, train_indices),
X_test=_safe_indexing(X, test_indices),
y_test=_safe_indexing(y, test_indices),
)
) -> Union[EstimatorReport, KeyboardInterrupt, Exception]:
try:
return EstimatorReport(
estimator,
fit=True,
X_train=_safe_indexing(X, train_indices),
y_train=_safe_indexing(y, train_indices),
X_test=_safe_indexing(X, test_indices),
y_test=_safe_indexing(y, test_indices),
)
except (KeyboardInterrupt, Exception) as e:
return e


class CrossValidationReport(_BaseReport, DirNamesMixin):
Expand Down Expand Up @@ -198,31 +201,55 @@ def _fit_estimator_reports(self) -> list[EstimatorReport]:
)

estimator_reports = []
try:
for report in generator:
estimator_reports.append(report)
progress.update(task, advance=1, refresh=True)
except (Exception, KeyboardInterrupt) as e:
from skore import console # avoid circular import
for report in generator:
estimator_reports.append(report)
progress.update(task, advance=1, refresh=True)

warn_msg = None
if not any (
isinstance(report, EstimatorReport)
for report in estimator_reports
):
traceback_msg = "\n".join(str(exc) for exc in estimator_reports)
raise RuntimeError(
"Cross-validation failed: no estimators were successfully fitted. "
"Please check your data, estimator, or cross-validation setup.\n"
f"Traceback: \n{traceback_msg}"
)
elif any(isinstance(report, Exception) for report in estimator_reports):
msg_traceback = "\n".join(
str(exc) for exc in estimator_reports if isinstance(exc, Exception)
)
warn_msg = (
"Cross-validation process was interrupted by an error before "
"all estimators could be fitted; CrossValidationReport object "
"might not contain all the expected results.\n"
f"Traceback: \n{msg_traceback}"
)
estimator_reports = [
report
for report in estimator_reports
if not isinstance(report, Exception)
]
elif any(isinstance(report, KeyboardInterrupt) for report in estimator_reports):
warn_msg = (
"Cross-validation process was interrupted manually before all "
"estimators could be fitted; CrossValidationReport object "
"might not contain all the expected results."
)
estimator_reports = [
report
for report in estimator_reports
if not isinstance(report, KeyboardInterrupt)
]

if isinstance(e, KeyboardInterrupt):
message = (
"Cross-validation process was interrupted manually before all "
"estimators could be fitted; CrossValidationReport object "
"might not contain all the expected results."
)
else:
message = (
"Cross-validation process was interrupted by an error before "
"all estimators could be fitted; CrossValidationReport object "
"might not contain all the expected results. "
f"Traceback: \n{e}"
)
if warn_msg is not None:
from skore import console # avoid circular import

console.print(
Panel(
title="Cross-validation interrupted",
renderable=message,
renderable=warn_msg,
style="orange1",
border_style="cyan",
)
Expand Down
23 changes: 23 additions & 0 deletions skore/src/skore/utils/_testing.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
import contextlib
import copy

import numpy as np
from sklearn.base import BaseEstimator, ClassifierMixin


@contextlib.contextmanager
def check_cache_changed(value):
Expand All @@ -16,3 +19,23 @@ def check_cache_unchanged(value):
initial_value = copy.copy(value)
yield
assert value == initial_value


class MockEstimator(ClassifierMixin, BaseEstimator):
    """Test-only classifier whose ``fit`` raises once it has been cloned enough.

    The estimator is never actually copied: ``__sklearn_clone__`` increments
    ``n_call`` and returns the very same instance, so the counter is shared
    by every "clone". As soon as ``n_call`` is strictly greater than
    ``fail_after_n_clone``, ``fit`` raises ``error``.

    Parameters
    ----------
    error : BaseException
        Exception instance raised by ``fit`` once the clone counter has
        gone past the threshold.
    n_call : int, default=0
        Starting value of the clone counter.
    fail_after_n_clone : int, default=3
        ``fit`` raises when ``n_call`` exceeds this value.
    """

    def __init__(self, *, error, n_call=0, fail_after_n_clone=3):
        self.fail_after_n_clone = fail_after_n_clone
        self.n_call = n_call
        self.error = error

    def __sklearn_clone__(self):
        """Do not clone the estimator.

        Instead, bump the clone counter and hand back this same object.
        """
        self.n_call += 1
        return self

    def fit(self, X, y):
        """Record the classes found in ``y``, or raise the configured error."""
        clone_budget_exceeded = self.n_call > self.fail_after_n_clone
        if clone_budget_exceeded:
            raise self.error
        self.classes_ = np.unique(y)
        return self

    def predict(self, X):
        """Return an array of ones with one entry per row of ``X``."""
        n_rows = X.shape[0]
        return np.ones(n_rows)
47 changes: 20 additions & 27 deletions skore/tests/unit/sklearn/cross_validation/test_cross_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import numpy as np
import pandas as pd
import pytest
from sklearn.base import BaseEstimator, ClassifierMixin, clone
from sklearn.base import clone
from sklearn.datasets import make_classification, make_regression
from sklearn.ensemble import RandomForestClassifier
from sklearn.exceptions import NotFittedError
Expand All @@ -26,6 +26,7 @@
)
from skore.sklearn._estimator import EstimatorReport
from skore.sklearn._plot import RocCurveDisplay
from skore.utils._testing import MockEstimator


@pytest.fixture
Expand Down Expand Up @@ -895,38 +896,16 @@ def test_cross_validation_report_custom_metric(binary_classification_data):
(KeyboardInterrupt(), "Cross-validation interrupted manually"),
],
)
@pytest.mark.parametrize("n_jobs", [None, 1, 2])
def test_cross_validation_report_interrupted(
binary_classification_data, capsys, error, error_message
binary_classification_data, capsys, error, error_message, n_jobs
):
"""Check that we can interrupt cross-validation without losing all
data."""

class MockEstimator(ClassifierMixin, BaseEstimator):
def __init__(self, n_call=0, fail_after_n_clone=3):
self.n_call = n_call
self.fail_after_n_clone = fail_after_n_clone

def fit(self, X, y):
if self.n_call > self.fail_after_n_clone:
raise error
self.classes_ = np.unique(y)
return self

def __sklearn_clone__(self):
"""Do not clone the estimator

Instead, we increment a counter each time that
`sklearn.clone` is called.
"""
self.n_call += 1
return self

def predict(self, X):
return np.ones(X.shape[0])

_, X, y = binary_classification_data

report = CrossValidationReport(MockEstimator(), X, y, cv_splitter=10)
estimator = MockEstimator(error=error, n_call=0, fail_after_n_clone=8)
report = CrossValidationReport(estimator, X, y, cv_splitter=10, n_jobs=n_jobs)

captured = capsys.readouterr()
assert all(word in captured.out for word in error_message.split(" "))
Expand Down Expand Up @@ -990,6 +969,20 @@ def test_cross_validation_timings(
assert timings.columns.tolist() == expected_columns


@pytest.mark.parametrize("n_jobs", [None, 1, 2])
def test_cross_validation_report_failure_all_splits(n_jobs):
    """Check that an error is raised when no estimator could be fitted during
    the cross-validation process."""
    expected_message = (
        "Cross-validation failed: no estimators were successfully fitted"
    )
    # A threshold of 0 is meant to make `fit` raise from the first clone on.
    failing_estimator = MockEstimator(
        fail_after_n_clone=0,
        error=ValueError("Intentional failure for testing"),
    )
    X, y = make_classification(n_samples=100, n_features=10, random_state=42)

    with pytest.raises(RuntimeError, match=expected_message):
        CrossValidationReport(failing_estimator, X, y, n_jobs=n_jobs)


def test_cross_validation_timings_flat_index(binary_classification_data):
"""Check the behaviour of the `timings` method display formatting."""
estimator, X, y = binary_classification_data
Expand Down
Loading