Fix xla_pmap_p import for JAX versions that removed pmap#2173
Conversation
JAX removed the C++ pmap infrastructure (including xla_pmap_p) in a recent release. Guard the import so numpyro works with both old and new JAX versions.
|
Thanks @saitcakmak |
|
LGTM. It seems the failing tests in Python 3.14 CI (test-inference) are not caused by this PR. They fail inside funsor, a transitive dependency: .venv/lib/python3.14/site-packages/funsor/jax/ops.py:203
E TypeError: clip() got an unexpected keyword argument 'a_max' funsor calls I am looking into the upstream issue: pyro-ppl/funsor#611 @fehiepsi would you mind taking a look? |
is resovled
* chore: Temporarily upperbound jax Until numpyro v0.20 is compatible with jax v0.10, pyro-ppl/numpyro#2173 * fix: Make support a property Change made in numpyro v0.20.0 * fix: Explicitly import jax.test_util
setup.cfg is retained for install paths (e.g. Jetson) that don't read pyproject.toml, so the upper bound has to live in both places to stay consistent. Same rationale as the pyproject.toml pin — tracks pyro-ppl/numpyro#2173. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
jax 0.10.0 removed xla_pmap_p from jax.extend.core.primitives, which breaks numpyro's provenance module and thus any pybefit-using test. Upstream fix is pyro-ppl/numpyro#2173 (approved, not yet released). Revert this pin once a patched numpyro ships. Fixes infer-actively#389. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
|
I'm sorry to be "that person" but would it be possible to please kindly get this or a similar fix merged and released? |
We are working on an upstream fix pyro-ppl/funsor#611 |
|
@saitcakmak now that the upstream issue has been fixed, could you please rebase or sync with master so that we can merge this one :) Thanks! |
|
@fehiepsi @Qazalbash, as this one was approved before, I will merge it as many folks wanna have it merged :D We should cut a release after this one and other queued PRs that have been approved 🙏 |
Absolutely! We need it. |
* fix: remove upper pin from JAX pyro-ppl/numpyro#2173 has been merged along with a new numpyro release * fix: update numpyro and jax minimum versions
## Summary - Reverts the temporary `<0.10` cap on `jax`/`jaxlib` (and `jax[cuda12]`) introduced in #391 across `pyproject.toml` and `setup.cfg`. - Removes the inline TODO comments referencing [pyro-ppl/numpyro#2173](pyro-ppl/numpyro#2173). - numpyro#2173 merged 2026-05-02 and shipped in numpyro **0.21.0** the same day. The fix wraps the `xla_pmap_p` import in a try/except so the absence on jax ≥ 0.10.0 no longer breaks `import numpyro` (and therefore `pybefit`). ## Verification - [x] Confirmed numpyro 0.21.0 release tag contains the try/except guard in `numpyro/ops/provenance.py`. - [x] `uv lock --dry-run` resolves cleanly: jax 0.10.0, jaxlib 0.10.0, numpyro 0.21.0, pybefit 0.1.23. - [x] Standard `Test` CI passes on this PR. - [x] Manual Branch Nightly run on this branch passes. ## Notes This closes out the workaround from #389. Should be a clean inverse of #391. 🤖 Generated with [Claude Code](https://claude.com/claude-code) --------- Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Summary
xla_pmap_p). This causes anImportErrorwhen importing numpyro with newer JAX versions.xla_pmap_pimport with a try/except and skip the provenance tracking rule registration when it's unavailable.Test plan
python -c "from numpyro.ops.provenance import eval_provenance"succeedseval_provenanceworks end-to-end (tested withlambda x, y, z: x + y)test/ops/test_provenance.pypass