-
Notifications
You must be signed in to change notification settings - Fork 82
Update JAX export example to use new functionality #1602
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 1 commit
1f14a27
e4444a0
2eb8181
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -6,12 +6,12 @@ JAX. We assume that users are familiar with | |
| [Reactant compilation of Lux models](@ref reactant-compilation). | ||
|
|
||
| ```@example exporting_to_stablehlo | ||
| using Lux, Reactant, Random | ||
| using Lux, Reactant, Random, NPZ | ||
|
|
||
| const dev = reactant_device() | ||
| ``` | ||
|
|
||
| We simply define a Lux model and generate the stablehlo code using `Reactant.@code_hlo`. | ||
| We simply define a Lux model and parameters. | ||
|
|
||
| ```@example exporting_to_stablehlo | ||
| model = Chain( | ||
|
|
@@ -37,95 +37,48 @@ x = randn(Random.default_rng(), Float32, 28, 28, 1, 4) |> dev | |
| nothing # hide | ||
| ``` | ||
|
|
||
| Now instead of compiling the model, we will use `Reactant.@code_hlo` to generate the | ||
| StableHLO code. | ||
| Now, we can use `Reactant.Serialization.export_to_enzymejax` to generate the necessary files to run the model in JAX. This function will generate a Python script, an MLIR file, and a `.npz` file with the inputs. | ||
|
|
||
| ```@example exporting_to_stablehlo | ||
| hlo_code = @code_hlo model(x, ps, st) | ||
| ``` | ||
|
|
||
| Now we just save this into an `mlir` file. | ||
|
|
||
| ```@example exporting_to_stablehlo | ||
| write("exported_lux_model.mlir", string(hlo_code)) | ||
| lux_model_func(x, ps, st) = model(x, ps, st) | ||
| # It's recommended to create a temporary directory for the exported files | ||
| output_dir = mktempdir() | ||
| println("Exported files will be in: ", output_dir) | ||
| py_script_path = Reactant.Serialization.export_to_enzymejax( | ||
| lux_model_func, x, ps, st; function_name="lux_model", output_dir=output_dir) | ||
| println("Python script generated at: ", py_script_path) | ||
| nothing # hide | ||
| ``` | ||
|
|
||
| Now we define a python script to run the model using EnzymeJAX. | ||
| The generated files will be in the `output_dir`. You can now run the model in Python. | ||
|
|
||
| ```python | ||
| from enzyme_ad.jax import hlo_call | ||
|
|
||
| # This is a sample python script. The actual generated script can be run directly. | ||
|
||
| # Assuming you are in the output directory you can run `python lux_model.py`. | ||
| # | ||
| # To integrate into your own python code, you can do the following. | ||
| # Make sure the output directory is in your python path. | ||
| import jax | ||
| import jax.numpy as jnp | ||
|
|
||
| with open("exported_lux_model.mlir", "r") as file: | ||
| code = file.read() | ||
|
|
||
| import numpy as np | ||
| from lux_model import run_lux_model | ||
| import os | ||
| import glob | ||
|
|
||
| def run_lux_model( | ||
| x, | ||
| weight1, | ||
| bias1, | ||
| weight3, | ||
| bias3, | ||
| weight6_1, | ||
| bias6_1, | ||
| weight6_2, | ||
| bias6_2, | ||
| weight6_3, | ||
| bias6_3, | ||
| ): | ||
| return hlo_call( | ||
| x, | ||
| weight1, | ||
| bias1, | ||
| weight3, | ||
| bias3, | ||
| weight6_1, | ||
| bias6_1, | ||
| weight6_2, | ||
| bias6_2, | ||
| weight6_3, | ||
| bias6_3, | ||
| source=code, | ||
| ) | ||
|
|
||
|
|
||
| # Note that all the inputs must be transposed, i.e. if the julia function has an input of | ||
| # shape (28, 28, 1, 4), then the input to the exported function called from python must be | ||
| # of shape (4, 1, 28, 28). This is because multi-dimensional arrays in Julia are stored in | ||
| # column-major order, while in JAX/Python they are stored in row-major order. | ||
|
|
||
| # Input as defined in our exported Lux model | ||
| x = jax.random.normal(jax.random.PRNGKey(0), (4, 1, 28, 28)) | ||
| # Find the inputs file | ||
| # The format is {function_name}_{id}_inputs.npz | ||
| input_files = glob.glob(os.path.join(".", "lux_model_*_inputs.npz")) | ||
| assert len(input_files) == 1, "Expected to find exactly one inputs file" | ||
|
|
||
| # Weights and biases corresponding to `ps` and `st` in our exported Lux model | ||
| weight1 = jax.random.normal(jax.random.PRNGKey(0), (6, 1, 5, 5)) | ||
| bias1 = jax.random.normal(jax.random.PRNGKey(0), (6,)) | ||
| weight3 = jax.random.normal(jax.random.PRNGKey(0), (16, 6, 5, 5)) | ||
| bias3 = jax.random.normal(jax.random.PRNGKey(0), (16,)) | ||
| weight6_1 = jax.random.normal(jax.random.PRNGKey(0), (256, 128)) | ||
| bias6_1 = jax.random.normal(jax.random.PRNGKey(0), (128,)) | ||
| weight6_2 = jax.random.normal(jax.random.PRNGKey(0), (128, 84)) | ||
| bias6_2 = jax.random.normal(jax.random.PRNGKey(0), (84,)) | ||
| weight6_3 = jax.random.normal(jax.random.PRNGKey(0), (84, 10)) | ||
| bias6_3 = jax.random.normal(jax.random.PRNGKey(0), (10,)) | ||
| # Load the inputs from the .npz file | ||
| inputs = np.load(input_files[0]) | ||
| inputs = [inputs[f] for f in inputs.files] | ||
|
|
||
| # Run the exported Lux model | ||
| print( | ||
| jax.jit(run_lux_model)( | ||
| x, | ||
| weight1, | ||
| bias1, | ||
| weight3, | ||
| bias3, | ||
| weight6_1, | ||
| bias6_1, | ||
| weight6_2, | ||
| bias6_2, | ||
| weight6_3, | ||
| bias6_3, | ||
| ) | ||
| ) | ||
| result = run_lux_model(*inputs) | ||
| print(result) | ||
|
|
||
| # The function can be JIT-compiled | ||
| jitted_model = jax.jit(run_lux_model) | ||
| result_jitted = jitted_model(*inputs) | ||
| print(result_jitted) | ||
| ``` | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We need to add NPZ to the project dependencies in
docs/. activate the project and add it usingPkg.add("NPZ")Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@julesWe need to add NPZ to the project dependencies in docs/. activate the project and add it using Pkg.add("NPZ")