Skip to content

Commit 7ddd115

Browse files
authored
Merge pull request #163 from muhlbach/main
Update sklearn.py by addind catch to OptunaSearchCV
2 parents e401294 + 5982b92 commit 7ddd115

File tree

2 files changed

+51
-0
lines changed

2 files changed

+51
-0
lines changed

optuna_integration/sklearn/sklearn.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -501,6 +501,11 @@ class OptunaSearchCV(BaseEstimator):
501501
See the tutorial of `Callback for Study.optimize <https://optuna.readthedocs.io/en/stable/tutorial/20_recipes/007_optuna_callback.html#optuna-callback>`_
502502
for how to use and implement callback functions.
503503
504+
catch:
505+
A study continues to run even when a trial raises one of the exceptions specified
506+
in this argument. Default is an empty tuple, i.e. the study will stop for any
507+
exception except for :class:`~optuna.exceptions.TrialPruned`.
508+
504509
Attributes:
505510
best_estimator_:
506511
Estimator that was chosen by the search. This is present only if
@@ -733,6 +738,7 @@ def __init__(
733738
timeout: float | None = None,
734739
verbose: int = 0,
735740
callbacks: list[Callable[[study_module.Study, FrozenTrial], None]] | None = None,
741+
catch: Iterable[type[Exception]] | type[Exception] = (),
736742
) -> None:
737743
_imports.check()
738744

@@ -767,6 +773,7 @@ def __init__(
767773
self.timeout = timeout
768774
self.verbose = verbose
769775
self.callbacks = callbacks
776+
self.catch = catch
770777

771778
def _check_is_fitted(self) -> None:
772779
attributes = ["n_splits_", "sample_indices_", "scorer_", "study_"]
@@ -925,6 +932,7 @@ def fit(
925932
n_trials=self.n_trials,
926933
timeout=self.timeout,
927934
callbacks=self.callbacks,
935+
catch=self.catch,
928936
)
929937

930938
_logger.info("Finished hyperparameter search!")

tests/sklearn/test_sklearn.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@
2424
from sklearn.exceptions import NotFittedError
2525
from sklearn.linear_model import LogisticRegression
2626
from sklearn.linear_model import SGDClassifier
27+
from sklearn.metrics import make_scorer
28+
from sklearn.metrics import r2_score
2729
from sklearn.model_selection import PredefinedSplit
2830
from sklearn.neighbors import KernelDensity
2931
from sklearn.tree import DecisionTreeRegressor
@@ -458,6 +460,47 @@ def test_callbacks() -> None:
458460
assert callback.call_count == n_trials
459461

460462

463+
@pytest.mark.parametrize("catch", [(ValueError,), ()])
464+
def test_catch(catch: tuple) -> None:
465+
466+
class MockDististribution(distributions.BaseDistribution):
467+
468+
def _contains(self) -> None: # type: ignore
469+
raise ValueError
470+
471+
def single(self) -> None: # type: ignore
472+
raise ValueError
473+
474+
def to_internal_repr(self) -> None: # type: ignore
475+
raise ValueError
476+
477+
est = SGDClassifier(max_iter=5, tol=1e-03)
478+
X, y = make_blobs(n_samples=10)
479+
param_dist = {"param": MockDististribution()}
480+
n_trials = 3
481+
482+
with warnings.catch_warnings():
483+
warnings.simplefilter("ignore", ExperimentalWarning)
484+
optuna_search = OptunaSearchCV(
485+
est,
486+
param_dist,
487+
cv=3,
488+
max_iter=5,
489+
n_trials=n_trials,
490+
error_score=0,
491+
refit=False,
492+
scoring=make_scorer(r2_score),
493+
catch=catch,
494+
)
495+
496+
if catch:
497+
optuna_search.fit(X, y)
498+
assert optuna_search.n_trials_ == 3
499+
else:
500+
with pytest.raises(ValueError):
501+
optuna_search.fit(X, y)
502+
503+
461504
@pytest.mark.filterwarnings("ignore::UserWarning")
462505
@patch("optuna_integration.sklearn.sklearn.cross_validate")
463506
def test_terminator_cv_score_reporting(mock: MagicMock) -> None:

0 commit comments

Comments
 (0)