|
6 | 6 |
|
7 | 7 |
|
8 | 8 | import itertools |
| 9 | +import subprocess |
| 10 | +import sys |
| 11 | +import textwrap |
9 | 12 | from unittest import mock |
10 | 13 | from unittest.mock import patch |
11 | 14 |
|
@@ -1297,3 +1300,80 @@ def test_missing_jax_raises_on_instantiation(self) -> None: |
1297 | 1300 | train_X=torch.rand(10, 2), |
1298 | 1301 | train_Y=torch.rand(10, 1), |
1299 | 1302 | ) |
| 1303 | + |
| 1304 | + def test_core_workflow_without_jax(self) -> None: |
| 1305 | + """Core BoTorch works when JAX/jaxlib/NumPyro are not installed. |
| 1306 | +
|
| 1307 | + JAX, jaxlib, and NumPyro are optional dependencies that are only needed |
| 1308 | + to fit fully Bayesian models. We run this in a subprocess that blocks |
| 1309 | + those imports (the parent process has them installed and has already |
| 1310 | + imported the gated modules with ``_HAS_JAX=True``), and assert that the |
| 1311 | + standard fit + acquisition-optimization loop works end to end while |
| 1312 | + constructing a fully Bayesian model raises a clear ``ImportError``. |
| 1313 | + """ |
| 1314 | + script = textwrap.dedent( |
| 1315 | + """ |
| 1316 | + import builtins |
| 1317 | +
|
| 1318 | + _real_import = builtins.__import__ |
| 1319 | +
|
| 1320 | + def _blocked_import(name, *args, **kwargs): |
| 1321 | + if name.split(".")[0] in {"jax", "jaxlib", "numpyro"}: |
| 1322 | + raise ImportError(f"blocked {name}") |
| 1323 | + return _real_import(name, *args, **kwargs) |
| 1324 | +
|
| 1325 | + builtins.__import__ = _blocked_import |
| 1326 | +
|
| 1327 | + import torch |
| 1328 | + import botorch # noqa: F401 |
| 1329 | + import botorch.fit # noqa: F401 |
| 1330 | + import botorch.models.fully_bayesian_multitask # noqa: F401 |
| 1331 | + from botorch.acquisition.analytic import LogExpectedImprovement |
| 1332 | + from botorch.fit import fit_gpytorch_mll |
| 1333 | + from botorch.models import SingleTaskGP |
| 1334 | + from botorch.models.fully_bayesian import ( |
| 1335 | + _HAS_JAX, |
| 1336 | + SaasFullyBayesianSingleTaskGP, |
| 1337 | + ) |
| 1338 | + from botorch.optim import optimize_acqf |
| 1339 | + from botorch.test_utils.mock import mock_optimize_context_manager |
| 1340 | + from gpytorch.mlls import ExactMarginalLogLikelihood |
| 1341 | +
|
| 1342 | + assert _HAS_JAX is False, "_HAS_JAX should be False with JAX blocked" |
| 1343 | +
|
| 1344 | + # Standard (non-fully-Bayesian) BO loop must work without JAX. |
| 1345 | + train_X = torch.rand(8, 2, dtype=torch.double) |
| 1346 | + train_Y = train_X.sum(dim=-1, keepdim=True) |
| 1347 | + model = SingleTaskGP(train_X=train_X, train_Y=train_Y) |
| 1348 | + mll = ExactMarginalLogLikelihood(model.likelihood, model) |
| 1349 | + with mock_optimize_context_manager(): |
| 1350 | + fit_gpytorch_mll(mll) |
| 1351 | + acqf = LogExpectedImprovement(model=model, best_f=train_Y.max()) |
| 1352 | + candidate, _ = optimize_acqf( |
| 1353 | + acq_function=acqf, |
| 1354 | + bounds=torch.tensor( |
| 1355 | + [[0.0, 0.0], [1.0, 1.0]], dtype=torch.double |
| 1356 | + ), |
| 1357 | + q=1, |
| 1358 | + num_restarts=2, |
| 1359 | + raw_samples=4, |
| 1360 | + ) |
| 1361 | + assert candidate.shape == (1, 2), candidate.shape |
| 1362 | +
|
| 1363 | + # Fully Bayesian models still require JAX -> clear ImportError. |
| 1364 | + try: |
| 1365 | + SaasFullyBayesianSingleTaskGP(train_X=train_X, train_Y=train_Y) |
| 1366 | + except ImportError: |
| 1367 | + pass |
| 1368 | + else: |
| 1369 | + raise AssertionError( |
| 1370 | + "expected ImportError when constructing SAAS without JAX" |
| 1371 | + ) |
| 1372 | + """ |
| 1373 | + ) |
| 1374 | + result = subprocess.run( |
| 1375 | + [sys.executable, "-c", script], |
| 1376 | + capture_output=True, |
| 1377 | + text=True, |
| 1378 | + ) |
| 1379 | + self.assertEqual(result.returncode, 0, msg=result.stderr) |
0 commit comments