Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -199,3 +199,5 @@ cython_debug/

# Cursor
.cursor/
docs/tutorials/generated/
docs/sg_execution_times.rst
16 changes: 15 additions & 1 deletion docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
"sphinx.ext.napoleon",
"sphinx.ext.autosummary",
"sphinx.ext.mathjax",
"sphinx_gallery.gen_gallery",
]

autodoc_default_flags = ["members"]
Expand All @@ -47,11 +48,24 @@

language = "en"

exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"]
exclude_patterns = [
"_build",
"Thumbs.db",
".DS_Store",
"tutorials/jax_tvm_ffi/README.rst",
]

# The name of the Pygments (syntax highlighting) style to use.
pygments_style = "sphinx"

sphinx_gallery_conf = {
"examples_dirs": "tutorials/jax_tvm_ffi",
"gallery_dirs": "tutorials/generated/jax_tvm_ffi",
"filename_pattern": r".*\.py",
"plot_gallery": "False",
"download_all_examples": False,
}

# A list of ignored prefixes for module index sorting.
# If true, `todo` and `todoList` produce output, else they produce nothing.
todo_include_todos = False
Expand Down
1 change: 1 addition & 0 deletions docs/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ FlashInfer is a library and kernel generator for Large Language Models that prov

tutorials/recursive_attention
tutorials/kv_layout
tutorials/generated/jax_tvm_ffi/index

.. toctree::
:maxdepth: 2
Expand Down
1 change: 1 addition & 0 deletions docs/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
furo == 2024.8.6
sphinx == 8.1.3
sphinx-gallery == 0.19.0
sphinx-reredirects == 0.1.5
sphinx-tabs == 3.4.5
sphinx-toolbox == 3.8.1
66 changes: 66 additions & 0 deletions docs/tutorials/jax_tvm_ffi/README.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
FlashInfer on JAX with TVM FFI
==============================

These tutorials show how to call FlashInfer GPU kernels from JAX through the
`jax-tvm-ffi <https://github.com/NVIDIA/jax-tvm-ffi>`_ bridge.

The Sphinx-Gallery ``.py`` files in this directory are the canonical source:

* ``flashinfer_jax_tvm_ffi.py`` explains the core build, register, and call
pattern for FlashInfer kernels from JAX.
* ``gemma3_flashinfer_jax.py`` applies the same pattern to Gemma 3 1B Instruct
inference.

During the documentation build, Sphinx-Gallery renders these files into HTML
pages and creates downloadable Python and Jupyter notebook versions from the
same source files. Do not edit or commit the generated
``docs/tutorials/generated/jax_tvm_ffi/`` directory; it is produced by
Sphinx-Gallery.

The examples are not executed during the default documentation build because
they require an NVIDIA GPU, CUDA, FlashInfer JIT compilation, and in the Gemma 3
case Hugging Face credentials for a gated model.

Execution requirements
----------------------

To run the tutorials directly, use a CUDA-capable environment with:

* NVIDIA GPU with SM 7.5 or newer.
* CUDA 12.6 or newer.
* Python 3.10 or newer.
* JAX with CUDA support.
* ``flashinfer-python`` and ``jax-tvm-ffi``.

The Gemma 3 tutorial additionally requires:

* ``torch`` CPU wheels for dtype literals used by FlashInfer's JIT API.
* ``safetensors``, ``huggingface_hub``, and ``transformers``.
* Hugging Face access to ``google/gemma-3-1b-it`` and an ``HF_TOKEN``.

For example:

.. code-block:: bash

pip install 'jax[cuda13]'
pip install flashinfer-python -U jax-tvm-ffi \
--no-build-isolation \
--extra-index-url https://flashinfer.ai/whl/cu130/

# Additional dependencies for the Gemma 3 tutorial only:
pip install torch --index-url https://download.pytorch.org/whl/cpu
pip install safetensors huggingface_hub transformers

To build the documentation locally from the repository root:

.. code-block:: bash

pip install -r docs/requirements.txt
sphinx-build -b html docs docs/_build/html -j auto

To run a tutorial directly, execute its canonical source file:

.. code-block:: bash

python docs/tutorials/jax_tvm_ffi/flashinfer_jax_tvm_ffi.py
python docs/tutorials/jax_tvm_ffi/gemma3_flashinfer_jax.py
Loading
Loading