|
24 | 24 | from sklearn.exceptions import NotFittedError |
25 | 25 | from sklearn.linear_model import LogisticRegression |
26 | 26 | from sklearn.linear_model import SGDClassifier |
| 27 | +from sklearn.metrics import make_scorer |
| 28 | +from sklearn.metrics import r2_score |
27 | 29 | from sklearn.model_selection import PredefinedSplit |
28 | 30 | from sklearn.neighbors import KernelDensity |
29 | 31 | from sklearn.tree import DecisionTreeRegressor |
@@ -458,6 +460,47 @@ def test_callbacks() -> None: |
458 | 460 | assert callback.call_count == n_trials |
459 | 461 |
|
460 | 462 |
|
| 463 | +@pytest.mark.parametrize("catch", [(ValueError,), ()]) |
| 464 | +def test_catch(catch: tuple) -> None: |
| 465 | + |
| 466 | + class MockDististribution(distributions.BaseDistribution): |
| 467 | + |
| 468 | + def _contains(self) -> None: # type: ignore |
| 469 | + raise ValueError |
| 470 | + |
| 471 | + def single(self) -> None: # type: ignore |
| 472 | + raise ValueError |
| 473 | + |
| 474 | + def to_internal_repr(self) -> None: # type: ignore |
| 475 | + raise ValueError |
| 476 | + |
| 477 | + est = SGDClassifier(max_iter=5, tol=1e-03) |
| 478 | + X, y = make_blobs(n_samples=10) |
| 479 | + param_dist = {"param": MockDististribution()} |
| 480 | + n_trials = 3 |
| 481 | + |
| 482 | + with warnings.catch_warnings(): |
| 483 | + warnings.simplefilter("ignore", ExperimentalWarning) |
| 484 | + optuna_search = OptunaSearchCV( |
| 485 | + est, |
| 486 | + param_dist, |
| 487 | + cv=3, |
| 488 | + max_iter=5, |
| 489 | + n_trials=n_trials, |
| 490 | + error_score=0, |
| 491 | + refit=False, |
| 492 | + scoring=make_scorer(r2_score), |
| 493 | + catch=catch, |
| 494 | + ) |
| 495 | + |
| 496 | + if catch: |
| 497 | + optuna_search.fit(X, y) |
| 498 | + assert optuna_search.n_trials_ == 3 |
| 499 | + else: |
| 500 | + with pytest.raises(ValueError): |
| 501 | + optuna_search.fit(X, y) |
| 502 | + |
| 503 | + |
461 | 504 | @pytest.mark.filterwarnings("ignore::UserWarning") |
462 | 505 | @patch("optuna_integration.sklearn.sklearn.cross_validate") |
463 | 506 | def test_terminator_cv_score_reporting(mock: MagicMock) -> None: |
|
0 commit comments