diff --git a/docs/src/manual/exporting_to_jax.md b/docs/src/manual/exporting_to_jax.md
index 257285495f..808cbffeb3 100644
--- a/docs/src/manual/exporting_to_jax.md
+++ b/docs/src/manual/exporting_to_jax.md
@@ -6,12 +6,12 @@ JAX. We assume that users are familiar with
 [Reactant compilation of Lux models](@ref reactant-compilation).
 
 ```@example exporting_to_stablehlo
-using Lux, Reactant, Random
+using Lux, Reactant, Random, NPZ
 
 const dev = reactant_device()
 ```
 
-We simply define a Lux model and generate the stablehlo code using `Reactant.@code_hlo`.
+We simply define a Lux model and its parameters.
 
 ```@example exporting_to_stablehlo
 model = Chain(
@@ -37,95 +37,50 @@ x = randn(Random.default_rng(), Float32, 28, 28, 1, 4) |> dev
 nothing # hide
 ```
 
-Now instead of compiling the model, we will use `Reactant.@code_hlo` to generate the
-StableHLO code.
+Now, instead of compiling the model, we use `Reactant.Serialization.export_to_enzymejax`
+to generate everything needed to run the model from JAX: a Python script, an MLIR file,
+and a `.npz` file containing the inputs.
 
 ```@example exporting_to_stablehlo
-hlo_code = @code_hlo model(x, ps, st)
+lux_model_func(x, ps, st) = model(x, ps, st)
+# It's recommended to create a temporary directory for the exported files
+output_dir = mktempdir()
+py_script_path = Reactant.Serialization.export_to_enzymejax(
+    lux_model_func, x, ps, st; function_name="lux_model", output_dir=output_dir)
+
+println("Exported files are in: ", output_dir)
+println("Generated Python script `", py_script_path, "` contains:")
 ```
 
-Now we just save this into an `mlir` file.
+The generated Python script can be run directly. Here are its contents:
 
-```@example exporting_to_stablehlo
-write("exported_lux_model.mlir", string(hlo_code))
-nothing # hide
-```
-
-Now we define a python script to run the model using EnzymeJAX.
-
-```python
-from enzyme_ad.jax import hlo_call
-
-import jax
-import jax.numpy as jnp
-
-with open("exported_lux_model.mlir", "r") as file:
-    code = file.read()
-
-
-def run_lux_model(
-    x,
-    weight1,
-    bias1,
-    weight3,
-    bias3,
-    weight6_1,
-    bias6_1,
-    weight6_2,
-    bias6_2,
-    weight6_3,
-    bias6_3,
-):
-    return hlo_call(
-        x,
-        weight1,
-        bias1,
-        weight3,
-        bias3,
-        weight6_1,
-        bias6_1,
-        weight6_2,
-        bias6_2,
-        weight6_3,
-        bias6_3,
-        source=code,
-    )
-
-
-# Note that all the inputs must be transposed, i.e. if the julia function has an input of
-# shape (28, 28, 1, 4), then the input to the exported function called from python must be
-# of shape (4, 1, 28, 28). This is because multi-dimensional arrays in Julia are stored in
-# column-major order, while in JAX/Python they are stored in row-major order.
-
-# Input as defined in our exported Lux model
-x = jax.random.normal(jax.random.PRNGKey(0), (4, 1, 28, 28))
-
-# Weights and biases corresponding to `ps` and `st` in our exported Lux model
-weight1 = jax.random.normal(jax.random.PRNGKey(0), (6, 1, 5, 5))
-bias1 = jax.random.normal(jax.random.PRNGKey(0), (6,))
-weight3 = jax.random.normal(jax.random.PRNGKey(0), (16, 6, 5, 5))
-bias3 = jax.random.normal(jax.random.PRNGKey(0), (16,))
-weight6_1 = jax.random.normal(jax.random.PRNGKey(0), (256, 128))
-bias6_1 = jax.random.normal(jax.random.PRNGKey(0), (128,))
-weight6_2 = jax.random.normal(jax.random.PRNGKey(0), (128, 84))
-bias6_2 = jax.random.normal(jax.random.PRNGKey(0), (84,))
-weight6_3 = jax.random.normal(jax.random.PRNGKey(0), (84, 10))
-bias6_3 = jax.random.normal(jax.random.PRNGKey(0), (10,))
-
-# Run the exported Lux model
-print(
-    jax.jit(run_lux_model)(
-        x,
-        weight1,
-        bias1,
-        weight3,
-        bias3,
-        weight6_1,
-        bias6_1,
-        weight6_2,
-        bias6_2,
-        weight6_3,
-        bias6_3,
-    )
-)
+```@repl exporting_to_stablehlo
+print(read(py_script_path, String))
 ```
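+
+To give a sense of what the generated script does, below is a minimal,
+hypothetical sketch assuming it follows the same pattern as the hand-written
+EnzymeJAX script this page previously showed (the file names are illustrative,
+not necessarily the ones `export_to_enzymejax` writes):
+
+```python
+import numpy as np
+import jax
+from enzyme_ad.jax import hlo_call
+
+# Load the exported MLIR module (illustrative file name)
+with open("lux_model.mlir", "r") as f:
+    code = f.read()
+
+# Load the operands saved from Julia (illustrative file name), assuming the
+# archive stores them in call order
+inputs = np.load("lux_model_inputs.npz")
+args = [jax.numpy.asarray(inputs[k]) for k in inputs.files]
+
+# Invoke the module: `hlo_call` takes the operands followed by the MLIR source
+print(jax.jit(lambda *a: hlo_call(*a, source=code))(*args))
+```
+
+If you construct your own inputs on the Python side, remember that Julia arrays
+are column-major while JAX arrays are row-major, so a Julia input of shape
+`(28, 28, 1, 4)` corresponds to an array of shape `(4, 1, 28, 28)` in JAX.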