Skip to content

Commit 74b42c3

Browse files
sdaultonmeta-codesync[bot]
authored andcommitted
Back out "Replace Pyro with NumPyro for fully Bayesian NUTS inference (96% reduction in fit time)" (facebook#5136)
Summary: X-link: meta-pytorch/botorch#3263 Pull Request resolved: facebook#5136 Reviewed By: Balandat Differential Revision: D99446023 fbshipit-source-id: 76efd2e595d48589ee3252efe407e31521044468
1 parent c8a547d commit 74b42c3

2 files changed

Lines changed: 55 additions & 37 deletions

File tree

ax/utils/sensitivity/tests/test_sensitivity.py

Lines changed: 52 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -201,23 +201,37 @@ def test_DgsmGpSampling(self) -> None:
201201
def _test_sobol_gp_mean(
202202
self,
203203
sensitivity: SobolSensitivityGPMean,
204-
expected_first_order_shape: torch.Size,
205-
expected_total_order_shape: torch.Size,
206-
expected_second_order_shape: torch.Size | None = None,
204+
expected_first_order: Tensor,
205+
expected_total_order: Tensor,
206+
expected_second_order: Tensor | None = None,
207207
) -> None:
208-
"""Check that outputs have the correct type and shape."""
208+
"""
209+
Check that outputs are as expected. The `assertAllClose` checks are
210+
characterization tests rather than correctness tests; they check that
211+
behavior is the same as it was in the past, not that it is
212+
quantitatively correct. Innocuous changes such as changing a random seed
213+
could potentially break these tests, and it may become necessary to
214+
delete them.
215+
"""
216+
atol = 4e-3
217+
rtol = 2e-3
209218
first_order = sensitivity.first_order_indices()
210219
self.assertIsInstance(first_order, Tensor)
211-
self.assertEqual(first_order.shape, expected_first_order_shape)
220+
self.assertAllClose(first_order, expected_first_order, atol=atol, rtol=rtol)
221+
self.assertEqual(first_order.shape, expected_first_order.shape)
212222

213223
total_order = sensitivity.total_order_indices()
214224
self.assertIsInstance(total_order, Tensor)
215-
self.assertEqual(total_order.shape, expected_total_order_shape)
225+
self.assertAllClose(total_order, expected_total_order, atol=atol, rtol=rtol)
226+
self.assertEqual(total_order.shape, expected_total_order.shape)
216227

217-
if expected_second_order_shape is not None:
228+
if expected_second_order is not None:
218229
second_order = sensitivity.second_order_indices()
219230
self.assertIsInstance(second_order, Tensor)
220-
self.assertEqual(second_order.shape, expected_second_order_shape)
231+
self.assertAllClose(
232+
second_order, expected_second_order, atol=atol, rtol=rtol
233+
)
234+
self.assertEqual(second_order.shape, expected_second_order.shape)
221235

222236
def test_SobolGPMean(self) -> None:
223237
bounds = torch.tensor([(0.0, 1.0) for _ in range(2)]).t()
@@ -226,9 +240,9 @@ def test_SobolGPMean(self) -> None:
226240
)
227241
self._test_sobol_gp_mean(
228242
sensitivity=sensitivity_mean,
229-
expected_first_order_shape=torch.Size([2]),
230-
expected_total_order_shape=torch.Size([2]),
231-
expected_second_order_shape=torch.Size([1]),
243+
expected_first_order=torch.tensor([1.1547, -0.4024], dtype=torch.float64),
244+
expected_total_order=torch.tensor([0.4299, 0.4894], dtype=torch.float64),
245+
expected_second_order=torch.tensor([-1.4845], dtype=torch.float64),
232246
)
233247

234248
def test_SobolGPMean_SAASBO(self) -> None:
@@ -238,9 +252,9 @@ def test_SobolGPMean_SAASBO(self) -> None:
238252
)
239253
self._test_sobol_gp_mean(
240254
sensitivity=sensitivity_mean_saas,
241-
expected_first_order_shape=torch.Size([2]),
242-
expected_total_order_shape=torch.Size([2]),
243-
expected_second_order_shape=torch.Size([1]),
255+
expected_first_order=torch.tensor([0.5752, 0.5143], dtype=torch.double),
256+
expected_total_order=torch.tensor([0.9897, 0.0979], dtype=torch.float64),
257+
expected_second_order=torch.tensor([0.8332], dtype=torch.double),
244258
)
245259

