fix: remove upper pin from JAX#811
Conversation
pyro-ppl/numpyro#2173 has been merged along with a new numpyro release
There was a problem hiding this comment.
Code Review
This pull request updates dependency versions in pyproject.toml for numpyro and jax. Feedback indicates that the numpyro version 0.21.0 is likely a typo as it does not exist on PyPI, and the jax minimum version increase to 0.10.0 contradicts the PR's goal of supporting JAX 0.5.0 while creating an unnecessary breaking change.
I am having trouble creating individual review comments. Click here to see my feedback.
pyproject.toml (56)
The minimum version of numpyro is bumped from 0.19.0 to 0.21.0. Note that the latest official release of NumPyro is 0.15.3. If this is intended to support the latest release that includes JAX 0.5.0 support (as mentioned in the PR description), this version requirement might be a typo. Using a version number that does not exist on PyPI will prevent the package from being installed in standard environments.
pyproject.toml (82-85)
The change to JAX dependencies removes the upper pin <0.10.0 but also increases the minimum version from 0.7.0 to 0.10.0. This is a breaking change for users on JAX versions between 0.7.0 and 0.10.0. Furthermore, if the goal is to support JAX 0.5.0 (as mentioned in the PR description), setting the minimum to 0.10.0 will exclude it, as 0.5.0 is numerically lower than 0.10.0. If the intent was only to remove the upper bound to allow newer versions, the lower bound should remain 0.7.0 or be set to the actual minimum version that provides the required support.
Updating the version of JAX and NumPyro, because pyro-ppl/numpyro#2173 has been merged along with a new numpyro release.