Skip to content

deps: pin jax/jaxlib <0.10 pending numpyro fix (#389)#391

Merged
conorheins merged 2 commits into
mainfrom
claude/upbeat-allen-8d7618
Apr 22, 2026
Merged

deps: pin jax/jaxlib <0.10 pending numpyro fix (#389)#391
conorheins merged 2 commits into
mainfrom
claude/upbeat-allen-8d7618

Conversation

@conorheins
Copy link
Copy Markdown
Collaborator

@conorheins conorheins commented Apr 21, 2026

Summary

  • Nightly CI has been failing since jax 0.10.0 landed — it removed xla_pmap_p from jax.extend.core.primitives, which is imported unconditionally by numpyro/ops/provenance.py. That breaks import numpyro, which in turn breaks pybefit and the two model-fitting test modules (test_pybefit_model_fitting.py, test_tmaze_recoverability.py).
  • Upstream fix is pyro-ppl/numpyro#2173 (approved, not yet merged/released). Tracking bug: pyro-ppl/numpyro#2174.
  • This PR pins jax/jaxlib to <0.10 as a short-term workaround and adds a comment in pyproject.toml pointing at the upstream PR so we remember to revert.
  • Closes Nightly test failing due to removal of xla_pmap_p #389 once merged; revert this pin as soon as a patched numpyro release is on PyPI.

Test plan

  • uv lock --dry-run resolves jax 0.9.2, numpyro 0.20.1, pybefit 0.1.23 cleanly
  • Nightly Tests CI run on this branch (via manual-branch-nightly.yaml) passes
  • Standard Test CI on this PR passes

🤖 Generated with Claude Code

jax 0.10.0 removed xla_pmap_p from jax.extend.core.primitives, which
breaks numpyro's provenance module and thus any pybefit-using test.
Upstream fix is pyro-ppl/numpyro#2173 (approved, not yet released).
Revert this pin once a patched numpyro ships.

Fixes #389.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
@conorheins conorheins requested a review from Copilot April 21, 2026 20:14
@conorheins conorheins added dependencies Pull requests that update a dependency file tests related to test coverage labels Apr 21, 2026
Copy link
Copy Markdown
Contributor

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

Pins jax/jaxlib to avoid import numpyro failures caused by jax>=0.10.0 removing xla_pmap_p, which breaks downstream pybefit-based test modules and nightly CI.

Changes:

  • Add <0.10 upper bounds for jax and jaxlib in core dependencies.
  • Add <0.10 upper bound for jax[cuda12] in the gpu extra.
  • Document the upstream numpyro fix/links in pyproject.toml to aid future unpinning.

💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.

Comment thread pyproject.toml
Comment thread pyproject.toml
setup.cfg is retained for install paths (e.g. Jetson) that don't read
pyproject.toml, so the upper bound has to live in both places to stay
consistent. Same rationale as the pyproject.toml pin — tracks
pyro-ppl/numpyro#2173.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
@conorheins
Copy link
Copy Markdown
Collaborator Author

conorheins commented Apr 21, 2026

Addressed the Copilot review: mirrored the <0.10 pin on jax, jaxlib, and the gpu extra in setup.cfg in 47604ec.

Not consolidating away setup.cfg (see #162) — it's retained intentionally to support install paths (e.g. Nvidia Jetson) that can't cleanly consume pyproject.toml. Both files will need to stay in sync on dependency changes; noted that in the setup.cfg comment.

Copy link
Copy Markdown
Contributor

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

Copilot reviewed 2 out of 2 changed files in this pull request and generated no new comments.


💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.

Copy link
Copy Markdown
Collaborator Author

@conorheins conorheins left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Reviewed Issue #389 first, then checked PR #391 against the failure mode described there. The dependency bounds land in the wheel metadata, uv lock --dry-run --python 3.12 resolves below JAX 0.10, and the two affected tests pass locally in an isolated worktree:

uv run --group test pytest test/test_pybefit_model_fitting.py test/test_tmaze_recoverability.py -q

No blocking findings from me. This looks like the right short-term workaround until NumPyro releases the upstream fix.

  • codex

@conorheins conorheins merged commit a63de85 into main Apr 22, 2026
9 checks passed
@conorheins conorheins deleted the claude/upbeat-allen-8d7618 branch April 22, 2026 06:31
@conorheins
Copy link
Copy Markdown
Collaborator Author

Looks like pyro-ppl/numpyro#2173 got merged so we should revisit this and remove the upper version restriction now

conorheins added a commit that referenced this pull request May 7, 2026
## Summary
- Reverts the temporary `<0.10` cap on `jax`/`jaxlib` (and
`jax[cuda12]`) introduced in #391 across `pyproject.toml` and
`setup.cfg`.
- Removes the inline TODO comments referencing
[pyro-ppl/numpyro#2173](pyro-ppl/numpyro#2173).
- numpyro#2173 merged 2026-05-02 and shipped in numpyro **0.21.0** the
same day. The fix wraps the `xla_pmap_p` import in a try/except so the
absence on jax ≥ 0.10.0 no longer breaks `import numpyro` (and therefore
`pybefit`).

## Verification
- [x] Confirmed numpyro 0.21.0 release tag contains the try/except guard
in `numpyro/ops/provenance.py`.
- [x] `uv lock --dry-run` resolves cleanly: jax 0.10.0, jaxlib 0.10.0,
numpyro 0.21.0, pybefit 0.1.23.
- [x] Standard `Test` CI passes on this PR.
- [x] Manual Branch Nightly run on this branch passes.

## Notes
This closes out the workaround from #389. Should be a clean inverse of
#391.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

---------

Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

dependencies Pull requests that update a dependency file tests related to test coverage

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Nightly test failing due to removal of xla_pmap_p

2 participants