Add examples of calling FlashInfer from JAX via jax-tvm-ffi#3092

Merged

yongwww merged 1 commit into flashinfer-ai:main from katjasrz:flashinfer-jax-demo on Apr 24, 2026

Conversation

@katjasrz (Contributor) commented Apr 16, 2026

📌 Description

This PR adds a new example under examples/jax_tvm_ffi/ showing how to call FlashInfer from JAX via jax-tvm-ffi. It also adds examples/README.md to document the examples directory and make the new example easier to discover.

The goal is to provide a minimal reference for users interested in integrating FlashInfer outside of PyTorch, especially in JAX-based workflows.

🔍 Related Issues

N/A

🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or used your preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

🧪 Tests

  • Tests have been added or updated as needed.
  • All tests are passing (unittest, etc.).

This PR only adds example code and documentation and makes no changes to core functionality, so no additional tests were added. The examples run successfully end-to-end.

Summary by CodeRabbit

  • Documentation

    • Added an Examples overview and detailed per-example guides covering setup, installation, GPU/CUDA prerequisites, compilation/caching behavior, Hugging Face gated-model steps, authentication flows, and troubleshooting for JAX↔TVM FFI workflows.
  • New Features

    • Added runnable JAX↔TVM FFI examples (notebooks and standalone scripts) demonstrating fused activations/FFN, RoPE, and attention kernels, end-to-end Gemma 3 inference, correctness validations, and latency micro-benchmarks.

@coderabbitai Bot (Contributor) commented Apr 16, 2026

Note

Reviews paused

It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior by changing the reviews.auto_review.auto_pause_after_reviewed_commits setting.

Use the following commands to manage reviews:

  • @coderabbitai resume to resume automatic reviews.
  • @coderabbitai review to trigger a single review.

📝 Walkthrough

Adds a new examples directory with READMEs, notebooks, and standalone scripts demonstrating JAX ↔ TVM FFI integration that registers and calls FlashInfer CUDA kernels, plus an end-to-end Gemma 3 inference example with environment, installation, HF auth, and troubleshooting guidance.

Changes

  • Top-level & example READMEs (examples/README.md, examples/jax_tvm_ffi/README.md):
    New documentation introducing the examples directory, describing the JAX+TVM FFI example, prerequisites, installation variants, running instructions, compilation/caching notes, Gemma gating steps, and troubleshooting tips.
  • JAX↔TVM FFI demo, notebook (examples/jax_tvm_ffi/flashinfer_jax_tvm_ffi.ipynb):
    New notebook demonstrating the build → load → register → call pattern via jax-tvm-ffi; implements and validates silu_and_mul, apply_rope, decode_attention, a composed decode_step, helper wrappers, reference implementations, and micro-benchmarks.
  • JAX↔TVM FFI demo, script (examples/jax_tvm_ffi/flashinfer_jax_tvm_ffi.py):
    New standalone script mirroring the notebook: GPU detection and env setup, build/load/register kernels, JAX-callable wrappers, reference assertions, an @jax.jit composed decode_step, and a latency benchmark.
  • Gemma 3 end-to-end, notebook (examples/jax_tvm_ffi/gemma3_flashinfer_jax.ipynb):
    New notebook that loads Gemma 3 weights, sets up HF auth and safetensors, converts tensors to JAX bfloat16, compiles/registers FlashInfer kernels (gelu_tanh_and_mul, apply_rope, prefill/decode local+global), and implements prefill/decode control flow plus sampling generation.
  • Gemma 3 end-to-end, script (examples/jax_tvm_ffi/gemma3_flashinfer_jax.py):
    New standalone script for Gemma 3 inference: GPU/env checks, HF authentication and download, safetensors→JAX arrays, kernel registration via jax_tvm_ffi, JAX model scaffolding (RMSNorm, embeddings, gated FFN), prefill + iterative decode, and a generate(prompt) entrypoint.
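The notebook validates each kernel against a pure-array reference implementation. As an illustration of what such a reference looks like, here is a NumPy sketch of silu_and_mul, under the assumption (common to fused gate/up FFN kernels, but not confirmed by this PR) that the input's last dimension packs the gate and up projections side by side:

```python
import numpy as np

def silu_and_mul_ref(x: np.ndarray) -> np.ndarray:
    """Reference for a fused SiLU-and-multiply activation.

    Assumes the input's last dimension holds [gate | up] halves,
    so out = silu(gate) * up. Verify against the actual kernel's layout.
    """
    d = x.shape[-1] // 2
    gate, up = x[..., :d], x[..., d:]
    # silu(g) = g * sigmoid(g)
    return gate / (1.0 + np.exp(-gate)) * up

x = np.array([[1.0, 2.0, 3.0, 4.0]])
out = silu_and_mul_ref(x)  # silu([1, 2]) * [3, 4]
```

In the examples, a reference like this is compared elementwise against the FlashInfer kernel's output to catch wiring mistakes in the FFI registration.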

Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant User
    participant JAX
    participant TVM as TVM_FFI
    participant FlashInfer as FlashInfer_JIT
    participant CUDA as CUDA_Kernel

    User->>JAX: Call JAX wrapper (e.g., silu_and_mul)
    JAX->>TVM: jax.ffi.ffi_call (args, shape/dtype spec)
    TVM->>FlashInfer: lookup/load compiled module and invoke target
    FlashInfer->>CUDA: Launch CUDA kernel
    CUDA-->>FlashInfer: Kernel result buffer
    FlashInfer-->>TVM: Return result buffer
    TVM-->>JAX: Produce JAX Array
    JAX-->>User: Return result
```
```mermaid
sequenceDiagram
    participant User
    participant App as Gemma3_App
    participant HF as HuggingFace
    participant JAX
    participant FlashInfer as FlashInfer_Kernels

    User->>App: generate(prompt)
    App->>HF: Authenticate & download weights
    HF-->>App: Model shards (safetensors)
    App->>JAX: Load shards → bfloat16 arrays, tokenize
    App->>JAX: prefill(token_ids)
    JAX->>FlashInfer: prefill_attention + ffn calls
    FlashInfer-->>JAX: KV caches, logits
    loop decode until EOS
        App->>JAX: decode_step(token, caches)
        JAX->>FlashInfer: decode_attention + ffn
        FlashInfer-->>JAX: logits
        JAX->>App: next token (sampling)
    end
    App-->>User: Generated text
```

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

Suggested labels

ready

Suggested reviewers

  • aleozlx
  • yzh119
  • cyx-6
  • yongwww
  • jimmyzho
  • bkryu
  • nv-yunzheq

Poem

🐰 I hopped through kernels, docs, and code tonight,
From JAX to TVM the calls took flight,
Gemma learned to chat with FlashInfer's tune,
Notebooks, scripts — examples glowing soon,
Carrots compiled, and benchmarks felt right!

🚥 Pre-merge checks | ✅ 2 | ❌ 1

❌ Failed checks (1 warning)

  • Docstring Coverage (⚠️ Warning): docstring coverage is 17.14%, which is below the required threshold of 80.00%. Resolution: write docstrings for the functions missing them to satisfy the coverage threshold.

✅ Passed checks (2 passed)

  • Title check (✅ Passed): the pull request title directly and clearly summarizes the main change, adding JAX integration examples via jax-tvm-ffi, which matches the core objective of the changeset.
  • Description check (✅ Passed): the pull request description addresses the required sections from the template (Description, Related Issues) and includes the checklist with pre-commit verification and test-status explanations, though the test-related checkbox items differ from the template defaults given the documentation-only nature of the change.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.


Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.


Comment @coderabbitai help to get the list of available commands and usage tips.

@gemini-code-assist Bot (Contributor) left a comment

Code Review

This pull request introduces a suite of examples and tutorials for using FlashInfer GPU kernels within JAX via the TVM FFI bridge, including an end-to-end inference implementation for Gemma 3 1B Instruct. The review feedback highlights opportunities to enhance code robustness and performance, specifically by avoiding hardcoded data types in the decode_attention function and replacing inefficient jnp.concatenate operations for KV-cache growth with pre-allocated buffers and jax.lax.dynamic_update_slice.

@coderabbitai Bot (Contributor) left a comment

Actionable comments posted: 5

🧹 Nitpick comments (3)
examples/jax_tvm_ffi/flashinfer_jax_tvm_ffi.ipynb (1)

115-148: Self-containedness nit: subprocess is only imported in the earlier compute-capability cell.

The CUDA_HOME auto-detection block at lines 126–131 references subprocess, but this cell does not re-import it. If a user re-runs this cell after a kernel restart without running the earlier cell, it will NameError. Adding import subprocess alongside the other imports in this cell (or merging the two setup cells) makes it robust to out-of-order execution.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@examples/jax_tvm_ffi/flashinfer_jax_tvm_ffi.ipynb` around lines 115 - 148,
The CUDA_HOME auto-detection block in this cell uses subprocess but subprocess
is not imported here, causing a NameError if this cell runs standalone; to fix
it, add an import subprocess alongside the other imports at the top of the cell
(near import os, time, math, jinja2, numpy) or merge this cell with the earlier
setup cell so that subprocess is guaranteed to be defined before the CUDA_HOME
detection code in the block that references nvcc via subprocess.check_output.
examples/jax_tvm_ffi/gemma3_flashinfer_jax.py (1)

38-194: All setup/download/compile work executes at module import time.

Everything from compute-capability detection (line 46) through HF authentication (line 124), weight download/load (lines 135–156) and kernel compilation/registration (lines 209–511) runs as top-level module code. Only the for q in questions: loop (lines 755–762) is guarded by if __name__ == "__main__":. Consequences:

  • Importing this file as a module in another script triggers ~2 GB of HF downloads and multiple nvcc builds.
  • getpass(...) at line 114 will block on import if HF_TOKEN isn't set.
  • Static analyzers / IDEs that import the file to collect symbols will also trigger the side effects.

Not required for a demo, but moving the download/compile/register blocks into main() (or into lazily-invoked helpers) would make the example safer to re-use.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@examples/jax_tvm_ffi/gemma3_flashinfer_jax.py` around lines 38 - 194, The
module performs heavy side-effect work at import (compute-cap detection via
get_compute_capability, HF cache setup via _ensure_writable_hf_cache, HF_TOKEN
prompting, snapshot_download and weights loading into weights, and kernel
compilation/registration) which should be deferred; refactor by moving all
top-level side-effect logic into a guarded main() (or into explicit functions
like prepare_environment(), authenticate_hf(), download_and_load_weights(),
compile_and_register_kernels()) and ensure only lightweight definitions
(functions/classes/constants) remain at import time, then call main() under if
__name__ == "__main__": so importing the module no longer triggers downloads,
getpass prompts, or nvcc builds.
examples/jax_tvm_ffi/README.md (1)

89-93: Add a language tag to the fenced code block (optional).

markdownlint MD040 flags this block. Using text or none keeps the ASCII diagram untouched while silencing the lint.

♻️ Proposed diff
-```
+```text
   Step 1  BUILD & LOAD   jit_spec.build_and_load()  ->  tvm_ffi.Module
   Step 2  REGISTER       jax_tvm_ffi.register_ffi_target(name, wrapper, arg_spec)
   Step 3  CALL           jax.ffi.ffi_call(name, output_shapes)(*inputs, **scalar_attrs)
🤖 Prompt for AI Agents

Verify each finding against the current code and only fix it if needed.

In @examples/jax_tvm_ffi/README.md around lines 89 - 93, The fenced ASCII
diagram block containing the lines starting with "Step 1 BUILD & LOAD", "Step 2
REGISTER", and "Step 3 CALL" should include a language tag to satisfy
markdownlint MD040; change the opening fence from a plain ``` to ```text (or
```none) so the block becomes a labeled code fence and the diagram remains unchanged.



🤖 Prompt for all review comments with AI agents

Verify each finding against the current code and only fix it if needed.

Inline comments:
In @examples/jax_tvm_ffi/gemma3_flashinfer_jax.ipynb:

  • Around line 362-364: The docstring for is_global(layer_idx: int) is stale — it
    says "18-layer model" and lists global layers as 5, 11, 17; update the docstring
    in the is_global function to reflect the 26-layer Gemma 3 1B model and list the
    four global full-attention layers as 5, 11, 17, 23 (keeping the implementation
    return (layer_idx + 1) % 6 == 0 unchanged).
  • Around line 119-145: The cell uses subprocess.check_output (referenced as
    subprocess and check_output) but never imports subprocess, causing a NameError
    if run after a kernel restart; add an explicit top-level import for subprocess
    in this notebook cell (alongside json, math, os, time, jax, jax.numpy) so
    subprocess is available before calling subprocess.check_output when setting
    CUDA_HOME and ensure this import appears before any use of SM_MAJOR/SM_MINOR
    environment logic.
  • Around line 478-486: The TVM FFI expects a double but ROPE_THETA_LOCAL /
    ROPE_THETA_GLOBAL are ints, so cast rope_theta to a float when calling the FFI:
    update apply_rope to pass rope_theta=float(rope_theta) into the jax.ffi.ffi_call
    invocation and also ensure any callers that forward rope_theta (e.g.,
    prefill_layer and decode_layer) forward float(rope_theta) instead of the raw
    int; locate these by the symbols apply_rope, prefill_layer, decode_layer and the
    constants ROPE_THETA_LOCAL / ROPE_THETA_GLOBAL and replace the direct int usage
    with float(...) to avoid scalar-type errors.

In @examples/jax_tvm_ffi/gemma3_flashinfer_jax.py:

  • Around line 284-301: The apply_rope wrapper should ensure rope_theta is
    keyword-only and passed to the TVM FFI as a Python float; update the function
    signature to def apply_rope(q, k, indptr, offsets, *, rope_theta=1e4) and
    convert rope_theta to float (e.g., rope_theta = float(rope_theta)) before
    calling jax.ffi.ffi_call so the value arriving at the "flashinfer.apply_rope"
    FFI matches the expected C double type.
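The rope_theta fix above can be sketched in isolation. This is hypothetical illustration code: `_ffi_call` stands in for the real `jax.ffi.ffi_call` invocation (which needs a registered target and a GPU), and it just reports the scalar's type so the coercion is visible.

```python
# Sketch of the suggested wrapper: rope_theta is keyword-only and coerced to
# float, so the scalar crossing the TVM FFI boundary is a C double.

def _ffi_call(q, k, indptr, offsets, *, rope_theta):
    # Placeholder for jax.ffi.ffi_call: report the scalar's Python type
    # instead of launching a kernel.
    return type(rope_theta).__name__

def apply_rope(q, k, indptr, offsets, *, rope_theta=1e4):
    rope_theta = float(rope_theta)  # int constants like ROPE_THETA_LOCAL stay valid callers
    return _ffi_call(q, k, indptr, offsets, rope_theta=rope_theta)

kind = apply_rope(None, None, None, None, rope_theta=10000)  # int in, float at the boundary
```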

In @examples/jax_tvm_ffi/README.md:

  • Around line 14-34: Update the README installation commands so the CUDA
    variants are consistent with the "CUDA 11.8+" requirement: show paired
    substitutions for both JAX and the FlashInfer wheel (e.g., replace 'jax[cuda13]'
    with the matching 'jax[cuda12]' or 'jax[cuda11]' and change --extra-index-url
    https://flashinfer.ai/whl/cu130/ to the corresponding wheel suffix like cu124 or
    cu118). Modify the install block that currently uses 'pip install
    'jax[cuda13]'' and the FlashInfer URL '--extra-index-url
    https://flashinfer.ai/whl/cu130/' to include explicit examples for CUDA 11.8,
    12.x, and 13 (or state a template such as jax[cuda] and
    /whl/cu/) so users know to change both places together.

Nitpick comments:
In @examples/jax_tvm_ffi/flashinfer_jax_tvm_ffi.ipynb:

  • Around line 115-148: The CUDA_HOME auto-detection block in this cell uses
    subprocess but subprocess is not imported here, causing a NameError if this cell
    runs standalone; to fix it, add an import subprocess alongside the other imports
    at the top of the cell (near import os, time, math, jinja2, numpy) or merge this
    cell with the earlier setup cell so that subprocess is guaranteed to be defined
    before the CUDA_HOME detection code in the block that references nvcc via
    subprocess.check_output.

In @examples/jax_tvm_ffi/gemma3_flashinfer_jax.py:

  • Around line 38-194: The module performs heavy side-effect work at import
    (compute-cap detection via get_compute_capability, HF cache setup via
    _ensure_writable_hf_cache, HF_TOKEN prompting, snapshot_download and weights
    loading into weights, and kernel compilation/registration) which should be
    deferred; refactor by moving all top-level side-effect logic into a guarded
    main() (or into explicit functions like prepare_environment(),
    authenticate_hf(), download_and_load_weights(), compile_and_register_kernels())
    and ensure only lightweight definitions (functions/classes/constants) remain at
    import time, then call main() under if __name__ == "__main__": so importing the
    module no longer triggers downloads, getpass prompts, or nvcc builds.

In @examples/jax_tvm_ffi/README.md:

  • Around line 89-93: The fenced ASCII diagram block containing the lines
    starting with "Step 1 BUILD & LOAD", "Step 2 REGISTER", and "Step 3 CALL"
    should include a language tag to satisfy markdownlint MD040; change the opening
    fence from a plain ``` to ```text (or ```none) so the block becomes a labeled code fence
    and the diagram remains unchanged.

🪄 Autofix (Beta)

Fix all unresolved CodeRabbit comments on this PR:

  • Push a commit to this branch (recommended)
  • Create a new PR with the fixes

---

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 134ea15d-524d-4714-a813-27ee015b2b8b

📥 Commits

Reviewing files that changed from the base of the PR and between a99ee72fcdf7f2d8579877ef4ea7c5b2da34f29d and eeab46a4ad0f00f5f88aa944abf874f8acd4431c.

📒 Files selected for processing (6)
  • examples/README.md
  • examples/jax_tvm_ffi/README.md
  • examples/jax_tvm_ffi/flashinfer_jax_tvm_ffi.ipynb
  • examples/jax_tvm_ffi/flashinfer_jax_tvm_ffi.py
  • examples/jax_tvm_ffi/gemma3_flashinfer_jax.ipynb
  • examples/jax_tvm_ffi/gemma3_flashinfer_jax.py


@coderabbitai Bot (Contributor) left a comment

Actionable comments posted: 3

🧹 Nitpick comments (6)
examples/jax_tvm_ffi/flashinfer_jax_tvm_ffi.ipynb (1)

800-803: Consistency nit: np.dtype vs jnp.dtype.

The sibling script flashinfer_jax_tvm_ffi.py at line 383 uses jnp.dtype(q.dtype).itemsize while this notebook uses np.dtype(...). Both work for standard float dtypes, but aligning the two keeps the tutorial and script consistent for readers switching between them.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@examples/jax_tvm_ffi/flashinfer_jax_tvm_ffi.ipynb` around lines 800 - 803,
Consistency nit: replace np.dtype usage with jnp.dtype to match the sibling
script. In the block where sm_scale and tmp_elems are computed (variables
sm_scale, tmp_elems, and q), change np.dtype(q.dtype).itemsize to
jnp.dtype(q.dtype).itemsize so the notebook uses JAX's dtype utility like
flashinfer_jax_tvm_ffi.py and remains consistent for readers switching between
the script and notebook.
examples/jax_tvm_ffi/gemma3_flashinfer_jax.py (2)

305-305: Hardcoded bfloat16 byte-size assumption in _TMP_ELEMS.

_TMP_ELEMS = 32 * 1024 * 1024 // 2 hardcodes the 2-byte element size. The script is bfloat16-only today, but this will silently produce a wrong-sized scratch buffer if anyone retargets it to float32. A self-documenting computation would avoid the footgun:

♻️ Proposed diff
-_TMP_ELEMS = 32 * 1024 * 1024 // 2
+_TMP_ELEMS = 32 * 1024 * 1024 // jnp.dtype(jnp.bfloat16).itemsize
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@examples/jax_tvm_ffi/gemma3_flashinfer_jax.py` at line 305, The constant
_TMP_ELEMS currently assumes 2 bytes per element (bfloat16) causing incorrect
scratch buffer sizing if element dtype changes; update the computation to derive
element size dynamically (e.g., from the array dtype or a bytes-per-element
constant) and compute _TMP_ELEMS = (32 * 1024 * 1024) // bytes_per_element so it
self-documents and works for bfloat16, float32, etc.; locate and replace the
literal expression where _TMP_ELEMS is defined in gemma3_flashinfer_jax.py to
use the dynamic bytes_per_element value (or a named helper) rather than
hardcoded “// 2”.
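The self-documenting computation the review proposes can be sketched with NumPy's dtype machinery (the example itself uses `jnp.dtype`, which behaves the same way for this purpose):

```python
import numpy as np

TMP_BYTES = 32 * 1024 * 1024  # fixed 32 MiB scratch budget, as in the example

def tmp_elems_for(dtype) -> int:
    """Element count for the scratch buffer, derived from the dtype's itemsize
    rather than a hardcoded '// 2' that silently assumes a 2-byte element."""
    return TMP_BYTES // np.dtype(dtype).itemsize

# float16 shares bfloat16's 2-byte itemsize, so it reproduces the original
# constant; retargeting to float32 halves the element count automatically.
elems_f16 = tmp_elems_for(np.float16)
elems_f32 = tmp_elems_for(np.float32)
```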

102-122: Swallowed exception on HF token lookup loses diagnostic info.

Ruff flags S110/BLE001 here. huggingface_hub.get_token() will raise on rare cases (corrupted token file, unexpected permissions), and silently falling through to the interactive prompt hides the root cause. Consider at least narrowing the catch and surfacing it to the user:

♻️ Proposed diff
     try:
         from huggingface_hub import get_token

         HF_TOKEN = get_token() or ""
-    except Exception:
-        pass
+    except Exception as e:
+        print(f"Warning: failed to read cached HF token ({e!r}); falling back to prompt.")
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@examples/jax_tvm_ffi/gemma3_flashinfer_jax.py` around lines 102 - 122, The
current try/except around huggingface_hub.get_token() swallows all exceptions
and hides diagnostics; update the block that calls get_token() (related symbols:
HF_TOKEN and get_token) to catch only expected exceptions or catch Exception as
e and surface the error (e.g., log or re-raise a RuntimeError including the
original exception message) instead of silently passing, so failures like
corrupted token files or permission errors are visible to the user and preserved
as the cause of the failure.
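The narrowed-catch pattern can be sketched as follows. This is illustrative: `get_token` is passed in as a callable so the snippet runs without huggingface_hub installed, and OSError is one reasonable choice of "expected" failure (covering corrupted token files and permission errors); the real code may prefer `except Exception as e` with a logged warning, as in the proposed diff.

```python
import warnings

def read_cached_token(get_token) -> str:
    """Read a cached HF token, surfacing failures instead of swallowing them."""
    try:
        return get_token() or ""
    except OSError as e:  # corrupted token file, permission errors, ...
        warnings.warn(f"failed to read cached HF token ({e!r}); falling back to prompt")
        return ""

def broken_get_token():
    raise PermissionError("token file unreadable")

token = read_cached_token(broken_get_token)  # warns instead of passing silently
```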
examples/jax_tvm_ffi/flashinfer_jax_tvm_ffi.py (3)

479-487: Benchmark measures wall time including Python/dispatch overhead.

time.perf_counter() around block_until_ready captures host-side dispatch and JAX sync overhead in addition to kernel time. For a single µs-scale decode kernel this is usually dominated by dispatch latency, which can mislead readers comparing it to FlashInfer's PyTorch benchmarks (which typically use CUDA events). Consider either:

  • Running more iterations (e.g. N=1000) and noting the measurement includes Python dispatch, or
  • Adding a brief comment that this is end-to-end latency including host overhead.

Not blocking for an example.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@examples/jax_tvm_ffi/flashinfer_jax_tvm_ffi.py` around lines 479 - 487, The
benchmark currently measures decode_attention latency using time.perf_counter()
around block_until_ready which includes Python host/dispatch overhead; either
increase the iteration count (e.g. bump N from 100 to 1000) to amortize dispatch
or add a short comment next to the decode_attention(...).block_until_ready()
loop stating that the reported us includes host-side dispatch and JAX sync
overhead (so results are end-to-end, not pure kernel time); update the N
variable and/or the print comment in the same block where decode_attention,
block_until_ready, time.perf_counter, and N are used.

31-37: Minor: nvidia-smi invoked via partial executable path.

Ruff flags S607 at lines 33 and 50. For example scripts this is acceptable, but the script will fail with a somewhat confusing FileNotFoundError if nvidia-smi/which are not on PATH. Consider wrapping the subprocess.check_output call in a clearer error message (e.g. "nvidia-smi not found on PATH; is a CUDA driver installed?") so that users running in a minimal container get actionable feedback rather than a raw stack trace.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@examples/jax_tvm_ffi/flashinfer_jax_tvm_ffi.py` around lines 31 - 37, The
call to subprocess.check_output in get_compute_capability can raise a
FileNotFoundError with an opaque stack trace if nvidia-smi is missing; catch
FileNotFoundError around the subprocess.check_output call in
get_compute_capability and raise (or log) a clearer RuntimeError with a message
like "nvidia-smi not found on PATH; is the CUDA driver installed?" so users get
actionable feedback instead of the raw FileNotFoundError; reference the
get_compute_capability function and the subprocess.check_output invocation when
applying the change.
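A wrapper with the clearer error message might look like this. It is a sketch: the `--query-gpu=compute_cap` flag is what recent nvidia-smi versions accept and is assumed here, not taken from the PR; adjust the query to match the example's actual invocation.

```python
import subprocess

def query_compute_capability(binary: str = "nvidia-smi") -> str:
    """Query the GPU compute capability, failing with an actionable message."""
    cmd = [binary, "--query-gpu=compute_cap", "--format=csv,noheader"]
    try:
        return subprocess.check_output(cmd, text=True).strip()
    except FileNotFoundError as e:
        # Raised when the binary is not on PATH (e.g. a minimal container).
        raise RuntimeError(
            f"{binary} not found on PATH; is a CUDA driver installed?"
        ) from e
```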

438-459: k_r from the outer apply_rope call is unused.

Line 458 unpacks q_r, k_r = apply_rope(...) but only q_r is consumed on line 459. Rename to _k_r (Ruff RUF059) to make the intent explicit, mirroring the unused unpack inside decode_step at line 449.

♻️ Proposed diff
-q_r, k_r = apply_rope(q_new, k_new, indptr, offsets)
+q_r, _k_r = apply_rope(q_new, k_new, indptr, offsets)
 attn_ref = decode_attention(q_r.reshape(NUM_QO, HEAD_DIM), k, v)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@examples/jax_tvm_ffi/flashinfer_jax_tvm_ffi.py` around lines 438 - 459, The
outer call to apply_rope unpacks q_r, k_r but k_r is unused—rename k_r to _k_r
in the outer scope to indicate it's intentionally unused (mirror the unused k_r
inside decode_step); update the line that currently reads "q_r, k_r =
apply_rope(...)" to "q_r, _k_r = apply_rope(...)" and keep all other uses of q_r
the same (no changes to decode_step, apply_rope, silu_and_mul, or
decode_attention).
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@examples/jax_tvm_ffi/flashinfer_jax_tvm_ffi.ipynb`:
- Around line 113-148: The setup cell is missing "import subprocess", causing a
NameError when subprocess.check_output() is called; add "import subprocess" near
the top imports (alongside import os, time, math, jinja2, numpy) so the
subprocess symbol is always defined before the CUDA_HOME detection block that
calls subprocess.check_output; ensure the import appears before the if
'CUDA_HOME' not in os.environ: block that uses subprocess.

In `@examples/jax_tvm_ffi/gemma3_flashinfer_jax.ipynb`:
- Around line 924-928: The current initialization of _STOP_IDS can include None
if tokenizer.eos_token_id is None; change the logic in the block that builds
_STOP_IDS to only add tokenizer.eos_token_id when it is not None and not equal
to tokenizer.unk_token_id, and keep the existing guard when converting special
tokens via tokenizer.convert_tokens_to_ids (i.e., only add _id when _id is not
None and _id != tokenizer.unk_token_id). Update the code that references
_STOP_IDS, tokenizer.eos_token_id, tokenizer.unk_token_id, convert_tokens_to_ids
and the loop over _tok to follow this guarded-add pattern so next_tok in
_STOP_IDS won't see a None entry.

In `@examples/jax_tvm_ffi/gemma3_flashinfer_jax.py`:
- Around line 697-702: The _STOP_IDS set currently seeds with
tokenizer.eos_token_id which may be None; change the construction so you start
with an empty set, then add tokenizer.eos_token_id only if it is not None (and
optionally cast to int), and retain the existing logic for converted token IDs
from tokenizer.convert_tokens_to_ids ensuring _id is not None and _id !=
tokenizer.unk_token_id before adding; this prevents None being in _STOP_IDS and
fixes the subsequent next_tok in _STOP_IDS check used elsewhere.
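The guarded-add pattern both comments describe can be sketched with a stand-in tokenizer. `FakeTokenizer` and the token names are hypothetical stand-ins for the real Hugging Face tokenizer; the point is that neither `eos_token_id = None` nor an unk fallback ever lands in the stop set.

```python
# Guarded construction of the stop-token set: eos_token_id can legitimately
# be None, and convert_tokens_to_ids returns the unk id for unknown specials.

class FakeTokenizer:
    eos_token_id = None          # the problematic case from the review
    unk_token_id = 3

    def convert_tokens_to_ids(self, tok):
        return {"<end_of_turn>": 106}.get(tok, self.unk_token_id)

def build_stop_ids(tokenizer, special_tokens=("<end_of_turn>", "<eos>")):
    stop_ids = set()
    eos = tokenizer.eos_token_id
    if eos is not None and eos != tokenizer.unk_token_id:
        stop_ids.add(int(eos))
    for tok in special_tokens:
        _id = tokenizer.convert_tokens_to_ids(tok)
        if _id is not None and _id != tokenizer.unk_token_id:
            stop_ids.add(int(_id))
    return stop_ids

stop_ids = build_stop_ids(FakeTokenizer())
```

With this construction, the later `next_tok in _STOP_IDS` check can never compare against a `None` entry.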

---

Nitpick comments:
In `@examples/jax_tvm_ffi/flashinfer_jax_tvm_ffi.ipynb`:
- Around line 800-803: Consistency nit: replace np.dtype usage with jnp.dtype to
match the sibling script. In the block where sm_scale and tmp_elems are computed
(variables sm_scale, tmp_elems, and q), change np.dtype(q.dtype).itemsize to
jnp.dtype(q.dtype).itemsize so the notebook uses JAX's dtype utility like
flashinfer_jax_tvm_ffi.py and remains consistent for readers switching between
the script and notebook.

In `@examples/jax_tvm_ffi/flashinfer_jax_tvm_ffi.py`:
- Around line 479-487: The benchmark currently measures decode_attention latency
using time.perf_counter() around block_until_ready which includes Python
host/dispatch overhead; either increase the iteration count (e.g. bump N from
100 to 1000) to amortize dispatch or add a short comment next to the
decode_attention(...).block_until_ready() loop stating that the reported us
includes host-side dispatch and JAX sync overhead (so results are end-to-end,
not pure kernel time); update the N variable and/or the print comment in the
same block where decode_attention, block_until_ready, time.perf_counter, and N
are used.
- Around line 31-37: The call to subprocess.check_output in
get_compute_capability can raise a FileNotFoundError with an opaque stack trace
if nvidia-smi is missing; catch FileNotFoundError around the
subprocess.check_output call in get_compute_capability and raise (or log) a
clearer RuntimeError with a message like "nvidia-smi not found on PATH; is the
CUDA driver installed?" so users get actionable feedback instead of the raw
FileNotFoundError; reference the get_compute_capability function and the
subprocess.check_output invocation when applying the change.
- Around line 438-459: The outer call to apply_rope unpacks q_r, k_r but k_r is
unused—rename k_r to _k_r in the outer scope to indicate it's intentionally
unused (mirror the unused k_r inside decode_step); update the line that
currently reads "q_r, k_r = apply_rope(...)" to "q_r, _k_r = apply_rope(...)"
and keep all other uses of q_r the same (no changes to decode_step, apply_rope,
silu_and_mul, or decode_attention).

In `@examples/jax_tvm_ffi/gemma3_flashinfer_jax.py`:
- Line 305: The constant _TMP_ELEMS currently assumes 2 bytes per element
(bfloat16) causing incorrect scratch buffer sizing if element dtype changes;
update the computation to derive element size dynamically (e.g., from the array
dtype or a bytes-per-element constant) and compute _TMP_ELEMS = (32 * 1024 *
1024) // bytes_per_element so it self-documents and works for bfloat16, float32,
etc.; locate and replace the literal expression where _TMP_ELEMS is defined in
gemma3_flashinfer_jax.py to use the dynamic bytes_per_element value (or a named
helper) rather than hardcoded “// 2”.
- Around line 102-122: The current try/except around huggingface_hub.get_token()
swallows all exceptions and hides diagnostics; update the block that calls
get_token() (related symbols: HF_TOKEN and get_token) to catch only expected
exceptions or catch Exception as e and surface the error (e.g., log or re-raise
a RuntimeError including the original exception message) instead of silently
passing, so failures like corrupted token files or permission errors are visible
to the user and preserved as the cause of the failure.

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: f0942cbb-ac2d-41e9-acd4-e560fe59b092

📥 Commits

Reviewing files that changed from the base of the PR and between eeab46a and e4b5da0.

📒 Files selected for processing (4)
  • examples/jax_tvm_ffi/flashinfer_jax_tvm_ffi.ipynb
  • examples/jax_tvm_ffi/flashinfer_jax_tvm_ffi.py
  • examples/jax_tvm_ffi/gemma3_flashinfer_jax.ipynb
  • examples/jax_tvm_ffi/gemma3_flashinfer_jax.py

@katjasrz katjasrz force-pushed the flashinfer-jax-demo branch from e4b5da0 to 1a4c0e7 on April 17, 2026 18:04
Contributor

@coderabbitai coderabbitai Bot left a comment

🧹 Nitpick comments (6)
examples/jax_tvm_ffi/flashinfer_jax_tvm_ffi.py (2)

40-42: assert is skipped under python -O.

This SM-version gate is a hard prerequisite. Since the script will be run with python …, it almost always holds — but if anyone ever invokes python -O flashinfer_jax_tvm_ffi.py, the assertion is elided and the subsequent kernels will surface far less helpful errors. A raise RuntimeError(...) would be more robust.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@examples/jax_tvm_ffi/flashinfer_jax_tvm_ffi.py` around lines 40 - 42, Replace
the assert check on SM_MAJOR/SM_MINOR with an explicit runtime error so the
guard is not skipped under python -O: locate the SM version gate using the
symbols SM_MAJOR and SM_MINOR (the current lines using assert (SM_MAJOR,
SM_MINOR) >= (7, 5)), and change it to raise a RuntimeError with the same
message (e.g., "SM X.Y is below the minimum SM 7.5 (Turing)") so the script
always fails fast when the GPU SM is too low.
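The suggested guard, as a standalone sketch (the function name is illustrative; the SM values would come from the script's GPU probe):

```python
def check_min_sm(sm_major: int, sm_minor: int) -> None:
    # Unlike `assert`, an explicit raise is not elided under `python -O`.
    if (sm_major, sm_minor) < (7, 5):
        raise RuntimeError(
            f"SM {sm_major}.{sm_minor} is below the minimum SM 7.5 (Turing)"
        )


check_min_sm(8, 0)  # Ampere: passes silently
```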

445-454: k_r is discarded — consider renaming or adding a comment.

In a real decode step, the RoPE'd k_r would be appended to the KV-cache and then attended over, not the raw pre-RoPE k_cache that's currently used. For the demo this is fine (the composition still validates), but to avoid misleading readers following this as a template, either rename k_r to _ or add a one-line comment that this is a demonstrative composition (not a faithful decode step).

Proposed diff
-    q_r, k_r = apply_rope(q_new, k_new, indptr, offsets)
+    # Note: in a real decode step, k_r would be appended to k_cache before attention;
+    # here we reuse the raw cache for brevity to demonstrate XLA composition only.
+    q_r, _k_r = apply_rope(q_new, k_new, indptr, offsets)
     attn_out = decode_attention(q_r.reshape(NUM_QO, HEAD_DIM), k_cache, v_cache)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@examples/jax_tvm_ffi/flashinfer_jax_tvm_ffi.py` around lines 445 - 454, The
variable k_r produced by apply_rope in decode_step is computed but never used
(the function still attends over the original k_cache), which can mislead
readers; either rename k_r to _ to indicate it's intentionally unused, or update
decode_step to use the RoPE'd keys when calling decode_attention (or add a
one-line comment inside decode_step clarifying this is a demonstrative
composition and not a faithful KV-cache update). Locate the apply_rope call and
the decode_step function (symbols: decode_step, apply_rope, k_r, k_cache,
decode_attention) and make the rename or comment change (or adjust the call to
pass k_r into decode_attention) accordingly.
examples/jax_tvm_ffi/flashinfer_jax_tvm_ffi.ipynb (1)

121-122: Mirror setdefault(...) from the standalone script.

os.environ["XLA_FLAGS"] = "..." overwrites any pre-existing XLA flags a reader may have set. The .py counterpart uses setdefault. Same pattern is worth applying here.

Proposed diff
-"os.environ[\"TF_CPP_MIN_LOG_LEVEL\"] = \"2\"           # suppress TF/XLA info & warnings\n",
-"os.environ[\"XLA_FLAGS\"] = \"--xla_gpu_cuda_data_dir=/usr/local/cuda\"\n",
+"os.environ.setdefault(\"TF_CPP_MIN_LOG_LEVEL\", \"2\")  # suppress TF/XLA info & warnings\n",
+"os.environ.setdefault(\"XLA_FLAGS\", \"--xla_gpu_cuda_data_dir=/usr/local/cuda\")\n",
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@examples/jax_tvm_ffi/flashinfer_jax_tvm_ffi.ipynb` around lines 121 - 122,
The notebook currently overwrites any existing XLA flags by setting
os.environ["XLA_FLAGS"] directly; change this to use
os.environ.setdefault("XLA_FLAGS", "--xla_gpu_cuda_data_dir=/usr/local/cuda")
(mirror the standalone script pattern) so pre-existing XLA_FLAGS are preserved;
update the cell containing os.environ["XLA_FLAGS"] = "..." and replace it with a
setdefault call referencing the same key ("XLA_FLAGS").
examples/jax_tvm_ffi/gemma3_flashinfer_jax.py (1)

105-110: Silent except Exception: pass hides HF-token errors.

If get_token() (or its import) fails for any non-missing reason, the error is silently swallowed and the user ends up at the interactive prompt with no hint why. Narrowing to the expected ImportError (and optionally logging anything else) improves diagnosability.

Proposed diff
     try:
         from huggingface_hub import get_token
 
         HF_TOKEN = get_token() or ""
-    except Exception:
-        pass
+    except ImportError:
+        pass
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@examples/jax_tvm_ffi/gemma3_flashinfer_jax.py` around lines 105 - 110, The
current try/except around importing and calling huggingface_hub.get_token()
swallows all exceptions; change it to catch ImportError for the missing module
and set HF_TOKEN = "" in that branch, and add a separate except Exception as e
branch that logs or warns about the unexpected failure (including the exception)
rather than silently passing; adjust the block referencing get_token and
HF_TOKEN so ImportError is handled silently but other errors are surfaced via
logging.warn/logging.exception or re-raising as appropriate.
examples/jax_tvm_ffi/gemma3_flashinfer_jax.ipynb (1)

127-128: Prefer os.environ.setdefault(...) for consistency with the .py version.

The standalone script (gemma3_flashinfer_jax.py lines 52–53) uses setdefault so user-configured values are preserved. The notebook unconditionally overwrites TF_CPP_MIN_LOG_LEVEL and XLA_FLAGS, which can silently clobber flags a reader has pre-set (e.g. --xla_gpu_deterministic_ops).

Proposed diff
-"os.environ[\"TF_CPP_MIN_LOG_LEVEL\"] = \"2\"           # suppress TF/XLA info & warnings\n",
-"os.environ[\"XLA_FLAGS\"] = \"--xla_gpu_cuda_data_dir=/usr/local/cuda\"\n",
+"os.environ.setdefault(\"TF_CPP_MIN_LOG_LEVEL\", \"2\")  # suppress TF/XLA info & warnings\n",
+"os.environ.setdefault(\"XLA_FLAGS\", \"--xla_gpu_cuda_data_dir=/usr/local/cuda\")\n",
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@examples/jax_tvm_ffi/gemma3_flashinfer_jax.ipynb` around lines 127 - 128, The
notebook currently unconditionally overwrites environment variables
TF_CPP_MIN_LOG_LEVEL and XLA_FLAGS; change those assignments to use
os.environ.setdefault(...) (matching the approach used in
gemma3_flashinfer_jax.py) so any user-preconfigured values are preserved —
locate the lines that set TF_CPP_MIN_LOG_LEVEL and XLA_FLAGS and replace the
direct assignments with setdefault calls.
examples/jax_tvm_ffi/README.md (1)

89-93: Add a language to the fenced block for MD040.

markdownlint flags the fenced block at line 89. Tagging it as text (or similar) silences the warning without changing rendering.

Proposed diff
-```
+```text
 Step 1  BUILD & LOAD   jit_spec.build_and_load()  ->  tvm_ffi.Module
 Step 2  REGISTER       jax_tvm_ffi.register_ffi_target(name, wrapper, arg_spec)
 Step 3  CALL           jax.ffi.ffi_call(name, output_shapes)(*inputs, **scalar_attrs)
🤖 Prompt for AI Agents

Verify each finding against the current code and only fix it if needed.

In @examples/jax_tvm_ffi/README.md around lines 89 - 93, The fenced code block
containing the three steps ("Step 1 BUILD & LOAD jit_spec.build_and_load()
-> tvm_ffi.Module", "Step 2 REGISTER
jax_tvm_ffi.register_ffi_target(name, wrapper, arg_spec)", "Step 3 CALL
jax.ffi.ffi_call(name, output_shapes)(*inputs, **scalar_attrs)") should include
a language tag to satisfy markdownlint MD040; update the fence from ``` to
```text (or another appropriate tag) so the block is explicitly
labeled and the linter warning is silenced.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Nitpick comments:
In `@examples/jax_tvm_ffi/flashinfer_jax_tvm_ffi.ipynb`:
- Around line 121-122: The notebook currently overwrites any existing XLA flags
by setting os.environ["XLA_FLAGS"] directly; change this to use
os.environ.setdefault("XLA_FLAGS", "--xla_gpu_cuda_data_dir=/usr/local/cuda")
(mirror the standalone script pattern) so pre-existing XLA_FLAGS are preserved;
update the cell containing os.environ["XLA_FLAGS"] = "..." and replace it with a
setdefault call referencing the same key ("XLA_FLAGS").

In `@examples/jax_tvm_ffi/flashinfer_jax_tvm_ffi.py`:
- Around line 40-42: Replace the assert check on SM_MAJOR/SM_MINOR with an
explicit runtime error so the guard is not skipped under python -O: locate the
SM version gate using the symbols SM_MAJOR and SM_MINOR (the current lines using
assert (SM_MAJOR, SM_MINOR) >= (7, 5)), and change it to raise a RuntimeError
with the same message (e.g., "SM X.Y is below the minimum SM 7.5 (Turing)") so
the script always fails fast when the GPU SM is too low.
- Around line 445-454: The variable k_r produced by apply_rope in decode_step is
computed but never used (the function still attends over the original k_cache),
which can mislead readers; either rename k_r to _ to indicate it's intentionally
unused, or update decode_step to use the RoPE'd keys when calling
decode_attention (or add a one-line comment inside decode_step clarifying this
is a demonstrative composition and not a faithful KV-cache update). Locate the
apply_rope call and the decode_step function (symbols: decode_step, apply_rope,
k_r, k_cache, decode_attention) and make the rename or comment change (or adjust
the call to pass k_r into decode_attention) accordingly.

In `@examples/jax_tvm_ffi/gemma3_flashinfer_jax.ipynb`:
- Around line 127-128: The notebook currently unconditionally overwrites
environment variables TF_CPP_MIN_LOG_LEVEL and XLA_FLAGS; change those
assignments to use os.environ.setdefault(...) (matching the approach used in
gemma3_flashinfer_jax.py) so any user-preconfigured values are preserved —
locate the lines that set TF_CPP_MIN_LOG_LEVEL and XLA_FLAGS and replace the
direct assignments with setdefault calls.

In `@examples/jax_tvm_ffi/gemma3_flashinfer_jax.py`:
- Around line 105-110: The current try/except around importing and calling
huggingface_hub.get_token() swallows all exceptions; change it to catch
ImportError for the missing module and set HF_TOKEN = "" in that branch, and add
a separate except Exception as e branch that logs or warns about the unexpected
failure (including the exception) rather than silently passing; adjust the block
referencing get_token and HF_TOKEN so ImportError is handled silently but other
errors are surfaced via logging.warn/logging.exception or re-raising as
appropriate.

In `@examples/jax_tvm_ffi/README.md`:
- Around line 89-93: The fenced code block containing the three steps ("Step 1 
BUILD & LOAD   jit_spec.build_and_load()  ->  tvm_ffi.Module", "Step 2  REGISTER
jax_tvm_ffi.register_ffi_target(name, wrapper, arg_spec)", "Step 3  CALL        
jax.ffi.ffi_call(name, output_shapes)(*inputs, **scalar_attrs)") should include
a language tag to satisfy markdownlint MD040; update the fence from ``` to
```text (or another appropriate tag) in README.md so the block is explicitly
labeled and the linter warning is silenced.

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 7a795534-2179-467f-84ec-ba3d898c53bf

📥 Commits

Reviewing files that changed from the base of the PR and between e4b5da0 and 1a4c0e7.

📒 Files selected for processing (6)
  • examples/README.md
  • examples/jax_tvm_ffi/README.md
  • examples/jax_tvm_ffi/flashinfer_jax_tvm_ffi.ipynb
  • examples/jax_tvm_ffi/flashinfer_jax_tvm_ffi.py
  • examples/jax_tvm_ffi/gemma3_flashinfer_jax.ipynb
  • examples/jax_tvm_ffi/gemma3_flashinfer_jax.py
✅ Files skipped from review due to trivial changes (1)
  • examples/README.md

Contributor

@coderabbitai coderabbitai Bot left a comment

Actionable comments posted: 3

♻️ Duplicate comments (2)
examples/jax_tvm_ffi/gemma3_flashinfer_jax.ipynb (2)

899-906: ⚠️ Potential issue | 🟡 Minor

Guard eos_token_id before adding it to _STOP_IDS.

Line 902 can add None to _STOP_IDS when the tokenizer has no EOS token. Keep the same guarded-add pattern used for converted special tokens.

Proposed fix
-_STOP_IDS = {tokenizer.eos_token_id}
+_STOP_IDS = set()
+if tokenizer.eos_token_id is not None and tokenizer.eos_token_id != tokenizer.unk_token_id:
+    _STOP_IDS.add(tokenizer.eos_token_id)
 for _tok in ['<end_of_turn>', '<eos>']:
     _id = tokenizer.convert_tokens_to_ids(_tok)
     if _id is not None and _id != tokenizer.unk_token_id:
         _STOP_IDS.add(_id)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@examples/jax_tvm_ffi/gemma3_flashinfer_jax.ipynb` around lines 899 - 906, The
code adds tokenizer.eos_token_id to _STOP_IDS without guarding for None; change
the logic in the stop-token block so you only add tokenizer.eos_token_id if it
is not None and not equal to tokenizer.unk_token_id (same guard used for tokens
converted via tokenizer.convert_tokens_to_ids), i.e., check the value before
calling _STOP_IDS.add(...) in the section that initializes _STOP_IDS and before
the loop that converts '<end_of_turn>' and '<eos>'.

334-337: ⚠️ Potential issue | 🟡 Minor

Normalize RoPE theta to float before passing it as an FFI attribute.

ROPE_THETA_LOCAL / ROPE_THETA_GLOBAL are currently ints, and apply_rope forwards rope_theta unchanged. Cast at the wrapper boundary so both prefill and decode pass the scalar type expected by the TVM FFI path.

Proposed fix
-ROPE_THETA_LOCAL  = int(cfg.get('rope_local_base_freq', 10_000))
-ROPE_THETA_GLOBAL = int(cfg.get('rope_theta', 1_000_000))
+ROPE_THETA_LOCAL  = float(cfg.get('rope_local_base_freq', 10_000))
+ROPE_THETA_GLOBAL = float(cfg.get('rope_theta', 1_000_000))
@@
     )(q, k, indptr, offsets,
-      rotary_dim=q.shape[-1], interleave=False, rope_scale=1.0, rope_theta=rope_theta)
+      rotary_dim=q.shape[-1], interleave=False, rope_scale=1.0,
+      rope_theta=float(rope_theta))
What scalar Python type should jax-tvm-ffi / TVM FFI use for a double-valued attribute passed through jax.ffi.ffi_call?

Also applies to: 456-464, 716-738, 798-820

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@examples/jax_tvm_ffi/gemma3_flashinfer_jax.ipynb` around lines 334 - 337,
ROPE_THETA_LOCAL and ROPE_THETA_GLOBAL are defined as ints but must be converted
to Python floats before being sent across the TVM FFI boundary; update the
wrapper(s) that call apply_rope / the jax.ffi.ffi_call so they cast rope_theta
(and any rope_local_base_freq-derived value) to float at the boundary for both
prefill and decode paths (e.g., where apply_rope is invoked and where
jax.ffi.ffi_call constructs the attribute dict), ensuring all FFI attributes use
float scalar types.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@examples/jax_tvm_ffi/gemma3_flashinfer_jax.ipynb`:
- Around line 1006-1013: The summary snippet uses a non-public symbol
gen_decode_jit_spec; replace it with the publicly exported
gen_single_decode_module used elsewhere (see earlier imports/usage) and pass the
same argument pattern as the working examples (e.g., dtype 'bfloat16',
max_batch/sequence size 256, and use_sliding_window flag). Update both
occurrences of gen_decode_jit_spec to gen_single_decode_module and ensure the
call signature and flags match the earlier prefill/decode examples so the
notebook uses the public flashinfer.jit API.
- Around line 948-960: The first sampled token is chosen unconditionally
(first_token = _sample(...)) which causes an extra decode step and generates a
token even when max_new_tokens == 0; fix by guarding and early-exiting: only
sample/print the first token if max_new_tokens > 0, and after sampling check if
first_token is in _STOP_IDS — if so, print/return immediately and do not enter
the decode loop or call decode_step; otherwise proceed to append to generated
and run the for-loop as before (referencing first_token, _sample, _STOP_IDS,
generated, decode_step, and max_new_tokens).
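One way to structure the guarded generation loop described above, with stub callables standing in for the notebook's `_sample` / `decode_step` (hypothetical names and signatures, not the notebook's exact ones):

```python
def generate(sample, decode_step, stop_ids, max_new_tokens):
    """Sample tokens, honoring max_new_tokens == 0 and early stop ids."""
    generated = []
    if max_new_tokens <= 0:      # no extra decode step when 0 tokens requested
        return generated
    first_token = sample()
    if first_token in stop_ids:  # exit before ever entering the decode loop
        return generated
    generated.append(first_token)
    for _ in range(max_new_tokens - 1):
        tok = decode_step(generated[-1])
        if tok in stop_ids:
            break
        generated.append(tok)
    return generated


# Stub model: first token is 1, then each step emits last + 1.
out = generate(lambda: 1, lambda last: last + 1, stop_ids={10}, max_new_tokens=4)
assert out == [1, 2, 3, 4]
assert generate(lambda: 1, lambda last: last + 1, {10}, 0) == []   # no decode
assert generate(lambda: 10, lambda last: last + 1, {10}, 4) == []  # early stop
```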
- Around line 125-146: Move the XLA environment setup and CUDA discovery so it
runs before importing jax: ensure subprocess is imported, detect CUDA_HOME
(using subprocess.check_output as currently written), set
os.environ['CUDA_HOME'] to the discovered path (fallback '/usr/local/cuda'),
then set os.environ['XLA_FLAGS'] to include the discovered CUDA path, and only
after that import jax and jax.numpy; update the cell to import subprocess at top
and reorder the CUDA detection/XLA_FLAGS assignment to precede the import of
jax.
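A sketch of the reordered setup described above. `shutil.which("nvcc")` is a stand-in for the notebook's subprocess-based CUDA detection, and the `jax` import is left commented so the ordering point stands on its own:

```python
import os
import shutil

# XLA reads XLA_FLAGS at import time, so set env vars BEFORE importing jax.
cuda_home = "/usr/local/cuda"  # fallback
nvcc = shutil.which("nvcc")
if nvcc:
    # nvcc lives at <CUDA_HOME>/bin/nvcc
    cuda_home = os.path.dirname(os.path.dirname(nvcc))
os.environ.setdefault("CUDA_HOME", cuda_home)
os.environ.setdefault(
    "XLA_FLAGS", f"--xla_gpu_cuda_data_dir={os.environ['CUDA_HOME']}"
)
# Only now is it safe to import JAX:
# import jax
# import jax.numpy as jnp
```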

---

Duplicate comments:
In `@examples/jax_tvm_ffi/gemma3_flashinfer_jax.ipynb`:
- Around line 899-906: The code adds tokenizer.eos_token_id to _STOP_IDS without
guarding for None; change the logic in the stop-token block so you only add
tokenizer.eos_token_id if it is not None and not equal to tokenizer.unk_token_id
(same guard used for tokens converted via tokenizer.convert_tokens_to_ids),
i.e., check the value before calling _STOP_IDS.add(...) in the section that
initializes _STOP_IDS and before the loop that converts '<end_of_turn>' and
'<eos>'.
- Around line 334-337: ROPE_THETA_LOCAL and ROPE_THETA_GLOBAL are defined as
ints but must be converted to Python floats before being sent across the TVM FFI
boundary; update the wrapper(s) that call apply_rope / the jax.ffi.ffi_call so
they cast rope_theta (and any rope_local_base_freq-derived value) to float at
the boundary for both prefill and decode paths (e.g., where apply_rope is
invoked and where jax.ffi.ffi_call constructs the attribute dict), ensuring all
FFI attributes use float scalar types.

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: bd13bf32-0800-49af-bc7f-0382380fb7ce

📥 Commits

Reviewing files that changed from the base of the PR and between 4d07790 and eefb1eb.

📒 Files selected for processing (1)
  • examples/jax_tvm_ffi/gemma3_flashinfer_jax.ipynb

Member

@kahyunnam kahyunnam left a comment
LGTM overall, just one nit

@yongwww
Member

yongwww commented Apr 22, 2026

Thanks for the contribution! Since the repo already uses Sphinx for documentation, would it make sense to integrate this content into the Sphinx docs (it could go under tutorials: https://github.com/flashinfer-ai/flashinfer/tree/main/docs/tutorials) instead of keeping it as a standalone Jupyter notebook?

@yongwww yongwww added the run-ci label Apr 22, 2026
@kahyunnam
Member

@katjasrz this is failing a pre-merge formatting check: https://github.com/flashinfer-ai/flashinfer/actions/runs/24689882948/job/72521293385?pr=3092 — could you please take a look? You may have to rerun pre-commit run --all-files (see the "Pull Request Checklist" in the PR description).

@katjasrz katjasrz force-pushed the flashinfer-jax-demo branch 2 times, most recently from b8e6347 to f5d397a on April 22, 2026 19:48
@katjasrz
Contributor Author

The formatting should pass now

@katjasrz katjasrz force-pushed the flashinfer-jax-demo branch from f5d397a to e4204f8 on April 23, 2026 19:54
@katjasrz katjasrz force-pushed the flashinfer-jax-demo branch from e4204f8 to 9c3ce4d on April 23, 2026 20:02
@katjasrz
Contributor Author

Thanks everyone for the guidance!

I updated the PR to integrate the tutorials into the existing Sphinx docs using Sphinx-Gallery. The canonical sources now live under docs/tutorials/jax_tvm_ffi/, and the docs build generates the rendered HTML pages plus downloadable .py and .ipynb versions from those same sources.

I also removed the standalone examples/ copy to avoid having two sources of truth.

I verified the docs build locally with:

sphinx-build -b html docs docs/_build/html -j auto

Please let me know if this structure looks aligned with the repo’s documentation conventions.

@yongwww yongwww enabled auto-merge (squash) April 23, 2026 20:40
@yongwww yongwww merged commit 2764147 into flashinfer-ai:main Apr 24, 2026
30 of 35 checks passed

5 participants