Starting with version 0.4.36, JAX deprecated `jax.lib.xla_extension.XlaRuntimeError` in favor of `jax.errors.JaxRuntimeError` (https://docs.jax.dev/en/latest/changelog.html#jax-0-4-36-dec-5-2024), so the current code is not compatible with newer versions of JAX.
I may be able to open a PR in the next few weeks, but first I wanted to check: are you aware of any other changes that would break compatibility?
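For reference, one way the fix could look is a small lookup that prefers the new name and falls back to the old one. This is only a sketch: `resolve_first` is a hypothetical helper I made up for illustration, not something from this repository or from JAX, and I am assuming the old class is reachable as an attribute of `jax.lib` on older releases.

```python
import importlib


def resolve_first(candidates):
    """Return the first object that can be resolved, or None.

    `candidates` is a list of (module_name, dotted_attribute_path) pairs.
    Hypothetical helper, not part of JAX.
    """
    for module_name, attr_path in candidates:
        try:
            obj = importlib.import_module(module_name)
            # Walk the dotted path one attribute at a time.
            for name in attr_path.split("."):
                obj = getattr(obj, name)
            return obj
        except (ImportError, AttributeError):
            # Module or attribute missing in this version: try the next one.
            continue
    return None


# Prefer the new public name (JAX >= 0.4.36), then the deprecated location.
# Assumption: on older JAX the class is exposed as jax.lib.xla_extension.XlaRuntimeError.
JaxRuntimeError = resolve_first([
    ("jax.errors", "JaxRuntimeError"),
    ("jax.lib", "xla_extension.XlaRuntimeError"),
])
# JaxRuntimeError is None when neither location exists (e.g. JAX not installed).
```

Existing `except XlaRuntimeError:` clauses could then catch `JaxRuntimeError` instead. On older JAX releases the fallback path may still emit a deprecation warning, so bumping the minimum JAX version and importing only from `jax.errors` might be the simpler option.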