Skip to content

Commit d48d4b7

Browse files
Balandatfacebook-github-bot
authored andcommitted
Make fitting compatible with scipy 1.15 optimization changes (#2667)
Summary: Resolves #2666 by updating the regexp that is used to filter out the optimization warning emitted when the maximium number of iterations is reached. NOTE: This kind of regexp filtering is kind of brittle, but it's not necessarily obvious how to do this differently if scipy doesn't return these in a more structured form. Pull Request resolved: #2667 Reviewed By: saitcakmak Differential Revision: D67816653 Pulled By: Balandat fbshipit-source-id: aa6f83298e50e4e1ebedde7893888ecd82c4bae1
1 parent dc69c05 commit d48d4b7

File tree

8 files changed

+86
-41
lines changed

8 files changed

+86
-41
lines changed

botorch/optim/core.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,9 @@
3434

3535

3636
_LBFGSB_MAXITER_MAXFUN_REGEX = re.compile( # regex for maxiter and maxfun messages
37-
"TOTAL NO. of (ITERATIONS REACHED LIMIT|f AND g EVALUATIONS EXCEEDS LIMIT)"
37+
# Note that the messages changed with scipy 1.15, hence the different matching here.
38+
"TOTAL NO. (of|OF) "
39+
+ "(ITERATIONS REACHED LIMIT|(f AND g|F,G) EVALUATIONS EXCEEDS LIMIT)"
3840
)
3941

4042

requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,5 +4,5 @@ gpytorch==1.13
44
linear_operator==0.5.3
55
torch>=2.0.1
66
pyro-ppl>=1.8.4
7-
scipy<1.15
7+
scipy
88
multipledispatch

test/generation/test_gen.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
# LICENSE file in the root directory of this source tree.
66

77
import math
8+
import re
89
import warnings
910
from unittest import mock
1011

@@ -225,13 +226,14 @@ def test_gen_candidates_scipy_with_fixed_features_inequality_constraints(self):
225226
def test_gen_candidates_scipy_warns_opt_failure(self):
226227
with warnings.catch_warnings(record=True) as ws:
227228
self.test_gen_candidates(options={"maxls": 1})
228-
expected_msg = (
229+
expected_msg = re.compile(
230+
# The message changed with scipy 1.15, hence the different matching here.
229231
"Optimization failed within `scipy.optimize.minimize` with status 2"
230-
" and message ABNORMAL_TERMINATION_IN_LNSRCH."
232+
" and message ABNORMAL(|_TERMINATION_IN_LNSRCH)."
231233
)
232234
expected_warning_raised = any(
233235
issubclass(w.category, OptimizationWarning)
234-
and expected_msg in str(w.message)
236+
and expected_msg.search(str(w.message))
235237
for w in ws
236238
)
237239
self.assertTrue(expected_warning_raised)

test/optim/test_core.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -135,11 +135,20 @@ def _callback(parameters, result, out) -> None:
135135
def test_post_processing(self):
136136
closure = next(iter(self.closures.values()))
137137
wrapper = NdarrayOptimizationClosure(closure, closure.parameters)
138+
139+
# Scipy changed return values and messages in v1.15, so we check both
140+
# old and new versions here.
141+
status_msgs = [
142+
# scipy >=1.15
143+
(OptimizationStatus.FAILURE, "ABNORMAL_TERMINATION_IN_LNSRCH"),
144+
(OptimizationStatus.STOPPED, "TOTAL NO. of ITERATIONS REACHED LIMIT"),
145+
# scipy <1.15
146+
(OptimizationStatus.FAILURE, "ABNORMAL "),
147+
(OptimizationStatus.STOPPED, "TOTAL NO. OF ITERATIONS REACHED LIMIT"),
148+
]
149+
138150
with patch.object(core, "minimize_with_timeout") as mock_minimize_with_timeout:
139-
for status, msg in (
140-
(OptimizationStatus.FAILURE, b"ABNORMAL_TERMINATION_IN_LNSRCH"),
141-
(OptimizationStatus.STOPPED, "TOTAL NO. of ITERATIONS REACHED LIMIT"),
142-
):
151+
for status, msg in status_msgs:
143152
mock_minimize_with_timeout.return_value = OptimizeResult(
144153
x=wrapper.state,
145154
fun=1.0,

test/optim/test_fit.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
# LICENSE file in the root directory of this source tree.
66

77
import math
8+
import re
89
from unittest.mock import MagicMock, patch
910
from warnings import catch_warnings
1011

@@ -20,6 +21,11 @@
2021
from gpytorch.mlls.exact_marginal_log_likelihood import ExactMarginalLogLikelihood
2122
from scipy.optimize import OptimizeResult
2223

24+
MAX_ITER_MSG_REGEX = re.compile(
25+
# Note that the message changed with scipy 1.15, hence the different matching here.
26+
"TOTAL NO. (of|OF) ITERATIONS REACHED LIMIT"
27+
)
28+
2329

2430
class TestFitGPyTorchMLLScipy(BotorchTestCase):
2531
def setUp(self, suppress_input_warnings: bool = True) -> None:
@@ -63,15 +69,18 @@ def _test_fit_gpytorch_mll_scipy(self, mll):
6369
)
6470

6571
# Test maxiter warning message
66-
self.assertTrue(any("TOTAL NO. of" in str(w.message) for w in ws))
72+
73+
self.assertTrue(any(MAX_ITER_MSG_REGEX.search(str(w.message)) for w in ws))
6774
self.assertTrue(
6875
any(issubclass(w.category, OptimizationWarning) for w in ws)
6976
)
7077

7178
# Test iteration tracking
7279
self.assertIsInstance(result, OptimizationResult)
7380
self.assertLessEqual(result.step, options["maxiter"])
74-
self.assertEqual(sum(1 for w in ws if "TOTAL NO. of" in str(w.message)), 1)
81+
self.assertEqual(
82+
sum(1 for w in ws if MAX_ITER_MSG_REGEX.search(str(w.message))), 1
83+
)
7584

7685
# Test that user provided bounds are respected
7786
with self.subTest("bounds"), module_rollback_ctx(mll, checkpoint=ckpt):

test/optim/test_optimize.py

Lines changed: 31 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
# LICENSE file in the root directory of this source tree.
66

77
import itertools
8+
import re
89
import warnings
910
from functools import partial
1011
from itertools import product
@@ -724,19 +725,20 @@ def test_optimize_acqf_warns_on_opt_failure(self):
724725
raw_samples=raw_samples,
725726
batch_initial_conditions=initial_conditions,
726727
)
727-
message = (
728-
"Optimization failed in `gen_candidates_scipy` with the following "
729-
"warning(s):\n[OptimizationWarning('Optimization failed within "
730-
"`scipy.optimize.minimize` with status 2 and message "
731-
"ABNORMAL_TERMINATION_IN_LNSRCH.')]\nBecause you specified "
732-
"`batch_initial_conditions` larger than required `num_restarts`, "
733-
"optimization will not be retried with new initial conditions and "
734-
"will proceed with the current solution. Suggested remediation: "
735-
"Try again with different `batch_initial_conditions`, don't provide "
736-
"`batch_initial_conditions`, or increase `num_restarts`."
728+
message_regex = re.compile(
729+
r"Optimization failed in `gen_candidates_scipy` with the following "
730+
r"warning\(s\):\n\[OptimizationWarning\('Optimization failed within "
731+
r"`scipy.optimize.minimize` with status 2 and message "
732+
r"ABNORMAL(: |_TERMINATION_IN_LNSRCH).'\)]\nBecause you specified "
733+
r"`batch_initial_conditions` larger than required `num_restarts`, "
734+
r"optimization will not be retried with new initial conditions and "
735+
r"will proceed with the current solution. Suggested remediation: "
736+
r"Try again with different `batch_initial_conditions`, don't provide "
737+
r"`batch_initial_conditions`, or increase `num_restarts`."
737738
)
738739
expected_warning_raised = any(
739-
issubclass(w.category, RuntimeWarning) and message in str(w.message)
740+
issubclass(w.category, RuntimeWarning)
741+
and message_regex.search(str(w.message))
740742
for w in ws
741743
)
742744
self.assertTrue(expected_warning_raised)
@@ -774,14 +776,16 @@ def test_optimize_acqf_successfully_restarts_on_opt_failure(self):
774776
# more likely
775777
options={"maxls": 2},
776778
)
777-
message = (
778-
"Optimization failed in `gen_candidates_scipy` with the following "
779-
"warning(s):\n[OptimizationWarning('Optimization failed within "
780-
"`scipy.optimize.minimize` with status 2 and message ABNORMAL_TERMINATION"
781-
"_IN_LNSRCH.')]\nTrying again with a new set of initial conditions."
779+
message_regex = re.compile(
780+
r"Optimization failed in `gen_candidates_scipy` with the following "
781+
r"warning\(s\):\n\[OptimizationWarning\('Optimization failed within "
782+
r"`scipy.optimize.minimize` with status 2 and message ABNORMAL(: |"
783+
r"_TERMINATION_IN_LNSRCH).'\)\]\nTrying again with a new set of "
784+
r"initial conditions."
782785
)
783786
expected_warning_raised = any(
784-
issubclass(w.category, RuntimeWarning) and message in str(w.message)
787+
issubclass(w.category, RuntimeWarning)
788+
and message_regex.search(str(w.message))
785789
for w in ws
786790
)
787791
self.assertTrue(expected_warning_raised)
@@ -803,7 +807,8 @@ def test_optimize_acqf_successfully_restarts_on_opt_failure(self):
803807
retry_on_optimization_warning=False,
804808
)
805809
expected_warning_raised = any(
806-
issubclass(w.category, RuntimeWarning) and message in str(w.message)
810+
issubclass(w.category, RuntimeWarning)
811+
and message_regex.search(str(w.message))
807812
for w in ws
808813
)
809814
self.assertFalse(expected_warning_raised)
@@ -840,19 +845,21 @@ def test_optimize_acqf_warns_on_second_opt_failure(self):
840845
options={"maxls": 2},
841846
)
842847

