Skip to content

fix: update JAX dependency stack to modern compatible versions#544

Open
Okyumi wants to merge 1 commit into
vwxyzjn:masterfrom
Okyumi:fix/jax-dependency-stack
Open

fix: update JAX dependency stack to modern compatible versions#544
Okyumi wants to merge 1 commit into
vwxyzjn:masterfrom
Okyumi:fix/jax-dependency-stack

Conversation

@Okyumi
Copy link
Copy Markdown

@Okyumi Okyumi commented Mar 19, 2026

Summary

Fixes #540 — Updates the optional JAX dependency group in pyproject.toml from pinned 2023 versions to bounded ranges that resolve correctly with modern package managers (pip, uv).

Problem

The current JAX deps pin ancient versions (jax==0.4.8, jaxlib==0.4.7, flax==0.6.8). When pip/uv resolves these, it often installs an incompatible jaxlib, producing:

RuntimeError: jaxlib version 0.4.30 is newer than and incompatible with jax version 0.4.8

This blocks all JAX contributors from running or merging JAX-related PRs.

Changes

File modified: pyproject.toml (dependency versions only, zero code changes)

Package Before After Rationale
jax ==0.4.8 >=0.4.26,<0.5 Modern stable release, <0.5 prevents future major breaks
jaxlib ==0.4.7 >=0.4.26,<0.5 Matched to jax range
flax ==0.6.8 >=0.7.0,<0.8 Codebase uses flax.serialization.to_bytes, FrozenDict, flax.struct.dataclass — all preserved in 0.7.x, deprecated in 0.8+
optax ==0.1.4 >=0.1.4,<0.2 optax.adam() and optax.incremental_update() APIs changed in 0.2.0; 14 files use these
chex ==0.1.5 >=0.1.5 No breaking changes
scipy <1.13.0 >=1.10.0 jax 0.4.26+ works with modern scipy

Why These Bounds

The upper bounds on flax (<0.8) and optax (<0.2) are specifically chosen to avoid API-breaking changes that would require modifying the 14 JAX algorithm files. This means zero code changes — the fix is entirely in dependency version constraints.

Testing

  • pyproject.toml parses correctly
  • Version bounds verified against API usage across all JAX files: ppo_continuous_action_jax.py, sac_continuous_action_jax.py, ddpg_continuous_action_jax.py, td3_continuous_action_jax.py, and their Atari/evaluation variants

Update optional JAX dependencies from pinned 2023 versions to bounded
ranges that resolve correctly with modern package managers:

- jax/jaxlib: ==0.4.8/==0.4.7 → >=0.4.26,<0.5
- flax: ==0.6.8 → >=0.7.0,<0.8  (preserves FrozenDict, serialization APIs)
- optax: ==0.1.4 → >=0.1.4,<0.2  (preserves adam(), incremental_update() APIs)
- chex: ==0.1.5 → >=0.1.5  (no breaking changes)
- scipy: <1.13.0 → >=1.10.0  (jax 0.4.26+ works with modern scipy)

Upper bounds on flax (<0.8) and optax (<0.2) prevent API-breaking upgrades
that would require code changes to the 14 JAX algorithm files.

Fixes vwxyzjn#540
@vercel
Copy link
Copy Markdown

vercel Bot commented Mar 19, 2026

@Okyumi is attempting to deploy a commit to the Costa Huang's projects Team on Vercel.

A member of the Team first needs to authorize it.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

CI: All JAX tests failing due to outdated JAX dependency stack

1 participant