Commit 2764147

Add examples of calling FlashInfer from JAX via jax-tvm-ffi (#3092)
## 📌 Description

This PR adds a new example under examples/jax_tvm_ffi/ showing how to call FlashInfer from JAX via jax-tvm-ffi. It also adds examples/README.md to document the examples directory and make the new example easier to discover.

The goal is to provide a minimal reference for users interested in integrating FlashInfer outside of PyTorch, especially in JAX-based workflows.

## 🔍 Related Issues

N/A

## 🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

### ✅ Pre-commit Checks

- [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.

## 🧪 Tests

- [ ] Tests have been added or updated as needed.
- [x] All tests are passing (`unittest`, etc.).

This PR only adds example code and documentation; there are no changes to core functionality, so no additional tests were added. The examples run successfully end-to-end.

## Summary by CodeRabbit

* **Documentation**
  * Added an Examples overview and detailed per-example guides covering setup, installation, GPU/CUDA prerequisites, compilation/caching behavior, Hugging Face gated-model steps, authentication flows, and troubleshooting for JAX↔TVM FFI workflows.
* **New Features**
  * Added runnable JAX↔TVM FFI examples (notebooks and standalone scripts) demonstrating fused activations/FFN, RoPE, and attention kernels, end-to-end Gemma 3 inference, correctness validations, and latency micro-benchmarks.
1 parent 7a5e604 commit 2764147

7 files changed

Lines changed: 2358 additions & 1 deletion

.gitignore

Lines changed: 2 additions & 0 deletions
@@ -199,3 +199,5 @@ cython_debug/

 # Cursor
 .cursor/
+docs/tutorials/generated/
+docs/sg_execution_times.rst

docs/conf.py

Lines changed: 15 additions & 1 deletion
@@ -38,6 +38,7 @@
     "sphinx.ext.napoleon",
     "sphinx.ext.autosummary",
     "sphinx.ext.mathjax",
+    "sphinx_gallery.gen_gallery",
 ]

 autodoc_default_flags = ["members"]

@@ -47,11 +48,24 @@

 language = "en"

-exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"]
+exclude_patterns = [
+    "_build",
+    "Thumbs.db",
+    ".DS_Store",
+    "tutorials/jax_tvm_ffi/README.rst",
+]

 # The name of the Pygments (syntax highlighting) style to use.
 pygments_style = "sphinx"

+sphinx_gallery_conf = {
+    "examples_dirs": "tutorials/jax_tvm_ffi",
+    "gallery_dirs": "tutorials/generated/jax_tvm_ffi",
+    "filename_pattern": r".*\.py",
+    "plot_gallery": "False",
+    "download_all_examples": False,
+}
+
 # A list of ignored prefixes for module index sorting.
 # If true, `todo` and `todoList` produce output, else they produce nothing.
 todo_include_todos = False

docs/index.rst

Lines changed: 1 addition & 0 deletions
@@ -25,6 +25,7 @@ FlashInfer is a library and kernel generator for Large Language Models that prov

    tutorials/recursive_attention
    tutorials/kv_layout
+   tutorials/generated/jax_tvm_ffi/index

 .. toctree::
    :maxdepth: 2

docs/requirements.txt

Lines changed: 1 addition & 0 deletions
@@ -1,5 +1,6 @@
 furo == 2024.8.6
 sphinx == 8.1.3
+sphinx-gallery == 0.19.0
 sphinx-reredirects == 0.1.5
 sphinx-tabs == 3.4.5
 sphinx-toolbox == 3.8.1
docs/tutorials/jax_tvm_ffi/README.rst

Lines changed: 66 additions & 0 deletions
@@ -0,0 +1,66 @@

FlashInfer on JAX with TVM FFI
==============================

These tutorials show how to call FlashInfer GPU kernels from JAX through the
`jax-tvm-ffi <https://github.com/NVIDIA/jax-tvm-ffi>`_ bridge.

The Sphinx-Gallery ``.py`` files in this directory are the canonical source:

* ``flashinfer_jax_tvm_ffi.py`` explains the core build, register, and call
  pattern for FlashInfer kernels from JAX (sketched after this list).
* ``gemma3_flashinfer_jax.py`` applies the same pattern to Gemma 3 1B Instruct
  inference.
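
As a rough illustration, here is a minimal sketch of that build, register,
and call pattern. The names ``jit.get_silu_and_mul_module`` and
``jax_tvm_ffi.register_tvm_ffi_target`` are hypothetical placeholders, not
confirmed APIs; only ``jax.ffi.ffi_call`` and ``jax.ShapeDtypeStruct`` are
standard JAX. Consult ``flashinfer_jax_tvm_ffi.py`` for the real calls.

.. code-block:: python

   import jax
   import jax.numpy as jnp

   import jax_tvm_ffi          # the JAX <-> TVM FFI bridge
   from flashinfer import jit  # FlashInfer's JIT compilation entry point

   # 1. Build: JIT-compile a FlashInfer kernel into a TVM-FFI-compatible
   #    module. `get_silu_and_mul_module` is a placeholder accessor.
   module = jit.get_silu_and_mul_module()

   # 2. Register: expose the compiled kernel as a named JAX FFI target.
   #    `register_tvm_ffi_target` is a placeholder helper and signature.
   jax_tvm_ffi.register_tvm_ffi_target(
       "flashinfer_silu_and_mul", module.silu_and_mul, platform="CUDA"
   )

   # 3. Call: invoke the registered target from JAX; it composes with jax.jit.
   def silu_and_mul(x: jax.Array) -> jax.Array:
       # The fused kernel computes SiLU(gate) * up over a concatenated
       # last dimension, so the output halves the trailing dimension.
       out = jax.ShapeDtypeStruct((x.shape[0], x.shape[1] // 2), x.dtype)
       return jax.ffi.ffi_call("flashinfer_silu_and_mul", out)(x)

   x = jnp.ones((8, 256), dtype=jnp.bfloat16)
   y = jax.jit(silu_and_mul)(x)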

During the documentation build, Sphinx-Gallery renders these files into HTML
pages and creates downloadable Python and Jupyter notebook versions from the
same source files. Do not edit or commit the generated
``docs/tutorials/generated/jax_tvm_ffi/`` directory; it is produced by
Sphinx-Gallery.

The examples are not executed during the default documentation build because
they require an NVIDIA GPU, CUDA, FlashInfer JIT compilation, and, in the
Gemma 3 case, Hugging Face credentials for a gated model.

Execution requirements
----------------------

To run the tutorials directly, use a CUDA-capable environment with:

* NVIDIA GPU with SM 7.5 or newer.
* CUDA 12.6 or newer.
* Python 3.10 or newer.
* JAX with CUDA support.
* ``flashinfer-python`` and ``jax-tvm-ffi``.

The Gemma 3 tutorial additionally requires:

* ``torch`` CPU wheels for dtype literals used by FlashInfer's JIT API.
* ``safetensors``, ``huggingface_hub``, and ``transformers``.
* Hugging Face access to ``google/gemma-3-1b-it`` and an ``HF_TOKEN``.

For example:

.. code-block:: bash

   pip install 'jax[cuda13]'
   pip install -U flashinfer-python jax-tvm-ffi \
       --no-build-isolation \
       --extra-index-url https://flashinfer.ai/whl/cu130/

   # Additional dependencies for the Gemma 3 tutorial only:
   pip install torch --index-url https://download.pytorch.org/whl/cpu
   pip install safetensors huggingface_hub transformers
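
Before running a tutorial, it can help to confirm the environment is wired
up correctly. A small sanity check (assuming the distribution names match
the pip commands above):

.. code-block:: python

   import importlib.metadata as metadata

   import jax

   # A CUDA build of JAX should list at least one GPU device here.
   print(jax.devices())

   # Confirm the FlashInfer and bridge wheels are installed.
   for dist in ("flashinfer-python", "jax-tvm-ffi"):
       print(dist, metadata.version(dist))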

To build the documentation locally from the repository root:

.. code-block:: bash

   pip install -r docs/requirements.txt
   sphinx-build -b html docs docs/_build/html -j auto

To run a tutorial directly, execute its canonical source file:

.. code-block:: bash

   python docs/tutorials/jax_tvm_ffi/flashinfer_jax_tvm_ffi.py
   python docs/tutorials/jax_tvm_ffi/gemma3_flashinfer_jax.py
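
Because ``google/gemma-3-1b-it`` is gated, the Gemma 3 script assumes you
are already authenticated with Hugging Face. One minimal approach, assuming
the token is exported as ``HF_TOKEN`` (which ``huggingface_hub`` also reads
automatically):

.. code-block:: python

   import os

   from huggingface_hub import login

   # Assumes HF_TOKEN holds a token with access to google/gemma-3-1b-it.
   login(token=os.environ["HF_TOKEN"])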
