deps: pin jax/jaxlib <0.10 pending numpyro fix (#389)#391
Conversation
jax 0.10.0 removed xla_pmap_p from jax.extend.core.primitives, which breaks numpyro's provenance module and thus any pybefit-using test. Upstream fix is pyro-ppl/numpyro#2173 (approved, not yet released). Revert this pin once a patched numpyro ships. Fixes #389. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
There was a problem hiding this comment.
Pull request overview
Pins jax/jaxlib to avoid import numpyro failures caused by jax>=0.10.0 removing xla_pmap_p, which breaks downstream pybefit-based test modules and nightly CI.
Changes:
- Add
<0.10upper bounds forjaxandjaxlibin core dependencies. - Add
<0.10upper bound forjax[cuda12]in thegpuextra. - Document the upstream numpyro fix/links in
pyproject.tomlto aid future unpinning.
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
setup.cfg is retained for install paths (e.g. Jetson) that don't read pyproject.toml, so the upper bound has to live in both places to stay consistent. Same rationale as the pyproject.toml pin — tracks pyro-ppl/numpyro#2173. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
|
Addressed the Copilot review: mirrored the Not consolidating away |
There was a problem hiding this comment.
Pull request overview
Copilot reviewed 2 out of 2 changed files in this pull request and generated no new comments.
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
conorheins
left a comment
There was a problem hiding this comment.
Reviewed Issue #389 first, then checked PR #391 against the failure mode described there. The dependency bounds land in the wheel metadata, uv lock --dry-run --python 3.12 resolves below JAX 0.10, and the two affected tests pass locally in an isolated worktree:
uv run --group test pytest test/test_pybefit_model_fitting.py test/test_tmaze_recoverability.py -q
No blocking findings from me. This looks like the right short-term workaround until NumPyro releases the upstream fix.
- codex
|
Looks like pyro-ppl/numpyro#2173 got merged so we should revisit this and remove the upper version restriction now |
## Summary - Reverts the temporary `<0.10` cap on `jax`/`jaxlib` (and `jax[cuda12]`) introduced in #391 across `pyproject.toml` and `setup.cfg`. - Removes the inline TODO comments referencing [pyro-ppl/numpyro#2173](pyro-ppl/numpyro#2173). - numpyro#2173 merged 2026-05-02 and shipped in numpyro **0.21.0** the same day. The fix wraps the `xla_pmap_p` import in a try/except so the absence on jax ≥ 0.10.0 no longer breaks `import numpyro` (and therefore `pybefit`). ## Verification - [x] Confirmed numpyro 0.21.0 release tag contains the try/except guard in `numpyro/ops/provenance.py`. - [x] `uv lock --dry-run` resolves cleanly: jax 0.10.0, jaxlib 0.10.0, numpyro 0.21.0, pybefit 0.1.23. - [x] Standard `Test` CI passes on this PR. - [x] Manual Branch Nightly run on this branch passes. ## Notes This closes out the workaround from #389. Should be a clean inverse of #391. 🤖 Generated with [Claude Code](https://claude.com/claude-code) --------- Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Summary
xla_pmap_pfromjax.extend.core.primitives, which is imported unconditionally bynumpyro/ops/provenance.py. That breaksimport numpyro, which in turn breakspybefitand the two model-fitting test modules (test_pybefit_model_fitting.py,test_tmaze_recoverability.py).jax/jaxlibto<0.10as a short-term workaround and adds a comment inpyproject.tomlpointing at the upstream PR so we remember to revert.xla_pmap_p#389 once merged; revert this pin as soon as a patched numpyro release is on PyPI.Test plan
uv lock --dry-runresolves jax 0.9.2, numpyro 0.20.1, pybefit 0.1.23 cleanlymanual-branch-nightly.yaml) passesTestCI on this PR passes🤖 Generated with Claude Code