Skip to content

Commit 2eb8181

Browse files
docs: Update JAX export example to print generated script
Addresses feedback on the pull request to print the content of the generated Python script directly in the documentation. This ensures the documentation is always in sync with the generated output.
1 parent e4444a0 commit 2eb8181

File tree

1 file changed

+6
-33
lines changed

1 file changed

+6
-33
lines changed

docs/src/manual/exporting_to_jax.md

Lines changed: 6 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -43,42 +43,15 @@ Now, we can use `Reactant.Serialization.export_to_enzymejax` to generate the nec
4343
lux_model_func(x, ps, st) = model(x, ps, st)
4444
# It's recommended to create a temporary directory for the exported files
4545
output_dir = mktempdir()
46-
println("Exported files will be in: ", output_dir)
4746
py_script_path = Reactant.Serialization.export_to_enzymejax(
4847
lux_model_func, x, ps, st; function_name="lux_model", output_dir=output_dir)
49-
println("Python script generated at: ", py_script_path)
50-
nothing # hide
51-
```
52-
53-
The generated files will be in the `output_dir`. You can now run the model in Python.
5448
55-
```python
56-
# This is a sample python script. The actual generated script can be run directly.
57-
# Assuming you are in the output directory you can run `python lux_model.py`.
58-
#
59-
# To integrate into your own python code, you can do the following.
60-
# Make sure the output directory is in your python path.
61-
import jax
62-
import numpy as np
63-
from lux_model import run_lux_model
64-
import os
65-
import glob
66-
67-
# Find the inputs file
68-
# The format is {function_name}_{id}_inputs.npz
69-
input_files = glob.glob(os.path.join(".", "lux_model_*_inputs.npz"))
70-
assert len(input_files) == 1, "Expected to find exactly one inputs file"
71-
72-
# Load the inputs from the .npz file
73-
inputs = np.load(input_files[0])
74-
inputs = [inputs[f] for f in inputs.files]
49+
println("Exported files are in: ", output_dir)
50+
println("Generated python script `", py_script_path, "` contains:")
51+
```
7552

76-
# Run the exported Lux model
77-
result = run_lux_model(*inputs)
78-
print(result)
53+
The generated Python script can be run directly. Here are its contents:
7954

80-
# The function can be JIT-compiled
81-
jitted_model = jax.jit(run_lux_model)
82-
result_jitted = jitted_model(*inputs)
83-
print(result_jitted)
55+
```@repl exporting_to_stablehlo
56+
print(read(py_script_path, String))
8457
```

0 commit comments

Comments
 (0)