Make JAX/jaxlib/NumPyro optional dependencies#3319
Open
saitcakmak wants to merge 4 commits into
Open
Conversation
JAX, jaxlib, and NumPyro are only needed to fit fully Bayesian models (e.g. SaasFullyBayesianSingleTaskGP / SaasFullyBayesianMultiTaskGP), which run NUTS via NumPyro. The rest of BoTorch imports and runs without them: imports are already gated behind try/except (_HAS_JAX) and a local import in fit_fully_bayesian_model_nuts. Move them out of the core dependencies into a new 'fully_bayesian' extra, included in the 'test' extra so CI still exercises them. Also update the _check_jax_available error message to point at 'pip install botorch[fully_bayesian]' instead of the Meta-internal PACKAGE / python-scientific-stack instructions. Fixes #3318
Subprocess smoke test that blocks jax/jaxlib/numpyro at import time and verifies the standard fit + acquisition-optimization loop works end to end (SingleTaskGP + LogExpectedImprovement under mock_optimize), while constructing a fully Bayesian model raises a clear ImportError.
|
@saitcakmak has imported this pull request. If you are a Meta employee, you can view this in D107686465. |
Codecov Report✅ All modified and coverable lines are covered by tests. Additional details and impacted files@@ Coverage Diff @@
## main #3319 +/- ##
=======================================
Coverage 99.98% 99.98%
=======================================
Files 226 226
Lines 22470 22470
=======================================
Hits 22466 22466
Misses 4 4 ☔ View full report in Codecov by Harness. 🚀 New features to boost your workflow:
|
Tutorials exercise fully Bayesian models, so the 'tutorials' extra (used by reusable_tutorials.yml) must pull in JAX. The 'test' and 'dev' extras already chain to 'fully_bayesian'; add it to 'tutorials' too.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Summary
Fixes #3318. JAX, jaxlib, and NumPyro were added as required dependencies in 0.18.0, but they are only needed to fit fully Bayesian models (
SaasFullyBayesianSingleTaskGP/SaasFullyBayesianMultiTaskGP), which run NUTS via NumPyro. As the issue notes, a Torch project shouldn't pull in JAX for everyone.The code is already written for these to be optional:
botorch/models/fully_bayesian.pyguardsimport jax/numpyrobehindtry/except(_HAS_JAX) and raises via_check_jax_available()only at model construction.botorch/models/fully_bayesian_multitask.pyguards its imports behindif _HAS_JAX:.fit_fully_bayesian_model_nutsuses a local import (not module-level).I verified empirically (with
jax/jaxlib/numpyroblocked at import) thatimport botorch,import botorch.fit, importing the fully-Bayesian modules, and building/usingSingleTaskGPall work; only constructing a fully Bayesian model raises a clearImportError.Changes
jax,jaxlib,numpyrofrom coredependenciesinto a newfully_bayesianoptional extra inpyproject.toml.botorch[fully_bayesian]to thetestextra so CI still exercises the fully Bayesian path._check_jax_availableerror message to point atpip install "botorch[fully_bayesian]"instead of the Meta-internalPACKAGE/python-scientific-stackinstructions (which are meaningless to OSS users).Note:
botorch/utils/jax_utils.pystill imports JAX at module level, but it is not imported by any production code (only its own test), so it does not affectimport botorch.Test plan
pytest test/models/test_fully_bayesian.py test/models/test_fully_bayesian_multitask.py→ 98 passed.import botorchand core model usage work with JAX/jaxlib/numpyro blocked.