Skip to content

Commit

Permalink
Make fitting compatible with scipy 1.15 optimization changes (#2667)
Browse files Browse the repository at this point in the history
Summary:
Resolves #2666 by updating the regexp that is used to filter out the optimization warning emitted when the maximium number of iterations is reached.

NOTE: This kind of regexp filtering is kind of brittle, but it's not necessarily obvious how to do this differently if scipy doesn't return these in a more structured form.

Pull Request resolved: #2667

Reviewed By: saitcakmak

Differential Revision: D67816653

Pulled By: Balandat

fbshipit-source-id: aa6f83298e50e4e1ebedde7893888ecd82c4bae1
  • Loading branch information
Balandat authored and facebook-github-bot committed Jan 4, 2025
1 parent dc69c05 commit d48d4b7
Show file tree
Hide file tree
Showing 8 changed files with 86 additions and 41 deletions.
4 changes: 3 additions & 1 deletion botorch/optim/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,9 @@


_LBFGSB_MAXITER_MAXFUN_REGEX = re.compile( # regex for maxiter and maxfun messages
"TOTAL NO. of (ITERATIONS REACHED LIMIT|f AND g EVALUATIONS EXCEEDS LIMIT)"
# Note that the messages changed with scipy 1.15, hence the different matching here.
"TOTAL NO. (of|OF) "
+ "(ITERATIONS REACHED LIMIT|(f AND g|F,G) EVALUATIONS EXCEEDS LIMIT)"
)


Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,5 +4,5 @@ gpytorch==1.13
linear_operator==0.5.3
torch>=2.0.1
pyro-ppl>=1.8.4
scipy<1.15
scipy
multipledispatch
8 changes: 5 additions & 3 deletions test/generation/test_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
# LICENSE file in the root directory of this source tree.

import math
import re
import warnings
from unittest import mock

Expand Down Expand Up @@ -225,13 +226,14 @@ def test_gen_candidates_scipy_with_fixed_features_inequality_constraints(self):
def test_gen_candidates_scipy_warns_opt_failure(self):
with warnings.catch_warnings(record=True) as ws:
self.test_gen_candidates(options={"maxls": 1})
expected_msg = (
expected_msg = re.compile(
# The message changed with scipy 1.15, hence the different matching here.
"Optimization failed within `scipy.optimize.minimize` with status 2"
" and message ABNORMAL_TERMINATION_IN_LNSRCH."
" and message ABNORMAL(|_TERMINATION_IN_LNSRCH)."
)
expected_warning_raised = any(
issubclass(w.category, OptimizationWarning)
and expected_msg in str(w.message)
and expected_msg.search(str(w.message))
for w in ws
)
self.assertTrue(expected_warning_raised)
Expand Down
17 changes: 13 additions & 4 deletions test/optim/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,11 +135,20 @@ def _callback(parameters, result, out) -> None:
def test_post_processing(self):
closure = next(iter(self.closures.values()))
wrapper = NdarrayOptimizationClosure(closure, closure.parameters)

# Scipy changed return values and messages in v1.15, so we check both
# old and new versions here.
status_msgs = [
# scipy >=1.15
(OptimizationStatus.FAILURE, "ABNORMAL_TERMINATION_IN_LNSRCH"),
(OptimizationStatus.STOPPED, "TOTAL NO. of ITERATIONS REACHED LIMIT"),
# scipy <1.15
(OptimizationStatus.FAILURE, "ABNORMAL "),
(OptimizationStatus.STOPPED, "TOTAL NO. OF ITERATIONS REACHED LIMIT"),
]

with patch.object(core, "minimize_with_timeout") as mock_minimize_with_timeout:
for status, msg in (
(OptimizationStatus.FAILURE, b"ABNORMAL_TERMINATION_IN_LNSRCH"),
(OptimizationStatus.STOPPED, "TOTAL NO. of ITERATIONS REACHED LIMIT"),
):
for status, msg in status_msgs:
mock_minimize_with_timeout.return_value = OptimizeResult(
x=wrapper.state,
fun=1.0,
Expand Down
13 changes: 11 additions & 2 deletions test/optim/test_fit.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
# LICENSE file in the root directory of this source tree.

import math
import re
from unittest.mock import MagicMock, patch
from warnings import catch_warnings

Expand All @@ -20,6 +21,11 @@
from gpytorch.mlls.exact_marginal_log_likelihood import ExactMarginalLogLikelihood
from scipy.optimize import OptimizeResult

MAX_ITER_MSG_REGEX = re.compile(
# Note that the message changed with scipy 1.15, hence the different matching here.
"TOTAL NO. (of|OF) ITERATIONS REACHED LIMIT"
)


class TestFitGPyTorchMLLScipy(BotorchTestCase):
def setUp(self, suppress_input_warnings: bool = True) -> None:
Expand Down Expand Up @@ -63,15 +69,18 @@ def _test_fit_gpytorch_mll_scipy(self, mll):
)

# Test maxiter warning message
self.assertTrue(any("TOTAL NO. of" in str(w.message) for w in ws))

self.assertTrue(any(MAX_ITER_MSG_REGEX.search(str(w.message)) for w in ws))
self.assertTrue(
any(issubclass(w.category, OptimizationWarning) for w in ws)
)

# Test iteration tracking
self.assertIsInstance(result, OptimizationResult)
self.assertLessEqual(result.step, options["maxiter"])
self.assertEqual(sum(1 for w in ws if "TOTAL NO. of" in str(w.message)), 1)
self.assertEqual(
sum(1 for w in ws if MAX_ITER_MSG_REGEX.search(str(w.message))), 1
)

# Test that user provided bounds are respected
with self.subTest("bounds"), module_rollback_ctx(mll, checkpoint=ckpt):
Expand Down
55 changes: 31 additions & 24 deletions test/optim/test_optimize.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
# LICENSE file in the root directory of this source tree.

import itertools
import re
import warnings
from functools import partial
from itertools import product
Expand Down Expand Up @@ -724,19 +725,20 @@ def test_optimize_acqf_warns_on_opt_failure(self):
raw_samples=raw_samples,
batch_initial_conditions=initial_conditions,
)
message = (
"Optimization failed in `gen_candidates_scipy` with the following "
"warning(s):\n[OptimizationWarning('Optimization failed within "
"`scipy.optimize.minimize` with status 2 and message "
"ABNORMAL_TERMINATION_IN_LNSRCH.')]\nBecause you specified "
"`batch_initial_conditions` larger than required `num_restarts`, "
"optimization will not be retried with new initial conditions and "
"will proceed with the current solution. Suggested remediation: "
"Try again with different `batch_initial_conditions`, don't provide "
"`batch_initial_conditions`, or increase `num_restarts`."
message_regex = re.compile(
r"Optimization failed in `gen_candidates_scipy` with the following "
r"warning\(s\):\n\[OptimizationWarning\('Optimization failed within "
r"`scipy.optimize.minimize` with status 2 and message "
r"ABNORMAL(: |_TERMINATION_IN_LNSRCH).'\)]\nBecause you specified "
r"`batch_initial_conditions` larger than required `num_restarts`, "
r"optimization will not be retried with new initial conditions and "
r"will proceed with the current solution. Suggested remediation: "
r"Try again with different `batch_initial_conditions`, don't provide "
r"`batch_initial_conditions`, or increase `num_restarts`."
)
expected_warning_raised = any(
issubclass(w.category, RuntimeWarning) and message in str(w.message)
issubclass(w.category, RuntimeWarning)
and message_regex.search(str(w.message))
for w in ws
)
self.assertTrue(expected_warning_raised)
Expand Down Expand Up @@ -774,14 +776,16 @@ def test_optimize_acqf_successfully_restarts_on_opt_failure(self):
# more likely
options={"maxls": 2},
)
message = (
"Optimization failed in `gen_candidates_scipy` with the following "
"warning(s):\n[OptimizationWarning('Optimization failed within "
"`scipy.optimize.minimize` with status 2 and message ABNORMAL_TERMINATION"
"_IN_LNSRCH.')]\nTrying again with a new set of initial conditions."
message_regex = re.compile(
r"Optimization failed in `gen_candidates_scipy` with the following "
r"warning\(s\):\n\[OptimizationWarning\('Optimization failed within "
r"`scipy.optimize.minimize` with status 2 and message ABNORMAL(: |"
r"_TERMINATION_IN_LNSRCH).'\)\]\nTrying again with a new set of "
r"initial conditions."
)
expected_warning_raised = any(
issubclass(w.category, RuntimeWarning) and message in str(w.message)
issubclass(w.category, RuntimeWarning)
and message_regex.search(str(w.message))
for w in ws
)
self.assertTrue(expected_warning_raised)
Expand All @@ -803,7 +807,8 @@ def test_optimize_acqf_successfully_restarts_on_opt_failure(self):
retry_on_optimization_warning=False,
)
expected_warning_raised = any(
issubclass(w.category, RuntimeWarning) and message in str(w.message)
issubclass(w.category, RuntimeWarning)
and message_regex.search(str(w.message))
for w in ws
)
self.assertFalse(expected_warning_raised)
Expand Down Expand Up @@ -840,19 +845,21 @@ def test_optimize_acqf_warns_on_second_opt_failure(self):
options={"maxls": 2},
)

