Releases: mila-iqia/torch_jax_interop
v0.0.8 - Upgrade to more recent versions of Jax
v0.0.7 - Don't require jax[cuda12], just jax
What's Changed
Full Changelog: v0.0.6...v0.0.7
v0.0.6 - Reduce python version requirement to 3.10
v0.0.5 - Rename JaxModule->WrappedJaxFunction, add wrapper for jax scalar-valued functions
- Rename `JaxModule` -> `WrappedJaxFunction`
- Add `WrappedJaxScalarFunction`:
  - Offered as an alternative to `WrappedJaxFunction` for scalar-valued functions (although both will work).
  - This is potentially more efficient, since it uses a `jax.jit`-ed `jax.value_and_grad` to do a fused forward and backward pass, compared with `WrappedJaxFunction`, which uses `jax.vjp`.
- Add more tests
- Add more doctests, which serve as both documentation and unit tests.
- Add some examples in the README.
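The efficiency note for `WrappedJaxScalarFunction` contrasts two ways of getting a value and its gradient in JAX. The sketch below illustrates that contrast in plain JAX only; it does not use the `torch_jax_interop` API, and the function `loss_fn` and the helper `vjp_value_and_grad` are hypothetical names chosen for illustration. It assumes only that the wrappers build on `jax.jit(jax.value_and_grad(...))` and `jax.vjp` respectively, as stated above; the library's actual implementation may differ.

```python
import jax
import jax.numpy as jnp


def loss_fn(params, x):
    # Hypothetical scalar-valued function: mean squared output of a linear map.
    return jnp.mean((x @ params) ** 2)


# Approach used (per the notes) by WrappedJaxScalarFunction: one jit-compiled
# call that computes the value and the gradient together, so XLA can fuse the
# forward and backward passes.
fused_value_and_grad = jax.jit(jax.value_and_grad(loss_fn))


def vjp_value_and_grad(params, x):
    # Approach used (per the notes) by WrappedJaxFunction: run the forward pass
    # with jax.vjp, then pull back a cotangent of 1.0 for the scalar output.
    value, vjp_fn = jax.vjp(lambda p: loss_fn(p, x), params)
    (grad,) = vjp_fn(jnp.ones_like(value))
    return value, grad


params = jnp.arange(3.0)
x = jnp.ones((4, 3))

v1, g1 = fused_value_and_grad(params, x)
v2, g2 = vjp_value_and_grad(params, x)
```

Both approaches should agree: here the value is 9.0 and the gradient with respect to `params` is `[6., 6., 6.]`. The difference lies in how the work is staged, not in the result, which is why both wrappers work for scalar-valued functions.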
Full Changelog: v0.0.4...v0.0.5