Skip to content
Closed
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
115 changes: 34 additions & 81 deletions docs/src/manual/exporting_to_jax.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,12 @@ JAX. We assume that users are familiar with
[Reactant compilation of Lux models](@ref reactant-compilation).

```@example exporting_to_stablehlo
using Lux, Reactant, Random
using Lux, Reactant, Random, NPZ
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We need to add NPZ to the project dependencies in docs/. activate the project and add it using Pkg.add("NPZ")

Copy link
Member

@avik-pal avik-pal Dec 18, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@jules We need to add NPZ to the project dependencies in docs/. activate the project and add it using Pkg.add("NPZ")


const dev = reactant_device()
```

We simply define a Lux model and generate the stablehlo code using `Reactant.@code_hlo`.
We simply define a Lux model and parameters.

```@example exporting_to_stablehlo
model = Chain(
Expand All @@ -37,95 +37,48 @@ x = randn(Random.default_rng(), Float32, 28, 28, 1, 4) |> dev
nothing # hide
```

Now instead of compiling the model, we will use `Reactant.@code_hlo` to generate the
StableHLO code.
Now, we can use `Reactant.Serialization.export_to_enzymejax` to generate the necessary files to run the model in JAX. This function will generate a Python script, an MLIR file, and a `.npz` file with the inputs.

```@example exporting_to_stablehlo
hlo_code = @code_hlo model(x, ps, st)
```

Now we just save this into an `mlir` file.

```@example exporting_to_stablehlo
write("exported_lux_model.mlir", string(hlo_code))
lux_model_func(x, ps, st) = model(x, ps, st)
# It's recommended to create a temporary directory for the exported files
output_dir = mktempdir()
println("Exported files will be in: ", output_dir)
py_script_path = Reactant.Serialization.export_to_enzymejax(
lux_model_func, x, ps, st; function_name="lux_model", output_dir=output_dir)
println("Python script generated at: ", py_script_path)
nothing # hide
```

Now we define a python script to run the model using EnzymeJAX.
The generated files will be in the `output_dir`. You can now run the model in Python.

```python
from enzyme_ad.jax import hlo_call

# This is a sample python script. The actual generated script can be run directly.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

instead of manually writing the python script, print the contents of the script generated

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@jules instead of manually writing the python script, print the contents of the script generated

# Assuming you are in the output directory you can run `python lux_model.py`.
#
# To integrate into your own python code, you can do the following.
# Make sure the output directory is in your python path.
import jax
import jax.numpy as jnp

with open("exported_lux_model.mlir", "r") as file:
code = file.read()

import numpy as np
from lux_model import run_lux_model
import os
import glob

def run_lux_model(
x,
weight1,
bias1,
weight3,
bias3,
weight6_1,
bias6_1,
weight6_2,
bias6_2,
weight6_3,
bias6_3,
):
return hlo_call(
x,
weight1,
bias1,
weight3,
bias3,
weight6_1,
bias6_1,
weight6_2,
bias6_2,
weight6_3,
bias6_3,
source=code,
)


# Note that all the inputs must be transposed, i.e. if the julia function has an input of
# shape (28, 28, 1, 4), then the input to the exported function called from python must be
# of shape (4, 1, 28, 28). This is because multi-dimensional arrays in Julia are stored in
# column-major order, while in JAX/Python they are stored in row-major order.

# Input as defined in our exported Lux model
x = jax.random.normal(jax.random.PRNGKey(0), (4, 1, 28, 28))
# Find the inputs file
# The format is {function_name}_{id}_inputs.npz
input_files = glob.glob(os.path.join(".", "lux_model_*_inputs.npz"))
assert len(input_files) == 1, "Expected to find exactly one inputs file"

# Weights and biases corresponding to `ps` and `st` in our exported Lux model
weight1 = jax.random.normal(jax.random.PRNGKey(0), (6, 1, 5, 5))
bias1 = jax.random.normal(jax.random.PRNGKey(0), (6,))
weight3 = jax.random.normal(jax.random.PRNGKey(0), (16, 6, 5, 5))
bias3 = jax.random.normal(jax.random.PRNGKey(0), (16,))
weight6_1 = jax.random.normal(jax.random.PRNGKey(0), (256, 128))
bias6_1 = jax.random.normal(jax.random.PRNGKey(0), (128,))
weight6_2 = jax.random.normal(jax.random.PRNGKey(0), (128, 84))
bias6_2 = jax.random.normal(jax.random.PRNGKey(0), (84,))
weight6_3 = jax.random.normal(jax.random.PRNGKey(0), (84, 10))
bias6_3 = jax.random.normal(jax.random.PRNGKey(0), (10,))
# Load the inputs from the .npz file
inputs = np.load(input_files[0])
inputs = [inputs[f] for f in inputs.files]

# Run the exported Lux model
print(
jax.jit(run_lux_model)(
x,
weight1,
bias1,
weight3,
bias3,
weight6_1,
bias6_1,
weight6_2,
bias6_2,
weight6_3,
bias6_3,
)
)
result = run_lux_model(*inputs)
print(result)

# The function can be JIT-compiled
jitted_model = jax.jit(run_lux_model)
result_jitted = jitted_model(*inputs)
print(result_jitted)
```
Loading