Releases: lebrice/torch_jax_interop
v0.0.7 - Don't require jax[cuda12], just jax
What's Changed
Full Changelog: v0.0.6...v0.0.7
v0.0.6 - Reduce Python version requirement to 3.10
v0.0.5 - Rename JaxModule->WrappedJaxFunction, add wrapper for jax scalar-valued functions
- Rename `JaxModule` -> `WrappedJaxFunction`
- Add `WrappedJaxScalarFunction`:
  - Offered as an alternative to `WrappedJaxFunction` for scalar-valued functions (although both will work).
  - This is potentially more efficient, since it uses a `jax.jit`-ed `jax.value_and_grad` to do a fused forward and backward pass, compared with `WrappedJaxFunction`, which uses `jax.vjp`.
- Add more tests
- Add more doctests, which serve as both documentation and unit tests.
- Add some examples in the README.
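As a rough illustration of the difference described in the v0.0.5 notes, the sketch below compares the two gradient strategies in plain JAX. It is not torch_jax_interop's implementation and does not use its API: `loss_fn`, `forward_then_backward`, and `fused_value_and_grad` are hypothetical names, and the PyTorch side of the bridge is left out entirely.

```python
import jax
import jax.numpy as jnp


def loss_fn(params, x, y):
    """Hypothetical scalar-valued function: mean squared error of a linear model."""
    w, b = params
    pred = x @ w + b
    return jnp.mean((pred - y) ** 2)


def forward_then_backward(params, x, y):
    """VJP style: run the forward pass first, apply the pullback later."""
    loss, vjp_fn = jax.vjp(lambda p: loss_fn(p, x, y), params)
    # Assumption: in an autograd bridge, vjp_fn would be kept until the backward
    # pass and called with the incoming gradient. Here the cotangent of a scalar
    # loss is simply 1.0.
    (grads,) = vjp_fn(jnp.ones_like(loss))
    return loss, grads


# Fused style: one jit-compiled call returns both the value and the gradient
# with respect to `params` (argument 0 of loss_fn).
fused_value_and_grad = jax.jit(jax.value_and_grad(loss_fn))

x = jnp.arange(6.0).reshape(3, 2)
y = jnp.array([1.0, 2.0, 3.0])
params = (jnp.array([0.5, -0.5]), jnp.array(0.1))

loss1, grads1 = forward_then_backward(params, x, y)
loss2, grads2 = fused_value_and_grad(params, x, y)

# Both approaches should agree on the loss and on the gradients for w and b.
print(float(loss1), float(loss2))
```

A fused pass of this kind fits scalar-valued functions naturally: a scalar output receives a single scalar upstream gradient, so by the chain rule the gradient computed during the forward call only needs to be multiplied by that number afterwards. For non-scalar outputs the cotangent is not known until the backward pass, which is where `jax.vjp` comes in.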
Full Changelog: v0.0.4...v0.0.5