Skip to content

Commit 094095b

Browse files
committed
Add regression test that core BoTorch works without JAX
Subprocess smoke test that blocks jax/jaxlib/numpyro at import time and verifies the standard fit + acquisition-optimization loop works end to end (SingleTaskGP + LogExpectedImprovement under mock_optimize), while constructing a fully Bayesian model raises a clear ImportError.
1 parent fb270da commit 094095b

1 file changed

Lines changed: 80 additions & 0 deletions

File tree

test/models/test_fully_bayesian.py

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,9 @@
66

77

88
import itertools
9+
import subprocess
10+
import sys
11+
import textwrap
912
from unittest import mock
1013
from unittest.mock import patch
1114

@@ -1297,3 +1300,80 @@ def test_missing_jax_raises_on_instantiation(self) -> None:
12971300
train_X=torch.rand(10, 2),
12981301
train_Y=torch.rand(10, 1),
12991302
)
1303+
1304+
def test_core_workflow_without_jax(self) -> None:
1305+
"""Core BoTorch works when JAX/jaxlib/NumPyro are not installed.
1306+
1307+
JAX, jaxlib, and NumPyro are optional dependencies that are only needed
1308+
to fit fully Bayesian models. We run this in a subprocess that blocks
1309+
those imports (the parent process has them installed and has already
1310+
imported the gated modules with ``_HAS_JAX=True``), and assert that the
1311+
standard fit + acquisition-optimization loop works end to end while
1312+
constructing a fully Bayesian model raises a clear ``ImportError``.
1313+
"""
1314+
script = textwrap.dedent(
1315+
"""
1316+
import builtins
1317+
1318+
_real_import = builtins.__import__
1319+
1320+
def _blocked_import(name, *args, **kwargs):
1321+
if name.split(".")[0] in {"jax", "jaxlib", "numpyro"}:
1322+
raise ImportError(f"blocked {name}")
1323+
return _real_import(name, *args, **kwargs)
1324+
1325+
builtins.__import__ = _blocked_import
1326+
1327+
import torch
1328+
import botorch # noqa: F401
1329+
import botorch.fit # noqa: F401
1330+
import botorch.models.fully_bayesian_multitask # noqa: F401
1331+
from botorch.acquisition.analytic import LogExpectedImprovement
1332+
from botorch.fit import fit_gpytorch_mll
1333+
from botorch.models import SingleTaskGP
1334+
from botorch.models.fully_bayesian import (
1335+
_HAS_JAX,
1336+
SaasFullyBayesianSingleTaskGP,
1337+
)
1338+
from botorch.optim import optimize_acqf
1339+
from botorch.test_utils.mock import mock_optimize_context_manager
1340+
from gpytorch.mlls import ExactMarginalLogLikelihood
1341+
1342+
assert _HAS_JAX is False, "_HAS_JAX should be False with JAX blocked"
1343+
1344+
# Standard (non-fully-Bayesian) BO loop must work without JAX.
1345+
train_X = torch.rand(8, 2, dtype=torch.double)
1346+
train_Y = train_X.sum(dim=-1, keepdim=True)
1347+
model = SingleTaskGP(train_X=train_X, train_Y=train_Y)
1348+
mll = ExactMarginalLogLikelihood(model.likelihood, model)
1349+
with mock_optimize_context_manager():
1350+
fit_gpytorch_mll(mll)
1351+
acqf = LogExpectedImprovement(model=model, best_f=train_Y.max())
1352+
candidate, _ = optimize_acqf(
1353+
acq_function=acqf,
1354+
bounds=torch.tensor(
1355+
[[0.0, 0.0], [1.0, 1.0]], dtype=torch.double
1356+
),
1357+
q=1,
1358+
num_restarts=2,
1359+
raw_samples=4,
1360+
)
1361+
assert candidate.shape == (1, 2), candidate.shape
1362+
1363+
# Fully Bayesian models still require JAX -> clear ImportError.
1364+
try:
1365+
SaasFullyBayesianSingleTaskGP(train_X=train_X, train_Y=train_Y)
1366+
except ImportError:
1367+
pass
1368+
else:
1369+
raise AssertionError(
1370+
"expected ImportError when constructing SAAS without JAX"
1371+
)
1372+
"""
1373+
)
1374+
result = subprocess.run(
1375+
[sys.executable, "-c", script],
1376+
capture_output=True,
1377+
text=True,
1378+
)
1379+
self.assertEqual(result.returncode, 0, msg=result.stderr)

0 commit comments

Comments
 (0)