message_1 = (
"Optimization failed in `gen_candidates_scipy` with the following "
"warning(s):\n[OptimizationWarning('Optimization failed within "
"`scipy.optimize.minimize` with status 2 and message ABNORMAL_TERMINATION"
"_IN_LNSRCH.')]\nTrying again with a new set of initial conditions."
message_1_regex = re.compile(
r"Optimization failed in `gen_candidates_scipy` with the following "
r"warning\(s\):\n\[OptimizationWarning\('Optimization failed within "
r"`scipy.optimize.minimize` with status 2 and message ABNORMAL(: |"
r"_TERMINATION_IN_LNSRCH).'\)\]\nTrying again with a new set of "
r"initial conditions."
)

message_2 = (
"Optimization failed on the second try, after generating a new set "
"of initial conditions."
)
first_expected_warning_raised = any(
issubclass(w.category, RuntimeWarning) and message_1 in str(w.message)
issubclass(w.category, RuntimeWarning)
and message_1_regex.search(str(w.message))
for w in ws
)
second_expected_warning_raised = any(
Expand Down
17 changes: 14 additions & 3 deletions test/test_fit.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
# LICENSE file in the root directory of this source tree.

import math
import re
from collections.abc import Callable, Iterable
from contextlib import ExitStack, nullcontext
from copy import deepcopy
Expand All @@ -30,7 +31,10 @@
from gpytorch.mlls import ExactMarginalLogLikelihood, VariationalELBO
from linear_operator.utils.errors import NotPSDError

MAX_ITER_MSG = "TOTAL NO. of ITERATIONS REACHED LIMIT"
MAX_ITER_MSG_REGEX = re.compile(
# Note that the message changed with scipy 1.15, hence the different matching here.
"TOTAL NO. (of|OF) ITERATIONS REACHED LIMIT"
)


class MockOptimizer:
Expand Down Expand Up @@ -215,7 +219,12 @@ def _test_warnings(self, mll, ckpt):
optimizer = MockOptimizer(randomize_requires_grad=False)
optimizer.warnings = [
WarningMessage("test_runtime_warning", RuntimeWarning, __file__, 0),
WarningMessage(MAX_ITER_MSG, OptimizationWarning, __file__, 0),
WarningMessage(
"STOP: TOTAL NO. OF ITERATIONS REACHED LIMIT",
OptimizationWarning,
__file__,
0,
),
WarningMessage(
"Optimization timed out after X", OptimizationWarning, __file__, 0
),
Expand Down Expand Up @@ -260,7 +269,9 @@ def _test_warnings(self, mll, ckpt):
{str(w.message) for w in rethrown + unresolved},
)
if logs: # test that default filter logs certain warnings
self.assertTrue(any(MAX_ITER_MSG in log for log in logs.output))
self.assertTrue(
any(MAX_ITER_MSG_REGEX.search(log) for log in logs.output)
)

# Test default of retrying upon encountering an uncaught OptimizationWarning
optimizer.warnings.append(
Expand Down
11 changes: 8 additions & 3 deletions test/test_utils/test_mock.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
# LICENSE file in the root directory of this source tree.


import re
import warnings
from unittest.mock import patch

Expand All @@ -26,6 +27,12 @@
from botorch.utils.testing import BotorchTestCase, MockAcquisitionFunction


MAX_ITER_MSG = re.compile(
# Note that the message changed with scipy 1.15, hence the different matching here.
"TOTAL NO. (of|OF) ITERATIONS REACHED LIMIT"
)


class SinAcqusitionFunction(MockAcquisitionFunction):
"""Simple acquisition function with known numerical properties."""

Expand Down Expand Up @@ -56,9 +63,7 @@ def closure():

with mock_optimize_context_manager():
result = scipy_minimize(closure=closure, parameters={"x": x})
self.assertEqual(
result.message, "STOP: TOTAL NO. of ITERATIONS REACHED LIMIT"
)
self.assertTrue(MAX_ITER_MSG.search(result.message))

with self.subTest("optimize_acqf"):
with mock_optimize_context_manager():
Expand Down

0 comments on commit d48d4b7

Please sign in to comment.