Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
102 changes: 14 additions & 88 deletions docs/src/manual/exporting_to_jax.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,12 @@ JAX. We assume that users are familiar with
[Reactant compilation of Lux models](@ref reactant-compilation).

```@example exporting_to_stablehlo
using Lux, Reactant, Random
using Lux, Reactant, Random, NPZ
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We need to add NPZ to the project dependencies in docs/. activate the project and add it using Pkg.add("NPZ")

Copy link
Member

@avik-pal avik-pal Dec 18, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@jules We need to add NPZ to the project dependencies in docs/. activate the project and add it using Pkg.add("NPZ")


const dev = reactant_device()
```

We simply define a Lux model and generate the stablehlo code using `Reactant.@code_hlo`.
We simply define a Lux model and parameters.

```@example exporting_to_stablehlo
model = Chain(
Expand All @@ -37,95 +37,21 @@ x = randn(Random.default_rng(), Float32, 28, 28, 1, 4) |> dev
nothing # hide
```

Now instead of compiling the model, we will use `Reactant.@code_hlo` to generate the
StableHLO code.
Now, we can use `Reactant.Serialization.export_to_enzymejax` to generate the necessary files to run the model in JAX. This function will generate a Python script, an MLIR file, and a `.npz` file with the inputs.

```@example exporting_to_stablehlo
hlo_code = @code_hlo model(x, ps, st)
lux_model_func(x, ps, st) = model(x, ps, st)
# It's recommended to create a temporary directory for the exported files
output_dir = mktempdir()
py_script_path = Reactant.Serialization.export_to_enzymejax(
lux_model_func, x, ps, st; function_name="lux_model", output_dir=output_dir)

println("Exported files are in: ", output_dir)
println("Generated python script `", py_script_path, "` contains:")
```

Now we just save this into an `mlir` file.
The generated Python script can be run directly. Here are its contents:

```@example exporting_to_stablehlo
write("exported_lux_model.mlir", string(hlo_code))
nothing # hide
```

Now we define a python script to run the model using EnzymeJAX.

```python
from enzyme_ad.jax import hlo_call

import jax
import jax.numpy as jnp

with open("exported_lux_model.mlir", "r") as file:
code = file.read()


def run_lux_model(
x,
weight1,
bias1,
weight3,
bias3,
weight6_1,
bias6_1,
weight6_2,
bias6_2,
weight6_3,
bias6_3,
):
return hlo_call(
x,
weight1,
bias1,
weight3,
bias3,
weight6_1,
bias6_1,
weight6_2,
bias6_2,
weight6_3,
bias6_3,
source=code,
)


# Note that all the inputs must be transposed, i.e. if the julia function has an input of
# shape (28, 28, 1, 4), then the input to the exported function called from python must be
# of shape (4, 1, 28, 28). This is because multi-dimensional arrays in Julia are stored in
# column-major order, while in JAX/Python they are stored in row-major order.

# Input as defined in our exported Lux model
x = jax.random.normal(jax.random.PRNGKey(0), (4, 1, 28, 28))

# Weights and biases corresponding to `ps` and `st` in our exported Lux model
weight1 = jax.random.normal(jax.random.PRNGKey(0), (6, 1, 5, 5))
bias1 = jax.random.normal(jax.random.PRNGKey(0), (6,))
weight3 = jax.random.normal(jax.random.PRNGKey(0), (16, 6, 5, 5))
bias3 = jax.random.normal(jax.random.PRNGKey(0), (16,))
weight6_1 = jax.random.normal(jax.random.PRNGKey(0), (256, 128))
bias6_1 = jax.random.normal(jax.random.PRNGKey(0), (128,))
weight6_2 = jax.random.normal(jax.random.PRNGKey(0), (128, 84))
bias6_2 = jax.random.normal(jax.random.PRNGKey(0), (84,))
weight6_3 = jax.random.normal(jax.random.PRNGKey(0), (84, 10))
bias6_3 = jax.random.normal(jax.random.PRNGKey(0), (10,))

# Run the exported Lux model
print(
jax.jit(run_lux_model)(
x,
weight1,
bias1,
weight3,
bias3,
weight6_1,
bias6_1,
weight6_2,
bias6_2,
weight6_3,
bias6_3,
)
)
```@repl exporting_to_stablehlo
print(read(py_script_path, String))
```
Loading