diff --git a/.gitignore b/.gitignore index e28797e5ed..771a46b29d 100644 --- a/.gitignore +++ b/.gitignore @@ -199,3 +199,5 @@ cython_debug/ # Cursor .cursor/ +docs/tutorials/generated/ +docs/sg_execution_times.rst diff --git a/docs/conf.py b/docs/conf.py index bbf5c9d75f..2302b8f3bd 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -38,6 +38,7 @@ "sphinx.ext.napoleon", "sphinx.ext.autosummary", "sphinx.ext.mathjax", + "sphinx_gallery.gen_gallery", ] autodoc_default_flags = ["members"] @@ -47,11 +48,24 @@ language = "en" -exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"] +exclude_patterns = [ + "_build", + "Thumbs.db", + ".DS_Store", + "tutorials/jax_tvm_ffi/README.rst", +] # The name of the Pygments (syntax highlighting) style to use. pygments_style = "sphinx" +sphinx_gallery_conf = { + "examples_dirs": "tutorials/jax_tvm_ffi", + "gallery_dirs": "tutorials/generated/jax_tvm_ffi", + "filename_pattern": r".*\.py", + "plot_gallery": "False", + "download_all_examples": False, +} + # A list of ignored prefixes for module index sorting. # If true, `todo` and `todoList` produce output, else they produce nothing. todo_include_todos = False diff --git a/docs/index.rst b/docs/index.rst index 028ed54a59..24a9388183 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -25,6 +25,7 @@ FlashInfer is a library and kernel generator for Large Language Models that prov tutorials/recursive_attention tutorials/kv_layout + tutorials/generated/jax_tvm_ffi/index .. toctree:: :maxdepth: 2 diff --git a/docs/requirements.txt b/docs/requirements.txt index 9811ed0268..93a3bce641 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -1,5 +1,6 @@ furo == 2024.8.6 sphinx == 8.1.3 +sphinx-gallery == 0.19.0 sphinx-reredirects == 0.1.5 sphinx-tabs == 3.4.5 sphinx-toolbox == 3.8.1 diff --git a/docs/tutorials/jax_tvm_ffi/README.rst b/docs/tutorials/jax_tvm_ffi/README.rst new file mode 100644 index 0000000000..2a821d7ead --- /dev/null +++ b/docs/tutorials/jax_tvm_ffi/README.rst @@ -0,0 +1,66 @@ +FlashInfer on JAX with TVM FFI +============================== + +These tutorials show how to call FlashInfer GPU kernels from JAX through the +`jax-tvm-ffi `_ bridge. + +The Sphinx-Gallery ``.py`` files in this directory are the canonical source: + +* ``flashinfer_jax_tvm_ffi.py`` explains the core build, register, and call + pattern for FlashInfer kernels from JAX. +* ``gemma3_flashinfer_jax.py`` applies the same pattern to Gemma 3 1B Instruct + inference. + +During the documentation build, Sphinx-Gallery renders these files into HTML +pages and creates downloadable Python and Jupyter notebook versions from the +same source files. Do not edit or commit the generated +``docs/tutorials/generated/jax_tvm_ffi/`` directory; it is produced by +Sphinx-Gallery. + +The examples are not executed during the default documentation build because +they require an NVIDIA GPU, CUDA, FlashInfer JIT compilation, and in the Gemma 3 +case Hugging Face credentials for a gated model. + +Execution requirements +---------------------- + +To run the tutorials directly, use a CUDA-capable environment with: + +* NVIDIA GPU with SM 7.5 or newer. +* CUDA 12.6 or newer. +* Python 3.10 or newer. +* JAX with CUDA support. +* ``flashinfer-python`` and ``jax-tvm-ffi``. + +The Gemma 3 tutorial additionally requires: + +* ``torch`` CPU wheels for dtype literals used by FlashInfer's JIT API. +* ``safetensors``, ``huggingface_hub``, and ``transformers``. +* Hugging Face access to ``google/gemma-3-1b-it`` and an ``HF_TOKEN``. + +For example: + +.. 
code-block:: bash + + pip install 'jax[cuda13]' + pip install flashinfer-python -U jax-tvm-ffi \ + --no-build-isolation \ + --extra-index-url https://flashinfer.ai/whl/cu130/ + + # Additional dependencies for the Gemma 3 tutorial only: + pip install torch --index-url https://download.pytorch.org/whl/cpu + pip install safetensors huggingface_hub transformers + +To build the documentation locally from the repository root: + +.. code-block:: bash + + pip install -r docs/requirements.txt + sphinx-build -b html docs docs/_build/html -j auto + +To run a tutorial directly, execute its canonical source file: + +.. code-block:: bash + + python docs/tutorials/jax_tvm_ffi/flashinfer_jax_tvm_ffi.py + python docs/tutorials/jax_tvm_ffi/gemma3_flashinfer_jax.py diff --git a/docs/tutorials/jax_tvm_ffi/flashinfer_jax_tvm_ffi.py b/docs/tutorials/jax_tvm_ffi/flashinfer_jax_tvm_ffi.py new file mode 100644 index 0000000000..f2e6538a94 --- /dev/null +++ b/docs/tutorials/jax_tvm_ffi/flashinfer_jax_tvm_ffi.py @@ -0,0 +1,1111 @@ +# This file is the canonical Sphinx-Gallery source for the tutorial. +# The Sphinx docs build generates HTML, a downloadable .py file, and a +# downloadable .ipynb notebook from this single source. + +""" +Enabling FlashInfer GPU Kernels on JAX with the JAX TVM FFI Bridge +================================================================== + + +Overview +-------- + + +JAX's XLA compiler is excellent for training and general tensor computation, but LLM inference has a distinct performance profile: every decode step attends over an ever-growing KV-cache, placing the bottleneck squarely on memory bandwidth rather than raw compute. FlashInfer is a library of hand-tuned CUDA kernels built precisely for this regime. + +FlashInfer ships every CUDA kernel as a shared library - an ``.so``-file compiled with a cross-language binary interface defined by `Apache TVM's Foreign Function Interface `_. Any language with a TVM FFI binding can load these ``.so``-files and call the functions inside. This tutorial shows how to do it from JAX via `jax-tvm-ffi `_, a bridge library that adapts TVM FFI functions to XLA custom calls. + + +What you'll build +----------------- + + +Three FlashInfer kernels: + + +.. list-table:: + :header-rows: 1 + + * - Kernel + - What it computes + - New concept + * - ``silu_and_mul`` + - Gated FFN activation: ``silu(gate) x up`` + - The minimal bridge: load -> register -> call + * - ``apply_rope`` + - Rotary positional embeddings on packed batches + - Multiple outputs; argument reordering + * - ``single_decode`` + - Attention over a KV-cache (single request, GQA) + - Type-specialized JIT; scratch buffers; optional-argument sentinels + + +At the end, all three run together inside a single ``@jax.jit`` region - exactly as they would in a real LLM decode loop. + + +Preliminaries +------------- + + +Hardware and software requirements +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + +.. list-table:: + :header-rows: 0 + + * - GPU + - NVIDIA, SM 7.5+ (Turing or later) + * - Python packages + - ``jax``, ``jax-tvm-ffi``, ``flashinfer`` + * - CUDA + - 12.6+ +""" + +# %% +# Setting the environment +# ~~~~~~~~~~~~~~~~~~~~~~~ +# +# +# The easiest way to get a working JAX environment is the `NVIDIA NGC JAX container `_. To install it manually: +# +# Recommended (CUDA 13): +# +# +# .. code-block:: bash +# +# pip install 'jax[cuda13]' +# +# +# For CUDA 12.x: +# +# +# .. code-block:: bash +# +# pip install 'jax[cuda12]' +# +# +# Three packages are required beyond a standard JAX environment: +# +# +# .. 
list-table:: +# :header-rows: 1 +# +# * - Package +# - Role +# * - ``flashinfer-python`` +# - FlashInfer CUDA kernels + JIT compilation system +# * - ``jax-tvm-ffi`` +# - Bridge: adapts TVM FFI functions to XLA custom calls +# +# +# ``flashinfer-python`` ships pre-built wheels for each CUDA/Python combination. The ``--extra-index-url`` below selects the CUDA 13.0 wheel; replace ``cu130`` with a corresponding mapping for your `CUDA Toolkit release `_: +# +# - CUDA 13 -> cu130 +# - CUDA 12.x -> cu12x (e.g., cu126) +# +# Run the cell below only once in your environment. + +# %% +# Install the tutorial dependencies before running the notebook or script:: +# +# pip install flashinfer-python -U jax-tvm-ffi \ +# --no-build-isolation \ +# --extra-index-url https://flashinfer.ai/whl/cu130/ \ +# + +# %% +# Loading dependencies +# ~~~~~~~~~~~~~~~~~~~~ + +# %% +# Run the cell below to load the dependencies. + +# %% +import os +import time +import math +import jinja2 +import numpy as np +import subprocess + +os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2" # suppress TF/XLA info & warnings + +if "CUDA_HOME" not in os.environ: + try: + nvcc = subprocess.check_output(["which", "nvcc"], text=True).strip() + os.environ["CUDA_HOME"] = str(os.path.dirname(os.path.dirname(nvcc))) + except subprocess.CalledProcessError: + os.environ["CUDA_HOME"] = "/usr/local/cuda" + +if "--xla_gpu_cuda_data_dir=" not in os.environ.get("XLA_FLAGS", ""): + os.environ["XLA_FLAGS"] = ( + f"{os.environ.get('XLA_FLAGS', '')} " + f"--xla_gpu_cuda_data_dir={os.environ['CUDA_HOME']}" + ).strip() + +import jax +import jax.numpy as jnp + +import jax_tvm_ffi # Bridge adapter: TVM FFI -> XLA custom call +from flashinfer.jit import gen_act_and_mul_module, gen_jit_spec, env as jit_env +from flashinfer.jit.rope import gen_rope_module +from flashinfer.jit.attention.utils import generate_additional_params +from flashinfer.jit.utils import write_if_different + +print(f"JAX: {jax.__version__}") +print(f"Devices: {jax.devices()}") +print(f"CUDA home: {os.environ.get('CUDA_HOME')}") +print(f"JIT cache: {jit_env.FLASHINFER_GEN_SRC_DIR.parent}") + +# %% +# The JAX TVM FFI bridge +# ---------------------- +# +# +# Every FlashInfer kernel lives in a compiled ``.so``-file. Getting that kernel into a ``@jax.jit`` computation graph takes three steps - the same three steps for every kernel in this tutorial: +# +# +# .. code-block:: text +# +# Step 1 BUILD & LOAD jit_spec.build_and_load() -> tvm_ffi.Module +# Step 2 REGISTER jax_tvm_ffi.register_ffi_target(name, wrapper, arg_spec) +# Step 3 CALL jax.ffi.ffi_call(name, output_shapes)(*inputs, **scalar_attrs) +# +# +# Step 1 - Compile and load the TVM FFI module +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# +# +# FlashInfer compiles kernels on demand using its JIT system. ``flashinfer.jit`` contains framework-agnostic helpers - one per kernel family - that generate CUDA source code, compile it with nvcc + ninja, and produce a shared library (``.so``-file). The result is cached in ``~/.cache/flashinfer/`` so compilation only happens once per configuration. You'll see the specific helper for each kernel when we get to the examples. +# +# Each helper returns a **JitSpec** - a recipe describing what to compile. Calling ``.build_and_load()`` on a JitSpec runs the compilation pipeline and returns a ``tvm_ffi.Module``: +# +# +# .. 
code-block:: python +# +# module = some_jit_spec.build_and_load() +# module.my_kernel_function # tvm_ffi.Function, callable from Python +# module.my_kernel_function(a, b, c) # call it directly +# +# +# Under the hood, ``.build_and_load()`` writes the generated ``.so`` to the cache directory and calls ``tvm_ffi.load_module()`` to open it. If the ``.so`` already exists (cache hit), the compilation step is skipped and ``tvm_ffi.load_module()`` loads the cached binary directly. +# +# In other words: ``tvm_ffi.load_module(path)`` is the low-level primitive that opens any ``.so``-file, while ``.build_and_load()`` is the high-level entry point that handles source generation, compilation, caching, *and* loading in one call. In this tutorial we always use ``.build_and_load()`` because FlashInfer's JIT system manages the full pipeline for us. +# +# +# Step 2 - Register as a JAX FFI target +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# +# +# ``jax_tvm_ffi.register_ffi_target(name, wrapper, arg_spec)`` teaches XLA about the kernel. It wraps a Python callable into an XLA custom call node. Once registered, XLA can schedule the kernel inside ``@jax.jit`` computations. +# +# The ``wrapper`` is a Python function whose job is to call the TVM function. It may need to reorder arguments: JAX delivers output tensors first and input tensors second, but TVM functions are compiled with their own parameter order that you must match. +# +# +# .. code-block:: python +# +# def _wrapper(out, x, scalar): # positional order set by arg_spec +# module.my_kernel_function(x, out, scalar) # reorder here to match TVM signature +# +# jax_tvm_ffi.register_ffi_target( +# 'registered_name', _wrapper, +# arg_spec=['rets', 'args', 'attrs.scalar'], +# platform='gpu', allow_cuda_graph=True, pass_owned_tensor=True, +# ) +# +# +# .. note:: Concept: ``arg_spec`` routing +# +# +# ``arg_spec`` is a list that tells the bridge how to route JAX's three categories of call-time data into the wrapper's positional arguments: +# +# .. list-table:: +# :header-rows: 1 +# +# * - ``arg_spec`` value +# - What the wrapper receives +# * - ``'rets'`` +# - All output tensors, pre-allocated by XLA +# * - ``'args'`` +# - All input tensors, in call order +# * - ``'attrs.KEY'`` +# - A scalar keyword argument named ``KEY`` +# +# For example, ``arg_spec=['rets', 'args', 'attrs.top_k']`` means the wrapper is called as ``wrapper(*outputs, *inputs, top_k)``. +# +# With ``pass_owned_tensor=True`` (required for GPU kernels), the tensors inside the wrapper are ``tvm_ffi.Tensor`` objects - they have a ``.shape`` attribute and are passed directly to TVM functions. +# +# +# Step 3 - Call as a regular JAX expression +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# +# +# .. code-block:: python +# +# result = jax.ffi.ffi_call( +# 'registered_name', +# jax.ShapeDtypeStruct(output_shape, output_dtype), # tells XLA what to pre-allocate +# vmap_method='broadcast_all', # how jax.vmap should batch this call +# )(*input_tensors, scalar=value) +# +# +# The name ``'registered_name'`` must match what was passed to ``register_ffi_target`` in Step 2, and ``scalar=value`` must match the ``'attrs.scalar'`` entry in ``arg_spec``. +# +# Output shapes must be statically known - XLA needs them at trace time to pre-allocate the output buffers that get passed as ``'rets'`` in the wrapper. The call site is a regular JAX expression and composes naturally with ``@jax.jit``. +# +# +# .. 
note:: Concept: ``vmap_method`` +# +# +# ``vmap_method`` tells JAX how to handle ``jax.vmap`` over this call: +# +# .. list-table:: +# :header-rows: 1 +# +# * - Value +# - Behaviour +# * - ``'broadcast_all'`` +# - Treat all inputs as sharing the same batch axis; run the kernel once per batch element. Use this when the operation is independent across a batch dimension, such as an element-wise activation or a per-token positional embedding. +# * - omitted +# - The call raises ``NotImplementedError`` inside ``jax.vmap``. Use this when the kernel's internals, such as scratch buffers or cross-head reduction, do not decompose cleanly over an added batch axis. +# +# Please refer to `this documentation `_ to learn more about ``vmap``. +# +# +# The full execution path +# ~~~~~~~~~~~~~~~~~~~~~~~ +# +# +# .. code-block:: text +# +# User JAX code +# | jax.ffi.ffi_call(name, output_shapes)(*inputs, **attrs) +# v +# XLA Compiler +# - emits a custom call node; records output shapes and scalar attrs at trace time +# | +# XLA Runtime +# - looks up the registered target; pre-allocates output buffers +# | +# JAX-TVM-FFI Bridge +# - unpacks call frame -> positional args according to arg_spec +# | +# Python wrapper (user-defined) +# - reorders arguments to match the TVM function signature +# | +# tvm_ffi.Function -> CUDA kernel + +# %% +# Example 1: Gated SiLU (a.k.a. SwiGLU-style) +# ------------------------------------------- +# +# +# In this example, you should learn how minimal the JAX-TVM bridge can be when the function signature and ``arg_spec`` are aligned: we have one kernel and one thin wrapper, and JAX handles the rest. +# +# Modern LLM feed-forward layers use a gated activation instead of a plain nonlinearity. A common form is: +# +# +# .. code-block:: text +# +# FFN(x) = SiLU(W1 * x) * (W2 * x) +# +# where * is elementwise multiply. +# +# In practice, implementations often compute both linear projections in one matmul by using a weight with twice the hidden width, producing a tensor of shape ``[..., 2H]``. Here, hidden size ``H`` is just the width of the model's internal feature vectors. And in gated FFNs, we temporarily double it (``2H``) to compute two parallel projections before combining them back to size ``H``: +# +# +# .. code-block:: text +# +# a = input[..., :H] (the "gate" half) +# b = input[..., H:] (the "value / up-projection" half) +# +# +# The TVM function signature is: +# +# +# .. code-block:: python +# +# silu_and_mul(out, input, enable_pdl) +# +# +# - ``out``: a pre-allocated output buffer that the kernel writes into. +# - ``input``: the fused ``[..., 2H]`` activation tensor (contains both halves concatenated). +# - ``enable_pdl``: a boolean that toggles Programmatic Dependent Launch (PDL) - an SM90+ feature that can help chain GPU work with less host involvement. For this tutorial, we keep it ``False``. +# +# The same compiled binary supports ``float16`` and ``bfloat16``, selecting the right path via runtime dispatch based on the input dtype. +# +# This is the simplest bridge case: **one input tensor -> one output tensor -> one scalar attribute**. And because the TVM function's parameter order (``out, input, enable_pdl``) matches the order JAX delivers them (``rets``, then ``args``, then ``attrs``), the wrapper needs no reordering. +# +# +# Compile and load +# ~~~~~~~~~~~~~~~~ +# +# +# ``gen_act_and_mul_module('silu')`` returns a JitSpec - a recipe describing what to compile. 
+# +# ``.build_and_load()`` runs nvcc + ninja, writes the .so to the FlashInfer cache directory, and returns a ``tvm_ffi.Module``. + +# %% +print("Compiling silu_and_mul (first run may take ~30 s)...") +silu_module = gen_act_and_mul_module("silu").build_and_load() +print(f" Module type: {type(silu_module).__name__}") +print(f" Function: {silu_module.silu_and_mul}") + +# %% +# Register as a JAX FFI target and validate +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# +# +# In this section, we connect the compiled TVM kernel to JAX by registering it as an FFI target and exposing a clean Python wrapper. The key idea is that JAX groups parameters as outputs **(rets) -> inputs (args) -> scalar attributes (attrs)**, and because the TVM function already expects ``(out, input, enable_pdl)`` in that exact order, the wrapper simply forwards the call with no argument reshuffling. The ``silu_and_mul`` function then uses ``jax.ffi.ffi_call`` to allocate the output buffer, pass the fused ``[..., 2H]`` tensor to the kernel, and return the computed ``[..., H]`` result, while remaining fully batch-compatible via ``vmap_method="broadcast_all"``. +# +# The validation step should confirm correctness against a pure JAX reference. + +# %% +# -- Register as a JAX FFI target -------------------------------------- + +# TVM function: silu_and_mul(out, input, enable_pdl) +# arg_spec: ['rets', 'args', 'attrs.enable_pdl'] + +# JAX delivers (out, input, enable_pdl) in the same order as the TVM function +# expects, so no reordering is needed in this wrapper. + + +def _silu_and_mul_wrapper(out, x, enable_pdl): + silu_module.silu_and_mul(out, x, enable_pdl) + + +jax_tvm_ffi.register_ffi_target( + "flashinfer.silu_and_mul", + _silu_and_mul_wrapper, + arg_spec=["rets", "args", "attrs.enable_pdl"], + platform="gpu", + allow_cuda_graph=True, + pass_owned_tensor=True, +) + +# -- JAX-facing function ----------------------------------------------- + + +def silu_and_mul(x: jax.Array) -> jax.Array: + """Fused silu(gate) * up. Input: [..., 2H] Output: [..., H]""" + out_shape = x.shape[:-1] + (x.shape[-1] // 2,) + return jax.ffi.ffi_call( + "flashinfer.silu_and_mul", + jax.ShapeDtypeStruct(out_shape, x.dtype), + vmap_method="broadcast_all", # element-wise op: independent across any batch dim + )(x, enable_pdl=False) + + +# -- Validate ------------------------------------------------------------------ + +TOKENS, HIDDEN = 32, 256 +gate_up = jax.random.normal(jax.random.key(0), (TOKENS, 2 * HIDDEN), dtype=jnp.float16) + +out = silu_and_mul(gate_up) + +# Compute reference to test our function +gate_ref = gate_up[..., :HIDDEN].astype(jnp.float32) +up_ref = gate_up[..., HIDDEN:].astype(jnp.float32) +ref = (jax.nn.silu(gate_ref) * up_ref).astype(jnp.float16) + +np.testing.assert_allclose( + np.array(out.astype(jnp.float32)), + np.array(ref.astype(jnp.float32)), + rtol=1e-2, + atol=1e-2, +) +print("silu_and_mul: PASSED") +print(f" {gate_up.shape} -> {out.shape}") +print(f" max error: {float(jnp.max(jnp.abs(out.astype(jnp.float32) - ref))):.5f}") + +# %% +# Example 2: Rotary Positional Embeddings +# --------------------------------------- +# +# +# In this example, you should learn how to handle two new complications that were absent in Example 1: a TVM function that **produces two outputs** and whose **argument order does not match JAX's convention**. Both are resolved entirely inside the Python wrapper - registration and the call site otherwise follow the same pattern. 
+# +# RoPE encodes each token's position by rotating its query and key vectors in pairs. For position ``p`` and dimension index ``i``: +# +# +# .. code-block:: text +# +# theta_i = rope_theta^(-2i / head_dim) +# q_rot[2i] = q[2i] * cos(p*theta_i) - q[2i+1] * sin(p*theta_i) +# q_rot[2i+1] = q[2i] * sin(p*theta_i) + q[2i+1] * cos(p*theta_i) +# +# +# The same rotation is applied to both ``q`` and ``k``, so the kernel produces two rotated outputs. +# +# +# Ragged (packed) batches +# ~~~~~~~~~~~~~~~~~~~~~~~ +# +# +# Rather than a padded ``[batch, seq_len, heads, dim]`` tensor, ``apply_rope`` takes a flat ``[total_tokens, heads, dim]`` tensor with two auxiliary arrays: +# +# - ``indptr`` - a CSR-style pointer array where ``indptr[i]:indptr[i+1]`` gives the token range of sequence ``i`` +# - ``offsets`` - the KV-cache position of the first token of each sequence +# +# This lets the kernel handle variable-length sequences without padding overhead. +# +# +# The TVM function signature +# ~~~~~~~~~~~~~~~~~~~~~~~~~~ +# +# +# .. code-block:: python +# +# apply_rope(q, k, q_rope, k_rope, indptr, offsets, rotary_dim, interleave, rope_scale, rope_theta) +# +# +# - ``q``, ``k``: input query and key tensors, shape ``[total_tokens, num_heads, head_dim]`` +# - ``q_rope``, ``k_rope``: pre-allocated output buffers - they appear **in the middle** of the argument list, after the inputs but before the index arrays +# - ``indptr``, ``offsets``: ragged-batch descriptors +# - ``rotary_dim``, ``interleave``, ``rope_scale``, ``rope_theta``: scalar parameters +# +# Because the outputs are interleaved with the inputs, the JAX convention (all outputs first) does not match the TVM signature. The wrapper must swap them back. +# +# +# .. note:: Concept: argument reordering +# +# +# TVM functions are compiled with whatever parameter order the kernel author chose. JAX's FFI convention always delivers ``(outputs, inputs, scalars)``. The wrapper bridges the two: +# +# .. code-block:: python +# +# # JAX delivers (via arg_spec): q_rope, k_rope, q, k, indptr, offsets, *scalars +# # TVM function expects: q, k, q_rope, k_rope, indptr, offsets, *scalars +# def _wrapper(q_rope, k_rope, q, k, indptr, offsets, *scalars): +# tvm_fn(q, k, q_rope, k_rope, indptr, offsets, *scalars) # reorder here +# +# +# Compile and load +# ~~~~~~~~~~~~~~~~ +# +# +# ``gen_rope_module()`` returns a JitSpec for the RoPE kernel. Unlike ``gen_act_and_mul_module``, this kernel dispatches over dtypes at runtime from a single binary - no dtype specialisation is needed at compile time. + +# %% +print("Compiling apply_rope...") +rope_module = gen_rope_module().build_and_load() +print(f" Function: {rope_module.apply_rope}") + +# %% +# Register as a JAX FFI target and validate +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# +# +# The ``apply_rope`` wrapper introduces the argument-reordering pattern described above. Because the TVM function places ``q_rope`` and ``k_rope`` after the input tensors - rather than at the beginning - the wrapper receives them first (via ``'rets'``) and must swap them back before calling the TVM function. The kernel also produces two output tensors, so ``jax.ffi.ffi_call`` receives a tuple of two ``ShapeDtypeStruct`` descriptors and returns a tuple of two JAX arrays. Because RoPE rotates each token independently, the operation decomposes cleanly over any outer batch dimension: ``vmap_method='broadcast_all'`` enables ``jax.vmap`` to map over it without any changes to the kernel. 
+# +# The validation uses two sequences with different starting positions - sequence 0 begins at position 0 and sequence 1 at position 100 - simulating a prefill where the second sequence already has 100 tokens in the KV-cache. Each token is rotated by angles matching its absolute position, so the test exercises both the CSR indexing and the offset arithmetic. + +# %% +# -- Register as a JAX FFI target ---------------------------------------------- +# +# TVM function: apply_rope(q, k, q_rope, k_rope, indptr, offsets, +# rotary_dim, interleave, rope_scale, rope_theta) +# +# JAX delivers: q_rope, k_rope first (rets), then q, k, indptr, offsets (args). +# TVM expects: q, k, q_rope, k_rope, indptr, offsets, ... - the wrapper swaps them. + + +def _rope_wrapper( + q_rope, + k_rope, + q, + k, + indptr, + offsets, + rotary_dim, + interleave, + rope_scale, + rope_theta, +): + rope_module.apply_rope( + q, + k, + q_rope, + k_rope, + indptr, + offsets, + rotary_dim, + interleave, + rope_scale, + rope_theta, + ) + + +jax_tvm_ffi.register_ffi_target( + "flashinfer.apply_rope", + _rope_wrapper, + arg_spec=[ + "rets", + "args", + "attrs.rotary_dim", + "attrs.interleave", + "attrs.rope_scale", + "attrs.rope_theta", + ], + platform="gpu", + allow_cuda_graph=True, + pass_owned_tensor=True, +) + +# -- JAX-facing function ------------------------------------------------------- + + +def apply_rope(q, k, indptr, offsets, *, rope_theta=1e4): + """Apply rotary positional embeddings to packed query and key tensors. + + q, k: [total_tokens, num_heads, head_dim] + indptr: [num_seqs + 1] CSR-style token range per sequence + offsets: [num_seqs] absolute position of the first token of each sequence + Returns: (q_rotated, k_rotated), same shapes as inputs + """ + head_dim = q.shape[-1] + return jax.ffi.ffi_call( + "flashinfer.apply_rope", + ( + jax.ShapeDtypeStruct(q.shape, q.dtype), + jax.ShapeDtypeStruct(k.shape, k.dtype), + ), + vmap_method="broadcast_all", # each packed batch is independent; safe to vmap + )( + q, + k, + indptr, + offsets, + rotary_dim=head_dim, + interleave=False, + rope_scale=1.0, + rope_theta=float(rope_theta), + ) + + +# -- Validate ------------------------------------------------------------------ + + +def _reference_rope(x, positions, theta=1e4): + """Non-interleaved RoPE reference. 
x: [tokens, heads, dim]""" + x32 = x.astype(jnp.float32) + d = x32.shape[-1] // 2 + freqs = 1.0 / (theta ** (2.0 * jnp.arange(d, dtype=jnp.float32) / x32.shape[-1])) + angles = positions[:, None].astype(jnp.float32) * freqs[None, :] # [T, d] + cos_a = jnp.cos(angles)[:, None, :] # [T, 1, d] + sin_a = jnp.sin(angles)[:, None, :] + x1, x2 = x32[..., :d], x32[..., d:] + return jnp.concatenate( + [x1 * cos_a - x2 * sin_a, x1 * sin_a + x2 * cos_a], axis=-1 + ).astype(x.dtype) + + +# Two sequences of 8 tokens: first starts at position 0, second at position 100 +NUM_HEADS, HEAD_DIM, SEQ_LEN, NUM_SEQ = 8, 64, 8, 2 +ROPE_THETA = 1e4 + +q_in = jax.random.normal( + jax.random.key(1), (NUM_SEQ * SEQ_LEN, NUM_HEADS, HEAD_DIM), dtype=jnp.bfloat16 +) +k_in = jax.random.normal( + jax.random.key(2), (NUM_SEQ * SEQ_LEN, NUM_HEADS, HEAD_DIM), dtype=jnp.bfloat16 +) + +indptr = jnp.array([0, SEQ_LEN, 2 * SEQ_LEN], dtype=jnp.int32) +offsets = jnp.array([0, 100], dtype=jnp.int32) + +q_rot, k_rot = apply_rope(q_in, k_in, indptr, offsets, rope_theta=ROPE_THETA) + +positions = jnp.concatenate( + [jnp.arange(SEQ_LEN, dtype=jnp.int32) + off for off in [0, 100]] +) +q_ref = _reference_rope(q_in, positions, theta=ROPE_THETA) +k_ref = _reference_rope(k_in, positions, theta=ROPE_THETA) + +for name, got, ref in [("q", q_rot, q_ref), ("k", k_rot, k_ref)]: + np.testing.assert_allclose( + np.array(got.astype(jnp.float32)), + np.array(ref.astype(jnp.float32)), + rtol=1e-2, + atol=1e-2, + ) + max_err = float(jnp.max(jnp.abs(got.astype(jnp.float32) - ref.astype(jnp.float32)))) + print(f"apply_rope {name}: PASSED max_err={max_err:.5f}") +print( + f" Input: {q_in.shape} ({NUM_SEQ} seqs x {SEQ_LEN} tokens, offsets {offsets.tolist()})" +) + +# %% +# Example 3: Single-request decode attention +# ------------------------------------------ +# +# +# In this example, you should learn about three patterns that appear together for the first time: a kernel **compiled separately per dtype and head dimension**, a **scratch buffer** the caller must allocate as an output, and **optional arguments** signalled as absent with empty tensors. +# +# In autoregressive generation, each new token attends over all previously generated tokens stored in the KV-cache: +# +# +# .. code-block:: text +# +# Attention(q, K, V) = softmax(q KT / sqrtd) V +# +# +# where ``q`` is a single query (or a small group in Grouped Query Attention), and ``K``, ``V`` are the full cached sequences. FlashInfer's single-request kernel is tuned for exactly this shape: one query token, many keys, memory-bound. +# +# +# Type-specialized compilation +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# +# +# Unlike ``silu_and_mul`` and ``apply_rope`` - which dispatch over dtypes at runtime from a single binary - the decode attention kernel is **compiled separately for each ``(dtype, head_dim)`` combination**. This lets the compiler choose tile sizes and memory access patterns for the exact configuration, with no runtime branching. +# +# FlashInfer's JIT system renders Jinja2 templates with concrete type names to produce configuration-specific CUDA code. A separate ``.so`` is compiled for each combination and identified by a URI that doubles as the on-disk cache key. ``gen_decode_jit_spec``, defined in the cell below, builds that URI and invokes the JIT system. +# +# +# The TVM function signature +# ~~~~~~~~~~~~~~~~~~~~~~~~~~ +# +# +# .. 
code-block:: python +# +# run(q, k, v, +# tmp, out, # outputs: scratch buffer first, result second +# maybe_lse, # optional output: log-sum-exp values +# kv_layout_code, window_left, +# maybe_alibi_slopes, # optional input: ALiBi position biases +# logits_soft_cap, sm_scale, rope_rcp_scale, rope_rcp_theta) +# +# +# - ``q``, ``k``, ``v``: query and KV-cache tensors +# - ``tmp``: a **scratch buffer** for split-K partial sums - must be provided by the caller as a pre-allocated output +# - ``out``: the attention result +# - ``maybe_lse``: optional log-sum-exp output; ``None`` to skip +# - ``kv_layout_code``: ``0`` for NHD layout, ``1`` for HND +# - ``window_left``: sliding-window cutoff; ``-1`` for full attention +# - ``maybe_alibi_slopes``: optional ALiBi position biases; ``None`` to skip +# - ``logits_soft_cap``, ``sm_scale``, ``rope_rcp_scale``, ``rope_rcp_theta``: scalar parameters +# +# Two patterns are new here: +# +# +# .. note:: Concept: scratch buffer +# +# +# The decode kernel uses an internal buffer for split-K partial results. We declare it as an output tensor so XLA pre-allocates it, then discard it after the call: +# +# .. code-block:: python +# +# out, _, _ = jax.ffi.ffi_call(target, (result_struct, tmp_struct, lse_struct))(...) +# # ^^^ ^^^ +# # 32 MB scratch LSE sentinel +# +# +# .. note:: Concept: empty-tensor sentinel +# +# +# Optional arguments are signalled as absent with a tensor whose first dimension is zero. Inside the wrapper, ``tensor.shape[0] == 0`` maps to ``None``: +# +# .. code-block:: python +# +# # As an output (not computed): +# jax.ShapeDtypeStruct((0,), jnp.float32) +# +# # As an input (not provided): +# jnp.empty((0,), dtype=jnp.float32) +# +# +# Compile and load +# ~~~~~~~~~~~~~~~~ +# +# +# ``gen_decode_jit_spec`` assembles the build recipe: it renders two Jinja2 templates with the concrete dtype and head-dimension values, copies the ``.cu`` source files into the generated directory, and returns a ``JitSpec`` that ``.build_and_load()`` compiles and caches. + +# %% +# -- Type mappings for Jinja template rendering -------------------------------- + +DTYPE_CPP = {"float16": "half", "bfloat16": "nv_bfloat16", "float32": "float"} +DTYPE_SAFE = {"float16": "f16", "bfloat16": "bf16", "float32": "f32"} +POS_ENC = { + 0: "PosEncodingMode::kNone", + 1: "PosEncodingMode::kRoPELlama", + 2: "PosEncodingMode::kALiBi", +} + + +def gen_decode_jit_spec(dtype: str = "float16", head_dim: int = 64): + """Return a JitSpec for type-specialized single-request decode attention.""" + s = DTYPE_SAFE[dtype] + uri = ( + f"single_decode_with_kv_cache_dtype_q_{s}_dtype_kv_{s}_dtype_o_{s}_" + f"head_dim_qk_{head_dim}_head_dim_vo_{head_dim}_" + f"posenc_0_use_swa_False_use_logits_cap_False" + ) + gen_dir = jit_env.FLASHINFER_GEN_SRC_DIR / uri + os.makedirs(gen_dir, exist_ok=True) + + # generate_additional_params produces the C++ boilerplate strings + # for the optional alibi-slopes tensor and four scalar parameters. 
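    # The three returned strings are injected into the Jinja templates below via
    # additional_params_decl, additional_func_params, and additional_params_setter.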
+ params_decl, func_params, params_setter = generate_additional_params( + additional_tensor_names=["maybe_alibi_slopes"], + additional_tensor_dtypes=["float"], + additional_scalar_names=[ + "logits_soft_cap", + "sm_scale", + "rope_rcp_scale", + "rope_rcp_theta", + ], + additional_scalar_dtypes=["double", "double", "double", "double"], + ) + + kwargs = dict( + additional_func_params=func_params, + additional_params_decl=params_decl, + additional_params_setter=params_setter, + variant_decl="#include", + variant_name="DefaultAttention", + dtype_q=DTYPE_CPP[dtype], + dtype_kv=DTYPE_CPP[dtype], + dtype_o=DTYPE_CPP[dtype], + head_dim_qk=head_dim, + head_dim_vo=head_dim, + pos_encoding_mode=POS_ENC[0], + use_sliding_window="false", + use_logits_soft_cap="false", + ) + + csrc = jit_env.FLASHINFER_CSRC_DIR + + # Render Jinja2 templates with the type-specific values + for tmpl, out in [ + ("single_decode_customize_config.jinja", "single_decode_config.inc"), + ("single_decode_kernel_inst.jinja", "single_decode_kernel.cu"), + ]: + rendered = jinja2.Template((csrc / tmpl).read_text()).render(**kwargs) + write_if_different(gen_dir / out, rendered) + + # Copy the .cu source files that #include the rendered headers + sources = [gen_dir / "single_decode_kernel.cu"] + for fname in ["single_decode.cu", "single_decode_jit_binding.cu"]: + dest = gen_dir / fname + write_if_different(dest, (csrc / fname).read_text()) + sources.append(dest) + + return gen_jit_spec(uri, sources) + + +# -- Compile and load ---------------------------------------------------------- + +DTYPE, HEAD_DIM = "float16", 64 +print(f"Compiling decode attention ({DTYPE}, head_dim={HEAD_DIM})...") +decode_module = gen_decode_jit_spec(DTYPE, HEAD_DIM).build_and_load() +print(f" run function: {decode_module.run}") + +# %% +# Register as a JAX FFI target +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# +# +# The ``decode_attention`` wrapper handles all three patterns from this example. The optional ``lse`` and ``alibi_slopes`` arguments use the empty-tensor sentinel: when ``tensor.shape[0] == 0``, the wrapper passes ``None`` to the TVM function, telling the kernel to skip that output or ignore that input. The scratch buffer ``tmp`` is declared as a pre-allocated output - XLA reserves 32 MB for it - and the caller discards it after the call. Finally, ``vmap_method`` is intentionally omitted: the scratch buffer is a flat array with no batch dimension, and GQA head-grouping does not decompose cleanly over an added outer batch axis. Callers that need batching should loop explicitly or use a batch-aware kernel variant. 
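#
# For example, a caller with a small batch of independent requests could unroll an
# explicit Python loop over the single-request kernel (a sketch only;
# ``decode_attention`` is defined in the cells below, and ``qs``/``ks``/``vs`` are
# hypothetical per-request tensors):
#
# .. code-block:: python
#
#    # qs: [B, num_qo_heads, head_dim]; ks, vs: [B, kv_len, num_kv_heads, head_dim]
#    outs = jnp.stack(
#        [decode_attention(qs[i], ks[i], vs[i]) for i in range(qs.shape[0])]
#    )  # the loop unrolls at trace time into one custom call per request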
+ +# %% +# -- Register ------------------------------------------------------------------ +# +# TVM function signature: +# run(q, k, v, tmp, out, maybe_lse, kv_layout_code, window_left, +# maybe_alibi_slopes, logits_soft_cap, sm_scale, rope_rcp_scale, rope_rcp_theta) +# +# Sentinel rule: tensor.shape[0] == 0 => pass None (argument is absent) + +_run = decode_module.run + + +def _decode_wrapper( + out, + tmp, + lse_or_empty, # <- rets + q, + k, + v, + alibi_or_empty, # <- args + kv_layout_code, + window_left, # <- attrs + logits_soft_cap, + sm_scale, + rope_scale, + rope_theta, +): + lse = None if lse_or_empty.shape[0] == 0 else lse_or_empty + alibi = None if alibi_or_empty.shape[0] == 0 else alibi_or_empty + # Reorder to match TVM function signature: + # kv_layout_code and window_left come before maybe_alibi_slopes + _run( + q, + k, + v, + tmp, + out, + lse, + kv_layout_code, + window_left, + alibi, + logits_soft_cap, + sm_scale, + rope_scale, + rope_theta, + ) + + +# Embed dtype + head_dim in the name: each compiled binary is a distinct target +DECODE_TARGET = f"flashinfer.single_decode_{DTYPE}_h{HEAD_DIM}" + +jax_tvm_ffi.register_ffi_target( + DECODE_TARGET, + _decode_wrapper, + arg_spec=[ + "rets", + "args", + "attrs.kv_layout_code", + "attrs.window_left", + "attrs.logits_soft_cap", + "attrs.sm_scale", + "attrs.rope_scale", + "attrs.rope_theta", + ], + platform="gpu", + allow_cuda_graph=True, + pass_owned_tensor=True, +) + +# -- JAX-facing function ------------------------------------------------------- + + +def decode_attention(q, k, v): + """Single-request GQA decode attention. + + q: [num_qo_heads, head_dim] float16 (query for one new token) + k: [kv_len, num_kv_heads, head_dim] float16 (full KV-cache, NHD layout) + v: [kv_len, num_kv_heads, head_dim] float16 + Returns: [num_qo_heads, head_dim] + """ + sm_scale = 1.0 / math.sqrt(q.shape[-1]) + tmp_elems = 32 * 1024 * 1024 // np.dtype(q.dtype).itemsize + + out, _, _ = jax.ffi.ffi_call( + DECODE_TARGET, + ( + jax.ShapeDtypeStruct(q.shape, q.dtype), # out + jax.ShapeDtypeStruct((tmp_elems,), q.dtype), # tmp (scratch, discarded) + jax.ShapeDtypeStruct((0,), jnp.float32), # lse (sentinel: not computed) + ), + )( + q, + k, + v, + jnp.empty((0,), dtype=jnp.float32), # alibi slopes: sentinel (not used) + kv_layout_code=0, # NHD=0, HND=1 + window_left=-1, # full attention window + logits_soft_cap=0.0, + sm_scale=sm_scale, + rope_scale=1.0, + rope_theta=1e4, + ) + return out + + +print(f"Registered '{DECODE_TARGET}'.") + +# %% +# The validation uses a GQA configuration: 16 query heads and 4 KV heads (4x grouping) attending over a 512-token KV-cache. A pure-JAX reference computes the grouped softmax attention in float32 for comparison. + +# %% +# -- Validate ------------------------------------------------------------------ + + +def _reference_gqa_decode(q, k, v): + """Reference GQA decode. 
q: [H_q, D] k, v: [S, H_kv, D] (NHD)""" + H_q, H_kv = q.shape[0], k.shape[1] + scale = q.shape[-1] ** -0.5 + q32 = q.astype(jnp.float32).reshape(H_kv, H_q // H_kv, -1) # [H_kv, group, D] + scores = jnp.einsum("hgd,shd->hgs", q32, k.astype(jnp.float32)) * scale + weights = jax.nn.softmax(scores, axis=-1) + out = jnp.einsum("hgs,shd->hgd", weights, v.astype(jnp.float32)) + return out.reshape(H_q, -1) + + +# GQA: 16 query heads, 4 KV heads (4x grouping), 512-token KV-cache +NUM_QO, NUM_KV, KV_LEN = 16, 4, 512 + +q = jax.random.normal(jax.random.key(10), (NUM_QO, HEAD_DIM), dtype=jnp.float16) +k = jax.random.normal(jax.random.key(11), (KV_LEN, NUM_KV, HEAD_DIM), dtype=jnp.float16) +v = jax.random.normal(jax.random.key(12), (KV_LEN, NUM_KV, HEAD_DIM), dtype=jnp.float16) + +out_raw = decode_attention(q, k, v) +out_ref = _reference_gqa_decode(q, k, v) + +np.testing.assert_allclose( + np.array(out_raw.astype(jnp.float32)), + np.array(out_ref), + rtol=1e-2, + atol=1e-2, +) +print("decode_attention: PASSED") +print( + f" GQA: {NUM_QO} query / {NUM_KV} KV heads ({NUM_QO // NUM_KV}x groups), kv_len={KV_LEN}" +) +print(f" Output: {out_raw.shape}") +print( + f" Max error: {float(jnp.max(jnp.abs(out_raw.astype(jnp.float32) - out_ref))):.4f}" +) + +# %% +# Composing kernels in ``@jax.jit`` +# --------------------------------- +# +# +# In this final section, you should learn that registered TVM FFI targets are plain XLA custom call nodes and compose naturally inside a single ``@jax.jit``-decorated function alongside any other JAX operations - no special handling is needed. +# +# A realistic LLM decode step brings all three kernels together: +# +# 1. ``silu_and_mul`` - gated FFN activation +# 2. ``apply_rope`` - rotary embeddings applied to the new query and key tokens +# 3. ``decode_attention`` - cross-attention over the full KV-cache with the RoPE'd query +# +# XLA compiles the entire function once. All three kernels become custom call nodes in the same HLO computation graph. +# +# .. note:: Where to put ``@jax.jit`` +# +# ``jax.ffi.ffi_call`` works with or without ``@jax.jit``. Two patterns are common: +# +# **Per-kernel JIT** - decorate each helper individually: +# +# .. code-block:: python +# +# @jax.jit +# def silu_and_mul(x): ... +# +# @jax.jit +# def decode_attention(q, k, v): ... +# +# Each kernel compiles independently and runs fast when called in isolation. Good for standalone use, exploratory work, and benchmarking individual kernels. +# +# **Outer JIT only** - helpers are plain functions; only the composition function is decorated: +# +# .. code-block:: python +# +# def silu_and_mul(x): ... # plain Python +# def decode_attention(q, k, v): ... +# +# @jax.jit +# def decode_step(...): +# ffn_out = silu_and_mul(...) +# attn_out = decode_attention(...) +# return ffn_out, attn_out +# +# XLA traces the entire decode step as a single computation graph, seeing all custom call nodes at once. This is the right choice for production LLM inference, where all kernels run together every step anyway. +# +# JAX handles nested ``@jax.jit`` correctly - inner jits are inlined during tracing - so there is no penalty to mixing patterns later. This tutorial uses **outer JIT only** to keep the focus on the composition. 
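#
# After running the next cell, one way to confirm that all three kernels end up as
# custom call nodes in a single program is to inspect the lowered computation (a
# quick sketch; the exact textual form depends on your JAX version):
#
# .. code-block:: python
#
#    lowered = decode_step.lower(gate_up, q_new, k_new, k, v, indptr, offsets)
#    print(lowered.as_text())  # look for the custom_call entries, one per kernel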
+ +# %% +# -- Inputs for the decode step ----------------------------------------------- + +# FFN input: 4 tokens, each with a 2H fused gate+up projection +gate_up = jax.random.normal(jax.random.key(20), (4, 2 * HEAD_DIM), dtype=jnp.float16) + +# GQA dimensions and KV-cache (same config as the validate cell) +NUM_QO, NUM_KV, KV_LEN = 16, 4, 512 + +q = jax.random.normal(jax.random.key(10), (NUM_QO, HEAD_DIM), dtype=jnp.float16) +k = jax.random.normal(jax.random.key(11), (KV_LEN, NUM_KV, HEAD_DIM), dtype=jnp.float16) +v = jax.random.normal(jax.random.key(12), (KV_LEN, NUM_KV, HEAD_DIM), dtype=jnp.float16) + +# New-token Q and K as packed batches for apply_rope: [tokens, heads, dim] +q_new = q.reshape(1, NUM_QO, HEAD_DIM) # [1, 16, 64] <- new query token +k_new = k[:1] # [1, 4, 64] <- new key token +indptr = jnp.array([0, 1], dtype=jnp.int32) # one sequence of length 1 +offsets = jnp.array([KV_LEN], dtype=jnp.int32) # new token sits at position KV_LEN + + +# -- @jax.jit composition ------------------------------------------------------ + + +@jax.jit +def decode_step(gate_up, q_new, k_new, k_cache, v_cache, indptr, offsets): + """One LLM decode step compiled into a single XLA computation.""" + # 1. Gated FFN activation + ffn_out = silu_and_mul(gate_up) + + # 2. Rotary embeddings for the new query and key tokens + q_r, k_r = apply_rope(q_new, k_new, indptr, offsets) + + # 3. Decode attention over the full KV-cache with the RoPE'd query + attn_out = decode_attention(q_r.reshape(NUM_QO, HEAD_DIM), k_cache, v_cache) + + return ffn_out, attn_out + + +ffn_out, attn_out = decode_step(gate_up, q_new, k_new, k, v, indptr, offsets) + +# Validate against calling each kernel individually (outside @jax.jit) +ffn_ref = silu_and_mul(gate_up) +q_r, k_r = apply_rope(q_new, k_new, indptr, offsets) +attn_ref = decode_attention(q_r.reshape(NUM_QO, HEAD_DIM), k, v) + +np.testing.assert_allclose( + np.array(ffn_out.astype(jnp.float32)), + np.array(ffn_ref.astype(jnp.float32)), + rtol=1e-2, + atol=1e-2, +) +np.testing.assert_allclose( + np.array(attn_out.astype(jnp.float32)), + np.array(attn_ref.astype(jnp.float32)), + rtol=1e-2, + atol=1e-2, +) + +print("@jax.jit composition: PASSED") +print(f" gate_up {gate_up.shape} -> ffn_out {ffn_out.shape}") +print(f" q_new {q_new.shape} -> attn_out {attn_out.shape}") + +# -- Latency benchmark ---------------------------------------------------------- + +_ = decode_attention(q, k, v).block_until_ready() # warm-up (triggers XLA compilation) +N = 100 +t0 = time.perf_counter() +for _ in range(N): + decode_attention(q, k, v).block_until_ready() +us = (time.perf_counter() - t0) / N * 1e6 +print( + f"\ndecode_attention kv_len={KV_LEN}, {NUM_QO}/{NUM_KV} GQA heads -> {us:.1f} us" +) + +# %% +# Summary +# ------- +# +# +# You have applied the JAX-TVM FFI bridge to three real LLM inference kernels, each revealing a new layer of the pattern. +# +# +# .. code-block:: text +# +# Step 1 BUILD & LOAD jit_spec.build_and_load() -> tvm_ffi.Module +# Step 2 REGISTER jax_tvm_ffi.register_ffi_target(name, wrapper, arg_spec) +# Step 3 CALL jax.ffi.ffi_call(name, output_shapes)(*inputs, **scalar_attrs) +# +# +# FlashInfer provides the same bridge for batch decode with paged KV-cache, variable-length prefill attention, fused mixture-of-experts, quantized GEMM, and more. Every kernel uses the same three-step recipe summarized above. 
The main variables are: +# +# - How many output tensors to declare, and whether any are scratch buffers to discard after the call +# - Whether any inputs or outputs are optional (use the empty-tensor sentinel) +# - Whether the kernel needs type-specialized compilation (Jinja template rendering) or dispatches at runtime over dtypes +# +# +# Beyond the examples in this tutorial +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# +# +# This tutorial demonstrated ``silu_and_mul``, ``apply_rope``, and ``single_decode`` as representative examples, but FlashInfer's strength lies in its broader library of high-performance kernels - including Multi-head Latent Attention (MLA), sparse attention, TensorRT-LLM generative batch attention, and fused Mixture-of-Experts (MoE). The same three-step ``jax-tvm-ffi`` pattern shown here applies directly to all of them: compile the kernel, register the wrapper, and call it from JAX. No changes to the bridge are needed - only the TVM function signature and ``arg_spec`` vary per kernel. +# +# For more details on ``jax-tvm-ffi`` itself - including how to wrap your own C++ or Python TVM functions - see the `jax-tvm-ffi documentation and examples `_. diff --git a/docs/tutorials/jax_tvm_ffi/gemma3_flashinfer_jax.py b/docs/tutorials/jax_tvm_ffi/gemma3_flashinfer_jax.py new file mode 100644 index 0000000000..41f57242b9 --- /dev/null +++ b/docs/tutorials/jax_tvm_ffi/gemma3_flashinfer_jax.py @@ -0,0 +1,1162 @@ +# This file is the canonical Sphinx-Gallery source for the tutorial. +# The Sphinx docs build generates HTML, a downloadable .py file, and a +# downloadable .ipynb notebook from this single source. + +""" +Gemma 3 on JAX with FlashInfer and the JAX TVM FFI Bridge +========================================================= + + +Overview +-------- + + +:doc:`flashinfer_jax_tvm_ffi` built three FlashInfer kernels from scratch and wired them into JAX as XLA custom calls. This tutorial connects those same primitives to a real language model: **Gemma 3 1B Instruct**, Google's open-weight instruction-tuned LLM. + +Every Gemma 3 transformer layer uses the following kernels: + + +.. list-table:: + :header-rows: 1 + + * - Part 1 kernel + - Role in Gemma 3 + * - ``gelu_tanh_and_mul`` + - Gated FFN activation (GeGLU variant of SiLU-GLU) + * - ``apply_rope`` + - Query and key positional embeddings - with two different theta values + * - ``decode_attention`` + - Attention over the growing KV-cache - two compiled variants + + +Three things are new compared to Part 1: + +- **GeGLU instead of SiLU-GLU** - Gemma 3 uses ``gelu_tanh`` for its gated FFN; FlashInfer ships this as a one-word change from ``silu``. +- **QK-norm** - per-head RMSNorm applied to Q and K before computing dot products, replacing the logit soft-capping that Gemma 2 used. +- **Dual RoPE theta** - local-attention layers use theta = 10 000; global-attention layers use theta = 1 000 000. We select the right value per layer and pass it to ``apply_rope``. + + +Preliminaries +------------- + + +Hardware and software requirements +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + +.. 
list-table:: + :header-rows: 0 + + * - GPU + - NVIDIA, SM 7.5+ (Turing or later) + * - Python packages + - ``jax``, ``jax-tvm-ffi``, ``flashinfer``, ``torch``, ``safetensors``, ``huggingface_hub``, ``transformers`` + * - HuggingFace + - Account with Gemma 3 licence accepted - `request access `_ +""" + +# %% +# Setting the environment +# ~~~~~~~~~~~~~~~~~~~~~~~ +# +# +# If you haven't gone through :doc:`flashinfer_jax_tvm_ffi`, refer to it for the JAX and FlashInfer installation instructions. +# +# Four additional packages are required: +# +# +# .. list-table:: +# :header-rows: 1 +# +# * - Package +# - Role +# * - ``torch`` (CPU) +# - Provides ``torch.dtype`` literals used by FlashInfer's JIT API +# * - ``safetensors`` +# - Efficient loading of model weights from the HuggingFace format +# * - ``huggingface_hub`` +# - Model download from the HuggingFace Hub +# * - ``transformers`` +# - Tokenizer and chat-template formatting +# +# +# Run the cell below only once in your environment. + +# %% +# Install the tutorial dependencies before running the notebook or script:: +# +# pip install torch --index-url https://download.pytorch.org/whl/cpu +# pip install safetensors huggingface_hub transformers + +# %% +# Loading dependencies +# ~~~~~~~~~~~~~~~~~~~~ + +# %% +# Run the cell below to load the dependencies. + +# %% +import json +import math +import os +import time +import subprocess + +os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2" # suppress TF/XLA info & warnings + +import importlib.util + +IN_COLAB = importlib.util.find_spec("google.colab") is not None + +if "CUDA_HOME" not in os.environ: + try: + nvcc = subprocess.check_output(["which", "nvcc"], text=True).strip() + os.environ["CUDA_HOME"] = str(os.path.dirname(os.path.dirname(nvcc))) + except subprocess.CalledProcessError: + os.environ["CUDA_HOME"] = "/usr/local/cuda" + +if "--xla_gpu_cuda_data_dir=" not in os.environ.get("XLA_FLAGS", ""): + os.environ["XLA_FLAGS"] = ( + f"{os.environ.get('XLA_FLAGS', '')} " + f"--xla_gpu_cuda_data_dir={os.environ['CUDA_HOME']}" + ).strip() + +import jax +import jax.numpy as jnp + +from transformers import AutoTokenizer +from huggingface_hub import snapshot_download, HfApi +from safetensors import safe_open +import jax_tvm_ffi + +print(f"JAX: {jax.__version__}") +print(f"Devices: {jax.devices()}") +print(f"CUDA home: {os.environ['CUDA_HOME']}") + +# %% +# HuggingFace Authentication +# ~~~~~~~~~~~~~~~~~~~~~~~~~~ +# +# +# Gemma 3 is a gated model. Before downloading the weights, you need to accept the licence on the HuggingFace model page - visit `google/gemma-3-1b-it `_ and click *Request access*. +# +# The cell below reads your token from the ``HF_TOKEN`` environment variable, falls back to the Colab Secrets API if running on Colab, or prompts interactively. + +# %% +# -- HuggingFace Authentication -------------------------------------------------- +# Accept the Gemma 3 license at: https://huggingface.co/google/gemma-3-1b + +if IN_COLAB: + from google.colab import userdata + + HF_TOKEN = userdata.get("HF_TOKEN") +else: + HF_TOKEN = os.environ.get("HF_TOKEN", "") + +if not HF_TOKEN: + from getpass import getpass + + HF_TOKEN = getpass( + "Hugging Face token not found in environment. 
Please enter it here: " + ) + +if not HF_TOKEN: + raise RuntimeError("Authentication failed: Hugging Face token is not set.") + +# Ensure token is set in this process +os.environ["HF_TOKEN"] = HF_TOKEN + +# Verify identity +api = HfApi() +user_info = api.whoami(token=HF_TOKEN) +username = user_info.get("name") or "Unknown user" +print(f"Authenticated with Hugging Face successfully as: {username}") + +# %% +# Downloading the model weights +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# +# +# The cell below downloads all model shards (~2 GB on first run) from the HuggingFace Hub, loads them as ``bfloat16`` JAX arrays, and instantiates the tokenizer. Weights are cached locally; subsequent runs skip the download. + +# %% +MODEL_ID = "google/gemma-3-1b-it" +HF_CACHE = os.environ.get("HF_HOME", os.path.expanduser("~/.cache/huggingface")) + +# -- Tokenizer ------------------------------------------------------------------ +print(f"Loading tokenizer from {MODEL_ID}...") +tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, token=HF_TOKEN, cache_dir=HF_CACHE) + +# -- Weights -------------------------------------------------------------------- +print("Downloading model weights (~2 GB on first run)...") +model_dir = snapshot_download(MODEL_ID, token=HF_TOKEN, cache_dir=f"{HF_CACHE}/hub") + +# Weights are split across shards - discover the full list from the index file +index_path = os.path.join(model_dir, "model.safetensors.index.json") +if os.path.exists(index_path): + with open(index_path) as f: + shard_files = sorted(set(json.load(f)["weight_map"].values())) +else: + shard_files = ["model.safetensors"] + +print(f"Loading {len(shard_files)} shard(s) as JAX bfloat16 arrays...") +weights = {} +for shard in shard_files: + with safe_open(os.path.join(model_dir, shard), framework="numpy") as f: + for key in f.keys(): + # jnp.array handles any numpy dtype (float32, bfloat16, ...) -> bfloat16 + weights[key] = jnp.array(f.get_tensor(key), dtype=jnp.bfloat16) + +n_params = sum(int(w.size) for w in weights.values()) +print(f"Loaded {len(weights)} tensors ({n_params / 1e9:.2f} B parameters)") +print(f" embed_tokens: {weights['model.embed_tokens.weight'].shape}") +print( + f" layer 0 q_proj: {weights['model.layers.0.self_attn.q_proj.weight'].shape}" +) +print( + f" layer 0 q_norm: {weights['model.layers.0.self_attn.q_norm.weight'].shape}" +) +print(f" layer 0 gate_proj: {weights['model.layers.0.mlp.gate_proj.weight'].shape}") + +# %% +# Gemma 3 Transformer Layer +# ------------------------- +# +# +# Each Gemma 3 1B layer has a **sandwich-norm** structure: RMSNorm before *and* after each sub-layer. +# +# +# .. 
code-block:: text +# +# -- Prefill (prompt, T tokens in parallel) ------------------------------------ +# x --+-- RMSNorm (input_layernorm) ------------------------------------------+ +# | Q, K, V <- linear projections | +# | Q, K <- QK-norm (per-head RMSNorm, new in Gemma 3) | +# | Q, K <- `apply_rope` (theta = local / global, per layer) | +# | out <- `prefill_attention` (causal, local or global) | +# | RMSNorm (post_attention_layernorm) | +# +-- + -------------------------------------------------------------------+ +# +# -- Decode (one new token at a time) ------------------------------------------ +# x --+-- RMSNorm (input_layernorm) ------------------------------------------+ +# | Q, K, V <- linear projections | +# | Q, K <- QK-norm (per-head RMSNorm, new in Gemma 3) | +# | Q, K <- `apply_rope` (theta = local / global, per layer) | +# | K, V -> KV-cache append | +# | out <- `decode_attention` (local sliding-window or global) | +# | RMSNorm (post_attention_layernorm) | +# +-- + -------------------------------------------------------------------+ +# +# -- Shared FFN (same code for prefill and decode) ----------------------------- +# x --+-- RMSNorm (pre_feedforward_layernorm) ----------------------------------+ +# | gate, up <- separate linear projections | +# | hidden <- `gelu_tanh_and_mul`( concat(gate, up) ) | +# | out <- down_proj(hidden) | +# | RMSNorm (post_feedforward_layernorm) | +# +-- + --------------------------------------------------------------------+ +# +# +# Local vs global attention - the 5:1 pattern +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# +# +# Gemma 3 alternates attention span in a repeating 5:1 pattern: +# +# +# .. list-table:: +# :header-rows: 1 +# +# * - Layer type +# - Frequency +# - Attends to +# - RoPE theta +# * - **Local** +# - 5 of every 6 +# - Last ``sliding_window`` tokens +# - ``rope_local_base_freq`` +# * - **Global** +# - 1 of every 6 +# - Full KV-cache +# - ``rope_theta`` +# +# +# For Gemma 3 1B (26 layers): layers **5, 11, 17, 23** are global; the remaining 22 are local. Exact values for window size and theta are read from ``config.json`` in the next cell. + +# %% +with open(os.path.join(model_dir, "config.json")) as _f: + _raw = json.load(_f) + +# Gemma 3 wraps the language model config under "text_config" in its multimodal JSON +cfg = _raw.get("text_config", _raw) + +HIDDEN = cfg["hidden_size"] +INTERMEDIATE = cfg["intermediate_size"] +N_LAYERS = cfg["num_hidden_layers"] +N_Q = cfg["num_attention_heads"] +N_KV = cfg["num_key_value_heads"] +HEAD_DIM = cfg.get("head_dim", HIDDEN // N_Q) +VOCAB = cfg["vocab_size"] +RMS_EPS = cfg.get("rms_norm_eps", 1e-6) +SLIDING_WINDOW = cfg.get("sliding_window", 1024) +SM_SCALE = 1.0 / math.sqrt(HEAD_DIM) + +# Dual RoPE theta: local layers use a small base, global layers use a large base. +# Gemma 3 stores these as rope_local_base_freq (local) and rope_theta (global). 
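# For Gemma 3 1B the config typically yields 10,000 (local) and 1,000,000 (global),
# matching the fallback defaults used below.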
+ROPE_THETA_LOCAL = int(cfg.get("rope_local_base_freq", 10_000)) +ROPE_THETA_GLOBAL = int(cfg.get("rope_theta", 1_000_000)) + + +def is_global(layer_idx: int) -> bool: + """True for global (full-attention) layers (Gemma 3 1B: 5, 11, 17, 23).""" + return (layer_idx + 1) % 6 == 0 + + +print("Architecture loaded from config.json:") +print(f" hidden={HIDDEN}, intermediate={INTERMEDIATE}, layers={N_LAYERS}") +print(f" N_Q={N_Q}, N_KV={N_KV}, head_dim={HEAD_DIM} (GQA ratio {N_Q // N_KV}x)") +print(f" vocab={VOCAB}, rms_eps={RMS_EPS}") +print(f" sliding_window={SLIDING_WINDOW}") +print( + f" rope_theta_local={ROPE_THETA_LOCAL:,}, rope_theta_global={ROPE_THETA_GLOBAL:,}" +) +print() +print(f"{'Layer':>5} {'Type':>8} {'RoPE theta':>12} {'Window':>8}") +print("-" * 42) +for i in range(N_LAYERS): + kind = "global" if is_global(i) else "local" + theta = ROPE_THETA_GLOBAL if is_global(i) else ROPE_THETA_LOCAL + window = "full" if is_global(i) else f"{SLIDING_WINDOW:,}" + print(f"{i:>5} {kind:>8} {theta:>12,} {window:>8}") + +# %% +# Concept 1: GeGLU - ``gelu_tanh`` replaces ``silu`` +# -------------------------------------------------- +# +# +# Part 1 used FlashInfer's ``silu_and_mul`` kernel. Gemma 3 swaps the gate activation: +# +# +# .. code-block:: text +# +# SiLU-GLU (Llama, Gemma 2): out = silu(gate) * up +# GeGLU (Gemma 3): out = gelu_tanh(gate) * up +# +# +# where ``gelu_tanh`` is the tanh-approximated GELU, matching ``torch.nn.functional.gelu(x, approximate="tanh")``. +# +# FlashInfer ships all three variants - ``silu``, ``gelu``, ``gelu_tanh`` - through the same ``gen_act_and_mul_module`` interface. Switching from Part 1 is a one-word change: +# +# +# .. code-block:: python +# +# # Part 1 +# silu_module = gen_act_and_mul_module('silu').build_and_load() +# +# # Part 2 - Gemma 3 +# gelu_module = gen_act_and_mul_module('gelu_tanh').build_and_load() +# +# +# Everything else - the three-step bridge pattern, the wrapper, the ``ffi_call`` shape declaration - is identical. + +# %% +# Compile and register all kernels +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# +# +# All four kernel pairs - ``gelu_tanh_and_mul``, ``apply_rope``, and local/global variants of both decode and prefill attention - are compiled and registered in a single cell below. Concepts 2 (QK-norm) and 3 (dual RoPE theta) are explained in the sections that follow; they require no additional kernels beyond what Part 1 introduced. + +# %% +import torch as _torch # used only for dtype spec in gen_single_*_module +from flashinfer.jit import ( + gen_act_and_mul_module, + gen_single_decode_module, + gen_single_prefill_module, +) +from flashinfer.jit.rope import gen_rope_module + +# -- 1. gelu_tanh_and_mul ------------------------------------------------------ +print("Compiling gelu_tanh_and_mul...") +_gelu_mod = gen_act_and_mul_module("gelu_tanh").build_and_load() + + +def _gelu_wrapper(out, x, enable_pdl): + _gelu_mod.gelu_tanh_and_mul(out, x, enable_pdl) + + +jax_tvm_ffi.register_ffi_target( + "flashinfer.gelu_tanh_and_mul", + _gelu_wrapper, + arg_spec=["rets", "args", "attrs.enable_pdl"], + platform="gpu", + allow_cuda_graph=True, + pass_owned_tensor=True, +) + + +def gelu_and_mul(x: jax.Array) -> jax.Array: + """Fused gelu_tanh(gate) * up. 
Input: [..., 2H] Output: [..., H]""" + out_shape = x.shape[:-1] + (x.shape[-1] // 2,) + return jax.ffi.ffi_call( + "flashinfer.gelu_tanh_and_mul", + jax.ShapeDtypeStruct(out_shape, x.dtype), + vmap_method="broadcast_all", # element-wise op: independent across any batch dim + )(x, enable_pdl=False) + + +# -- 2. apply_rope ------------------------------------------------------------- +print("Compiling apply_rope...") +_rope_mod = gen_rope_module().build_and_load() + + +def _rope_wrapper( + q_rope, + k_rope, + q, + k, + indptr, + offsets, + rotary_dim, + interleave, + rope_scale, + rope_theta, +): + _rope_mod.apply_rope( + q, + k, + q_rope, + k_rope, + indptr, + offsets, + rotary_dim, + interleave, + rope_scale, + rope_theta, + ) + + +jax_tvm_ffi.register_ffi_target( + "flashinfer.apply_rope", + _rope_wrapper, + arg_spec=[ + "rets", + "args", + "attrs.rotary_dim", + "attrs.interleave", + "attrs.rope_scale", + "attrs.rope_theta", + ], + platform="gpu", + allow_cuda_graph=True, + pass_owned_tensor=True, +) + + +def apply_rope(q, k, indptr, offsets, rope_theta=1e4): + """Apply RoPE to packed batches. Returns (q_rope, k_rope).""" + return jax.ffi.ffi_call( + "flashinfer.apply_rope", + ( + jax.ShapeDtypeStruct(q.shape, q.dtype), + jax.ShapeDtypeStruct(k.shape, k.dtype), + ), + vmap_method="broadcast_all", # each packed batch is independent; safe to vmap + )( + q, + k, + indptr, + offsets, + rotary_dim=q.shape[-1], + interleave=False, + rope_scale=1.0, + rope_theta=float(rope_theta), + ) + + +# -- 3. decode_attention: local + global variants ------------------------------- +_TMP_ELEMS = 32 * 1024 * 1024 // 2 # 32 MB scratch buffer in bfloat16 elements + +print(f"Compiling decode attention (local, sliding-window={SLIDING_WINDOW})...") +_local_dec_mod = gen_single_decode_module( + _torch.bfloat16, + _torch.bfloat16, + _torch.bfloat16, + HEAD_DIM, + HEAD_DIM, + 0, + True, + False, +).build_and_load() +print("Compiling decode attention (global, full attention)...") +_global_dec_mod = gen_single_decode_module( + _torch.bfloat16, + _torch.bfloat16, + _torch.bfloat16, + HEAD_DIM, + HEAD_DIM, + 0, + False, + False, +).build_and_load() + + +def _make_decode_wrapper(run_fn): + def _w( + out, + tmp, + lse_or_empty, + q, + k, + v, + alibi_or_empty, + layout, + window_left, + logits_soft_cap, + sm_scale, + rope_rcp_scale, + rope_rcp_theta, + ): + lse = None if lse_or_empty.shape[0] == 0 else lse_or_empty + alibi = None if alibi_or_empty.shape[0] == 0 else alibi_or_empty + run_fn( + q, + k, + v, + tmp, + out, + lse, + layout, + window_left, + alibi, + logits_soft_cap, + sm_scale, + rope_rcp_scale, + rope_rcp_theta, + ) + + return _w + + +_DEC_ARG_SPEC = [ + "rets", + "args", + "attrs.layout", + "attrs.window_left", + "attrs.logits_soft_cap", + "attrs.sm_scale", + "attrs.rope_rcp_scale", + "attrs.rope_rcp_theta", +] +_KW = dict(platform="gpu", allow_cuda_graph=True, pass_owned_tensor=True) + +jax_tvm_ffi.register_ffi_target( + "flashinfer.decode_local", + _make_decode_wrapper(_local_dec_mod.run), + _DEC_ARG_SPEC, + **_KW, +) +jax_tvm_ffi.register_ffi_target( + "flashinfer.decode_global", + _make_decode_wrapper(_global_dec_mod.run), + _DEC_ARG_SPEC, + **_KW, +) + + +def decode_attention(q, k_cache, v_cache, global_layer=False): + """Single-request GQA decode attention. 
+ + q: [N_Q, HEAD_DIM] bfloat16 + k_cache: [seq_len, N_KV, HEAD_DIM] bfloat16 + v_cache: [seq_len, N_KV, HEAD_DIM] bfloat16 + Returns: [N_Q, HEAD_DIM] + """ + target = "flashinfer.decode_global" if global_layer else "flashinfer.decode_local" + window = -1 if global_layer else SLIDING_WINDOW + out, _, _ = jax.ffi.ffi_call( + target, + ( + jax.ShapeDtypeStruct(q.shape, jnp.bfloat16), + jax.ShapeDtypeStruct((_TMP_ELEMS,), jnp.bfloat16), + jax.ShapeDtypeStruct((0,), jnp.float32), + ), + # vmap_method intentionally omitted: the scratch buffer has no batch + # dimension, and GQA head-grouping does not decompose over an outer batch axis. + )( + q, + k_cache, + v_cache, + jnp.empty((0,), dtype=jnp.float32), + layout=0, + window_left=window, + logits_soft_cap=0.0, + sm_scale=SM_SCALE, + rope_rcp_scale=1.0, + rope_rcp_theta=1.0, + ) + return out + + +# -- 4. prefill_attention: local + global variants ------------------------------ +print(f"Compiling prefill attention (local, sliding-window={SLIDING_WINDOW})...") +_local_pre_mod = gen_single_prefill_module( + "fa2", + _torch.bfloat16, + _torch.bfloat16, + _torch.bfloat16, + HEAD_DIM, + HEAD_DIM, + 0, + True, + False, + False, +).build_and_load() +print("Compiling prefill attention (global, full attention)...") +_global_pre_mod = gen_single_prefill_module( + "fa2", + _torch.bfloat16, + _torch.bfloat16, + _torch.bfloat16, + HEAD_DIM, + HEAD_DIM, + 0, + False, + False, + False, +).build_and_load() + + +def _make_prefill_wrapper(run_fn): + def _w( + out, + tmp, + lse_or_empty, + q, + k, + v, + alibi_or_empty, + mask_mode_code, + layout, + window_left, + logits_soft_cap, + sm_scale, + rope_rcp_scale, + rope_rcp_theta, + ): + lse = None if lse_or_empty.shape[0] == 0 else lse_or_empty + alibi = None if alibi_or_empty.shape[0] == 0 else alibi_or_empty + run_fn( + q, + k, + v, + tmp, + out, + lse, + mask_mode_code, + layout, + window_left, + None, + alibi, + logits_soft_cap, + sm_scale, + rope_rcp_scale, + rope_rcp_theta, + ) + + return _w + + +_PRE_ARG_SPEC = [ + "rets", + "args", + "attrs.mask_mode_code", + "attrs.layout", + "attrs.window_left", + "attrs.logits_soft_cap", + "attrs.sm_scale", + "attrs.rope_rcp_scale", + "attrs.rope_rcp_theta", +] + +jax_tvm_ffi.register_ffi_target( + "flashinfer.prefill_local", + _make_prefill_wrapper(_local_pre_mod.run), + _PRE_ARG_SPEC, + **_KW, +) +jax_tvm_ffi.register_ffi_target( + "flashinfer.prefill_global", + _make_prefill_wrapper(_global_pre_mod.run), + _PRE_ARG_SPEC, + **_KW, +) + + +def prefill_attention(q, k, v, layer_i): + """FlashInfer causal GQA attention for processing a multi-token prompt. + + Uses the same kernel-bridge pattern as decode_attention: mask_mode_code=1 + selects causal masking; window_left controls the sliding-window cut-off. + + q, k, v: [T, n_heads, HEAD_DIM] bfloat16 + Returns: [T, N_Q, HEAD_DIM] bfloat16 + """ + glob = is_global(layer_i) + target = "flashinfer.prefill_global" if glob else "flashinfer.prefill_local" + window = -1 if glob else SLIDING_WINDOW + out, _, _ = jax.ffi.ffi_call( + target, + ( + jax.ShapeDtypeStruct(q.shape, jnp.bfloat16), # output + jax.ShapeDtypeStruct((_TMP_ELEMS,), jnp.bfloat16), # tmp scratch + jax.ShapeDtypeStruct((0,), jnp.float32), + ), # lse (discard) + # vmap_method intentionally omitted: scratch buffer + GQA head-grouping + # do not decompose over an outer batch axis. 
+ )( + q, + k, + v, + jnp.empty((0,), dtype=jnp.float32), # alibi = empty sentinel + mask_mode_code=1, + layout=0, + window_left=window, + logits_soft_cap=0.0, + sm_scale=SM_SCALE, + rope_rcp_scale=1.0, + rope_rcp_theta=1.0, + ) + return out + + +print("All kernels compiled and registered.") + +# %% +# Concept 2: QK-norm - per-head normalization +# ------------------------------------------- +# +# +# Gemma 2 bounded attention score magnitudes with logit soft-capping: ``scores = tanh(scores / 50) x 50``. Gemma 3 replaces this with **QK-norm**: an RMSNorm applied independently to each query and key head *after* the linear projection and *before* the dot product. +# +# +# .. code-block:: python +# +# # Gemma 2 (inside the attention kernel, via logits_soft_cap parameter) +# scores = tanh(q @ k.T / sqrt(d) / 50) * 50 +# +# # Gemma 3 (in JAX, before calling decode_attention) +# q = rms_norm_per_head(q, q_norm_weight) # [N_Q, head_dim] +# k = rms_norm_per_head(k, k_norm_weight) # [N_KV, head_dim] +# scores = q @ k.T / sqrt(d) # bounded by weight norms +# +# +# The norm weights ``q_norm.weight`` and ``k_norm.weight`` have shape ``[head_dim]`` - the same weight is shared across all heads. In the model state dict they are ``model.layers.{i}.self_attn.q_norm.weight``. +# +# +# Concept 3: Dual RoPE theta - one theta per attention scope +# ---------------------------------------------------------- +# +# +# Standard RoPE uses a single base frequency theta. Gemma 3 uses two: +# +# +# .. list-table:: +# :header-rows: 1 +# +# * - Layer type +# - theta +# - Why +# * - Local (5/6 of layers) +# - 10 000 +# - Standard positional bias for the 1 024-token window +# * - Global (1/6 of layers) +# - 1 000 000 +# - Slower-decaying frequencies for long-range context +# +# +# In code this is a single ``if`` in the layer function - we select the right theta and pass it to ``apply_rope`` as a scalar attribute. The kernel is compiled once; the theta value is a runtime parameter. + +# %% +# -- Pure-JAX building blocks --------------------------------------------------- + + +@jax.jit +def rms_norm(x, weight, eps=RMS_EPS): + """Gemma-style RMSNorm: normalise then scale by (1 + weight).""" + x32 = x.astype(jnp.float32) + y = x32 * jax.lax.rsqrt(jnp.mean(x32**2, axis=-1, keepdims=True) + eps) + return y.astype(x.dtype) * (1.0 + weight) + + +@jax.jit +def qk_norm(x, weight): + """Per-head RMSNorm for Q or K vectors. x: [..., head_dim].""" + return rms_norm(x, weight) + + +def embed(token_ids): + """Embedding lookup. Gemma multiplies by sqrt(hidden_size) to keep + hidden-state norms stable through the first RMSNorm. + + token_ids: [T] -> [T, HIDDEN] + """ + return weights["model.embed_tokens.weight"][token_ids] * math.sqrt(HIDDEN) + + +def lm_head(h): + """Project hidden state to vocabulary logits. h: [HIDDEN] -> [VOCAB] float32.""" + # Gemma 3 ties the LM head to the embedding matrix + lm_w = weights.get("lm_head.weight", weights["model.embed_tokens.weight"]) + return h.astype(jnp.float32) @ lm_w.astype(jnp.float32).T + + +def ffn(h, layer_i): + """GeGLU feed-forward block. h: [..., HIDDEN] -> [..., HIDDEN]. + + Handles both single-token decode (h: [HIDDEN]) and + full-sequence prefill (h: [T, HIDDEN]) with the same code. 
+ """ + pre = rms_norm( + h, weights[f"model.layers.{layer_i}.pre_feedforward_layernorm.weight"] + ) + gate = ( + pre @ weights[f"model.layers.{layer_i}.mlp.gate_proj.weight"].T + ) # [..., INTER] + up = pre @ weights[f"model.layers.{layer_i}.mlp.up_proj.weight"].T # [..., INTER] + + # Concatenate along the last axis: gelu_and_mul splits it back in two + gate_up = jnp.concatenate([gate, up], axis=-1) # [..., 2*INTER] + hidden = gelu_and_mul(gate_up) # [..., INTER] <- FlashInfer kernel + + out = hidden @ weights[f"model.layers.{layer_i}.mlp.down_proj.weight"].T + out = rms_norm( + out, weights[f"model.layers.{layer_i}.post_feedforward_layernorm.weight"] + ) + return out + + +# Quick sanity check on the FFN +_x_test = jax.random.normal(jax.random.key(0), (HIDDEN,), dtype=jnp.bfloat16) +_out = ffn(_x_test, 0) +print( + f"FFN layer 0: {_x_test.shape} -> {_out.shape} dtype={_out.dtype} ok={not jnp.any(jnp.isnan(_out))}" +) + +# %% +# Prefill layer and full forward pass +# ----------------------------------- +# +# +# ``prefill_layer`` processes all prompt tokens in parallel through one transformer layer and builds the initial KV-cache. ``prefill`` chains it across all 26 layers, then applies ``rms_norm`` to the last token's hidden state and returns the per-layer KV-caches that the decode loop will update. + +# %% +# -- Prefill layer (processes all T prompt tokens in parallel) ------------------ + + +def prefill_layer(h, layer_i): + """Run one transformer layer over the full prompt. + + h: [T, HIDDEN] bfloat16 + Returns: (h: [T, HIDDEN], kv_cache: (k: [T, N_KV, D], v: [T, N_KV, D])) + """ + T = h.shape[0] + glob = is_global(layer_i) + rope_theta = ROPE_THETA_GLOBAL if glob else ROPE_THETA_LOCAL + + # -- Attention ------------------------------------------------------------- + ln = rms_norm(h, weights[f"model.layers.{layer_i}.input_layernorm.weight"]) + + q = (ln @ weights[f"model.layers.{layer_i}.self_attn.q_proj.weight"].T).reshape( + T, N_Q, HEAD_DIM + ) + k = (ln @ weights[f"model.layers.{layer_i}.self_attn.k_proj.weight"].T).reshape( + T, N_KV, HEAD_DIM + ) + v = (ln @ weights[f"model.layers.{layer_i}.self_attn.v_proj.weight"].T).reshape( + T, N_KV, HEAD_DIM + ) + + # QK-norm (per head, same weight across all token positions) + q = qk_norm(q, weights[f"model.layers.{layer_i}.self_attn.q_norm.weight"]) + k = qk_norm(k, weights[f"model.layers.{layer_i}.self_attn.k_norm.weight"]) + + # RoPE over all T tokens at once: one sequence starting at offset 0 + indptr = jnp.array([0, T], dtype=jnp.int32) + offsets = jnp.array([0], dtype=jnp.int32) + q, k = apply_rope(q, k, indptr, offsets, rope_theta=rope_theta) + + # FlashInfer causal attention + attn_out = prefill_attention(q, k, v, layer_i) + + attn_out = attn_out.reshape(T, N_Q * HEAD_DIM) + attn_out = attn_out @ weights[f"model.layers.{layer_i}.self_attn.o_proj.weight"].T + attn_out = rms_norm( + attn_out, weights[f"model.layers.{layer_i}.post_attention_layernorm.weight"] + ) + h = h + attn_out + + # -- FFN (works naturally for [T, HIDDEN]) -------------------------------- + h = h + ffn(h, layer_i) + + # KV-cache: store the RoPE-applied K and raw V for all prompt positions + return h, (k, v) + + +# -- Full prefill pass --------------------------------------------------------- + + +def prefill(prompt_ids): + """Process the full prompt. 
Returns (h_last: [HIDDEN], kv_caches).""" + h = embed(jnp.array(prompt_ids)) # [T, HIDDEN] + + kv_caches = [] + for i in range(N_LAYERS): + h, kv_cache = prefill_layer(h, i) + kv_caches.append(kv_cache) + + # Final norm applied to the last token's hidden state + h_last = rms_norm(h[-1], weights["model.norm.weight"]) # [HIDDEN] + return h_last, kv_caches + + +# %% +# Decode attention layer +# ---------------------- +# +# +# ``decode_layer`` processes one newly generated token through a full transformer layer. It applies QK-norm, selects the layer's RoPE theta, calls ``apply_rope`` and ``decode_attention``, appends to the KV-cache, and returns the updated hidden state. + +# %% +# -- Decode attention layer (one new token, growing KV-cache) ------------------- + + +def decode_layer(h, layer_i, kv_cache, pos): + """Process a single new token through one transformer layer. + + h: [HIDDEN] bfloat16 + kv_cache: (k: [pos, N_KV, D], v: [pos, N_KV, D]) + pos: current token's position in the full sequence (Python int) + Returns: (h: [HIDDEN], updated_kv_cache) + """ + glob = is_global(layer_i) + rope_theta = ROPE_THETA_GLOBAL if glob else ROPE_THETA_LOCAL + + # -- Attention ------------------------------------------------------------- + ln = rms_norm(h, weights[f"model.layers.{layer_i}.input_layernorm.weight"]) + + q = (ln @ weights[f"model.layers.{layer_i}.self_attn.q_proj.weight"].T).reshape( + N_Q, HEAD_DIM + ) + k = (ln @ weights[f"model.layers.{layer_i}.self_attn.k_proj.weight"].T).reshape( + N_KV, HEAD_DIM + ) + v = (ln @ weights[f"model.layers.{layer_i}.self_attn.v_proj.weight"].T).reshape( + N_KV, HEAD_DIM + ) + + # QK-norm (Concept 2: Gemma 3 replaces soft-capping with per-head RMSNorm) + q = qk_norm(q, weights[f"model.layers.{layer_i}.self_attn.q_norm.weight"]) + k = qk_norm(k, weights[f"model.layers.{layer_i}.self_attn.k_norm.weight"]) + + # apply_rope with the layer's theta (Concept 3: different theta for local vs global) + q_pack, k_pack = q[None], k[None] # [1, heads, D] packed batch of 1 token + indptr = jnp.array([0, 1], dtype=jnp.int32) + offsets = jnp.array([pos], dtype=jnp.int32) + q_r, k_r = apply_rope(q_pack, k_pack, indptr, offsets, rope_theta=rope_theta) + q_r = q_r.squeeze(0) # [N_Q, D] + k_r = k_r.squeeze(0) # [N_KV, D] + + # Append RoPE'd K and raw V to the KV-cache + # NOTE: Using jnp.concatenate to grow KV cache is intentional. + # In standard JAX this is inefficient (O(N^2)) and you'd normally preallocate + # and use lax.dynamic_update_slice. However, FlashInfer's single-request + # decode kernel infers sequence length from k_cache/v_cache.shape. + # Therefore we must keep the cache length equal to the actual number of tokens. + # Switching to a fixed-size buffer would require a different FlashInfer API + # (e.g. paged KV cache) or an explicit length/mask. 
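+    # For illustration only: a sketch of that preallocated alternative (assuming a
+    # hypothetical fixed-size cache of shape [MAX_SEQ, N_KV, D]; MAX_SEQ is not
+    # defined in this tutorial, and the kernel would also need the true length):
+    #
+    #   k_cache = jax.lax.dynamic_update_slice(k_cache, k_r[None], (pos, 0, 0))
+    #   v_cache = jax.lax.dynamic_update_slice(v_cache, v[None], (pos, 0, 0))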
+    k_cache, v_cache = kv_cache
+    k_cache = jnp.concatenate([k_cache, k_r[None]], axis=0)  # [pos+1, N_KV, D]
+    v_cache = jnp.concatenate([v_cache, v[None]], axis=0)  # [pos+1, N_KV, D]
+
+    # Decode attention over the full KV-cache (FlashInfer kernel)
+    attn_out = decode_attention(q_r, k_cache, v_cache, global_layer=glob)  # [N_Q, D]
+
+    attn_out = attn_out.reshape(N_Q * HEAD_DIM)
+    attn_out = attn_out @ weights[f"model.layers.{layer_i}.self_attn.o_proj.weight"].T
+    attn_out = rms_norm(
+        attn_out, weights[f"model.layers.{layer_i}.post_attention_layernorm.weight"]
+    )
+    h = h + attn_out
+
+    # -- FFN -------------------------------------------------------------------
+    h = h + ffn(h, layer_i)
+
+    return h, (k_cache, v_cache)
+
+
+# %%
+# Decode step
+# -----------
+#
+#
+# **Why there is no ``@jax.jit`` here**
+#
+# The FlashInfer kernels (``decode_attention``, ``apply_rope``, ...) are fully ``@jax.jit``-compatible XLA custom calls. The obstacle is the KV-cache. Each decode step appends one new row:
+#
+# .. code-block:: python
+#
+#     k_cache = jnp.concatenate([k_cache, k_r[None]], axis=0)  # shape grows every step
+#
+# ``@jax.jit`` requires statically known output shapes. Because ``k_cache.shape[0]`` increments at every step, XLA would have to recompile ``decode_step`` on each call - far more expensive than running eagerly.
+#
+# A production system fixes this by pre-allocating a maximum-length cache and writing into it with ``jax.lax.dynamic_update_slice``, which keeps shapes static and allows the entire decode loop to be compiled with ``jax.lax.scan``. That is the paged KV-cache direction described in the Summary.
+
+# %%
+# -- One decode step -----------------------------------------------------------
+
+
+def decode_step(token_id, kv_caches, pos):
+    """Process one newly generated token and predict the next.
+
+    token_id:  int   the most recently produced token
+    kv_caches: list  one (k, v) tuple per layer
+    pos:       int   this token's position in the full sequence
+    Returns: (logits: [VOCAB] float32, updated_kv_caches)
+    """
+    h = embed(jnp.array([token_id])).squeeze(0)  # [HIDDEN]
+
+    new_kv = []
+    for i in range(N_LAYERS):
+        h, kv = decode_layer(h, i, kv_caches[i], pos)
+        new_kv.append(kv)
+
+    h = rms_norm(h, weights["model.norm.weight"])
+    logits = lm_head(h)
+    return logits, new_kv
+
+
+# -- Stop tokens ---------------------------------------------------------------
+# Gemma instruct ends its turn with <end_of_turn>, not the generic <eos>.
+# Collect all token IDs that should halt generation.
+_STOP_IDS = {tokenizer.eos_token_id} if tokenizer.eos_token_id is not None else set()
+for _tok in ["<end_of_turn>", "<eos>"]:
+    _id = tokenizer.convert_tokens_to_ids(_tok)
+    if _id is not None and _id != tokenizer.unk_token_id:
+        _STOP_IDS.add(_id)
+
+
+# -- Text generation -----------------------------------------------------------
+
+
+def generate(prompt, max_new_tokens=200, temperature=0.7, seed=0):
+    """Autoregressive generation with the Gemma 3 instruct chat template."""
+    messages = [{"role": "user", "content": prompt}]
+
+    # Render chat template to plain text first.
+    rendered = tokenizer.apply_chat_template(
+        messages,
+        tokenize=False,
+        add_generation_prompt=True,
+    )
+
+    # Then tokenize explicitly and extract only input_ids.
+    enc = tokenizer(rendered, add_special_tokens=False)
+    prompt_ids = enc["input_ids"]
+
+    # Flatten batch dimension if present.
+ if len(prompt_ids) > 0 and isinstance(prompt_ids[0], list): + prompt_ids = prompt_ids[0] + + T = len(prompt_ids) + key = jax.random.key(seed) + + print(f"Prompt ({T} tokens): {prompt!r}") + print(f"Rendered prompt preview: {rendered[:120]!r}") + + print("Prefilling...", end=" ", flush=True) + t0 = time.perf_counter() + h_last, kv_caches = prefill(prompt_ids) + jax.block_until_ready(h_last) + print(f"{time.perf_counter() - t0:.1f}s") + + def _sample(logits, key): + if temperature == 0.0: + return int(jnp.argmax(logits)), key + key, subkey = jax.random.split(key) + return int(jax.random.categorical(subkey, logits / temperature)), key + + print("Response: ", end="", flush=True) + + generated = [] + for step in range(max_new_tokens): + if step == 0: + logits = lm_head(h_last) + else: + logits, kv_caches = decode_step(generated[-1], kv_caches, T + step - 1) + next_tok, key = _sample(logits, key) + generated.append(next_tok) + if next_tok in _STOP_IDS: + break + print(tokenizer.decode([next_tok]), end="", flush=True) + + print() + return tokenizer.decode(generated, skip_special_tokens=True) + + +# %% +# Running inference +# ----------------- +# +# +# The cell below runs the model on three sample questions using the Gemma 3 instruct chat template. XLA compiles the kernels on the first call; subsequent prompts reuse the cached compilation. + +# %% +questions = [ + "What is the capital of Germany", + "How does rotary positional embedding differ from learned positional embedding", + "What is grouped-query attention and why is it useful", +] + +for q in questions: + generate(q, max_new_tokens=150, temperature=0.7, seed=0) + print() + +# %% +# Summary +# ------- +# +# +# We have implemented end-to-end autoregressive inference for Gemma 3 1B Instruct using four FlashInfer kernels as the computational backbone - covering both the prompt (prefill) and generation (decode) phases. +# +# +# The complete inference recipe +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# +# +# .. code-block:: python +# +# # 1. Compile kernels (once) +# gelu_module = gen_act_and_mul_module('gelu_tanh').build_and_load() +# rope_module = gen_rope_module().build_and_load() +# decode_local_module = gen_single_decode_module(..., use_sliding_window=True, ...).build_and_load() +# decode_global_module = gen_single_decode_module(..., use_sliding_window=False, ...).build_and_load() +# prefill_local = gen_single_prefill_module('fa2', ..., use_sliding_window=True).build_and_load() +# prefill_global = gen_single_prefill_module('fa2', ..., use_sliding_window=False).build_and_load() +# +# # 2. Prefill: FlashInfer causal attention over all prompt tokens -> KV-cache +# h_last, kv_caches = prefill(prompt_ids) +# +# # 3. Decode: FlashInfer decode_attention, one token at a time +# for step in range(max_new_tokens): +# if step == 0: +# logits = lm_head(h_last) +# else: +# logits, kv_caches = decode_step(generated[-1], kv_caches, T + step - 1) +# +# +# Going further +# ~~~~~~~~~~~~~ +# +# +# - **Paged KV-cache**: Replace the growing list with a fixed-size paged cache and use FlashInfer's ``BatchDecodeWithPagedKVCacheWrapper`` for batch inference with mixed sequence lengths. +# - **Sampling**: Extend the sampler with top-p nucleus sampling or top-k filtering on the logits. +# - **Continuous batching**: Process multiple requests simultaneously, filling the decode kernel's batch dimension.
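+#
+# As a concrete starting point for the sampling extension above, the sketch below
+# shows one possible top-p (nucleus) filter over the final logits. It is a minimal
+# illustration, not part of this tutorial's code: ``nucleus_sample`` is a
+# hypothetical helper name and the default values are placeholders.
+#
+# .. code-block:: python
+#
+#     def nucleus_sample(logits, key, top_p=0.9, temperature=0.7):
+#         """Sample from the smallest token set whose cumulative probability >= top_p."""
+#         probs = jax.nn.softmax(logits / temperature)
+#         order = jnp.argsort(-probs)            # tokens from most to least likely
+#         sorted_probs = probs[order]
+#         cumulative = jnp.cumsum(sorted_probs)
+#         # Keep a token while the mass *before* it is still below top_p
+#         # (this always keeps the single most likely token).
+#         keep = (cumulative - sorted_probs) < top_p
+#         filtered = jnp.where(keep, sorted_probs, 0.0)
+#         key, subkey = jax.random.split(key)
+#         choice = jax.random.categorical(subkey, jnp.log(filtered + 1e-12))
+#         return int(order[choice]), key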