FlashInfer on JAX with TVM FFI
==============================

These tutorials show how to call FlashInfer GPU kernels from JAX through the
`jax-tvm-ffi <https://github.com/NVIDIA/jax-tvm-ffi>`_ bridge.

The Sphinx-Gallery ``.py`` files in this directory are the canonical source:

* ``flashinfer_jax_tvm_ffi.py`` explains the core build, register, and call
  pattern for FlashInfer kernels from JAX.
* ``gemma3_flashinfer_jax.py`` applies the same pattern to Gemma 3 1B Instruct
  inference.

During the documentation build, Sphinx-Gallery renders these files into HTML
pages and creates downloadable Python and Jupyter notebook versions from the
same source files. Do not edit or commit the generated
``docs/tutorials/generated/jax_tvm_ffi/`` directory; it is produced by
Sphinx-Gallery.

The examples are not executed during the default documentation build because
they require an NVIDIA GPU, CUDA, FlashInfer JIT compilation, and, for the
Gemma 3 tutorial, Hugging Face credentials for a gated model.

Execution requirements
----------------------

To run the tutorials directly, use a CUDA-capable environment with:

* An NVIDIA GPU with SM 7.5 or newer.
* CUDA 12.6 or newer.
* Python 3.10 or newer.
* JAX with CUDA support.
* ``flashinfer-python`` and ``jax-tvm-ffi``.
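
A quick way to confirm that the core packages above are installed is to query
their distribution metadata. This is only a sanity-check sketch: it looks up
the PyPI package names listed above and does not validate GPU visibility or
the CUDA toolchain.

.. code-block:: python

    from importlib.metadata import PackageNotFoundError, version

    def installed_versions(packages):
        """Map each PyPI package name to its installed version, or None."""
        found = {}
        for pkg in packages:
            try:
                found[pkg] = version(pkg)
            except PackageNotFoundError:
                found[pkg] = None
        return found

    for pkg, ver in installed_versions(
        ("jax", "flashinfer-python", "jax-tvm-ffi")
    ).items():
        print(f"{pkg}: {ver or 'not installed'}")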

The Gemma 3 tutorial additionally requires:

* ``torch`` CPU wheels for the dtype literals used by FlashInfer's JIT API.
* ``safetensors``, ``huggingface_hub``, and ``transformers``.
* Hugging Face access to ``google/gemma-3-1b-it`` and an ``HF_TOKEN``.

For example:

.. code-block:: bash

    pip install 'jax[cuda13]'
    pip install -U flashinfer-python jax-tvm-ffi \
        --no-build-isolation \
        --extra-index-url https://flashinfer.ai/whl/cu130/

    # Additional dependencies for the Gemma 3 tutorial only:
    pip install torch --index-url https://download.pytorch.org/whl/cpu
    pip install safetensors huggingface_hub transformers
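
The Gemma 3 tutorial reads Hugging Face credentials from the ``HF_TOKEN``
environment variable. The value below is a placeholder, not a real token:

.. code-block:: bash

    # Placeholder -- substitute a token whose account has access to
    # google/gemma-3-1b-it.
    export HF_TOKEN="hf_your_token_here"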

To build the documentation locally from the repository root:

.. code-block:: bash

    pip install -r docs/requirements.txt
    sphinx-build -b html docs docs/_build/html -j auto

To run a tutorial directly, execute its canonical source file:

.. code-block:: bash

    python docs/tutorials/jax_tvm_ffi/flashinfer_jax_tvm_ffi.py
    python docs/tutorials/jax_tvm_ffi/gemma3_flashinfer_jax.py