deps: unpin jax/jaxlib now that numpyro 0.21.0 ships fix#404
Conversation
Reverts the temporary <0.10 cap on jax/jaxlib (pyproject.toml + setup.cfg core deps and gpu extras) introduced in #391. The upstream cause — numpyro/ops/provenance.py importing the removed xla_pmap_p — was fixed by pyro-ppl/numpyro#2173 (merged 2026-05-02) and shipped in numpyro 0.21.0 the same day. uv now resolves jax 0.10.0 + numpyro 0.21.0 cleanly. Closes the workaround tracked in #389. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
jax 0.10.0 finished the deprecation of the a_min/a_max kwargs on jnp.clip in favour of the numpy 2.0-style min/max kwargs. Five call sites in examples/advanced/pymdp_with_neural_encoder.ipynb still used the old spelling and crashed nightly notebook tests with "TypeError: clip() got an unexpected keyword argument 'a_min'" as soon as the <0.10 cap was lifted in the previous commit. Library code is unaffected; only this notebook used the deprecated API. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
|
First nightly run (25428706834) green on the steps that gate this PR — It surfaced one secondary regression in Fixed in 225c9a1 by switching to numpy-2.0-style |
🤖 I have created a release *beep* *boop* --- ## [1.0.2](v1.0.1...v1.0.2) (2026-05-07) ### Dependencies * unpin jax/jaxlib now that numpyro 0.21.0 ships fix ([#404](#404)) ([5491979](5491979)) ### Documentation * **contributing:** fix release-trigger table and document cadence ([#403](#403)) ([66226ea](66226ea)) * **notebooks:** render math in mkdocs-jupyter notebooks ([#400](#400)) ([23c29ed](23c29ed)) --- This PR was generated with [Release Please](https://github.com/googleapis/release-please). See [documentation](https://github.com/googleapis/release-please#release-please).
Summary
<0.10cap onjax/jaxlib(andjax[cuda12]) introduced in deps: pin jax/jaxlib <0.10 pending numpyro fix (#389) #391 acrosspyproject.tomlandsetup.cfg.xla_pmap_pimport in a try/except so the absence on jax ≥ 0.10.0 no longer breaksimport numpyro(and thereforepybefit).Verification
numpyro/ops/provenance.py.uv lock --dry-runresolves cleanly: jax 0.10.0, jaxlib 0.10.0, numpyro 0.21.0, pybefit 0.1.23.TestCI passes on this PR.Notes
This closes out the workaround from #389. Should be a clean inverse of #391.
🤖 Generated with Claude Code