246260
sensitivity_mean_bootstrap = SobolSensitivityGPMean(
@@ -253,9 +267,17 @@ def test_SobolGPMean_SAASBO(self) -> None:
253267
)
254268
self._test_sobol_gp_mean(
255269
sensitivity=sensitivity_mean_bootstrap,
256-
expected_first_order_shape=torch.Size([2, 3]),
257-
expected_total_order_shape=torch.Size([2, 3]),
258-
expected_second_order_shape=torch.Size([1, 3]),
270+
expected_first_order=torch.tensor(
271+
[[0.6327, 10.0889, 1.0044], [0.2089, 0.7322, 0.2706]],
272+
dtype=torch.float64,
273+
),
274+
expected_total_order=torch.tensor(
275+
[[0.8013, 0.1824, 0.1351], [0.2203, 0.0304, 0.0551]],
276+
dtype=torch.float64,
277+
),
278+
expected_second_order=torch.tensor(
279+
[[0.7978, 22.6598, 1.5053]], dtype=torch.float64
280+
),
259281
)
260282

261283
sensitivity_mean_bootstrap = SobolSensitivityGPMean(
@@ -268,18 +290,26 @@ def test_SobolGPMean_SAASBO(self) -> None:
268290
)
269291
self._test_sobol_gp_mean(
270292
sensitivity=sensitivity_mean_bootstrap,
271-
expected_first_order_shape=torch.Size([2, 3]),
272-
expected_total_order_shape=torch.Size([2, 3]),
273-
expected_second_order_shape=torch.Size([1, 3]),
293+
expected_first_order=torch.tensor(
294+
[[3.4512, 32.4428, 1.8012], [0.2069, 121.8610, 3.4909]],
295+
dtype=torch.float64,
296+
),
297+
expected_total_order=torch.tensor(
298+
[[0.4288, 0.0903, 0.0950], [0.7923, 0.2218, 0.1489]],
299+
dtype=torch.float64,
300+
),
301+
expected_second_order=torch.tensor(
302+
[[-6.3790, 397.4363, 6.3043]], dtype=torch.float64
303+
),
274304
)
275305

276306
sensitivity_mean = SobolSensitivityGPMean(
277307
self.model, num_mc_samples=10, bounds=bounds, second_order=False
278308
)
279309
self._test_sobol_gp_mean(
280310
sensitivity=sensitivity_mean,
281-
expected_first_order_shape=torch.Size([2]),
282-
expected_total_order_shape=torch.Size([2]),
311+
expected_first_order=torch.tensor([0.9566, -0.4183], dtype=torch.float64),
312+
expected_total_order=torch.tensor([0.3440, 0.3685], dtype=torch.float64),
283313
)
284314

285315
with self.assertRaisesRegex(ValueError, "Second order indices"):

ax/utils/testing/tests/test_mock.py

Lines changed: 3 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88

99
from unittest.mock import patch
1010

11-
import jax.numpy as jnp
1211
import torch
1312
from ax.adapter.registry import Generators
1413
from ax.adapter.transforms.choice_encode import OrderedChoiceToIntegerRange
@@ -23,6 +22,7 @@
2322
from botorch.generation.gen import gen_candidates_scipy
2423
from botorch.optim.optimize_mixed import generate_starting_points
2524
from botorch.utils.testing import MockAcquisitionFunction, skip_if_import_error
25+
from pyro.infer import MCMC
2626

2727

2828
class TestMock(TestCase):
@@ -45,25 +45,13 @@ def test_botorch_mocks(self) -> None:
4545

4646
def test_fully_bayesian_mocks(self) -> None:
4747
experiment = get_branin_experiment(with_completed_batch=True)
48-
num_mcmc_samples = 16
49-
dim = len(experiment.search_space.parameters)
50-
# Mock MCMC to return proper JAX arrays that postprocess_mcmc_samples
51-
# can handle (it performs JAX operations on them).
52-
mock_samples = {
53-
"mean": jnp.ones(num_mcmc_samples),
54-
"noise": jnp.ones(num_mcmc_samples),
55-
"outputscale": jnp.ones(num_mcmc_samples),
56-
"kernel_tausq": jnp.ones(num_mcmc_samples),
57-
"_kernel_inv_length_sq": jnp.ones((num_mcmc_samples, dim)),
58-
}
59-
with patch("botorch.fit.MCMC") as mock_mcmc:
60-
mock_mcmc.return_value.get_samples.return_value = mock_samples
48+
with patch("botorch.fit.MCMC", wraps=MCMC) as mock_mcmc:
6149
with mock_botorch_optimize_context_manager():
6250
Generators.SAASBO(experiment=experiment, data=experiment.lookup_data())
6351
mock_mcmc.assert_called_once()
6452
kwargs = mock_mcmc.call_args.kwargs
6553
self.assertEqual(kwargs["num_samples"], 16)
66-
self.assertEqual(kwargs["num_warmup"], 0)
54+
self.assertEqual(kwargs["warmup_steps"], 0)
6755

6856
def test_mixed_optimizer_mocks(self) -> None:
6957
experiment = get_branin_experiment(

0 commit comments

Comments
 (0)