gerlero/jaxify


jaxify

Write Python. Run JAX.

⚠️ jaxify is an experimental project under development
You're welcome to try it out and report any issues!

jaxify lets you apply JAX transformations (such as @jax.jit and @jax.vmap) to functions that use common Python constructs JAX cannot handle on its own, such as if statements whose conditions depend on input values.
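To see the limitation jaxify addresses, here is a minimal sketch of plain JAX (no jaxify) rejecting a value-dependent if. The helper fails_under_jit is a hypothetical name introduced only for this illustration. Recent JAX versions raise jax.errors.TracerBoolConversionError here, a subclass of jax.errors.ConcretizationTypeError, so the sketch catches the parent class.

```python
import jax
import jax.numpy as jnp


def absolute_value(x):
    # Python needs a concrete bool here, but under jax.jit x is an abstract tracer
    if x >= 0:
        return x
    else:
        return -x


def fails_under_jit(fn, x):
    """Hypothetical helper: True if tracing fn with jax.jit raises a concretization error."""
    try:
        jax.jit(fn)(x)
    except jax.errors.ConcretizationTypeError:
        return True
    return False


print(fails_under_jit(absolute_value, jnp.float32(-3.0)))
```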

Installation

pip install jaxify

Getting started

import jax
import jax.numpy as jnp
from jaxify import jaxify

@jax.jit
@jax.vmap
@jaxify  # <-- Just decorate your function with @jaxify
def absolute_value(x):
    if x >= 0:  # <-- If block in a JIT-compiled function
        return x
    else:
        return -x

xs = jnp.arange(-1000, 1000)
ys = absolute_value(xs)  # <-- Runs at JAX speed!
print(ys)

How it works

The @jaxify decorator transforms Python functions, using a mixture of static analysis and dynamic tracing to replace unsupported Python constructs with JAX-compatible alternatives. After the transformation, the functions are traceable by JAX, so you can apply functional JAX transformations such as @jax.jit and @jax.vmap to them seamlessly.
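As an illustration only (a hand-written sketch, not the code jaxify actually generates), the absolute_value example from Getting started is equivalent to the following traceable function, where the if block has been replaced by jax.lax.cond. The name absolute_value_traceable is hypothetical.

```python
import jax
import jax.numpy as jnp


def absolute_value_traceable(x):
    # Both branches are traced; jax.lax.cond selects one based on the traced predicate
    return jax.lax.cond(x >= 0, lambda x: x, lambda x: -x, x)


# The same transformation stack as in Getting started, now without @jaxify
absolute_value_fast = jax.jit(jax.vmap(absolute_value_traceable))

xs = jnp.arange(-5, 6)
print(absolute_value_fast(xs))  # [5 4 3 2 1 0 1 2 3 4 5]
```

Under jax.vmap, JAX lowers a cond with a batched predicate to a select, so both branches are computed elementwise.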

Compatibility status

The following Python constructs are currently supported within @jaxify-decorated functions:

🔀 Conditionals

| Construct | Works? | Notes |
|---|---|---|
| if statements | ✅ | Fully supported, including elif and else clauses. All branches are traced and translated to calls to jax.lax.cond |
| if expressions (e.g. a if b else c) | ✅ | Traced and translated to jax.lax.cond |
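A minimal sketch of what the table describes, written by hand with jax.lax.cond (jaxify's actual translation may be structured differently): an elif becomes a second cond nested in the false branch, and a conditional expression becomes a single cond. The functions sign and clip_unit are hypothetical examples.

```python
import jax
import jax.numpy as jnp


def sign(x):
    # Python version:
    #     if x > 0: return 1.0
    #     elif x < 0: return -1.0
    #     else: return 0.0
    # The elif turns into a cond nested inside the outer cond's false branch.
    return jax.lax.cond(
        x > 0,
        lambda x: jnp.ones_like(x),
        lambda x: jax.lax.cond(x < 0, lambda x: -jnp.ones_like(x), jnp.zeros_like, x),
        x,
    )


def clip_unit(x):
    # Conditional expression `x if x < 1.0 else 1.0` as a single cond;
    # both branches must return the same shape and dtype.
    return jax.lax.cond(x < 1.0, lambda x: x, lambda x: jnp.ones_like(x), x)


print(jax.jit(jax.vmap(sign))(jnp.array([-2.0, 0.0, 3.0])))
```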

βš–οΈ Comparisons

| Construct | Works? | Notes |
|---|---|---|
| ==, !=, <, >, <=, >= | ✅ | Chained comparisons (e.g. x < y <= z) are supported by translation to the equivalent chain of individual comparisons |
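In Python, lo < x <= hi means lo < x and x <= hi. The sketch below splits the chain into its individual comparisons by hand and, as a swapped-in technique, combines them with jnp.logical_and rather than the cond-based and described under Logical operators. The function name in_half_open_interval is hypothetical.

```python
import jax
import jax.numpy as jnp


def in_half_open_interval(x, lo, hi):
    # `lo < x <= hi` is `lo < x and x <= hi`; Python cannot evaluate `and`
    # on traced arrays, so combine the two comparisons elementwise instead.
    return jnp.logical_and(lo < x, x <= hi)


xs = jnp.array([0.0, 0.5, 1.0, 1.5])
print(jax.jit(in_half_open_interval)(xs, 0.0, 1.0))
```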

1️⃣ Logical operators

| Construct | Works? | Notes |
|---|---|---|
| and / or | ✅ | Short-circuiting of traced values is supported via translation to jax.lax.cond calls |
| not | ✅ | Translates to jnp.logical_not for traced single values |
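A hand-written sketch of the idea, assuming scalar boolean operands (not necessarily the exact form jaxify emits): a and b returns a when a is false and b otherwise, so the right operand is only computed inside the branch where it matters; or mirrors this. The function names are hypothetical.

```python
import jax
import jax.numpy as jnp


def both_positive(x, y):
    # `x > 0 and y > 0`: evaluate `y > 0` only when `x > 0` is true,
    # otherwise return the (false) left operand, as Python's `and` does.
    return jax.lax.cond(x > 0, lambda x, y: y > 0, lambda x, y: x > 0, x, y)


def either_positive(x, y):
    # `x > 0 or y > 0`: return the left operand if true, else the right one.
    return jax.lax.cond(x > 0, lambda x, y: x > 0, lambda x, y: y > 0, x, y)


def not_positive(x):
    # `not x > 0` on a traced value becomes a logical negation.
    return jnp.logical_not(x > 0)


xs = jnp.array([1.0, 1.0, -1.0, -1.0])
ys = jnp.array([1.0, -1.0, 1.0, -1.0])
print(jax.vmap(both_positive)(xs, ys))
```

Under jax.vmap, a cond with a batched predicate is converted to a select, so both branches are evaluated and the short-circuit saves no work there.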

🔄 Loops

| Construct | Works? | Notes |
|---|---|---|
| for loops | ❌ | Currently unsupported. Use jax.lax.fori_loop, jax.lax.scan, or jax.lax.while_loop instead |
| while loops | ❌ | Currently unsupported. Use jax.lax.while_loop instead |
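Until loops are supported, the suggested workarounds look roughly like this sketch, which rewrites a counting for loop with jax.lax.fori_loop and a value-dependent while loop with jax.lax.while_loop. The functions power_by_loop and halvings are hypothetical examples.

```python
import jax
import jax.numpy as jnp


def power_by_loop(x, n):
    # Python: acc = 1; for _ in range(n): acc = acc * x
    # fori_loop calls body(i, carry) for i in [0, n) and returns the final carry.
    return jax.lax.fori_loop(0, n, lambda i, acc: acc * x, jnp.ones_like(x))


def halvings(x):
    # Python: count = 0; while x >= 1: x = x / 2; count += 1
    def cond_fun(state):
        value, _ = state
        return value >= 1.0

    def body_fun(state):
        value, count = state
        return value / 2.0, count + 1

    # Fixed dtypes keep the loop carry's type identical on every iteration
    init = (jnp.asarray(x, dtype=jnp.float32), jnp.int32(0))
    _, count = jax.lax.while_loop(cond_fun, body_fun, init)
    return count


print(jax.jit(power_by_loop)(jnp.float32(2.0), 10))
print(jax.jit(halvings)(10.0))
```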

🎯 Pattern matching

| Construct | Works? | Notes |
|---|---|---|
| match-case | ✅⚠️ | Static values only. For traced values, use an if-elif-else chain or jax.lax.switch instead |
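For a traced integer subject, a match-case can be sketched with jax.lax.switch by listing the cases as functions. The function apply_op and its three cases are hypothetical. Note that jax.lax.switch clamps an out-of-range index to the first or last branch, unlike a match with no matching case.

```python
import jax
import jax.numpy as jnp


def apply_op(op_index, x):
    # A traced integer cannot drive a Python `match`, so each case becomes a
    # branch function and jax.lax.switch selects one at run time.
    branches = [
        lambda x: x,        # case 0
        lambda x: -x,       # case 1
        lambda x: 2.0 * x,  # case 2
    ]
    return jax.lax.switch(op_index, branches, x)


ops = jnp.array([0, 1, 2])
xs = jnp.array([3.0, 3.0, 3.0])
print(jax.jit(jax.vmap(apply_op))(ops, xs))
```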
