Releases: mila-iqia/torch_jax_interop

v0.0.8 - Upgrade to more recent versions of Jax

08 Jan 19:40
83f00b1

What's Changed

Full Changelog: v0.0.7...v0.0.8

v0.0.7 - Don't require jax[cuda12], just jax

08 Jul 15:11
e93e5c5

What's Changed

  • Don't require jax[cuda12] as the dep, just jax by @lebrice in #2

Full Changelog: v0.0.6...v0.0.7

v0.0.6 - Reduce python version requirement to 3.10

02 Jul 17:20
18357dc

What's Changed

  • Reduce Python version requirement from 3.11 to 3.10 by @lebrice in #1

New Contributors

Full Changelog: v0.0.5...v0.0.6

v0.0.5 - Rename JaxModule->WrappedJaxFunction, add wrapper for jax scalar-valued functions

13 Jun 17:37
3ec8451

  • Rename JaxModule -> WrappedJaxFunction
  • Add WrappedJaxScalarFunction:
    • Offered as an alternative to WrappedJaxFunction for scalar-valued functions (although both will work).
    • This is potentially more efficient: it uses a jax.jit-compiled jax.value_and_grad to perform a fused forward and backward pass, whereas WrappedJaxFunction uses jax.vjp, which runs the forward pass and the backward pass as separate calls.
  • Add more tests
  • Add more doctests, which serve as both documentation and unit tests.
  • Add some examples in the README.
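The difference between the two strategies above can be sketched in plain JAX. This is a minimal illustration only: it does not call torch_jax_interop, and loss_fn and its inputs are made up for the example.

```python
import jax
import jax.numpy as jnp


def loss_fn(params, x):
    # Hypothetical scalar-valued function: sum of squared linear outputs.
    return jnp.sum((x @ params) ** 2)


params = jnp.ones((3,))
x = jnp.arange(6.0).reshape(2, 3)

# Fused: one jit-compiled call returns both the value and d(loss)/d(params).
fused = jax.jit(jax.value_and_grad(loss_fn))
loss, grad_params = fused(params, x)

# Split: jax.vjp runs the forward pass and returns a pullback function;
# the backward pass is a separate call that takes the output cotangent.
out, vjp_fn = jax.vjp(loss_fn, params, x)
vjp_grad_params, vjp_grad_x = vjp_fn(jnp.ones_like(out))
```

Both paths produce the same gradient with respect to params. The fused version gives the compiler the forward and backward computation in a single function, which is the likely source of the efficiency gain mentioned above; jax.vjp remains the general tool when the output is not a scalar.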

Full Changelog: v0.0.4...v0.0.5

Add JaxModule: enables using jax functions in PyTorch autograd graphs

06 Jun 20:55
754445a
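One way to place a JAX function inside a PyTorch autograd graph is a custom torch.autograd.Function whose forward pass calls jax.vjp and whose backward pass calls the stored pullback, following the jax.vjp approach the v0.0.5 notes attribute to WrappedJaxFunction. The sketch below is a hypothetical illustration, not the library's JaxModule or WrappedJaxFunction API: the class name JaxFunctionSketch is invented here, and the CPU round-trip through NumPy is a simplifying assumption (the library may convert tensors differently).

```python
import jax
import jax.numpy as jnp
import numpy as np
import torch


class JaxFunctionSketch(torch.autograd.Function):
    """Hypothetical wrapper: runs a JAX function inside a PyTorch autograd graph."""

    @staticmethod
    def forward(ctx, jax_fn, x: torch.Tensor) -> torch.Tensor:
        # Assumption: convert via NumPy (a CPU copy) for simplicity.
        x_jax = jnp.asarray(x.detach().cpu().numpy())
        # Forward pass in JAX; keep the pullback for the backward pass.
        out, vjp_fn = jax.vjp(jax_fn, x_jax)
        ctx.vjp_fn = vjp_fn
        return torch.from_numpy(np.array(out)).to(x.device)

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor):
        cotangent = jnp.asarray(grad_output.detach().cpu().numpy())
        (grad_x,) = ctx.vjp_fn(cotangent)
        # One gradient per forward input: None for jax_fn, a tensor for x.
        return None, torch.from_numpy(np.array(grad_x)).to(grad_output.device)


def jax_fn(x):
    # Example JAX function; its derivative is 2 * cos(x).
    return 2.0 * jnp.sin(x)


x = torch.tensor([0.0, 1.0, 2.0], requires_grad=True)
y = JaxFunctionSketch.apply(jax_fn, x)
y.sum().backward()  # PyTorch calls backward(), which runs the JAX pullback
```

After backward(), x.grad holds 2 * cos(x) as computed by JAX, so the JAX function behaves like any other differentiable PyTorch op.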