Skip to content

Fix xla_pmap_p import for JAX versions that removed pmap#2173

Merged
juanitorduz merged 2 commits into
pyro-ppl:masterfrom
saitcakmak:fix-xla-pmap-p-import
May 2, 2026
Merged

Fix xla_pmap_p import for JAX versions that removed pmap#2173
juanitorduz merged 2 commits into
pyro-ppl:masterfrom
saitcakmak:fix-xla-pmap-p-import

Conversation

@saitcakmak
Copy link
Copy Markdown
Contributor

Summary

  • JAX 0.10.0 removed the C++ pmap infrastructure (including xla_pmap_p). This causes an ImportError when importing numpyro with newer JAX versions.
  • Guard the xla_pmap_p import with a try/except and skip the provenance tracking rule registration when it's unavailable.

Test plan

  • python -c "from numpyro.ops.provenance import eval_provenance" succeeds
  • eval_provenance works end-to-end (tested with lambda x, y, z: x + y)
  • All 10 tests in test/ops/test_provenance.py pass

JAX removed the C++ pmap infrastructure (including xla_pmap_p) in a
recent release. Guard the import so numpyro works with both old and
new JAX versions.
@Qazalbash
Copy link
Copy Markdown
Collaborator

Thanks @saitcakmak

@Qazalbash Qazalbash requested a review from juanitorduz April 16, 2026 14:23
@juanitorduz
Copy link
Copy Markdown
Collaborator

juanitorduz commented Apr 16, 2026

LGTM.

It seems the failing tests in Python 3.14 CI (test-inference) are not caused by this PR. They fail inside funsor, a transitive dependency:

  .venv/lib/python3.14/site-packages/funsor/jax/ops.py:203                                                                                                                                                           
  E       TypeError: clip() got an unexpected keyword argument 'a_max' 

funsor calls jnp.clip(..., a_max=...), but recent JAX releases dropped the deprecated a_max/a_min kwargs (replaced by max/min, aligning with NumPy 2.x).

I am looking into the upstream issue: pyro-ppl/funsor#611 @fehiepsi would you mind taking a look?

@Qazalbash Qazalbash linked an issue Apr 17, 2026 that may be closed by this pull request
Qazalbash added a commit to kokabsc/gwkokab that referenced this pull request Apr 18, 2026
sethaxen added a commit to sethaxen/CAGPJax that referenced this pull request Apr 20, 2026
* chore: Temporarily upperbound jax

Until numpyro v0.20 is compatible with jax v0.10, pyro-ppl/numpyro#2173

* fix: Make support a property

Change made in numpyro v0.20.0

* fix: Explicitly import jax.test_util
conorheins added a commit to infer-actively/pymdp that referenced this pull request Apr 21, 2026
setup.cfg is retained for install paths (e.g. Jetson) that don't read
pyproject.toml, so the upper bound has to live in both places to stay
consistent. Same rationale as the pyproject.toml pin — tracks
pyro-ppl/numpyro#2173.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
tedvr pushed a commit to tedvr/pymdp that referenced this pull request Apr 24, 2026
jax 0.10.0 removed xla_pmap_p from jax.extend.core.primitives, which
breaks numpyro's provenance module and thus any pybefit-using test.
Upstream fix is pyro-ppl/numpyro#2173 (approved, not yet released).
Revert this pin once a patched numpyro ships.

Fixes infer-actively#389.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
@Zethson
Copy link
Copy Markdown

Zethson commented Apr 30, 2026

I'm sorry to be "that person" but would it be possible to please kindly get this or a similar fix merged and released?

@juanitorduz
Copy link
Copy Markdown
Collaborator

Fix funsor for JAX 0.10 / NumPy 2.x pyro-ppl/funsor#611

We are working on an upstream fix pyro-ppl/funsor#611

@juanitorduz
Copy link
Copy Markdown
Collaborator

@saitcakmak now that the upstream issue has been fixed, could you please rebase or sync with master so that we can merge this one :) Thanks!

@juanitorduz
Copy link
Copy Markdown
Collaborator

@fehiepsi @Qazalbash, as this one was approved before, I will merge it as many folks wanna have it merged :D

We should cut a release after this one and other queued PRs that have been approved 🙏

@juanitorduz juanitorduz merged commit 18c2cc1 into pyro-ppl:master May 2, 2026
9 checks passed
@Qazalbash
Copy link
Copy Markdown
Collaborator

We should cut a release after this one and other queued PRs that have been approved 🙏

Absolutely! We need it.

Qazalbash added a commit to kokabsc/gwkokab that referenced this pull request May 2, 2026
* fix: remove upper pin from JAX

pyro-ppl/numpyro#2173 has been merged along with
a new numpyro release

* fix: update numpyro and jax minimum versions
@saitcakmak saitcakmak deleted the fix-xla-pmap-p-import branch May 4, 2026 20:15
conorheins added a commit to infer-actively/pymdp that referenced this pull request May 7, 2026
## Summary
- Reverts the temporary `<0.10` cap on `jax`/`jaxlib` (and
`jax[cuda12]`) introduced in #391 across `pyproject.toml` and
`setup.cfg`.
- Removes the inline TODO comments referencing
[pyro-ppl/numpyro#2173](pyro-ppl/numpyro#2173).
- numpyro#2173 merged 2026-05-02 and shipped in numpyro **0.21.0** the
same day. The fix wraps the `xla_pmap_p` import in a try/except so the
absence on jax ≥ 0.10.0 no longer breaks `import numpyro` (and therefore
`pybefit`).

## Verification
- [x] Confirmed numpyro 0.21.0 release tag contains the try/except guard
in `numpyro/ops/provenance.py`.
- [x] `uv lock --dry-run` resolves cleanly: jax 0.10.0, jaxlib 0.10.0,
numpyro 0.21.0, pybefit 0.1.23.
- [x] Standard `Test` CI passes on this PR.
- [x] Manual Branch Nightly run on this branch passes.

## Notes
This closes out the workaround from #389. Should be a clean inverse of
#391.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

---------

Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

JAX 0.10.0 breaks NumPyro

5 participants