
Releases: lebrice/torch_jax_interop

v0.0.7 - Don't require jax[cuda12], just jax

08 Jul 15:11
e93e5c5

What's Changed

  • Depend on jax instead of jax[cuda12] by @lebrice in #2

Full Changelog: v0.0.6...v0.0.7

v0.0.6 - Reduce python version requirement to 3.10

02 Jul 17:20
18357dc

What's Changed

  • Reduce the Python version requirement from 3.11 to 3.10 by @lebrice in #1

New Contributors

Full Changelog: v0.0.5...v0.0.6

v0.0.5 - Rename JaxModule->WrappedJaxFunction, add wrapper for jax scalar-valued functions

13 Jun 17:37
3ec8451
  • Rename JaxModule -> WrappedJaxFunction
  • Add WrappedJaxScalarFunction:
    • Offered as an alternative to WrappedJaxFunction for scalar-valued functions (both work for such functions).
    • This can be more efficient, since it uses a jax.jit-compiled jax.value_and_grad to compute the forward and backward passes in one fused call, whereas WrappedJaxFunction uses jax.vjp.
  • Add more tests
  • Add more doctests, which serve as both documentation and unit tests.
  • Add some examples in the README.
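The difference between the two gradient strategies mentioned for WrappedJaxScalarFunction and WrappedJaxFunction can be illustrated with plain JAX. This is a minimal sketch, not the library's code: the `loss` function and the inputs are hypothetical, and only the public `jax.jit`, `jax.value_and_grad`, and `jax.vjp` APIs are used. Both paths should produce the same value and gradient; the fused version just does it in a single compiled call.

```python
import jax
import jax.numpy as jnp


def loss(w, x):
    # Hypothetical scalar-valued function, used only for illustration.
    return jnp.sum(jnp.tanh(x @ w) ** 2)


# Fused approach: one jit-compiled call returns (value, d loss / d w).
value_and_grad_fn = jax.jit(jax.value_and_grad(loss))


def forward_then_backward(w, x):
    # Two-step approach: jax.vjp returns the output plus a function that
    # computes the vector-Jacobian product later (e.g. during backward()).
    out, vjp_fn = jax.vjp(loss, w, x)
    # For a scalar output, the cotangent is simply 1.0.
    grad_w, _grad_x = vjp_fn(jnp.ones_like(out))
    return out, grad_w


w = jnp.ones((3,))
x = jnp.arange(6.0).reshape(2, 3)
v1, g1 = value_and_grad_fn(w, x)
v2, g2 = forward_then_backward(w, x)
```

The vjp route is more general (it works for non-scalar outputs), which is presumably why both wrappers are offered.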

Full Changelog: v0.0.4...v0.0.5

Add JaxModule: enables using jax functions in PyTorch autograd graphs

06 Jun 20:55
754445a
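As a rough illustration of how a jax function can participate in a PyTorch autograd graph, the sketch below wraps it in a `torch.autograd.Function` and uses `jax.vjp` for the backward pass. This is an assumption-laden sketch, not the library's implementation: `wrap_jax_function` is a hypothetical helper, tensors are copied through NumPy for simplicity (a real bridge would likely use zero-copy DLPack), and only a single float32 tensor input is handled.

```python
import jax
import jax.numpy as jnp
import numpy as np
import torch


def wrap_jax_function(jax_fn):
    """Hypothetical helper: expose a single-input jax function to torch autograd."""

    class _JaxFunction(torch.autograd.Function):
        @staticmethod
        def forward(ctx, x):
            # Copy the torch tensor into a jax array (via NumPy, for simplicity).
            x_jax = jnp.asarray(x.detach().cpu().numpy())
            # Run the forward pass and keep the vjp closure for backward().
            out, vjp_fn = jax.vjp(jax_fn, x_jax)
            ctx.vjp_fn = vjp_fn
            return torch.from_numpy(np.array(out)).to(x.device)

        @staticmethod
        def backward(ctx, grad_output):
            # Pull the incoming gradient back through the jax function.
            g_jax = jnp.asarray(grad_output.detach().cpu().numpy())
            (grad_x,) = ctx.vjp_fn(g_jax)
            return torch.from_numpy(np.array(grad_x)).to(grad_output.device)

    return _JaxFunction.apply


# Usage: y = 2 * sin(x), computed in jax, differentiated by torch autograd.
double_sin = wrap_jax_function(lambda t: 2.0 * jnp.sin(t))
x = torch.linspace(0.0, 1.0, 5, requires_grad=True)
y = double_sin(x)
y.sum().backward()
```

After `backward()`, `x.grad` should hold `2 * cos(x)`, computed entirely by the jax vjp.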