Skip to content

Make JAX/jaxlib/NumPyro optional dependencies#3319

Open
saitcakmak wants to merge 4 commits into
mainfrom
jax-optional-dependency
Open

Make JAX/jaxlib/NumPyro optional dependencies#3319
saitcakmak wants to merge 4 commits into
mainfrom
jax-optional-dependency

Conversation

@saitcakmak
Copy link
Copy Markdown
Contributor

Summary

Fixes #3318. JAX, jaxlib, and NumPyro were added as required dependencies in 0.18.0, but they are only needed to fit fully Bayesian models (SaasFullyBayesianSingleTaskGP / SaasFullyBayesianMultiTaskGP), which run NUTS via NumPyro. As the issue notes, a Torch project shouldn't pull in JAX for everyone.

The code is already written for these to be optional:

  • botorch/models/fully_bayesian.py guards import jax/numpyro behind try/except (_HAS_JAX) and raises via _check_jax_available() only at model construction.
  • botorch/models/fully_bayesian_multitask.py guards its imports behind if _HAS_JAX:.
  • fit_fully_bayesian_model_nuts uses a local import (not module-level).

I verified empirically (with jax/jaxlib/numpyro blocked at import) that import botorch, import botorch.fit, importing the fully-Bayesian modules, and building/using SingleTaskGP all work; only constructing a fully Bayesian model raises a clear ImportError.

Changes

  • Move jax, jaxlib, numpyro from core dependencies into a new fully_bayesian optional extra in pyproject.toml.
  • Add botorch[fully_bayesian] to the test extra so CI still exercises the fully Bayesian path.
  • Update the _check_jax_available error message to point at pip install "botorch[fully_bayesian]" instead of the Meta-internal PACKAGE / python-scientific-stack instructions (which are meaningless to OSS users).

Note: botorch/utils/jax_utils.py still imports JAX at module level, but it is not imported by any production code (only its own test), so it does not affect import botorch.

Test plan

  • pytest test/models/test_fully_bayesian.py test/models/test_fully_bayesian_multitask.py → 98 passed.
  • Confirmed import botorch and core model usage work with JAX/jaxlib/numpyro blocked.
  • ufmt + flake8 clean on changed files.

JAX, jaxlib, and NumPyro are only needed to fit fully Bayesian models
(e.g. SaasFullyBayesianSingleTaskGP / SaasFullyBayesianMultiTaskGP),
which run NUTS via NumPyro. The rest of BoTorch imports and runs without
them: imports are already gated behind try/except (_HAS_JAX) and a local
import in fit_fully_bayesian_model_nuts.

Move them out of the core dependencies into a new 'fully_bayesian' extra,
included in the 'test' extra so CI still exercises them. Also update the
_check_jax_available error message to point at
'pip install botorch[fully_bayesian]' instead of the Meta-internal
PACKAGE / python-scientific-stack instructions.

Fixes #3318
@meta-cla meta-cla Bot added the CLA Signed Do not delete this pull request or issue due to inactivity. label Jun 5, 2026
Subprocess smoke test that blocks jax/jaxlib/numpyro at import time and
verifies the standard fit + acquisition-optimization loop works end to
end (SingleTaskGP + LogExpectedImprovement under mock_optimize), while
constructing a fully Bayesian model raises a clear ImportError.
@meta-codesync
Copy link
Copy Markdown

meta-codesync Bot commented Jun 5, 2026

@saitcakmak has imported this pull request. If you are a Meta employee, you can view this in D107686465.

@codecov
Copy link
Copy Markdown

codecov Bot commented Jun 5, 2026

Codecov Report

✅ All modified and coverable lines are covered by tests.
✅ Project coverage is 99.98%. Comparing base (7a5692d) to head (9b41ee9).

Additional details and impacted files
@@           Coverage Diff           @@
##             main    #3319   +/-   ##
=======================================
  Coverage   99.98%   99.98%           
=======================================
  Files         226      226           
  Lines       22470    22470           
=======================================
  Hits        22466    22466           
  Misses          4        4           

☔ View full report in Codecov by Harness.
📢 Have feedback on the report? Share it here.

🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.
  • 📦 JS Bundle Analysis: Save yourself from yourself by tracking and limiting bundle sizes in JS merges.

Tutorials exercise fully Bayesian models, so the 'tutorials' extra (used
by reusable_tutorials.yml) must pull in JAX. The 'test' and 'dev' extras
already chain to 'fully_bayesian'; add it to 'tutorials' too.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

CLA Signed Do not delete this pull request or issue due to inactivity.

Projects

None yet

Development

Successfully merging this pull request may close these issues.

[Bug]: Jax, Jaxlib, Numpyro included as dependencies unconditionally

1 participant