Skip to content

Commit b5948c9

Browse files
sdaultonmeta-codesync[bot]
authored andcommitted
Back out "Back out "[botorch] Replace Pyro with NumPyro for fully Bayesian NUTS inference (96% reduction in fit time)"" (#5162)
Summary: X-link: meta-pytorch/botorch#3269 Pull Request resolved: #5162 Reviewed By: saitcakmak Differential Revision: D99688719 fbshipit-source-id: d78f3f75dd0503b61e1cb7f48e64484979c3ee48
1 parent f3e8831 commit b5948c9

2 files changed

Lines changed: 37 additions & 55 deletions

File tree

ax/utils/sensitivity/tests/test_sensitivity.py

Lines changed: 22 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -201,37 +201,23 @@ def test_DgsmGpSampling(self) -> None:
201201
def _test_sobol_gp_mean(
202202
self,
203203
sensitivity: SobolSensitivityGPMean,
204-
expected_first_order: Tensor,
205-
expected_total_order: Tensor,
206-
expected_second_order: Tensor | None = None,
204+
expected_first_order_shape: torch.Size,
205+
expected_total_order_shape: torch.Size,
206+
expected_second_order_shape: torch.Size | None = None,
207207
) -> None:
208-
"""
209-
Check that outputs are as expected. The `assertAllClose` checks are
210-
characterization tests rather than correctness tests; they check that
211-
behavior is the same as it was in the past, not that it is
212-
quantitatively correct. Innocuous changes such as changing a random seed
213-
could potentially break these tests, and it may become necessary to
214-
delete them.
215-
"""
216-
atol = 4e-3
217-
rtol = 2e-3
208+
"""Check that outputs have the correct type and shape."""
218209
first_order = sensitivity.first_order_indices()
219210
self.assertIsInstance(first_order, Tensor)
220-
self.assertAllClose(first_order, expected_first_order, atol=atol, rtol=rtol)
221-
self.assertEqual(first_order.shape, expected_first_order.shape)
211+
self.assertEqual(first_order.shape, expected_first_order_shape)
222212

223213
total_order = sensitivity.total_order_indices()
224214
self.assertIsInstance(total_order, Tensor)
225-
self.assertAllClose(total_order, expected_total_order, atol=atol, rtol=rtol)
226-
self.assertEqual(total_order.shape, expected_total_order.shape)
215+
self.assertEqual(total_order.shape, expected_total_order_shape)
227216

228-
if expected_second_order is not None:
217+
if expected_second_order_shape is not None:
229218
second_order = sensitivity.second_order_indices()
230219
self.assertIsInstance(second_order, Tensor)
231-
self.assertAllClose(
232-
second_order, expected_second_order, atol=atol, rtol=rtol
233-
)
234-
self.assertEqual(second_order.shape, expected_second_order.shape)
220+
self.assertEqual(second_order.shape, expected_second_order_shape)
235221

236222
def test_SobolGPMean(self) -> None:
237223
bounds = torch.tensor([(0.0, 1.0) for _ in range(2)]).t()
@@ -240,9 +226,9 @@ def test_SobolGPMean(self) -> None:
240226
)
241227
self._test_sobol_gp_mean(
242228
sensitivity=sensitivity_mean,
243-
expected_first_order=torch.tensor([1.1547, -0.4024], dtype=torch.float64),
244-
expected_total_order=torch.tensor([0.4299, 0.4894], dtype=torch.float64),
245-
expected_second_order=torch.tensor([-1.4845], dtype=torch.float64),
229+
expected_first_order_shape=torch.Size([2]),
230+
expected_total_order_shape=torch.Size([2]),
231+
expected_second_order_shape=torch.Size([1]),
246232
)
247233

248234
def test_SobolGPMean_SAASBO(self) -> None:
@@ -252,9 +238,9 @@ def test_SobolGPMean_SAASBO(self) -> None:
252238
)
253239
self._test_sobol_gp_mean(
254240
sensitivity=sensitivity_mean_saas,
255-
expected_first_order=torch.tensor([0.5752, 0.5143], dtype=torch.double),
256-
expected_total_order=torch.tensor([0.9897, 0.0979], dtype=torch.float64),
257-
expected_second_order=torch.tensor([0.8332], dtype=torch.double),
241+
expected_first_order_shape=torch.Size([2]),
242+
expected_total_order_shape=torch.Size([2]),
243+
expected_second_order_shape=torch.Size([1]),
258244
)
259245

