Skip to content

fix: remove upper pin from JAX#811

Merged
Qazalbash merged 2 commits into
devfrom
remove-jax-pin
May 2, 2026
Merged

fix: remove upper pin from JAX#811
Qazalbash merged 2 commits into
devfrom
remove-jax-pin

Conversation

@Qazalbash
Copy link
Copy Markdown
Member

Updating the version of JAX and NumPyro, because pyro-ppl/numpyro#2173 has been merged along with a new numpyro release.

pyro-ppl/numpyro#2173 has been merged along with
a new numpyro release
@Qazalbash Qazalbash self-assigned this May 2, 2026
@Qazalbash Qazalbash added the internal An internal refactor or improvement label May 2, 2026
Copy link
Copy Markdown
Contributor

@gemini-code-assist gemini-code-assist Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request updates dependency versions in pyproject.toml for numpyro and jax. Feedback indicates that the numpyro version 0.21.0 is likely a typo as it does not exist on PyPI, and the jax minimum version increase to 0.10.0 contradicts the PR's goal of supporting JAX 0.5.0 while creating an unnecessary breaking change.

I am having trouble creating individual review comments. Click here to see my feedback.

pyproject.toml (56)

high

The minimum version of numpyro is bumped from 0.19.0 to 0.21.0. Note that the latest official release of NumPyro is 0.15.3. If this is intended to support the latest release that includes JAX 0.5.0 support (as mentioned in the PR description), this version requirement might be a typo. Using a version number that does not exist on PyPI will prevent the package from being installed in standard environments.

pyproject.toml (82-85)

high

The change to JAX dependencies removes the upper pin <0.10.0 but also increases the minimum version from 0.7.0 to 0.10.0. This is a breaking change for users on JAX versions between 0.7.0 and 0.10.0. Furthermore, if the goal is to support JAX 0.5.0 (as mentioned in the PR description), setting the minimum to 0.10.0 will exclude it, as 0.5.0 is numerically lower than 0.10.0. If the intent was only to remove the upper bound to allow newer versions, the lower bound should remain 0.7.0 or be set to the actual minimum version that provides the required support.

@Qazalbash Qazalbash merged commit fc00c91 into dev May 2, 2026
5 checks passed
@Qazalbash Qazalbash deleted the remove-jax-pin branch May 2, 2026 18:49
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

internal An internal refactor or improvement

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant