Add examples of calling FlashInfer from JAX via jax-tvm-ffi #3092

yongwww merged 1 commit into flashinfer-ai:main
Conversation
Note: Reviews paused. It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. Use the following commands to manage reviews:
📝 Walkthrough

Adds a new examples directory with READMEs, notebooks, and standalone scripts demonstrating JAX ↔ TVM FFI integration that registers and calls FlashInfer CUDA kernels, plus an end-to-end Gemma 3 inference example with environment, installation, HF auth, and troubleshooting guidance.
Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant User
    participant JAX
    participant TVM as TVM_FFI
    participant FlashInfer as FlashInfer_JIT
    participant CUDA as CUDA_Kernel
    User->>JAX: Call JAX wrapper (e.g., silu_and_mul)
    JAX->>TVM: jax.ffi.ffi_call (args, shape/dtype spec)
    TVM->>FlashInfer: lookup/load compiled module and invoke target
    FlashInfer->>CUDA: Launch CUDA kernel
    CUDA-->>FlashInfer: Kernel result buffer
    FlashInfer-->>TVM: Return result buffer
    TVM-->>JAX: Produce JAX Array
    JAX-->>User: Return result
```

```mermaid
sequenceDiagram
    participant User
    participant App as Gemma3_App
    participant HF as HuggingFace
    participant JAX
    participant FlashInfer as FlashInfer_Kernels
    User->>App: generate(prompt)
    App->>HF: Authenticate & download weights
    HF-->>App: Model shards (safetensors)
    App->>JAX: Load shards → bfloat16 arrays, tokenize
    App->>JAX: prefill(token_ids)
    JAX->>FlashInfer: prefill_attention + ffn calls
    FlashInfer-->>JAX: KV caches, logits
    loop decode until EOS
        App->>JAX: decode_step(token, caches)
        JAX->>FlashInfer: decode_attention + ffn
        FlashInfer-->>JAX: logits
        JAX->>App: next token (sampling)
    end
    App-->>User: Generated text
```
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~60 minutes
🚥 Pre-merge checks: ✅ 2 passed | ❌ 1 failed (1 warning)

✏️ Tip: You can configure your own custom pre-merge checks in the settings.
Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.
Code Review
This pull request introduces a suite of examples and tutorials for using FlashInfer GPU kernels within JAX via the TVM FFI bridge, including an end-to-end inference implementation for Gemma 3 1B Instruct. The review feedback highlights opportunities to enhance code robustness and performance, specifically by avoiding hardcoded data types in the decode_attention function and replacing inefficient jnp.concatenate operations for KV-cache growth with pre-allocated buffers and jax.lax.dynamic_update_slice.
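The reviewer's pre-allocation suggestion can be illustrated with a small sketch. numpy stands in for JAX arrays here so the snippet runs anywhere; in the actual JAX code the in-place write would be `jax.lax.dynamic_update_slice` on a pre-allocated buffer instead of a `jnp.concatenate` per decode step. Shapes and names are illustrative, not the example's real ones.

```python
import numpy as np

MAX_LEN, NUM_KV, HEAD_DIM = 8, 2, 4

# Pre-allocate the KV cache once at the maximum sequence length...
k_cache = np.zeros((MAX_LEN, NUM_KV, HEAD_DIM), dtype=np.float32)

def append_kv(k_cache, k_new, cur_len):
    """Write the new key at position cur_len (the numpy analog of
    jax.lax.dynamic_update_slice) instead of growing the cache with
    a concatenate on every decode step."""
    k_cache[cur_len] = k_new
    return k_cache, cur_len + 1

cur_len = 0
for step in range(3):
    k_new = np.full((NUM_KV, HEAD_DIM), step, dtype=np.float32)
    k_cache, cur_len = append_kv(k_cache, k_new, cur_len)

print(cur_len)           # 3
print(k_cache[2, 0, 0])  # 2.0
```

The buffer never changes shape, so under `jax.jit` the equivalent JAX version avoids recompilation and repeated allocation as the sequence grows.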
Actionable comments posted: 5
🧹 Nitpick comments (3)
examples/jax_tvm_ffi/flashinfer_jax_tvm_ffi.ipynb (1)
115-148: Self-containedness nit: `subprocess` is only imported in the earlier compute-capability cell.

The `CUDA_HOME` auto-detection block at lines 126–131 references `subprocess`, but this cell does not re-import it. If a user re-runs this cell after a kernel restart without running the earlier cell, it will `NameError`. Adding `import subprocess` alongside the other imports in this cell (or merging the two setup cells) makes it robust to out-of-order execution.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@examples/jax_tvm_ffi/flashinfer_jax_tvm_ffi.ipynb` around lines 115-148: The CUDA_HOME auto-detection block in this cell uses subprocess but subprocess is not imported here, causing a NameError if this cell runs standalone; to fix it, add an import subprocess alongside the other imports at the top of the cell (near import os, time, math, jinja2, numpy) or merge this cell with the earlier setup cell so that subprocess is guaranteed to be defined before the CUDA_HOME detection code in the block that references nvcc via subprocess.check_output.

examples/jax_tvm_ffi/gemma3_flashinfer_jax.py (1)
38-194: All setup/download/compile work executes at module import time.

Everything from compute-capability detection (line 46) through HF authentication (line 124), weight download/load (lines 135–156), and kernel compilation/registration (lines 209–511) runs as top-level module code. Only the `for q in questions:` loop (lines 755–762) is guarded by `if __name__ == "__main__":`. Consequences:

- Importing this file as a module in another script triggers ~2 GB of HF downloads and multiple nvcc builds.
- `getpass(...)` at line 114 will block on import if `HF_TOKEN` isn't set.
- Static analyzers / IDEs that import the file to collect symbols will also trigger the side effects.

Not required for a demo, but moving the download/compile/register blocks into `main()` (or into lazily-invoked helpers) would make the example safer to re-use.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@examples/jax_tvm_ffi/gemma3_flashinfer_jax.py` around lines 38-194: The module performs heavy side-effect work at import (compute-cap detection via get_compute_capability, HF cache setup via _ensure_writable_hf_cache, HF_TOKEN prompting, snapshot_download and weights loading into weights, and kernel compilation/registration) which should be deferred; refactor by moving all top-level side-effect logic into a guarded main() (or into explicit functions like prepare_environment(), authenticate_hf(), download_and_load_weights(), compile_and_register_kernels()) and ensure only lightweight definitions (functions/classes/constants) remain at import time, then call main() under `if __name__ == "__main__":` so importing the module no longer triggers downloads, getpass prompts, or nvcc builds.

examples/jax_tvm_ffi/README.md (1)
89-93: Add a language tag to the fenced code block (optional).

markdownlint MD040 flags this block. Using `text` or `none` keeps the ASCII diagram untouched while silencing the lint.

♻️ Proposed diff

````diff
-```
+```text
 Step 1  BUILD & LOAD   jit_spec.build_and_load() -> tvm_ffi.Module
 Step 2  REGISTER       jax_tvm_ffi.register_ffi_target(name, wrapper, arg_spec)
 Step 3  CALL           jax.ffi.ffi_call(name, output_shapes)(*inputs, **scalar_attrs)
````

🤖 Prompt for AI Agents

Verify each finding against the current code and only fix it if needed.
In `@examples/jax_tvm_ffi/README.md` around lines 89-93: The fenced ASCII diagram block containing the lines starting with "Step 1 BUILD & LOAD", "Step 2 REGISTER", and "Step 3 CALL" should include a language tag to satisfy markdownlint MD040; change the opening fence from a bare ``` to ```text (or ```none) so the block becomes a labeled code fence and the diagram remains unchanged.

🤖 Prompt for all review comments with AI agents

Verify each finding against the current code and only fix it if needed.
Inline comments:

In `@examples/jax_tvm_ffi/gemma3_flashinfer_jax.ipynb`:
- Around line 362-364: The docstring for is_global(layer_idx: int) is stale — it says "18-layer model" and lists global layers as 5, 11, 17; update the docstring in the is_global function to reflect the 26-layer Gemma 3 1B model and list the four global full-attention layers as 5, 11, 17, 23 (keeping the implementation return (layer_idx + 1) % 6 == 0 unchanged).
- Around line 119-145: The cell uses subprocess.check_output (referenced as
subprocess and check_output) but never imports subprocess, causing a NameError
if run after a kernel restart; add an explicit top-level import for subprocess
in this notebook cell (alongside json, math, os, time, jax, jax.numpy) so
subprocess is available before calling subprocess.check_output when setting
CUDA_HOME and ensure this import appears before any use of SM_MAJOR/SM_MINOR
environment logic.- Around line 478-486: The TVM FFI expects a double but ROPE_THETA_LOCAL /
ROPE_THETA_GLOBAL are ints, so cast rope_theta to a float when calling the FFI:
update apply_rope to pass rope_theta=float(rope_theta) into the jax.ffi.ffi_call
invocation and also ensure any callers that forward rope_theta (e.g.,
prefill_layer and decode_layer) forward float(rope_theta) instead of the raw
int; locate these by the symbols apply_rope, prefill_layer, decode_layer and the
constants ROPE_THETA_LOCAL / ROPE_THETA_GLOBAL and replace the direct int usage
with float(...) to avoid scalar-type errors.

In `@examples/jax_tvm_ffi/gemma3_flashinfer_jax.py`:
- Around line 284-301: The apply_rope wrapper should ensure rope_theta is
keyword-only and passed to the TVM FFI as a Python float; update the function
signature to def apply_rope(q, k, indptr, offsets, *, rope_theta=1e4) and
convert rope_theta to float (e.g., rope_theta = float(rope_theta)) before
calling jax.ffi.ffi_call so the value arriving at the "flashinfer.apply_rope"
FFI matches the expected C double type.

In `@examples/jax_tvm_ffi/README.md`:
- Around line 14-34: Update the README installation commands so the CUDA variants are consistent with the "CUDA 11.8+" requirement: show paired substitutions for both JAX and the FlashInfer wheel (e.g., replace 'jax[cuda13]' with the matching 'jax[cuda12]' or 'jax[cuda11]' and change --extra-index-url https://flashinfer.ai/whl/cu130/ to the corresponding wheel suffix like cu124 or cu118). Modify the install block that currently uses 'pip install 'jax[cuda13]'' and the FlashInfer URL '--extra-index-url https://flashinfer.ai/whl/cu130/' to include explicit examples for CUDA 11.8, 12.x, and 13 (or state a template such as jax[cuda&lt;N&gt;] and /whl/cu&lt;ver&gt;/) so users know to change both places together.
Nitpick comments:

In `@examples/jax_tvm_ffi/flashinfer_jax_tvm_ffi.ipynb`:
- Around line 115-148: The CUDA_HOME auto-detection block in this cell uses
subprocess but subprocess is not imported here, causing a NameError if this cell
runs standalone; to fix it, add an import subprocess alongside the other imports
at the top of the cell (near import os, time, math, jinja2, numpy) or merge this
cell with the earlier setup cell so that subprocess is guaranteed to be defined
before the CUDA_HOME detection code in the block that references nvcc via
subprocess.check_output.

In `@examples/jax_tvm_ffi/gemma3_flashinfer_jax.py`:
- Around line 38-194: The module performs heavy side-effect work at import
(compute-cap detection via get_compute_capability, HF cache setup via
_ensure_writable_hf_cache, HF_TOKEN prompting, snapshot_download and weights
loading into weights, and kernel compilation/registration) which should be
deferred; refactor by moving all top-level side-effect logic into a guarded
main() (or into explicit functions like prepare_environment(),
authenticate_hf(), download_and_load_weights(), compile_and_register_kernels())
and ensure only lightweight definitions (functions/classes/constants) remain at
import time, then call main() under `if __name__ == "__main__":` so importing the module no longer triggers downloads, getpass prompts, or nvcc builds.

In `@examples/jax_tvm_ffi/README.md`:
- Around line 89-93: The fenced ASCII diagram block containing the lines
starting with "Step 1 BUILD & LOAD", "Step 2 REGISTER", and "Step 3 CALL"
should include a language tag to satisfy markdownlint MD040; change the opening
fence from a bare ``` to ```text (or ```none) so the block becomes a labeled code fence and the diagram remains unchanged.

🪄 Autofix (Beta)

Fix all unresolved CodeRabbit comments on this PR:

- [ ] Push a commit to this branch (recommended)
- [ ] Create a new PR with the fixes

ℹ️ Review info

⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: `134ea15d-524d-4714-a813-27ee015b2b8b`

📥 Commits

Reviewing files that changed from the base of the PR and between a99ee72fcdf7f2d8579877ef4ea7c5b2da34f29d and eeab46a4ad0f00f5f88aa944abf874f8acd4431c.

📒 Files selected for processing (6)

- `examples/README.md`
- `examples/jax_tvm_ffi/README.md`
- `examples/jax_tvm_ffi/flashinfer_jax_tvm_ffi.ipynb`
- `examples/jax_tvm_ffi/flashinfer_jax_tvm_ffi.py`
- `examples/jax_tvm_ffi/gemma3_flashinfer_jax.ipynb`
- `examples/jax_tvm_ffi/gemma3_flashinfer_jax.py`
Actionable comments posted: 3
🧹 Nitpick comments (6)
examples/jax_tvm_ffi/flashinfer_jax_tvm_ffi.ipynb (1)
800-803: Consistency nit: `np.dtype` vs `jnp.dtype`.

The sibling script `flashinfer_jax_tvm_ffi.py` at line 383 uses `jnp.dtype(q.dtype).itemsize` while this notebook uses `np.dtype(...)`. Both work for standard float dtypes, but aligning the two keeps the tutorial and script consistent for readers switching between them.

🤖 Prompt for AI Agents

Verify each finding against the current code and only fix it if needed. In `@examples/jax_tvm_ffi/flashinfer_jax_tvm_ffi.ipynb` around lines 800-803: Consistency nit: replace np.dtype usage with jnp.dtype to match the sibling script. In the block where sm_scale and tmp_elems are computed (variables sm_scale, tmp_elems, and q), change np.dtype(q.dtype).itemsize to jnp.dtype(q.dtype).itemsize so the notebook uses JAX's dtype utility like flashinfer_jax_tvm_ffi.py and remains consistent for readers switching between the script and notebook.

examples/jax_tvm_ffi/gemma3_flashinfer_jax.py (2)
305-305: Hardcoded bfloat16 byte-size assumption in `_TMP_ELEMS`.

`_TMP_ELEMS = 32 * 1024 * 1024 // 2` hardcodes the 2-byte element size. The script is bfloat16-only today, but this will silently produce a wrong-sized scratch buffer if anyone retargets it to float32. A self-documenting computation would avoid the footgun:

♻️ Proposed diff

```diff
-_TMP_ELEMS = 32 * 1024 * 1024 // 2
+_TMP_ELEMS = 32 * 1024 * 1024 // jnp.dtype(jnp.bfloat16).itemsize
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@examples/jax_tvm_ffi/gemma3_flashinfer_jax.py` at line 305, The constant _TMP_ELEMS currently assumes 2 bytes per element (bfloat16) causing incorrect scratch buffer sizing if element dtype changes; update the computation to derive element size dynamically (e.g., from the array dtype or a bytes-per-element constant) and compute _TMP_ELEMS = (32 * 1024 * 1024) // bytes_per_element so it self-documents and works for bfloat16, float32, etc.; locate and replace the literal expression where _TMP_ELEMS is defined in gemma3_flashinfer_jax.py to use the dynamic bytes_per_element value (or a named helper) rather than hardcoded “// 2”.
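The dtype-derived sizing the comment asks for can be sketched with plain numpy. bfloat16 is not a stock numpy dtype, so float16 (also 2 bytes) stands in for it here; the budget and helper name are illustrative.

```python
import numpy as np

SCRATCH_BYTES = 32 * 1024 * 1024  # 32 MiB scratch budget

def tmp_elems(dtype) -> int:
    """Number of elements of `dtype` that fit in the scratch budget,
    derived from the dtype's itemsize instead of a hardcoded // 2."""
    return SCRATCH_BYTES // np.dtype(dtype).itemsize

print(tmp_elems(np.float16))  # 16777216  (2-byte elements, same size as bfloat16)
print(tmp_elems(np.float32))  # 8388608   (4-byte elements)
```

Retargeting the buffer to another dtype now only requires changing the dtype argument, not a magic divisor.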
102-122: Swallowed exception on HF token lookup loses diagnostic info.

Ruff flags S110/BLE001 here. `huggingface_hub.get_token()` will raise on rare cases (corrupted token file, unexpected permissions), and silently falling through to the interactive prompt hides the root cause. Consider at least narrowing the catch and surfacing it to the user:

♻️ Proposed diff

```diff
 try:
     from huggingface_hub import get_token
     HF_TOKEN = get_token() or ""
-except Exception:
-    pass
+except Exception as e:
+    print(f"Warning: failed to read cached HF token ({e!r}); falling back to prompt.")
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@examples/jax_tvm_ffi/gemma3_flashinfer_jax.py` around lines 102-122: The current try/except around huggingface_hub.get_token() swallows all exceptions and hides diagnostics; update the block that calls get_token() (related symbols: HF_TOKEN and get_token) to catch only expected exceptions or catch Exception as e and surface the error (e.g., log or re-raise a RuntimeError including the original exception message) instead of silently passing, so failures like corrupted token files or permission errors are visible to the user and preserved as the cause of the failure.

examples/jax_tvm_ffi/flashinfer_jax_tvm_ffi.py (3)
479-487: Benchmark measures wall time including Python/dispatch overhead.
`time.perf_counter()` around `block_until_ready` captures host-side dispatch and JAX sync overhead in addition to kernel time. For a single µs-scale decode kernel this is usually dominated by dispatch latency, which can mislead readers comparing it to FlashInfer's PyTorch benchmarks (which typically use CUDA events). Consider either:
- Running more iterations (e.g. N=1000) and noting the measurement includes Python dispatch, or
- Adding a brief comment that this is end-to-end latency including host overhead.
Not blocking for an example.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@examples/jax_tvm_ffi/flashinfer_jax_tvm_ffi.py` around lines 479 - 487, The benchmark currently measures decode_attention latency using time.perf_counter() around block_until_ready which includes Python host/dispatch overhead; either increase the iteration count (e.g. bump N from 100 to 1000) to amortize dispatch or add a short comment next to the decode_attention(...).block_until_ready() loop stating that the reported us includes host-side dispatch and JAX sync overhead (so results are end-to-end, not pure kernel time); update the N variable and/or the print comment in the same block where decode_attention, block_until_ready, time.perf_counter, and N are used.
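The measurement pattern the comment describes (warmup, then averaging many timed iterations, with an explicit note that host dispatch is included) looks roughly like this. `decode_attention` below is a placeholder callable standing in for the real FFI-backed kernel, so the snippet runs without a GPU.

```python
import time

def decode_attention():
    """Placeholder for the real FFI-backed kernel call."""
    return sum(i * i for i in range(1000))

# Warmup: with JAX, the first call may trigger compilation and caching.
decode_attention()

N = 1000  # more iterations amortize per-call dispatch overhead
start = time.perf_counter()
for _ in range(N):
    decode_attention()  # with JAX, follow each call with .block_until_ready()
elapsed = time.perf_counter() - start

# Note: this is end-to-end latency including Python/host dispatch,
# not pure kernel time (CUDA events would isolate the latter).
per_call_us = elapsed / N * 1e6
print(f"{per_call_us:.2f} us/call (includes host overhead)")
```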
31-37: Minor: `nvidia-smi` invoked via partial executable path.

Ruff flags S607 at lines 33 and 50. For example scripts this is acceptable, but the script will fail with a somewhat confusing `FileNotFoundError` if `nvidia-smi`/`which` are not on `PATH`. Consider wrapping the `subprocess.check_output` call in a clearer error message (e.g. "nvidia-smi not found on PATH; is a CUDA driver installed?") so that users running in a minimal container get actionable feedback rather than a raw stack trace.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@examples/jax_tvm_ffi/flashinfer_jax_tvm_ffi.py` around lines 31 - 37, The call to subprocess.check_output in get_compute_capability can raise a FileNotFoundError with an opaque stack trace if nvidia-smi is missing; catch FileNotFoundError around the subprocess.check_output call in get_compute_capability and raise (or log) a clearer RuntimeError with a message like "nvidia-smi not found on PATH; is the CUDA driver installed?" so users get actionable feedback instead of the raw FileNotFoundError; reference the get_compute_capability function and the subprocess.check_output invocation when applying the change.
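The suggested guard can be sketched as follows. The error message comes from the comment; the helper name and the demonstration with a deliberately nonexistent binary are illustrative, not the script's actual code.

```python
import subprocess

def query_gpu(smi_cmd=("nvidia-smi",)):
    """Run a GPU probe command, converting a missing binary into an
    actionable RuntimeError instead of a raw FileNotFoundError."""
    try:
        out = subprocess.check_output(list(smi_cmd))
    except FileNotFoundError as e:
        raise RuntimeError(
            "nvidia-smi not found on PATH; is a CUDA driver installed?"
        ) from e
    return out.decode().strip()

# Demonstrate the friendly failure path with a binary that cannot exist.
try:
    query_gpu(("definitely-not-nvidia-smi",))
except RuntimeError as e:
    print(e)  # nvidia-smi not found on PATH; is a CUDA driver installed?
```

Chaining with `from e` preserves the original `FileNotFoundError` as the cause, so the full traceback is still available for debugging.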
438-459: `k_r` from the outer `apply_rope` call is unused.

Line 458 unpacks `q_r, k_r = apply_rope(...)` but only `q_r` is consumed on line 459. Rename to `_k_r` (Ruff RUF059) to make the intent explicit, mirroring the unused unpack inside `decode_step` at line 449.

♻️ Proposed diff

```diff
-q_r, k_r = apply_rope(q_new, k_new, indptr, offsets)
+q_r, _k_r = apply_rope(q_new, k_new, indptr, offsets)
 attn_ref = decode_attention(q_r.reshape(NUM_QO, HEAD_DIM), k, v)
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@examples/jax_tvm_ffi/flashinfer_jax_tvm_ffi.py` around lines 438 - 459, The outer call to apply_rope unpacks q_r, k_r but k_r is unused—rename k_r to _k_r in the outer scope to indicate it's intentionally unused (mirror the unused k_r inside decode_step); update the line that currently reads "q_r, k_r = apply_rope(...)" to "q_r, _k_r = apply_rope(...)" and keep all other uses of q_r the same (no changes to decode_step, apply_rope, silu_and_mul, or decode_attention).
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@examples/jax_tvm_ffi/flashinfer_jax_tvm_ffi.ipynb`:
- Around line 113-148: The setup cell is missing "import subprocess", causing a
NameError when subprocess.check_output() is called; add "import subprocess" near
the top imports (alongside import os, time, math, jinja2, numpy) so the
subprocess symbol is always defined before the CUDA_HOME detection block that
calls subprocess.check_output; ensure the import appears before the if
'CUDA_HOME' not in os.environ: block that uses subprocess.
In `@examples/jax_tvm_ffi/gemma3_flashinfer_jax.ipynb`:
- Around line 924-928: The current initialization of _STOP_IDS can include None
if tokenizer.eos_token_id is None; change the logic in the block that builds
_STOP_IDS to only add tokenizer.eos_token_id when it is not None and not equal
to tokenizer.unk_token_id, and keep the existing guard when converting special
tokens via tokenizer.convert_tokens_to_ids (i.e., only add _id when _id is not
None and _id != tokenizer.unk_token_id). Update the code that references
_STOP_IDS, tokenizer.eos_token_id, tokenizer.unk_token_id, convert_tokens_to_ids
and the loop over _tok to follow this guarded-add pattern so next_tok in
_STOP_IDS won't see a None entry.
In `@examples/jax_tvm_ffi/gemma3_flashinfer_jax.py`:
- Around line 697-702: The _STOP_IDS set currently seeds with
tokenizer.eos_token_id which may be None; change the construction so you start
with an empty set, then add tokenizer.eos_token_id only if it is not None (and
optionally cast to int), and retain the existing logic for converted token IDs
from tokenizer.convert_tokens_to_ids ensuring _id is not None and _id !=
tokenizer.unk_token_id before adding; this prevents None being in _STOP_IDS and
fixes the subsequent next_tok in _STOP_IDS check used elsewhere.
---
Nitpick comments:
In `@examples/jax_tvm_ffi/flashinfer_jax_tvm_ffi.ipynb`:
- Around line 800-803: Consistency nit: replace np.dtype usage with jnp.dtype to
match the sibling script. In the block where sm_scale and tmp_elems are computed
(variables sm_scale, tmp_elems, and q), change np.dtype(q.dtype).itemsize to
jnp.dtype(q.dtype).itemsize so the notebook uses JAX's dtype utility like
flashinfer_jax_tvm_ffi.py and remains consistent for readers switching between
the script and notebook.
In `@examples/jax_tvm_ffi/flashinfer_jax_tvm_ffi.py`:
- Around line 479-487: The benchmark currently measures decode_attention latency
using time.perf_counter() around block_until_ready which includes Python
host/dispatch overhead; either increase the iteration count (e.g. bump N from
100 to 1000) to amortize dispatch or add a short comment next to the
decode_attention(...).block_until_ready() loop stating that the reported us
includes host-side dispatch and JAX sync overhead (so results are end-to-end,
not pure kernel time); update the N variable and/or the print comment in the
same block where decode_attention, block_until_ready, time.perf_counter, and N
are used.
- Around line 31-37: The call to subprocess.check_output in
get_compute_capability can raise a FileNotFoundError with an opaque stack trace
if nvidia-smi is missing; catch FileNotFoundError around the
subprocess.check_output call in get_compute_capability and raise (or log) a
clearer RuntimeError with a message like "nvidia-smi not found on PATH; is the
CUDA driver installed?" so users get actionable feedback instead of the raw
FileNotFoundError; reference the get_compute_capability function and the
subprocess.check_output invocation when applying the change.
- Around line 438-459: The outer call to apply_rope unpacks q_r, k_r but k_r is
unused—rename k_r to _k_r in the outer scope to indicate it's intentionally
unused (mirror the unused k_r inside decode_step); update the line that
currently reads "q_r, k_r = apply_rope(...)" to "q_r, _k_r = apply_rope(...)"
and keep all other uses of q_r the same (no changes to decode_step, apply_rope,
silu_and_mul, or decode_attention).
In `@examples/jax_tvm_ffi/gemma3_flashinfer_jax.py`:
- Line 305: The constant _TMP_ELEMS currently assumes 2 bytes per element
(bfloat16) causing incorrect scratch buffer sizing if element dtype changes;
update the computation to derive element size dynamically (e.g., from the array
dtype or a bytes-per-element constant) and compute _TMP_ELEMS = (32 * 1024 *
1024) // bytes_per_element so it self-documents and works for bfloat16, float32,
etc.; locate and replace the literal expression where _TMP_ELEMS is defined in
gemma3_flashinfer_jax.py to use the dynamic bytes_per_element value (or a named
helper) rather than hardcoded “// 2”.
- Around line 102-122: The current try/except around huggingface_hub.get_token()
swallows all exceptions and hides diagnostics; update the block that calls
get_token() (related symbols: HF_TOKEN and get_token) to catch only expected
exceptions or catch Exception as e and surface the error (e.g., log or re-raise
a RuntimeError including the original exception message) instead of silently
passing, so failures like corrupted token files or permission errors are visible
to the user and preserved as the cause of the failure.
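One way to realize that advice, sketched with a fake token reader so the snippet stays self-contained: `read_cached_token` is hypothetical and simulates a failure, whereas the real call is `huggingface_hub.get_token()`.

```python
import warnings

def read_cached_token():
    """Hypothetical stand-in for huggingface_hub.get_token();
    here it simulates a corrupted/unreadable token file."""
    raise OSError("token file unreadable")

HF_TOKEN = ""
try:
    HF_TOKEN = read_cached_token() or ""
except ImportError:
    pass  # library simply not installed: fall back to the prompt silently
except Exception as e:
    # Surface unexpected failures instead of swallowing them.
    warnings.warn(f"failed to read cached HF token ({e!r}); falling back to prompt")

print(repr(HF_TOKEN))  # ''
```

Splitting the except clauses keeps the quiet path for the genuinely expected case while making every other failure visible with its original cause.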
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: f0942cbb-ac2d-41e9-acd4-e560fe59b092
📒 Files selected for processing (4)
- `examples/jax_tvm_ffi/flashinfer_jax_tvm_ffi.ipynb`
- `examples/jax_tvm_ffi/flashinfer_jax_tvm_ffi.py`
- `examples/jax_tvm_ffi/gemma3_flashinfer_jax.ipynb`
- `examples/jax_tvm_ffi/gemma3_flashinfer_jax.py`
e4b5da0 to 1a4c0e7
🧹 Nitpick comments (6)
examples/jax_tvm_ffi/flashinfer_jax_tvm_ffi.py (2)
40-42: `assert` is skipped under `python -O`.

This SM-version gate is a hard prerequisite. Since the script will be run with `python …`, it almost always holds — but if anyone ever invokes `python -O flashinfer_jax_tvm_ffi.py`, the assertion is elided and the subsequent kernels will surface far less helpful errors. A `raise RuntimeError(...)` would be more robust.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@examples/jax_tvm_ffi/flashinfer_jax_tvm_ffi.py` around lines 40 - 42, Replace the assert check on SM_MAJOR/SM_MINOR with an explicit runtime error so the guard is not skipped under python -O: locate the SM version gate using the symbols SM_MAJOR and SM_MINOR (the current lines using assert (SM_MAJOR, SM_MINOR) >= (7, 5)), and change it to raise a RuntimeError with the same message (e.g., "SM X.Y is below the minimum SM 7.5 (Turing)") so the script always fails fast when the GPU SM is too low.
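The replacement guard might look like the sketch below. The SM values are hardwired purely for illustration; in the real script SM_MAJOR/SM_MINOR come from the nvidia-smi probe.

```python
SM_MAJOR, SM_MINOR = 7, 0  # illustrative values below the minimum

def check_sm(major: int, minor: int) -> None:
    """Hard prerequisite gate that survives `python -O`, unlike an assert."""
    if (major, minor) < (7, 5):
        raise RuntimeError(
            f"SM {major}.{minor} is below the minimum SM 7.5 (Turing)"
        )

try:
    check_sm(SM_MAJOR, SM_MINOR)
except RuntimeError as e:
    print(e)  # SM 7.0 is below the minimum SM 7.5 (Turing)

check_sm(8, 0)  # passes silently on Ampere-class parts
```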
445-454: `k_r` is discarded — consider renaming or adding a comment.

In a real decode step, the RoPE'd `k_r` would be appended to the KV-cache and then attended over, not the raw pre-RoPE `k_cache` that's currently used. For the demo this is fine (the composition still validates), but to avoid misleading readers following this as a template, either rename `k_r` to `_` or add a one-line comment that this is a demonstrative composition (not a faithful decode step).

Proposed diff

```diff
-    q_r, k_r = apply_rope(q_new, k_new, indptr, offsets)
+    # Note: in a real decode step, k_r would be appended to k_cache before attention;
+    # here we reuse the raw cache for brevity to demonstrate XLA composition only.
+    q_r, _k_r = apply_rope(q_new, k_new, indptr, offsets)
     attn_out = decode_attention(q_r.reshape(NUM_QO, HEAD_DIM), k_cache, v_cache)
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@examples/jax_tvm_ffi/flashinfer_jax_tvm_ffi.py` around lines 445-454: The variable k_r produced by apply_rope in decode_step is computed but never used (the function still attends over the original k_cache), which can mislead readers; either rename k_r to _ to indicate it's intentionally unused, or update decode_step to use the RoPE'd keys when calling decode_attention (or add a one-line comment inside decode_step clarifying this is a demonstrative composition and not a faithful KV-cache update). Locate the apply_rope call and the decode_step function (symbols: decode_step, apply_rope, k_r, k_cache, decode_attention) and make the rename or comment change (or adjust the call to pass k_r into decode_attention) accordingly.

examples/jax_tvm_ffi/flashinfer_jax_tvm_ffi.ipynb (1)
121-122: Mirror `setdefault(...)` from the standalone script.

`os.environ["XLA_FLAGS"] = "..."` overwrites any pre-existing XLA flags a reader may have set. The `.py` counterpart uses `setdefault`. Same pattern is worth applying here.

Proposed diff

```diff
-"os.environ[\"TF_CPP_MIN_LOG_LEVEL\"] = \"2\"  # suppress TF/XLA info & warnings\n",
-"os.environ[\"XLA_FLAGS\"] = \"--xla_gpu_cuda_data_dir=/usr/local/cuda\"\n",
+"os.environ.setdefault(\"TF_CPP_MIN_LOG_LEVEL\", \"2\")  # suppress TF/XLA info & warnings\n",
+"os.environ.setdefault(\"XLA_FLAGS\", \"--xla_gpu_cuda_data_dir=/usr/local/cuda\")\n",
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@examples/jax_tvm_ffi/flashinfer_jax_tvm_ffi.ipynb` around lines 121-122: The notebook currently overwrites any existing XLA flags by setting os.environ["XLA_FLAGS"] directly; change this to use os.environ.setdefault("XLA_FLAGS", "--xla_gpu_cuda_data_dir=/usr/local/cuda") (mirror the standalone script pattern) so pre-existing XLA_FLAGS are preserved; update the cell containing os.environ["XLA_FLAGS"] = "..." and replace it with a setdefault call referencing the same key ("XLA_FLAGS").

examples/jax_tvm_ffi/gemma3_flashinfer_jax.py (1)
105-110: Silent `except Exception: pass` hides HF-token errors.

If `get_token()` (or its import) fails for any non-missing reason, the error is silently swallowed and the user ends up at the interactive prompt with no hint why. Narrowing to the expected `ImportError` (and optionally logging anything else) improves diagnosability.

Proposed diff

```diff
 try:
     from huggingface_hub import get_token
     HF_TOKEN = get_token() or ""
-except Exception:
-    pass
+except ImportError:
+    pass
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@examples/jax_tvm_ffi/gemma3_flashinfer_jax.py` around lines 105-110: The current try/except around importing and calling huggingface_hub.get_token() swallows all exceptions; change it to catch ImportError for the missing module and set HF_TOKEN = "" in that branch, and add a separate except Exception as e branch that logs or warns about the unexpected failure (including the exception) rather than silently passing; adjust the block referencing get_token and HF_TOKEN so ImportError is handled silently but other errors are surfaced via logging.warn/logging.exception or re-raising as appropriate.

examples/jax_tvm_ffi/gemma3_flashinfer_jax.ipynb (1)
127-128: Prefer `os.environ.setdefault(...)` for consistency with the `.py` version.

The standalone script (`gemma3_flashinfer_jax.py` lines 52–53) uses `setdefault` so user-configured values are preserved. The notebook unconditionally overwrites `TF_CPP_MIN_LOG_LEVEL` and `XLA_FLAGS`, which can silently clobber flags a reader has pre-set (e.g. `--xla_gpu_deterministic_ops`).

Proposed diff

```diff
-"os.environ[\"TF_CPP_MIN_LOG_LEVEL\"] = \"2\"  # suppress TF/XLA info & warnings\n",
-"os.environ[\"XLA_FLAGS\"] = \"--xla_gpu_cuda_data_dir=/usr/local/cuda\"\n",
+"os.environ.setdefault(\"TF_CPP_MIN_LOG_LEVEL\", \"2\")  # suppress TF/XLA info & warnings\n",
+"os.environ.setdefault(\"XLA_FLAGS\", \"--xla_gpu_cuda_data_dir=/usr/local/cuda\")\n",
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@examples/jax_tvm_ffi/gemma3_flashinfer_jax.ipynb` around lines 127-128: The notebook currently unconditionally overwrites environment variables TF_CPP_MIN_LOG_LEVEL and XLA_FLAGS; change those assignments to use os.environ.setdefault(...) (matching the approach used in gemma3_flashinfer_jax.py) so any user-preconfigured values are preserved — locate the lines that set TF_CPP_MIN_LOG_LEVEL and XLA_FLAGS and replace the direct assignments with setdefault calls.

examples/jax_tvm_ffi/README.md (1)
89-93: Add a language to the fenced block for `MD040`.

markdownlint flags the fenced block at line 89. Tagging it as `text` (or similar) silences the warning without changing rendering.

Proposed diff
````diff
-```
+```text
 Step 1 BUILD & LOAD jit_spec.build_and_load() -> tvm_ffi.Module
 Step 2 REGISTER jax_tvm_ffi.register_ffi_target(name, wrapper, arg_spec)
 Step 3 CALL jax.ffi.ffi_call(name, output_shapes)(*inputs, **scalar_attrs)
````

🤖 Prompt for AI Agents

Verify each finding against the current code and only fix it if needed.
In `@examples/jax_tvm_ffi/README.md` around lines 89 - 93, the fenced code block containing the three steps ("Step 1 BUILD & LOAD jit_spec.build_and_load() -> tvm_ffi.Module", "Step 2 REGISTER jax_tvm_ffi.register_ffi_target(name, wrapper, arg_spec)", "Step 3 CALL jax.ffi.ffi_call(name, output_shapes)(*inputs, **scalar_attrs)") should include a language tag to satisfy markdownlint MD040; update the fence from ``` to ```text (or another appropriate tag) in README.md so the block is explicitly labeled and the linter warning is silenced.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Nitpick comments:
In `@examples/jax_tvm_ffi/flashinfer_jax_tvm_ffi.ipynb`:
- Around line 121-122: The notebook currently overwrites any existing XLA flags
by setting os.environ["XLA_FLAGS"] directly; change this to use
os.environ.setdefault("XLA_FLAGS", "--xla_gpu_cuda_data_dir=/usr/local/cuda")
(mirror the standalone script pattern) so pre-existing XLA_FLAGS are preserved;
update the cell containing os.environ["XLA_FLAGS"] = "..." and replace it with a
setdefault call referencing the same key ("XLA_FLAGS").
In `@examples/jax_tvm_ffi/flashinfer_jax_tvm_ffi.py`:
- Around line 40-42: Replace the assert check on SM_MAJOR/SM_MINOR with an
explicit runtime error so the guard is not skipped under python -O: locate the
SM version gate using the symbols SM_MAJOR and SM_MINOR (the current lines using
assert (SM_MAJOR, SM_MINOR) >= (7, 5)), and change it to raise a RuntimeError
with the same message (e.g., "SM X.Y is below the minimum SM 7.5 (Turing)") so
the script always fails fast when the GPU SM is too low.
- Around line 445-454: The variable k_r produced by apply_rope in decode_step is
computed but never used (the function still attends over the original k_cache),
which can mislead readers; either rename k_r to _ to indicate it's intentionally
unused, or update decode_step to use the RoPE'd keys when calling
decode_attention (or add a one-line comment inside decode_step clarifying this
is a demonstrative composition and not a faithful KV-cache update). Locate the
apply_rope call and the decode_step function (symbols: decode_step, apply_rope,
k_r, k_cache, decode_attention) and make the rename or comment change (or adjust
the call to pass k_r into decode_attention) accordingly.
In `@examples/jax_tvm_ffi/gemma3_flashinfer_jax.ipynb`:
- Around line 127-128: The notebook currently unconditionally overwrites
environment variables TF_CPP_MIN_LOG_LEVEL and XLA_FLAGS; change those
assignments to use os.environ.setdefault(...) (matching the approach used in
gemma3_flashinfer_jax.py) so any user-preconfigured values are preserved —
locate the lines that set TF_CPP_MIN_LOG_LEVEL and XLA_FLAGS and replace the
direct assignments with setdefault calls.
In `@examples/jax_tvm_ffi/gemma3_flashinfer_jax.py`:
- Around line 105-110: The current try/except around importing and calling
huggingface_hub.get_token() swallows all exceptions; change it to catch
ImportError for the missing module and set HF_TOKEN = "" in that branch, and add
a separate except Exception as e branch that logs or warns about the unexpected
failure (including the exception) rather than silently passing; adjust the block
referencing get_token and HF_TOKEN so ImportError is handled silently but other
errors are surfaced via logging.warn/logging.exception or re-raising as
appropriate.
In `@examples/jax_tvm_ffi/README.md`:
- Around line 89-93: The fenced code block containing the three steps ("Step 1
BUILD & LOAD jit_spec.build_and_load() -> tvm_ffi.Module", "Step 2 REGISTER
jax_tvm_ffi.register_ffi_target(name, wrapper, arg_spec)", "Step 3 CALL
jax.ffi.ffi_call(name, output_shapes)(*inputs, **scalar_attrs)") should include
a language tag to satisfy markdownlint MD040; update the fence from ``` to
```text (or another appropriate tag) in README.md so the block is explicitly
labeled and the linter warning is silenced.
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: 7a795534-2179-467f-84ec-ba3d898c53bf
📒 Files selected for processing (6)
- examples/README.md
- examples/jax_tvm_ffi/README.md
- examples/jax_tvm_ffi/flashinfer_jax_tvm_ffi.ipynb
- examples/jax_tvm_ffi/flashinfer_jax_tvm_ffi.py
- examples/jax_tvm_ffi/gemma3_flashinfer_jax.ipynb
- examples/jax_tvm_ffi/gemma3_flashinfer_jax.py
✅ Files skipped from review due to trivial changes (1)
- examples/README.md
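Several of the findings above recommend `os.environ.setdefault` over direct assignment. A minimal sketch of why, using the flag values from the examples (the behavior itself is generic):

```python
import os

# Direct assignment would clobber a value the reader exported before launch.
# setdefault only writes when the key is absent, so pre-set flags survive.
os.environ.pop("XLA_FLAGS", None)  # start from a clean slate for the demo

# Case 1: nothing pre-set -> the default is used.
os.environ.setdefault("XLA_FLAGS", "--xla_gpu_cuda_data_dir=/usr/local/cuda")
print(os.environ["XLA_FLAGS"])  # --xla_gpu_cuda_data_dir=/usr/local/cuda

# Case 2: the user pre-set a flag -> setdefault leaves it alone.
os.environ["XLA_FLAGS"] = "--xla_gpu_deterministic_ops=true"
os.environ.setdefault("XLA_FLAGS", "--xla_gpu_cuda_data_dir=/usr/local/cuda")
print(os.environ["XLA_FLAGS"])  # --xla_gpu_deterministic_ops=true
```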
Actionable comments posted: 3
♻️ Duplicate comments (2)
examples/jax_tvm_ffi/gemma3_flashinfer_jax.ipynb (2)
899-906: ⚠️ Potential issue | 🟡 Minor

Guard `eos_token_id` before adding it to `_STOP_IDS`.

Line 902 can add `None` to `_STOP_IDS` when the tokenizer has no EOS token. Keep the same guarded-add pattern used for converted special tokens.

Proposed fix

```diff
-_STOP_IDS = {tokenizer.eos_token_id}
+_STOP_IDS = set()
+if tokenizer.eos_token_id is not None and tokenizer.eos_token_id != tokenizer.unk_token_id:
+    _STOP_IDS.add(tokenizer.eos_token_id)
 for _tok in ['<end_of_turn>', '<eos>']:
     _id = tokenizer.convert_tokens_to_ids(_tok)
     if _id is not None and _id != tokenizer.unk_token_id:
         _STOP_IDS.add(_id)
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@examples/jax_tvm_ffi/gemma3_flashinfer_jax.ipynb` around lines 899 - 906, The code adds tokenizer.eos_token_id to _STOP_IDS without guarding for None; change the logic in the stop-token block so you only add tokenizer.eos_token_id if it is not None and not equal to tokenizer.unk_token_id (same guard used for tokens converted via tokenizer.convert_tokens_to_ids), i.e., check the value before calling _STOP_IDS.add(...) in the section that initializes _STOP_IDS and before the loop that converts '<end_of_turn>' and '<eos>'.
334-337: ⚠️ Potential issue | 🟡 Minor

Normalize RoPE theta to `float` before passing it as an FFI attribute.

`ROPE_THETA_LOCAL` / `ROPE_THETA_GLOBAL` are currently `int`s, and `apply_rope` forwards `rope_theta` unchanged. Cast at the wrapper boundary so both prefill and decode pass the scalar type expected by the TVM FFI path.

Proposed fix

```diff
-ROPE_THETA_LOCAL = int(cfg.get('rope_local_base_freq', 10_000))
-ROPE_THETA_GLOBAL = int(cfg.get('rope_theta', 1_000_000))
+ROPE_THETA_LOCAL = float(cfg.get('rope_local_base_freq', 10_000))
+ROPE_THETA_GLOBAL = float(cfg.get('rope_theta', 1_000_000))
@@
 )(q, k, indptr, offsets,
-  rotary_dim=q.shape[-1], interleave=False, rope_scale=1.0, rope_theta=rope_theta)
+  rotary_dim=q.shape[-1], interleave=False, rope_scale=1.0,
+  rope_theta=float(rope_theta))
```

What scalar Python type should jax-tvm-ffi / TVM FFI use for a double-valued attribute passed through jax.ffi.ffi_call?

Also applies to: 456-464, 716-738, 798-820
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@examples/jax_tvm_ffi/gemma3_flashinfer_jax.ipynb` around lines 334 - 337, ROPE_THETA_LOCAL and ROPE_THETA_GLOBAL are defined as ints but must be converted to Python floats before being sent across the TVM FFI boundary; update the wrapper(s) that call apply_rope / the jax.ffi.ffi_call so they cast rope_theta (and any rope_local_base_freq-derived value) to float at the boundary for both prefill and decode paths (e.g., where apply_rope is invoked and where jax.ffi.ffi_call constructs the attribute dict), ensuring all FFI attributes use float scalar types.
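A minimal sketch of the boundary cast. `rope_attrs` is a hypothetical helper (not part of the example code) that builds the keyword attributes eventually forwarded to `jax.ffi.ffi_call`; the point is that `rope_theta` leaves it as a Python `float` regardless of how the config stored it:

```python
# Config values as the notebook reads them today (ints).
cfg = {"rope_local_base_freq": 10_000, "rope_theta": 1_000_000}

ROPE_THETA_LOCAL = float(cfg.get("rope_local_base_freq", 10_000))
ROPE_THETA_GLOBAL = float(cfg.get("rope_theta", 1_000_000))

def rope_attrs(rotary_dim, rope_theta):
    """Build FFI attributes, casting once at the boundary."""
    return dict(
        rotary_dim=int(rotary_dim),
        interleave=False,
        rope_scale=1.0,
        rope_theta=float(rope_theta),  # double-valued attribute for TVM FFI
    )

attrs = rope_attrs(256, ROPE_THETA_LOCAL)
print(type(attrs["rope_theta"]).__name__)  # float
```

Casting inside the helper means both the prefill and decode call sites get the correct scalar type without repeating the cast.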
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@examples/jax_tvm_ffi/gemma3_flashinfer_jax.ipynb`:
- Around line 1006-1013: The summary snippet uses a non-public symbol
gen_decode_jit_spec; replace it with the publicly exported
gen_single_decode_module used elsewhere (see earlier imports/usage) and pass the
same argument pattern as the working examples (e.g., dtype 'bfloat16',
max_batch/sequence size 256, and use_sliding_window flag). Update both
occurrences of gen_decode_jit_spec to gen_single_decode_module and ensure the
call signature and flags match the earlier prefill/decode examples so the
notebook uses the public flashinfer.jit API.
- Around line 948-960: The first sampled token is chosen unconditionally
(first_token = _sample(...)) which causes an extra decode step and generates a
token even when max_new_tokens == 0; fix by guarding and early-exiting: only
sample/print the first token if max_new_tokens > 0, and after sampling check if
first_token is in _STOP_IDS — if so, print/return immediately and do not enter
the decode loop or call decode_step; otherwise proceed to append to generated
and run the for-loop as before (referencing first_token, _sample, _STOP_IDS,
generated, decode_step, and max_new_tokens).
- Around line 125-146: Move the XLA environment setup and CUDA discovery so it
runs before importing jax: ensure subprocess is imported, detect CUDA_HOME
(using subprocess.check_output as currently written), set
os.environ['CUDA_HOME'] to the discovered path (fallback '/usr/local/cuda'),
then set os.environ['XLA_FLAGS'] to include the discovered CUDA path, and only
after that import jax and jax.numpy; update the cell to import subprocess at top
and reorder the CUDA detection/XLA_FLAGS assignment to precede the import of
jax.
---
Duplicate comments:
In `@examples/jax_tvm_ffi/gemma3_flashinfer_jax.ipynb`:
- Around line 899-906: The code adds tokenizer.eos_token_id to _STOP_IDS without
guarding for None; change the logic in the stop-token block so you only add
tokenizer.eos_token_id if it is not None and not equal to tokenizer.unk_token_id
(same guard used for tokens converted via tokenizer.convert_tokens_to_ids),
i.e., check the value before calling _STOP_IDS.add(...) in the section that
initializes _STOP_IDS and before the loop that converts '<end_of_turn>' and
'<eos>'.
- Around line 334-337: ROPE_THETA_LOCAL and ROPE_THETA_GLOBAL are defined as
ints but must be converted to Python floats before being sent across the TVM FFI
boundary; update the wrapper(s) that call apply_rope / the jax.ffi.ffi_call so
they cast rope_theta (and any rope_local_base_freq-derived value) to float at
the boundary for both prefill and decode paths (e.g., where apply_rope is
invoked and where jax.ffi.ffi_call constructs the attribute dict), ensuring all
FFI attributes use float scalar types.
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: bd13bf32-0800-49af-bc7f-0382380fb7ce
📒 Files selected for processing (1)
examples/jax_tvm_ffi/gemma3_flashinfer_jax.ipynb
04345ce to 48c404a (Compare)
Thanks for the contribution! Since the repo already uses Sphinx for documentation, would it make sense to integrate this content into the Sphinx docs (it could go under the tutorials: https://github.com/flashinfer-ai/flashinfer/tree/main/docs/tutorials) instead of keeping it as a standalone Jupyter notebook?
@katjasrz this is failing a pre-merge formatting check: https://github.com/flashinfer-ai/flashinfer/actions/runs/24689882948/job/72521293385?pr=3092. Could you please take a look? You may have to rerun `pre-commit run --all-files` (see the "Pull Request Checklist" instructions in the PR description).
b8e6347 to f5d397a (Compare)
The formatting should pass now.
f5d397a to e4204f8 (Compare)
e4204f8 to 9c3ce4d (Compare)
Thanks everyone for the guidance! I updated the PR to integrate the tutorials into the existing Sphinx docs using Sphinx-Gallery. The canonical sources now live under docs/tutorials/jax_tvm_ffi/, and the docs build generates the rendered HTML pages plus downloadable .py and .ipynb versions from those same sources. I also removed the standalone examples/ copy to avoid having two sources of truth.

I verified the docs build locally with:

Please let me know if this structure looks aligned with the repo's documentation conventions.
📌 Description
This PR adds a new example under examples/jax_tvm_ffi/ showing how to call FlashInfer from JAX via jax-tvm-ffi. It also adds examples/README.md to document the examples directory and make the new example easier to discover.
The goal is to provide a minimal reference for users interested in integrating FlashInfer outside of PyTorch, especially in JAX-based workflows.
🔍 Related Issues
N/A
🚀 Pull Request Checklist
Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.
✅ Pre-commit Checks
- [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks with `pre-commit run --all-files` and fixed any reported issues.

🧪 Tests
- [ ] All tests are passing (`unittest`, etc.).

This PR only adds example code and documentation; no changes to core functionality, so no additional tests were added. Examples run successfully end-to-end.
Summary by CodeRabbit
Documentation
New Features