260246
sensitivity_mean_bootstrap = SobolSensitivityGPMean(
@@ -267,17 +253,9 @@ def test_SobolGPMean_SAASBO(self) -> None:
267253
)
268254
self._test_sobol_gp_mean(
269255
sensitivity=sensitivity_mean_bootstrap,
270-
expected_first_order=torch.tensor(
271-
[[0.6327, 10.0889, 1.0044], [0.2089, 0.7322, 0.2706]],
272-
dtype=torch.float64,
273-
),
274-
expected_total_order=torch.tensor(
275-
[[0.8013, 0.1824, 0.1351], [0.2203, 0.0304, 0.0551]],
276-
dtype=torch.float64,
277-
),
278-
expected_second_order=torch.tensor(
279-
[[0.7978, 22.6598, 1.5053]], dtype=torch.float64
280-
),
256+
expected_first_order_shape=torch.Size([2, 3]),
257+
expected_total_order_shape=torch.Size([2, 3]),
258+
expected_second_order_shape=torch.Size([1, 3]),
281259
)
282260

283261
sensitivity_mean_bootstrap = SobolSensitivityGPMean(
@@ -290,26 +268,18 @@ def test_SobolGPMean_SAASBO(self) -> None:
290268
)
291269
self._test_sobol_gp_mean(
292270
sensitivity=sensitivity_mean_bootstrap,
293-
expected_first_order=torch.tensor(
294-
[[3.4512, 32.4428, 1.8012], [0.2069, 121.8610, 3.4909]],
295-
dtype=torch.float64,
296-
),
297-
expected_total_order=torch.tensor(
298-
[[0.4288, 0.0903, 0.0950], [0.7923, 0.2218, 0.1489]],
299-
dtype=torch.float64,
300-
),
301-
expected_second_order=torch.tensor(
302-
[[-6.3790, 397.4363, 6.3043]], dtype=torch.float64
303-
),
271+
expected_first_order_shape=torch.Size([2, 3]),
272+
expected_total_order_shape=torch.Size([2, 3]),
273+
expected_second_order_shape=torch.Size([1, 3]),
304274
)
305275

306276
sensitivity_mean = SobolSensitivityGPMean(
307277
self.model, num_mc_samples=10, bounds=bounds, second_order=False
308278
)
309279
self._test_sobol_gp_mean(
310280
sensitivity=sensitivity_mean,
311-
expected_first_order=torch.tensor([0.9566, -0.4183], dtype=torch.float64),
312-
expected_total_order=torch.tensor([0.3440, 0.3685], dtype=torch.float64),
281+
expected_first_order_shape=torch.Size([2]),
282+
expected_total_order_shape=torch.Size([2]),
313283
)
314284

315285
with self.assertRaisesRegex(ValueError, "Second order indices"):

ax/utils/testing/tests/test_mock.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
from unittest.mock import patch
1010

11+
import jax.numpy as jnp
1112
import torch
1213
from ax.adapter.registry import Generators
1314
from ax.adapter.transforms.choice_encode import OrderedChoiceToIntegerRange
@@ -22,7 +23,6 @@
2223
from botorch.generation.gen import gen_candidates_scipy
2324
from botorch.optim.optimize_mixed import generate_starting_points
2425
from botorch.utils.testing import MockAcquisitionFunction, skip_if_import_error
25-
from pyro.infer import MCMC
2626

2727

2828
class TestMock(TestCase):
@@ -45,13 +45,25 @@ def test_botorch_mocks(self) -> None:
4545

4646
def test_fully_bayesian_mocks(self) -> None:
4747
experiment = get_branin_experiment(with_completed_batch=True)
48-
with patch("botorch.fit.MCMC", wraps=MCMC) as mock_mcmc:
48+
num_mcmc_samples = 16
49+
dim = len(experiment.search_space.parameters)
50+
# Mock MCMC to return proper JAX arrays that postprocess_mcmc_samples
51+
# can handle (it performs JAX operations on them).
52+
mock_samples = {
53+
"mean": jnp.ones(num_mcmc_samples),
54+
"noise": jnp.ones(num_mcmc_samples),
55+
"outputscale": jnp.ones(num_mcmc_samples),
56+
"kernel_tausq": jnp.ones(num_mcmc_samples),
57+
"_kernel_inv_length_sq": jnp.ones((num_mcmc_samples, dim)),
58+
}
59+
with patch("botorch.fit.MCMC") as mock_mcmc:
60+
mock_mcmc.return_value.get_samples.return_value = mock_samples
4961
with mock_botorch_optimize_context_manager():
5062
Generators.SAASBO(experiment=experiment, data=experiment.lookup_data())
5163
mock_mcmc.assert_called_once()
5264
kwargs = mock_mcmc.call_args.kwargs
5365
self.assertEqual(kwargs["num_samples"], 16)
54-
self.assertEqual(kwargs["warmup_steps"], 0)
66+
self.assertEqual(kwargs["num_warmup"], 0)
5567

5668
def test_mixed_optimizer_mocks(self) -> None:
5769
experiment = get_branin_experiment(

0 commit comments

Comments
 (0)