I noticed that the nightly test failed last night: https://github.com/infer-actively/pymdp/actions/runs/24548242995
It appears to be caused by the latest jax release, v0.10.0, which removed xla_pmap_p (among other things) from the C++ pmap infrastructure.
ERROR test/test_pybefit_model_fitting.py - ImportError while importing test module '/home/runner/work/pymdp/pymdp/test/test_pybefit_model_fitting.py'.
Hint: make sure your test modules/packages have valid Python names.
Traceback:
/opt/hostedtoolcache/Python/3.12.13/x64/lib/python3.12/importlib/__init__.py:90: in import_module
return _bootstrap._gcd_import(name[level:], package, level)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
test/test_pybefit_model_fitting.py:14: in <module>
from pybefit.inference import Normal, NumpyroModel # noqa: E402
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
.venv/lib/python3.12/site-packages/pybefit/inference/__init__.py:1: in <module>
from .methods import *
.venv/lib/python3.12/site-packages/pybefit/inference/methods.py:7: in <module>
import numpyro.infer as ninfer
.venv/lib/python3.12/site-packages/numpyro/__init__.py:13: in <module>
from numpyro import compat, diagnostics, distributions, handlers, infer, ops, optim
.venv/lib/python3.12/site-packages/numpyro/infer/__init__.py:6: in <module>
from numpyro.infer.elbo import (
.venv/lib/python3.12/site-packages/numpyro/infer/elbo.py:35: in <module>
from numpyro.ops.provenance import eval_provenance
.venv/lib/python3.12/site-packages/numpyro/ops/provenance.py:7: in <module>
from jax.extend.core.primitives import call_p, closed_call_p, jit_p, xla_pmap_p
E ImportError: cannot import name 'xla_pmap_p' from 'jax.extend.core.primitives' (/home/runner/work/pymdp/pymdp/.venv/lib/python3.12/site-packages/jax/extend/core/primitives.py)
ERROR test/test_tmaze_recoverability.py - ImportError while importing test module '/home/runner/work/pymdp/pymdp/test/test_tmaze_recoverability.py'.
Hint: make sure your test modules/packages have valid Python names.
Traceback:
/opt/hostedtoolcache/Python/3.12.13/x64/lib/python3.12/importlib/__init__.py:90: in import_module
return _bootstrap._gcd_import(name[level:], package, level)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
test/test_tmaze_recoverability.py:5: in <module>
from examples.model_fitting.tmaze_recoverability import RecoverabilityConfig, run_recoverability
examples/model_fitting/tmaze_recoverability.py:23: in <module>
from numpyro.infer import Predictive
.venv/lib/python3.12/site-packages/numpyro/__init__.py:13: in <module>
from numpyro import compat, diagnostics, distributions, handlers, infer, ops, optim
.venv/lib/python3.12/site-packages/numpyro/infer/__init__.py:6: in <module>
from numpyro.infer.elbo import (
.venv/lib/python3.12/site-packages/numpyro/infer/elbo.py:35: in <module>
from numpyro.ops.provenance import eval_provenance
.venv/lib/python3.12/site-packages/numpyro/ops/provenance.py:7: in <module>
from jax.extend.core.primitives import call_p, closed_call_p, jit_p, xla_pmap_p
E ImportError: cannot import name 'xla_pmap_p' from 'jax.extend.core.primitives' (/home/runner/work/pymdp/pymdp/.venv/lib/python3.12/site-packages/jax/extend/core/primitives.py)
=================== 1 passed, 2 warnings, 2 errors in 41.75s ===================
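To check whether an environment is affected without running the whole suite, you can test whether the installed jax still exports the name that numpyro imports. This is a minimal sketch that needs only the standard library plus whichever jax is installed. The helper name exports_xla_pmap_p is hypothetical and is not part of pymdp, jax or numpyro:

```python
import importlib
from typing import Optional


def exports_xla_pmap_p() -> Optional[bool]:
    """Report whether jax.extend.core.primitives still provides xla_pmap_p.

    Returns True or False depending on whether the name is present, and
    None if jax (or that submodule) cannot be imported at all.
    """
    try:
        prims = importlib.import_module("jax.extend.core.primitives")
    except ImportError:
        # Covers both "jax not installed" and "submodule missing".
        return None
    return hasattr(prims, "xla_pmap_p")


if __name__ == "__main__":
    print("xla_pmap_p exported:", exports_xla_pmap_p())
```

In an environment like the failing nightly run, this should return False. None means jax itself is not importable.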
Update (2026-04-21): The immediate cause is in numpyro, not pybefit — numpyro/ops/provenance.py still does an unconditional from jax.extend.core.primitives import ... xla_pmap_p, which fails on jax ≥ 0.10.0. pybefit is just an unlucky downstream user (it depends on numpyro without pinning).
The upstream fix is pyro-ppl/numpyro#2173, which guards the xla_pmap_p import with a try/except. It has been approved but is not yet merged or released (latest numpyro on PyPI is 0.20.1, 2026-03-25). Tracking bug: pyro-ppl/numpyro#2174.
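The only description of pyro-ppl/numpyro#2173 given here is that it wraps the xla_pmap_p import in a try/except. The sketch below illustrates that general pattern. It is not the PR's actual code, and is_pmap_primitive is a hypothetical helper that shows how code using the name can still run when the name is absent. The unconditional imports of call_p, closed_call_p and jit_p are left out so the snippet also runs without jax installed:

```python
# Illustrative guarded import; NOT the actual code from numpyro#2173.
try:
    # Available on jax < 0.10.0; removed in jax 0.10.0 according to this issue.
    from jax.extend.core.primitives import xla_pmap_p
except ImportError:
    # Newer (or missing) jax: treat the legacy pmap primitive as unavailable.
    xla_pmap_p = None


def is_pmap_primitive(prim) -> bool:
    """Return True only if `prim` is the legacy xla_pmap_p primitive.

    When xla_pmap_p could not be imported, nothing can match it, so this
    always returns False instead of raising at import time.
    """
    return xla_pmap_p is not None and prim is xla_pmap_p
```

The key point is that the failure moves from import time (which breaks every module that transitively imports numpyro) to a plain runtime check that is simply never true on newer jax.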
Short-term workaround in pymdp: pin jax/jaxlib to <0.10 — see #391. Revert once a patched numpyro release ships.
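While the pin is in place, a quick runtime check at the start of a CI job (or in a conftest.py) can confirm that the resolved environment actually respects it. This is a minimal standard-library sketch. The helper names major_minor and satisfies_pin are hypothetical and are not taken from pymdp or #391:

```python
import re
from importlib.metadata import PackageNotFoundError, version
from typing import Optional, Tuple

# Upper bound from the workaround: jax/jaxlib must stay below 0.10.
PIN_UPPER: Tuple[int, int] = (0, 10)


def major_minor(v: str) -> Tuple[int, int]:
    """Extract (major, minor) from a version string like '0.9.2' or '0.10.0rc1'."""
    m = re.match(r"(\d+)\.(\d+)", v)
    if m is None:
        raise ValueError(f"unrecognized version string: {v!r}")
    return int(m.group(1)), int(m.group(2))


def satisfies_pin(dist: str, upper: Tuple[int, int] = PIN_UPPER) -> Optional[bool]:
    """True if `dist` is installed below `upper`, False if not, None if absent."""
    try:
        installed = version(dist)
    except PackageNotFoundError:
        return None
    return major_minor(installed) < upper


if __name__ == "__main__":
    for dist in ("jax", "jaxlib"):
        print(dist, "respects <0.10 pin:", satisfies_pin(dist))
```

Comparing only (major, minor) means 0.10.0 pre-releases also count as violating the pin. This matches how a `<0.10` specifier treats them.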