Skip to content

Commit 5491979

Browse files
conorheinsclaude
andauthored
deps: unpin jax/jaxlib now that numpyro 0.21.0 ships fix (#404)
## Summary - Reverts the temporary `<0.10` cap on `jax`/`jaxlib` (and `jax[cuda12]`) introduced in #391 across `pyproject.toml` and `setup.cfg`. - Removes the inline TODO comments referencing [pyro-ppl/numpyro#2173](pyro-ppl/numpyro#2173). - numpyro#2173 merged 2026-05-02 and shipped in numpyro **0.21.0** the same day. The fix wraps the `xla_pmap_p` import in a try/except so the absence on jax ≥ 0.10.0 no longer breaks `import numpyro` (and therefore `pybefit`). ## Verification - [x] Confirmed numpyro 0.21.0 release tag contains the try/except guard in `numpyro/ops/provenance.py`. - [x] `uv lock --dry-run` resolves cleanly: jax 0.10.0, jaxlib 0.10.0, numpyro 0.21.0, pybefit 0.1.23. - [x] Standard `Test` CI passes on this PR. - [x] Manual Branch Nightly run on this branch passes. ## Notes This closes out the workaround from #389. Should be a clean inverse of #391. 🤖 Generated with [Claude Code](https://claude.com/claude-code) --------- Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1 parent 23c29ed commit 5491979

3 files changed

Lines changed: 11 additions & 18 deletions

File tree

examples/advanced/pymdp_with_neural_encoder.ipynb

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -274,7 +274,7 @@
274274
"source": [
275275
"def sample_categorical_batch(key: jnp.ndarray, probs: jnp.ndarray) -> jnp.ndarray:\n",
276276
" keys = jr.split(key, probs.shape[0])\n",
277-
" return jax.vmap(lambda k, p: jr.categorical(k, jnp.log(jnp.clip(p, a_min=EPS))))(keys, probs).astype(jnp.int32)\n",
277+
" return jax.vmap(lambda k, p: jr.categorical(k, jnp.log(jnp.clip(p, min=EPS))))(keys, probs).astype(jnp.int32)\n",
278278
"\n",
279279
"\n",
280280
"def sample_balanced_initial_states(key: jnp.ndarray, num_sequences: int) -> jnp.ndarray:\n",
@@ -402,7 +402,7 @@
402402
"def model_obs_categorical(model: FrontendModel, x_t: jnp.ndarray):\n",
403403
" logits = jax.vmap(model.encoder)(x_t) / LOGIT_TEMPERATURE\n",
404404
" probs = jnn.softmax(logits, axis=-1)\n",
405-
" probs = jnp.clip(probs, a_min=EPS)\n",
405+
" probs = jnp.clip(probs, min=EPS)\n",
406406
" probs = probs / probs.sum(axis=-1, keepdims=True)\n",
407407
" return [probs]\n",
408408
"\n",
@@ -444,10 +444,10 @@
444444
" return jnp.array(0.0)\n",
445445
"\n",
446446
" pred_obs_next = jnp.einsum('btk,ok->bto', pred_next_state_seq, A_FIXED_SINGLE)\n",
447-
" pred_obs_next = jnp.clip(pred_obs_next, a_min=EPS)\n",
447+
" pred_obs_next = jnp.clip(pred_obs_next, min=EPS)\n",
448448
" pred_obs_next = pred_obs_next / pred_obs_next.sum(axis=-1, keepdims=True)\n",
449449
"\n",
450-
" obs_next = jnp.clip(obs_seq[:, 1:, :], a_min=EPS)\n",
450+
" obs_next = jnp.clip(obs_seq[:, 1:, :], min=EPS)\n",
451451
" kl = jnp.sum(obs_next * (jnp.log(obs_next) - jnp.log(pred_obs_next)), axis=-1)\n",
452452
" return jnp.mean(kl)\n",
453453
"\n",
@@ -710,7 +710,7 @@
710710
" t = true_states.reshape(-1)\n",
711711
" p = pred_states.reshape(-1)\n",
712712
" cm = cm.at[t, p].add(1)\n",
713-
" cm = cm / jnp.clip(cm.sum(axis=1, keepdims=True), a_min=1.0)\n",
713+
" cm = cm / jnp.clip(cm.sum(axis=1, keepdims=True), min=1.0)\n",
714714
" return cm\n",
715715
"\n",
716716
"\n",

pyproject.toml

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -24,11 +24,8 @@ classifiers=[
2424
requires-python = ">=3.10"
2525
dependencies = [
2626
'numpy>=1.19.5',
27-
# jax/jaxlib upper-bounded until numpyro ships the fix from
28-
# https://github.com/pyro-ppl/numpyro/pull/2173 (jax 0.10.0 removed
29-
# xla_pmap_p, which breaks numpyro -> pybefit imports). See issue #389.
30-
'jax>=0.3.4,<0.10',
31-
'jaxlib>=0.3.4,<0.10',
27+
'jax>=0.3.4',
28+
'jaxlib>=0.3.4',
3229
'equinox>=0.9',
3330
'multimethod>=1.11',
3431
'matplotlib>=3.1.3',
@@ -60,7 +57,7 @@ test = [
6057

6158
[project.optional-dependencies]
6259
gpu = [
63-
'jax[cuda12]>=0.3.4,<0.10',
60+
'jax[cuda12]>=0.3.4',
6461
]
6562
docs = [
6663
"mkdocs>=1.6.1,<2",

setup.cfg

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -24,12 +24,8 @@ project_urls =
2424
packages = find:
2525
install_requires =
2626
numpy>=1.19.5
27-
# jax/jaxlib upper-bounded until numpyro ships the fix from
28-
# https://github.com/pyro-ppl/numpyro/pull/2173 (jax 0.10.0 removed
29-
# xla_pmap_p, which breaks numpyro -> pybefit imports). See issue #389.
30-
# Keep in sync with pyproject.toml.
31-
jax>=0.3.4,<0.10
32-
jaxlib>=0.3.4,<0.10
27+
jax>=0.3.4
28+
jaxlib>=0.3.4
3329
equinox>=0.9
3430
multimethod>=1.11
3531
matplotlib>=3.1.3
@@ -40,7 +36,7 @@ include_package_data = True
4036

4137
[options.extras_require]
4238
gpu =
43-
jax[cuda12]>=0.3.4,<0.10
39+
jax[cuda12]>=0.3.4
4440

4541
[options.package_data]
4642
pymdp = envs/assets/*

0 commit comments

Comments
 (0)