843-
message_1 = (
844-
"Optimization failed in `gen_candidates_scipy` with the following "
845-
"warning(s):\n[OptimizationWarning('Optimization failed within "
846-
"`scipy.optimize.minimize` with status 2 and message ABNORMAL_TERMINATION"
847-
"_IN_LNSRCH.')]\nTrying again with a new set of initial conditions."
848+
message_1_regex = re.compile(
849+
r"Optimization failed in `gen_candidates_scipy` with the following "
850+
r"warning\(s\):\n\[OptimizationWarning\('Optimization failed within "
851+
r"`scipy.optimize.minimize` with status 2 and message ABNORMAL(: |"
852+
r"_TERMINATION_IN_LNSRCH).'\)\]\nTrying again with a new set of "
853+
r"initial conditions."
848854
)
849855

850856
message_2 = (
851857
"Optimization failed on the second try, after generating a new set "
852858
"of initial conditions."
853859
)
854860
first_expected_warning_raised = any(
855-
issubclass(w.category, RuntimeWarning) and message_1 in str(w.message)
861+
issubclass(w.category, RuntimeWarning)
862+
and message_1_regex.search(str(w.message))
856863
for w in ws
857864
)
858865
second_expected_warning_raised = any(

test/test_fit.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
# LICENSE file in the root directory of this source tree.
66

77
import math
8+
import re
89
from collections.abc import Callable, Iterable
910
from contextlib import ExitStack, nullcontext
1011
from copy import deepcopy
@@ -30,7 +31,10 @@
3031
from gpytorch.mlls import ExactMarginalLogLikelihood, VariationalELBO
3132
from linear_operator.utils.errors import NotPSDError
3233

33-
MAX_ITER_MSG = "TOTAL NO. of ITERATIONS REACHED LIMIT"
34+
MAX_ITER_MSG_REGEX = re.compile(
35+
# Note that the message changed with scipy 1.15, hence the different matching here.
36+
"TOTAL NO. (of|OF) ITERATIONS REACHED LIMIT"
37+
)
3438

3539

3640
class MockOptimizer:
@@ -215,7 +219,12 @@ def _test_warnings(self, mll, ckpt):
215219
optimizer = MockOptimizer(randomize_requires_grad=False)
216220
optimizer.warnings = [
217221
WarningMessage("test_runtime_warning", RuntimeWarning, __file__, 0),
218-
WarningMessage(MAX_ITER_MSG, OptimizationWarning, __file__, 0),
222+
WarningMessage(
223+
"STOP: TOTAL NO. OF ITERATIONS REACHED LIMIT",
224+
OptimizationWarning,
225+
__file__,
226+
0,
227+
),
219228
WarningMessage(
220229
"Optimization timed out after X", OptimizationWarning, __file__, 0
221230
),
@@ -260,7 +269,9 @@ def _test_warnings(self, mll, ckpt):
260269
{str(w.message) for w in rethrown + unresolved},
261270
)
262271
if logs: # test that default filter logs certain warnings
263-
self.assertTrue(any(MAX_ITER_MSG in log for log in logs.output))
272+
self.assertTrue(
273+
any(MAX_ITER_MSG_REGEX.search(log) for log in logs.output)
274+
)
264275

265276
# Test default of retrying upon encountering an uncaught OptimizationWarning
266277
optimizer.warnings.append(

test/test_utils/test_mock.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
# LICENSE file in the root directory of this source tree.
66

77

8+
import re
89
import warnings
910
from unittest.mock import patch
1011

@@ -26,6 +27,12 @@
2627
from botorch.utils.testing import BotorchTestCase, MockAcquisitionFunction
2728

2829

30+
MAX_ITER_MSG = re.compile(
31+
# Note that the message changed with scipy 1.15, hence the different matching here.
32+
"TOTAL NO. (of|OF) ITERATIONS REACHED LIMIT"
33+
)
34+
35+
2936
class SinAcqusitionFunction(MockAcquisitionFunction):
3037
"""Simple acquisition function with known numerical properties."""
3138

@@ -56,9 +63,7 @@ def closure():
5663

5764
with mock_optimize_context_manager():
5865
result = scipy_minimize(closure=closure, parameters={"x": x})
59-
self.assertEqual(
60-
result.message, "STOP: TOTAL NO. of ITERATIONS REACHED LIMIT"
61-
)
66+
self.assertTrue(MAX_ITER_MSG.search(result.message))
6267

6368
with self.subTest("optimize_acqf"):
6469
with mock_optimize_context_manager():

0 commit comments

Comments
 (0)