Write Python. Run JAX.
> You're welcome to try it out and report any issues!
jaxify lets you apply JAX transformations like `@jax.jit` and `@jax.vmap` to functions that use common Python constructs JAX cannot handle on its own, such as `if` conditions that depend on input values.
```shell
pip install jaxify
```

```python
import jax
import jax.numpy as jnp
from jaxify import jaxify


@jax.jit
@jax.vmap
@jaxify  # <-- Just decorate your function with @jaxify
def absolute_value(x):
    if x >= 0:  # <-- If block in a JIT-compiled function
        return x
    else:
        return -x


xs = jnp.arange(-1000, 1000)
ys = absolute_value(xs)  # <-- Runs at JAX speed!
print(ys)
```

The `@jaxify` decorator transforms Python functions using a mixture of static analysis and dynamic tracing to replace unsupported Python constructs with JAX-compatible alternatives. After the transformations, the functions become traceable by JAX, enabling you to apply functional JAX transformations like `@jax.jit` and `@jax.vmap` in a seamless manner.
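
For intuition, here is roughly what `absolute_value` looks like when written by hand in plain JAX, with the `if`/`else` expressed as a `jax.lax.cond` call. This is a sketch of the idea, not jaxify's actual generated code, and the name `absolute_value_by_hand` is purely illustrative:

```python
import jax
import jax.numpy as jnp


@jax.jit
@jax.vmap
def absolute_value_by_hand(x):
    # jax.lax.cond(pred, true_fun, false_fun, *operands) traces both branches
    # and chooses between them based on the traced predicate.
    return jax.lax.cond(
        x >= 0,
        lambda v: v,   # the original `if` branch
        lambda v: -v,  # the original `else` branch
        x,
    )


xs = jnp.arange(-1000, 1000)
ys = absolute_value_by_hand(xs)
print(ys)
```

Because `jax.vmap` batches the predicate, JAX lowers this `cond` to an element-wise select, so both branches are computed for every element.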
The following tables list the current support status of Python constructs within `@jaxify`-decorated functions:

| Construct | Works? | Notes |
|---|---|---|
| `if` statements | ✅ | Fully supported, including `elif` and `else` clauses. All branches are traced and translated to calls to `jax.lax.cond`. |
| `if` expressions (e.g. `a if b else c`) | ✅ | Traced and translated to `jax.lax.cond`. |
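
As an illustration of how an `elif` chain maps onto `jax.lax.cond`, the sketch below hand-writes nested `cond` calls for a hypothetical `sign` function and a single `cond` for an `if` expression in a hypothetical `relu`. jaxify's own translation may be structured differently; the point is that each `elif`/`else` tail becomes the false branch of the enclosing `cond`, and every branch must return values of the same shape and dtype:

```python
import jax
import jax.numpy as jnp


# Hand-written equivalent of:
#     if x > 0:
#         return 1.0
#     elif x < 0:
#         return -1.0
#     else:
#         return 0.0
@jax.jit
@jax.vmap
def sign(x):
    return jax.lax.cond(
        x > 0,
        lambda v: jnp.ones_like(v),
        # The `elif`/`else` tail becomes the false branch of the outer cond.
        lambda v: jax.lax.cond(
            v < 0,
            lambda w: -jnp.ones_like(w),
            lambda w: jnp.zeros_like(w),
            v,
        ),
        x,
    )


# Hand-written equivalent of the if expression `x if x > 0 else 0.0`.
@jax.jit
@jax.vmap
def relu(x):
    return jax.lax.cond(x > 0, lambda v: v, lambda v: jnp.zeros_like(v), x)


xs = jnp.array([-2.0, 0.0, 3.0])
print(sign(xs), relu(xs))
```
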

| Construct | Works? | Notes |
|---|---|---|
| `==`, `!=`, `<`, `>`, `<=`, `>=` | ✅ | Chained comparisons (e.g. `x < y <= z`) are supported by translation to the equivalent chain of individual comparisons. |
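
In Python, `lo < x <= hi` means `lo < x and x <= hi`, which calls `bool()` on a traced value and therefore fails under `jax.jit`. The sketch below shows that failure and a hand-written, trace-safe equivalent that combines the individual comparisons element-wise with `jnp.logical_and`. The function names are hypothetical, and jaxify may instead route the implied `and` through `jax.lax.cond`, as described for logical operators:

```python
import jax
import jax.numpy as jnp


@jax.jit
def in_range_chained(x, lo, hi):
    # Python evaluates this as `lo < x and x <= hi`, which calls bool() on a tracer.
    return lo < x <= hi


@jax.jit
def in_range(x, lo, hi):
    # Trace-safe equivalent: the individual comparisons, combined element-wise.
    return jnp.logical_and(lo < x, x <= hi)


try:
    in_range_chained(jnp.float32(1.5), 1.0, 2.0)
    chained_failed = False
except jax.errors.ConcretizationTypeError:
    chained_failed = True

mask = in_range(jnp.array([0.0, 1.0, 2.0, 3.0]), 1.0, 2.0)
print(chained_failed, mask)
```
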

| Construct | Works? | Notes |
|---|---|---|
| `and` / `or` | ✅ | Short-circuiting of traced values is supported via translation to `jax.lax.cond` calls. |
| `not` | ✅ | Translates to `jnp.logical_not` for traced single values. |
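
Python's `a and b` returns `a` when `a` is falsy and otherwise evaluates `b` (and `a or b` is the mirror image), which is what makes short-circuiting possible. A minimal sketch of that idea with `jax.lax.cond`, assuming boolean operands, is shown below. `traced_and`, `traced_or`, and `classify` are hypothetical names written for illustration, not part of jaxify's API; the second operand is passed as a zero-argument function so that it is only evaluated inside a branch:

```python
import jax
import jax.numpy as jnp


def traced_and(a, b_fn):
    # `a and b`: if a is true, the result is b (evaluated lazily); otherwise a.
    return jax.lax.cond(a, b_fn, lambda: a)


def traced_or(a, b_fn):
    # `a or b`: if a is true, the result is a; otherwise b (evaluated lazily).
    return jax.lax.cond(a, lambda: a, b_fn)


@jax.jit
@jax.vmap
def classify(x):
    inside = traced_and(x >= 0, lambda: x <= 1)  # x >= 0 and x <= 1
    outside = traced_or(x < 0, lambda: x > 1)    # x < 0 or x > 1
    not_inside = jnp.logical_not(inside)         # not inside
    return inside, outside, not_inside


inside, outside, not_inside = classify(jnp.array([-0.5, 0.5, 1.5]))
print(inside, outside, not_inside)
```

Under `jax.vmap`, a `cond` with a batched predicate is lowered to a select, so both branches run anyway; the short-circuit generally only skips work when the predicate is an unbatched scalar.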

| Construct | Works? | Notes |
|---|---|---|
| `for` loops | ❌ | Currently unsupported. Use `jax.lax.fori_loop`, `jax.lax.scan`, or `jax.lax.while_loop` instead. |
| `while` loops | ❌ | Currently unsupported. Use `jax.lax.while_loop` instead. |
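
Until loops are supported, they can be rewritten by hand with the suggested `jax.lax` primitives. The sketch below shows one rewrite per primitive; the functions themselves are hypothetical examples, and each comment names the Python loop it replaces:

```python
import jax
import jax.numpy as jnp


@jax.jit
def sum_of_squares(n):
    # Replaces: total = 0; for i in range(n): total += i * i
    return jax.lax.fori_loop(0, n, lambda i, total: total + i * i, jnp.int32(0))


@jax.jit
def running_sum(xs):
    # Replaces a `for x in xs` loop that records every partial sum.
    def step(total, x):
        total = total + x
        return total, total  # (new carry, per-step output)

    _, partial_sums = jax.lax.scan(step, jnp.zeros((), xs.dtype), xs)
    return partial_sums


@jax.jit
def collatz_steps(n):
    # Replaces: while n != 1: (halve n if even, else n = 3 * n + 1); steps += 1
    def cond_fun(state):
        value, _ = state
        return value != 1

    def body_fun(state):
        value, steps = state
        value = jnp.where(value % 2 == 0, value // 2, 3 * value + 1)
        return value, steps + 1

    start = jnp.asarray(n, dtype=jnp.int32)
    _, steps = jax.lax.while_loop(cond_fun, body_fun, (start, jnp.int32(0)))
    return steps


print(sum_of_squares(5), running_sum(jnp.array([1.0, 2.0, 3.0])), collatz_steps(6))
```
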

| Construct | Works? | Notes |
|---|---|---|
| `match`-`case` | ⚠️ | Static values only. For traced values, use an `if`-`elif`-`else` chain or `jax.lax.switch` instead. |