
feat(dllm): add Block Extend Attention for Diffusion LLM#2722

Open
fdz-1999 wants to merge 12 commits into flashinfer-ai:main from fdz-1999:feature/block-extend

Conversation


@fdz-1999 fdz-1999 commented Mar 8, 2026

Motivation

What is Diffusion LLM (DLLM)?

Diffusion LLM (DLLM) is an emerging text generation paradigm. Unlike traditional Auto-Regressive LLMs that generate one token at a time, DLLM generates multiple tokens in parallel at the block level. In each iteration, all tokens within the current block are produced simultaneously through multi-step denoising. As a result, tokens within the same block require bidirectional visibility, while tokens in subsequent blocks must be completely invisible — this is the semantic origin of the Block Extend Mask.

Similarity Between Block Diffusion and Chunked Prefill

The execution flow of Block Diffusion closely resembles SGLang's existing Chunked Prefill — both split the full sequence into chunks and process them step by step:

| Computation Phase | Chunked Prefill | Block Diffusion |
| --- | --- | --- |
| Context query (Q_curr × KV_prev) | Full Attention | Full Attention (identical) |
| Intra-block query (Q_curr × KV_curr) | Causal Attention Mask | Full Attention (bidirectional) |

The only difference lies in the intra-block query: Chunked Prefill uses a causal mask (earlier tokens cannot attend to later ones), whereas Block Diffusion requires bidirectional full attention since tokens within the same block are generated in parallel.
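To make the single difference concrete, here is a tiny illustrative sketch (plain Python, not library code) of the two intra-block masks for a 4-token chunk:

```python
# Intra-block visibility for a 4-token chunk attending to itself.
chunk = 4

# Chunked prefill: causal mask, earlier tokens cannot see later ones.
causal = [[1 if k <= q else 0 for k in range(chunk)] for q in range(chunk)]

# Block diffusion: the whole block is mutually visible.
bidirectional = [[1] * chunk for _ in range(chunk)]

assert causal[0] == [1, 0, 0, 0]         # token 0 sees only itself
assert bidirectional[0] == [1, 1, 1, 1]  # token 0 sees the whole block
```

The context-query part (Q_curr × KV_prev) is fully visible in both schemes, which is why only the intra-block mask has to change.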

Why Native Kernel Support is Needed

Chunked Prefill is SGLang's default and well-established execution path. Given the strong similarity between Block Diffusion and Chunked Prefill, the initial approach was to reuse the Chunked Prefill path — splitting by DLLM block size and using causal=False for the current chunk to approximate the Block Extend mask.

However, this indirect Cascade Attention-based approach has fundamental limitations:

  1. chunk_size is locked to dllm_block_size: causal=False is only correct when the chunk exactly equals one complete DLLM block; larger chunk sizes that would reduce the number of iteration steps cannot be used
  2. 2–3 kernel launches per step (current chunk + prefix + merge_state), with CPU overhead growing linearly with the number of steps
  3. Unable to exploit the structural sparsity of the Block Extend mask: invisible KV tiles are still loaded and computed

An alternative path is to use Custom Mask (MaskMode::kCustom), but it requires O(qo_len × kv_len) of GPU memory to store the 2D mask tensor, which is completely infeasible for long sequences (a single request at seq_len=32K needs 1GB of mask memory).
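The 1GB figure checks out arithmetically, assuming one byte per mask entry:

```python
# Custom Mask memory for a single request at seq_len = 32K,
# assuming one byte per (q, k) entry of the O(qo_len × kv_len) mask.
seq_len = 32 * 1024
mask_bytes = seq_len * seq_len
assert mask_bytes == 2**30  # exactly 1 GiB for one request
```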

Therefore, we implemented the native MaskMode::kBlockExpanding mask mode in FlashInfer, embedding the Block Extend mask semantics directly into the CUDA kernel:

  • Zero extra memory: the mask is computed via integer division in registers
  • Tile-level skip: invisible KV tiles are skipped at CTA granularity — zero load, zero compute
  • Single kernel: no need for Cascade's multiple kernel launches + state merge
  • Flexible chunk_size: chunk_size can be much larger than dllm_block_size, significantly reducing the number of iteration steps
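The tile-level skip bound can be sketched with a hypothetical helper (this mirrors the rule described above, not the kernel's actual code): under mask[q, k] = floor(q/B) >= floor(k/B), a query tile ending at global row q_tile_end - 1 can never see any KV position at or beyond the end of its last DLLM block, so whole KV tiles past that bound are skipped without being loaded.

```python
def kv_valid_end(q_tile_end: int, dllm_block_size: int) -> int:
    """Exclusive upper bound on visible KV positions for a query tile
    covering global rows [.., q_tile_end).  Hypothetical helper that
    mirrors the register-computed bound described above."""
    B = dllm_block_size
    return ((q_tile_end - 1) // B + 1) * B

# A CTA handling queries [0, 64) with B=32 sees KV [0, 64);
# every KV tile past index 63 is skipped: zero load, zero compute.
assert kv_valid_end(64, 32) == 64
# A tile ending mid-block still sees its entire current block.
assert kv_valid_end(40, 32) == 64
```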

📌 Description

Add DLLM Block Extend Attention with native MaskMode::kBlockExpanding tile-level skip optimization for Diffusion LLM inference.

Block Extend Mask Rule

image

mask[q, k] = floor(q_global / B) >= floor(k_global / B)

Tokens within the same DLLM block are bidirectionally visible; each token can see previous blocks but not subsequent ones. This is the core mask mode for Diffusion LLM (DLLM).
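The rule can be checked token by token with a plain-Python sketch of the predicate:

```python
def block_extend_visible(q_global: int, k_global: int, B: int) -> bool:
    # mask[q, k] = floor(q_global / B) >= floor(k_global / B)
    return q_global // B >= k_global // B

B = 4
assert block_extend_visible(1, 3, B)      # same block: later token is visible
assert block_extend_visible(5, 2, B)      # earlier blocks are visible
assert not block_extend_visible(2, 5, B)  # later blocks are invisible
```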

Core API

| Layer | API | Use Case |
| --- | --- | --- |
| Single Prefill | `block_extend_attention_with_offset()` | Single request, incremental chunk prefill |
| Batch Ragged | `BatchBlockExtendRaggedOffsetWrapper` | Multi-request parallel, ragged KV |
| Batch Paged | `BatchBlockExtendPagedOffsetWrapper` | Multi-request parallel, paged KV cache |
| Cascade | `block_extend_cascade()` / `batch_block_extend_cascade()` | Current chunk + prefix + merge state |

Why faster than Custom Mask

| Dimension | Custom Mask | Block Extend |
| --- | --- | --- |
| Memory | O(qo_len × kv_len) per request | 0 extra memory |
| Invisible tiles | Load + compute + mask discard | Skipped entirely (0 overhead) |
| Fully visible tiles | Load mask + element-wise check | Direct MMA (no mask check) |
| Mask source | Global memory read | Register computation (integer division) |

Why faster than Cascade Attention (SGLang-style)

image
| Dimension | Cascade | Block Extend |
| --- | --- | --- |
| Kernel launches/step | 2–3 (current + prefix + merge) | 1 |
| chunk_size constraint | Must equal dllm_block_size | Any ≥ dllm_block_size |
| Steps (tokens=256, B=32) | 256/32 = 8 steps | chunk=256 → 1 step |
| SM utilization | Small chunk → low occupancy | Larger chunk → better SM saturation |

Why faster than PyTorch Flex Attention

| Dimension | Flex Attention | Block Extend |
| --- | --- | --- |
| Tile granularity | Fixed 128×128 | Adaptive CTA tile |
| block_size < 128 | Many PARTIAL tiles | Precise block-level skip |
| KV cache | Dense BHSD only | Ragged + Paged |
| Compilation | torch.compile required | AOT pre-compiled / JIT once |
| First-call latency | Triton compile: seconds to tens of seconds | AOT: none / JIT: ~1–2 s |
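For reference, the same rule expressed in the shape Flex Attention expects for a `mask_mod` callable (batch, head, query index, key/value index) — written here as a plain function over integers so the logic can be checked without torch; in actual Flex Attention use it would be passed to `create_block_mask` and receive tensor indices:

```python
B = 32  # dllm_block_size (example value)

def block_extend_mask_mod(b, h, q_idx, kv_idx):
    # Signature follows torch flex_attention's mask_mod convention:
    # return whether query q_idx may attend to key/value kv_idx.
    return q_idx // B >= kv_idx // B

assert block_extend_mask_mod(0, 0, 33, 60)      # both in block 1: visible
assert not block_extend_mask_mod(0, 0, 10, 40)  # kv in a later block: masked
```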

Compilation

  • AOT pre-compilation for all head_dim × dtype × backend combinations
  • JIT fallback when AOT not available

🔍 Related Issues

None

🚀 Pull Request Checklist

- [ ] Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit
  • I have installed the hooks with pre-commit install
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues

🧪 Tests

  • Correctness: block extend vs custom_mask reference, cascade vs blockwise extend validation
  • Performance: FlashInfer Block Extend vs PyTorch Flex Attention benchmark (1K-32K context sweep, memory comparison)

📊 Benchmark Results


1. Block Extend vs Custom Mask

Baseline: single_prefill_with_kv_cache(custom_mask=...) (FA2 only, FA3 not supported)

Case 1: Single-Request Incremental Prefill (tokens=8192, CUDA Graph)

| Method | chunk_size | steps | CG (ms) | vs Custom Mask |
| --- | --- | --- | --- | --- |
| Custom Mask (FA2) | 32 | 256 | 14.171 | 1.00x |
| block_extend_attention_with_offset | 32 | 256 | 5.995 | 2.48x |
| block_extend_attention_with_offset | 64 | 128 | 3.484 | 4.07x |
| block_extend_attention_with_offset | 128 | 64 | 2.376 | 5.96x |
| block_extend_attention_with_offset | 256 | 32 | 1.888 | 7.50x |

Case 2: Multi-Request BatchPrefill (256 reqs × 512 tokens, CUDA Graph)

| Method | chunk_size | steps | CG (ms) | vs Custom Mask |
| --- | --- | --- | --- | --- |
| Custom Mask (FA2) | 32 | 16 | 6.399 | 1.00x |
| BatchBlockExtendRaggedOffsetWrapper | 32 | 16 | 3.696 | 1.73x |
| BatchBlockExtendRaggedOffsetWrapper | 64 | 8 | 3.290 | 1.95x |
| BatchBlockExtendRaggedOffsetWrapper | 128 | 4 | 3.178 | 2.01x |
| BatchBlockExtendRaggedOffsetWrapper | 256 | 2 | 3.179 | 2.01x |

2. Block Extend vs Cascade Attention

Baseline: SGLang-style 3-stage Cascade (FA3)

Case 1: Single-Request (tokens=8192, CUDA Graph)

| Method | chunk_size | steps | CG (ms) | vs Cascade |
| --- | --- | --- | --- | --- |
| SGLang Cascade (FA3) | 32 | 256 | 18.764 | 1.00x |
| block_extend_attention_with_offset | 32 | 256 | 5.995 | 3.13x |
| block_extend_attention_with_offset | 64 | 128 | 3.484 | 5.39x |
| block_extend_attention_with_offset | 128 | 64 | 2.376 | 7.90x |
| block_extend_attention_with_offset | 256 | 32 | 1.888 | 9.94x |

Case 2: Multi-Request BatchPrefill (256 reqs × 512 tokens, CUDA Graph)

| Method | chunk_size | steps | CG (ms) | vs Cascade |
| --- | --- | --- | --- | --- |
| SGLang Cascade (FA3) | 32 | 16 | 15.808 | 1.00x |
| BatchBlockExtendRaggedOffsetWrapper | 32 | 16 | 3.696 | 4.28x |
| BatchBlockExtendRaggedOffsetWrapper | 64 | 8 | 3.290 | 4.81x |
| BatchBlockExtendRaggedOffsetWrapper | 128 | 4 | 3.178 | 4.98x |
| BatchBlockExtendRaggedOffsetWrapper | 256 | 2 | 3.179 | 4.97x |

Why Block Extend is faster: SGLang Cascade is locked to chunk_size = dllm_block_size = 32, because it uses causal=True to approximate the block extend mask — this equivalence only holds when chunk_size == dllm_block_size. Block Extend uses the true blockwise mask (q_block >= k_block), so it can use larger chunk sizes. Doubling the chunk_size halves the number of steps, which halves the total kernel launch count.
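The step-count arithmetic behind these numbers can be replayed directly, using the Case 2 setting of 512 tokens per request:

```python
tokens, dllm_block_size = 512, 32

# Cascade must pin chunk_size to the DLLM block size.
cascade_steps = tokens // dllm_block_size
assert cascade_steps == 16

# The native blockwise mask allows any chunk_size >= dllm_block_size,
# so doubling the chunk halves the number of iteration steps.
for chunk_size, expected_steps in [(32, 16), (64, 8), (128, 4), (256, 2)]:
    assert tokens // chunk_size == expected_steps
```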


3. Block Extend vs PyTorch Flex Attention

Abbreviations:

  • FI: FlashInfer Block Extend Attention
  • FI CG: FlashInfer Block Extend Attention with CUDA Graph
  • Flex: Flex Attention with torch.compile()
  • Speedup: FI latency vs Flex latency
  • Mem Saved: Peak memory reduction vs Flex

Results are averaged across dllm_block_size ∈ {32, 64, 128, 256} as it has negligible impact on performance.

| seq_len | batch | FI (ms) | FI CG (ms) | FI peak (MB) | Flex (ms) | Flex peak (MB) | Speedup | CG Speedup | Mem Saved |
| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
| 2048 | 4 | 0.351 | 0.351 | 169 | 0.418 | 244 | 1.19x | 1.19x | +31% |
| 4096 | 4 | 1.117 | 1.152 | 329 | 1.447 | 784 | 1.30x | 1.26x | +58% |
| 8192 | 4 | 4.360 | 4.493 | 649 | 5.580 | 2752 | 1.28x | 1.24x | +76% |
| 16384 | 1 | 4.526 | 4.561 | 329 | 5.462 | 2752 | 1.21x | 1.20x | +88% |
| 24576 | 1 | 9.378 | 9.729 | 489 | 12.232 | 6048 | 1.31x | 1.26x | +92% |
| 32768 | 1 | 17.276 | 17.318 | 649 | 21.972 | 10625 | 1.27x | 1.27x | +94% |

Key Takeaways:

  • FlashInfer Block Extend achieves 1.19x–1.31x latency speedup over Flex Attention consistently.
  • Memory savings scale with sequence length: from +31% at 2K to +94% at 32K.
  • CUDA Graph adds minimal benefit here since the kernel itself is already efficient.

4. LLaDA2.0-flash-CAP End-to-End Benchmark Results

Environment: Integrated with SGLang on internal build, tested with LLaDA2.0-flash-CAP model.

  • Short prompt: 4K input tokens, 1.5K output tokens
  • Long prompt: 24K input tokens, 1.5K output tokens
  • Throughput (tokens/s): higher is better
  • TTFT (Time To First Token, ms): lower is better
  • TPOT (Time Per Output Token, ms): lower is better

| Version | Metric | Short Prompt / Concurrency 1 | Short Prompt / Concurrency 4 | Long Prompt / Concurrency 1 | Long Prompt / Concurrency 4 |
| --- | --- | --- | --- | --- | --- |
| 260209 (no cache) | Throughput | 460.60 | 1241.64 | 2789.68 | 5939.58 |
| 260209 (no cache) | TTFT | 269.55 | 818.44 | 794.71 | 2704.73 |
| 260209 (no cache) | TPOT | 8.01 | 11.62 | 5.80 | 9.39 |
| 251225 (no cache) | Throughput | 448.03 | 911.07 | 1396.97 | 2197.70 |
| 251225 (no cache) | TTFT | 1162.47 | 5826.76 | 6341.80 | 22274.96 |
| 251225 (no cache) | TPOT | 7.66 | 12.48 | 8.48 | 16.70 |

TTFT Optimization Highlights

The most significant improvement is in TTFT (Time To First Token), which directly reflects the prefill-stage kernel efficiency:

| Scenario | 251225 (no cache) | 260209 (no cache) | Speedup |
| --- | --- | --- | --- |
| Short Prompt / Concurrency 1 | 1162.47 ms | 269.55 ms | 4.31x |
| Short Prompt / Concurrency 4 | 5826.76 ms | 818.44 ms | 7.12x |
| Long Prompt / Concurrency 1 | 6341.80 ms | 794.71 ms | 7.98x |
| Long Prompt / Concurrency 4 | 22274.96 ms | 2704.73 ms | 8.23x |

With cache-aware scheduling enabled in version 260209, TTFT is reduced by 4.3x–8.2x across all scenarios. The benefit scales with both prompt length and concurrency — under the most demanding setting (long prompt, concurrency 4), TTFT drops from 22.3s to 2.7s.

refs : sgl-project/sglang#12766

Summary by CodeRabbit

  • New Features

    • Block-expanding attention with per-request Q/K offsets and DLLM block-size support for tile-aware inference
    • Cascade helpers for multi-stage prefix + current-chunk processing and merged softmax states
    • Runtime mask-mode override and CUDA-graph–friendly planning for batch wrappers
    • New Python dllm package exports and high-level paged/ragged wrappers with backend selection
  • Tests

    • New benchmarks and correctness tests comparing block-extend attention against Flex Attention


coderabbitai Bot commented Mar 8, 2026

Note

Reviews paused

It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior by changing the reviews.auto_review.auto_pause_after_reviewed_commits setting.

Use the following commands to manage reviews:

  • @coderabbitai resume to resume automatic reviews.
  • @coderabbitai review to trigger a single review.

📝 Walkthrough

Adds block‑expanding mask mode and per‑batch DLLM block offsets across CUDA/C++ kernels, parameter structs, JIT/AOT generation, and Python DLLM wrappers; implements offset‑aware block‑extend attention (single/batch, ragged/paged), cascade helpers, and a comprehensive FlashInfer vs Flex benchmark.

Changes

  • C++ Jinja Parameter Templates — `csrc/single_prefill_customize_config.jinja`, `csrc/single_prefill_sm90_customize_config.jinja`, `csrc/batch_prefill_customize_config.jinja`, `csrc/batch_prefill_sm90_customize_config.jinja`: Add helper accessors for block-expanding offsets and length getters (`get_q*_len`, `get_q_block_expanding_offset`, `get_kv_block_expanding_offset`) with conditional template blocks for optional params.
  • Core CUDA headers (block expanding) — `include/flashinfer/attention/block_expanding_prefill.cuh`: New `BlockExpandingTileSkipController` and free helpers to compute the KV valid end, iteration counts, mask iteration, and tile visibility/masking utilities.
  • CUDA headers (params & mask mode) — `include/flashinfer/attention/default_prefill_params.cuh`, `include/flashinfer/attention/mask.cuh`, `include/flashinfer/utils.cuh`: Add `dllm_block_size` and q/kv block-expanding offsets to Single/Batch params with safe accessors; introduce the `kBlockExpanding`/`BLOCK_EXPANDING` mask mode, add traits to detect new `AdditionalParams` fields, and extend `DISPATCH_MASK_MODE`.
  • CUDA headers (mainloops / MMA / prefill) — `include/flashinfer/attention/hopper/mainloop.cuh`, `include/flashinfer/attention/hopper/mainloop_mma.cuh`, `include/flashinfer/attention/hopper/sparse_mainloop.cuh`, `include/flashinfer/attention/hopper/prefill_sm90.cuh`, `include/flashinfer/attention/prefill.cuh`: Introduce the `BLOCK_EXPANDING` template flag, propagate `batch_idx` to tile-counting and MMA entrypoints, integrate per-batch offsets into `num_kv_tiles` and masking logic, and wire `BLOCK_EXPANDING` through dispatch/instantiation paths.
  • Python (DLLM block-extend modules) — `flashinfer/dllm/block_extend.py`, `flashinfer/dllm/batch_block_extend.py`: New modules implementing offset-aware block-extend attention: backend selection (FA2/FA3), AOT/JIT checks and module URI helpers, paged/ragged batch wrappers (plan/run, CUDA Graph handling), and cascade composition helpers.
  • Python (package & exports) — `flashinfer/__init__.py`, `flashinfer/dllm/__init__.py`: Expose the new `dllm` submodule and re-export block-extend APIs, wrappers, and variant declaration symbols at package level.
  • Python (JIT / prefill plumbing & utils) — `flashinfer/jit/attention/modules.py`, `flashinfer/jit/utils.py`, `flashinfer/prefill.py`, `flashinfer/utils.py`: Add optional `mask_modes` to codegen, map the new mask mode literal, thread optional `mask_mode` through planner/run paths, and add a `MaskMode.BLOCK_EXPANDING` enum value.
  • Tests (benchmarks & correctness) — `tests/attention/test_dllm_vs_flex_attention.py`: Large new benchmark/test comparing FlashInfer block-extend (ragged/paged, with/without CUDA Graph) to PyTorch `flex_attention`; includes reference compute, memory/timing helpers, multiple sweep modes, verification, and reporting.

Sequence Diagram(s)

sequenceDiagram
    participant User as User Code
    participant Cascade as block_extend_cascade
    participant Ragged as Ragged Kernel (current chunk)
    participant Paged as Paged Kernel (prefix)
    participant Device as CUDA Device

    User->>Cascade: q, k_current, v_current, k_prefix?, v_prefix?, offsets
    Cascade->>Ragged: Stage1: run block‑extend (q_offsets, kv_offsets)
    Ragged->>Device: execute kernel (block‑expanding mask)
    Device-->>Ragged: O1, LSE1
    Ragged-->>Cascade: return O1, LSE1

    Cascade->>Paged: Stage2: run prefix attention (if present)
    Paged->>Device: execute kernel (fully visible / paged)
    Device-->>Paged: O2, LSE2
    Paged-->>Cascade: return O2, LSE2

    Cascade->>Device: Stage3: merge O1 + O2 using LSEs
    Device-->>Cascade: merged output
    Cascade-->>User: final attention output
sequenceDiagram
    participant User as User Code
    participant Wrapper as BatchBlockExtendPagedOffsetWrapper
    participant Backend as Backend Selector
    participant Cache as AOT/JIT Module Cache
    participant Kernel as Compiled Kernel / CUDA

    User->>Wrapper: plan(indptrs..., backend="auto", mask_mode?)
    Wrapper->>Backend: select_best_backend(head_dim, dtype)
    Backend-->>Wrapper: backend (FA2/FA3)
    Wrapper->>Cache: get_or_build_module(backend, head_dim, dtype, mask_modes)
    Cache->>Cache: check AOT / generate JIT module
    Cache-->>Wrapper: compiled module
    User->>Wrapper: run(q, kv_cache, offsets, sm_scale)
    Wrapper->>Kernel: invoke compiled kernel with per‑batch offsets
    Kernel->>Kernel: execute on device
    Kernel-->>Wrapper: results (o, lse?)
    Wrapper-->>User: outputs

Estimated code review effort

🎯 5 (Critical) | ⏱️ ~120 minutes

Suggested labels

run-ci

Suggested reviewers

  • yzh119
  • aleozlx
  • bkryu
  • cyx-6
  • jimmyzho
  • kahyunnam
  • nvmbreughe
  • jiahanc

Poem

🐇 I hopped through tiles and offset rows,
Masks unfurled where the block‑wind blows,
Kernels stitched currents and prefix streams,
Cascades merged softmax dreams,
A rabbit cheers: "Fast attention — let’s go!"

🚥 Pre-merge checks | ✅ 4 | ❌ 1

❌ Failed checks (1 warning)

| Check name | Status | Explanation |
| --- | --- | --- |
| Docstring Coverage | ⚠️ Warning | Docstring coverage is 53.45%, below the required threshold of 80.00%. Write docstrings for the functions missing them to satisfy the coverage threshold. |

✅ Passed checks (4 passed)

| Check name | Status | Explanation |
| --- | --- | --- |
| Title check | ✅ Passed | The title 'feat(dllm): add Block Extend Attention for Diffusion LLM' is concise, clearly describes the main feature addition, and aligns with the changeset, which implements Block Extend Attention support across multiple kernel files and public APIs. |
| Description check | ✅ Passed | The PR description is comprehensive and well-structured, covering motivation (DLLM overview and mask rule), implementation details, and performance benchmarks. It includes the required sections: motivation, description, related issues, and test results. |
| Linked Issues check | ✅ Passed | Skipped: no linked issues were found for this pull request. |
| Out of Scope Changes check | ✅ Passed | Skipped: no linked issues were found for this pull request. |


@gemini-code-assist
Contributor

Summary of Changes

Hello, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request introduces a specialized Block Extend Attention mechanism tailored for Diffusion LLMs, significantly improving attention computation efficiency through tile-level skip optimizations. It provides comprehensive support for both single-request and batched operations, accommodating ragged and paged KV cache layouts, and is designed for seamless integration with JIT and AOT compilation workflows. The changes enhance the system's ability to handle complex attention patterns with improved performance.

Highlights

  • New Attention Mechanism: Introduced Block Extend Attention specifically designed for Diffusion LLMs, enabling efficient processing of attention masks.
  • Performance Optimization: Implemented tile-level skip optimization within the Block Extend Attention kernel to enhance performance by avoiding unnecessary computations.
  • Flexible API Support: Provided both single-request (block_extend_attention_with_offset) and batch processing APIs (BatchBlockExtendRaggedOffsetWrapper, BatchBlockExtendPagedOffsetWrapper) for various KV cache layouts.
  • Compilation Support: Ensured compatibility with both JIT (Just-In-Time) and AOT (Ahead-Of-Time) compilation for deployment flexibility.
  • Cascade Attention Integration: Integrated Block Extend Attention into cascade attention flows, allowing for efficient handling of current chunks and prefixes.


Changelog
  • csrc/batch_prefill_customize_config.jinja
    • Added get_q_block_expanding_offset and get_kv_block_expanding_offset methods to RaggedParams and PagedParams structs.
  • csrc/batch_prefill_sm90_customize_config.jinja
    • Added get_qo_len, get_kv_len, get_q_block_expanding_offset, and get_kv_block_expanding_offset methods to RaggedParams and PagedParams structs for SM90 architecture.
  • csrc/single_prefill_customize_config.jinja
    • Added get_q_block_expanding_offset and get_kv_block_expanding_offset methods to the Params struct.
  • csrc/single_prefill_sm90_customize_config.jinja
    • Added get_qo_len, get_kv_len, get_q_block_expanding_offset, and get_kv_block_expanding_offset methods to the Params struct for SM90 architecture.
  • flashinfer/__init__.py
    • Imported the new dllm module to expose Diffusion LLM functionalities.
  • flashinfer/dllm/__init__.py
    • Added a new module to define the public API for DLLM attention, including single-request and batch block extend functions and variant declarations.
  • flashinfer/dllm/batch_block_extend.py
    • Added implementation for batch block extend attention wrappers (BatchBlockExtendPagedOffsetWrapper, BatchBlockExtendRaggedOffsetWrapper) and cascade functions.
    • Included logic for dynamic backend selection (FA2/FA3) and JIT/AOT compilation.
  • flashinfer/dllm/block_extend.py
    • Added implementation for single-request block extend attention with offset support (block_extend_attention_with_offset).
    • Included a cascade version and variant declarations for FA2 and FA3 backends.
  • flashinfer/jit/attention/modules.py
    • Modified gen_customize_single_prefill_module and gen_customize_batch_prefill_module to accept an optional mask_modes parameter.
    • Added has_q_block_expanding_offset to the rendering context for Jinja templates.
  • flashinfer/jit/utils.py
    • Added MaskMode::kBlockExpanding to the MASK_MODE_TO_STRING mapping.
  • flashinfer/prefill.py
    • Modified get_customize_batch_prefill_module to pass mask_modes to the JIT compilation process.
    • Updated BatchPrefillWithPagedKVCacheWrapper and BatchPrefillWithRaggedKVCacheWrapper to store and utilize an optional _mask_mode.
  • flashinfer/utils.py
    • Added BLOCK_EXPANDING to the MaskMode enum.
    • Defined new DEFINE_HAS_MEMBER macros for DLLM block expanding related type traits.
  • include/flashinfer/attention/block_expanding_prefill.cuh
    • Added a new header defining CUDA device functions and a BlockExpandingTileSkipController struct for block expanding mask logic and tile-level skip optimization.
  • include/flashinfer/attention/default_prefill_params.cuh
    • Added dllm_block_size, q_block_expanding_offset, and kv_block_expanding_offset fields to SinglePrefillParams, BatchPrefillRaggedParams, and BatchPrefillPagedParams structs.
  • include/flashinfer/attention/hopper/mainloop.cuh
    • Modified CollectiveMainloop to include a BLOCK_EXPANDING template parameter.
    • Updated get_num_kv_tiles to incorporate block expanding logic with q_offset and kv_offset.
  • include/flashinfer/attention/hopper/mainloop_mma.cuh
    • Modified mma_f16 to include a BLOCK_EXPANDING template parameter.
    • Added block expanding mask helper functions and updated masking logic for FA3.
  • include/flashinfer/attention/hopper/prefill_sm90.cuh
    • Updated kernel dispatch functions to include the BLOCK_EXPANDING template parameter.
    • Adjusted getCTATileSize to consider CAUSAL_OR_BLOCK_EXPANDING for tile size selection.
  • include/flashinfer/attention/hopper/sparse_mainloop.cuh
    • Modified SparseCollectiveMainloop to include a BLOCK_EXPANDING template parameter.
    • Updated get_num_kv_tiles with block expanding logic.
  • include/flashinfer/attention/mask.cuh
    • Added kBlockExpanding to the MaskMode enum.
  • include/flashinfer/attention/prefill.cuh
    • Included block_expanding_prefill.cuh for new mask logic.
    • Updated logits_mask function to handle MaskMode::kBlockExpanding with q_offset and kv_offset.
    • Updated SinglePrefillWithKVCacheDevice, BatchPrefillWithRaggedKVCacheDevice, and BatchPrefillWithPagedKVCacheDevice to incorporate block expanding logic for num_iterations and mask_iteration.
  • include/flashinfer/utils.cuh
    • Added DEFINE_HAS_MEMBER macros for DLLM block expanding related type traits.
    • Extended FLASHINFER_DISPATCH_MASK_MODE macro to include kBlockExpanding.
  • tests/attention/test_dllm_vs_flex_attention.py
    • Added a new test file to benchmark FlashInfer's Block Extend Attention against PyTorch's Flex Attention.
    • Included correctness validation, performance benchmarks (with and without CUDA graphs), memory usage comparisons, and sweeps across sequence lengths and block sizes.
Activity
  • The author has performed correctness tests, comparing Block Extend against custom_mask reference and validating cascade vs blockwise extend.
  • Performance benchmarks were conducted, comparing FlashInfer Block Extend against PyTorch Flex Attention across various context lengths (1K-32K) and analyzing memory consumption.

Contributor

@gemini-code-assist gemini-code-assist Bot left a comment

Code Review

This pull request introduces Block Extend Attention for Diffusion LLMs, a significant new feature. The changes are comprehensive, adding support for single-request and batch processing with both ragged and paged KV caches, and including JIT/AOT compilation capabilities. The implementation is well-structured, with clear separation of concerns between Python wrappers, JIT logic, and CUDA kernels for different architectures. The inclusion of a thorough benchmark against PyTorch's flex_attention is also a great addition. My review focuses on a potential correctness issue in a helper function and a small opportunity to improve code clarity in one of the CUDA kernels.

ragged_wrapper.plan(
qo_indptr=qo_indptr, kv_indptr=kv_curr_indptr,
num_qo_heads=num_qo_heads, num_kv_heads=num_kv_heads,
head_dim_qk=head_dim, head_dim_vo=head_dim, q_data_type=q.dtype, causal=False,

Severity: high

There is a discrepancy between the comment on line 638, which states causal=True, and the code here, which passes causal=False. For a cascade attention pattern, the attention for the current chunk is typically causal. If the intention is to have causal attention for this stage, this should be corrected. If non-causal attention is intended, the comment on line 638 should be updated to avoid confusion.

Suggested change
head_dim_qk=head_dim, head_dim_vo=head_dim, q_data_type=q.dtype, causal=False,
head_dim_qk=head_dim, head_dim_vo=head_dim, q_data_type=q.dtype, causal=True,

Comment on lines +1542 to +1544
bool needs_mask = (MASK_MODE == MaskMode::kCustom) ||
(MASK_MODE == MaskMode::kBlockExpanding && iter >= mask_iteration) ||
(iter >= mask_iteration || iter < window_iteration);

Severity: medium

The condition for needs_mask is slightly redundant. The term (MASK_MODE == MaskMode::kBlockExpanding && iter >= mask_iteration) is already covered by the subsequent (iter >= mask_iteration) in the OR chain. Simplifying this expression will improve code readability without changing the logic.

      bool needs_mask = (MASK_MODE == MaskMode::kCustom) || (iter >= mask_iteration || iter < window_iteration);

Contributor

@coderabbitai coderabbitai Bot left a comment

Actionable comments posted: 15

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
include/flashinfer/attention/hopper/sparse_mainloop.cuh (1)

159-208: ⚠️ Potential issue | 🔴 Critical

Handle num_kv_tiles == 0 before the sparse prefetch path runs.

The new block-extend bound can shrink num_kv_tiles to 0. Here that becomes kv_tile_idx = -1, and the very first prefetch_kv_offset(kv_tile_idx, true) / load_kv_with_gather(..., kv_tile_idx, ...) sequence will read the page table with a negative tile index.

Also applies to: 242-243, 376-382

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@include/flashinfer/attention/hopper/sparse_mainloop.cuh` around lines 159 -
208, get_num_kv_tiles can return 0 under BLOCK_EXPANDING, which leads to
kv_tile_idx == -1 and then a negative index passed into prefetch_kv_offset and
load_kv_with_gather; update the caller(s) that compute kv_tile_idx from
get_num_kv_tiles (and any code paths around lines referenced) to check for
num_kv_tiles == 0 and early-skip the sparse prefetch/load sequence, or
clamp/skip any prefetch_kv_offset(kv_tile_idx, ...) and load_kv_with_gather(...,
kv_tile_idx, ...) calls when kv_tile_idx < 0; locate calls by the symbols
get_num_kv_tiles, kv_tile_idx, prefetch_kv_offset, and load_kv_with_gather and
add a guard that prevents negative tile indices from reaching the page-table
access.
🧹 Nitpick comments (1)
flashinfer/__init__.py (1)

23-23: Re-export the new DLLM entry points at the package root.

Importing the submodule here makes flashinfer.dllm.* available, but the new public wrappers/functions still are not top-level flashinfer.* exports like the rest of the package surface. Please add explicit imports for the public DLLM APIs in this file as well. As per coding guidelines "flashinfer/__init__.py: Export all public operations in flashinfer/__init__.py after implementing."

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/__init__.py` at line 23, Replace the current submodule-only import
with explicit re-exports of the DLLM public API: import the public symbols from
.dllm (e.g., the public classes/functions in that module such as DLLMClient,
run_inference, load_model — or the actual names defined in dllm) via "from .dllm
import <PublicName1>, <PublicName2>, ..." and add those names to the package
__all__ list so they become top-level flashinfer.* exports; keep the
module-level alias (dllm) if you still need it, but ensure all public DLLM
symbols are explicitly imported and included in __all__ in
flashinfer/__init__.py.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@flashinfer/dllm/batch_block_extend.py`:
- Around line 263-264: The current assertion allows dllm_block_size==0 because
(0 & -1)==0; update the check to explicitly reject zero and require a positive
power of two: replace the existing assert that tests (dllm_block_size &
(dllm_block_size - 1)) == 0 with a compound check that dllm_block_size > 0 and
(dllm_block_size & (dllm_block_size - 1)) == 0, and raise an informative
AssertionError/ValueError if it fails; apply the same change to the other
occurrence in this file (the second assertion at the later block).
- Around line 611-613: The helper is constructing stage-1 ragged attention with
causal=False and drops the logits_soft_cap parameter; update the ragged-wrapper
construction(s) so stage 1 uses causal=True to enforce current-chunk (causal +
merge) attention, and thread the logits_soft_cap argument through to the helper
calls instead of discarding it; specifically, in batch_block_extend.py adjust
the stage-1 ragged/planning call(s) that currently pass causal=False to
causal=True and ensure the helper invocations (the ones near the signature
containing logits_soft_cap, return_lse, backend and the similar block at lines
~638-659) forward logits_soft_cap into the downstream helper/function that
applies the soft cap.
- Around line 179-181: The URI generation in _get_batch_be_module_uri only
encodes head_dim and a coarse dtype string, causing different kernel ABIs (e.g.,
idtype variants like int64 or fp8) to collide; update _get_batch_be_module_uri
to include both the element dtype and the index/id dtype (idtype) in the
returned string (e.g., use explicit mappings for
torch.float16/torch.bfloat16/torch.int64/FP8 aliases or use dtype.name and
idtype.name), and ensure any other URI/cache identity builders and the wrapper
recreation check also incorporate idtype (or explicitly reject unsupported
dtypes up front) so each specialization yields a unique flashinfer::{uri}_* name
and the recreation logic compares both dtype and idtype.
- Around line 560-563: In batch_block_extend_cascade(), when q_offsets or
kv_offsets are None the code currently defaults both to zero which is incorrect
when has_prefix is true; change this to derive per-request global offsets from
the paged-prefix metadata (the per-request prefix length stored in the paged
prefix structure used by the function) instead of using torch.zeros, so that
q_offsets and kv_offsets reflect each request's prefix length (and block-aligned
adjustments) before the two-stage extension; alternatively validate and require
the caller to supply q_offsets/kv_offsets and raise a clear error if they are
omitted when has_prefix is true. Ensure you update the logic referencing
q_offsets, kv_offsets, has_prefix, and the paged prefix metadata in
batch_block_extend_cascade to use the computed per-request offsets.

In `@flashinfer/dllm/block_extend.py`:
- Around line 221-245: The FA3 capability check in
get_block_extend_module_with_offset uses a hardcoded torch.device("cuda") which
breaks mixed-arch multi-GPU setups; update get_block_extend_module_with_offset
to accept a device (torch.device or device-like) parameter and use that device
when calling is_sm90a_supported, and update any callers (notably
block_extend_attention_with_offset) to pass the q/kv tensor's device through
into get_block_extend_module_with_offset; also apply the same
device-threading/fix at the other occurrence referenced (around the second call
at line ~360) so all FA3 checks use the intended device rather than the default
CUDA device.
- Around line 131-136: The function _get_dtype_str currently maps unknown dtypes
to "fp16", causing wrong module URIs; update _get_dtype_str to explicitly map
all supported dtypes (e.g., torch.float16 -> "fp16", torch.bfloat16 -> "bf16",
torch.float32 -> "fp32", and any FP8 dtype used in this project -> the correct
string such as "fp8") and make the default case raise a ValueError (or return a
distinct sentinel like "unknown") instead of aliasing to "fp16"; apply the same
explicit mapping/failure behavior to the other similar helper usages referenced
in the diff (the other _get_dtype_str-like mappings at the other locations) so
that module names are unique per actual dtype and FP8 does not resolve to the
FP16 specialization.

In `@flashinfer/prefill.py`:
- Line 1651: The code accepts mask_mode in plan/run but doesn't enforce backend
support: detect and reject unsupported MaskMode values early (in plan and run)
or filter backends that cannot handle them; specifically, when mask_mode ==
MaskMode.BLOCK_EXPANDING.value (or any non-default), ensure you do not route to
the cudnn path that uses self._causal nor to the trtllm-gen paged path that
ignores mask_mode—validate mask_mode against the selected backend and either
remove unsupported backends from backend selection or raise an error before
continuing (update the functions that accept the mask_mode parameter and the
backend selection logic to perform this check).

In `@include/flashinfer/attention/hopper/mainloop.cuh`:
- Around line 141-188: get_num_kv_tiles can return 0 when
kv_block_expanding_offset pushes the valid KV range past the query block,
causing later code in load() to compute kv_tile_idx = -1 and perform invalid
tile reads via tKgK/tVgV; fix by adding a guard after computing num_kv_tiles (or
before the first copy() in load()) that checks for zero and early-returns or
skips scheduling that tile so kv_tile_idx is never -1. Specifically, in the
caller load() (or right after calling get_num_kv_tiles) detect num_kv_tiles == 0
and do an early return/no-op for that q_tile_idx, or in the scheduling loop
ensure you don't iterate when get_num_kv_tiles(...) == 0; this prevents tKgK(_,
kv_tile_idx) and tVgV(_, kv_tile_idx) from being invoked with an invalid index
and avoids the invalid tile access.

In `@include/flashinfer/attention/prefill.cuh`:
- Around line 1451-1457: The block-expanding iteration bound computed in the
MASK_MODE == MaskMode::kBlockExpanding branches only passes q_offset into
block_expanding_num_iterations but the later legality checks use both q_offset
and kv_offset; update the call to compute kv_offset (e.g., call
params.get_kv_block_expanding_offset(batch_idx) or the appropriate getter) and
pass that kv_offset into block_expanding_num_iterations so the loop bounds
exclude fully-masked KV tiles and avoid poisoning update_mdo_states(); make the
same change for the other block-expanding call sites that mirror this logic (the
other MaskMode::kBlockExpanding branches).

In `@tests/attention/test_dllm_vs_flex_attention.py`:
- Around line 61-104: compute_block_extend_reference and
make_block_extend_mask_mod currently never exercise per-request offsets or
kv_block_expanding_offset because compute_block_extend_reference hardcodes
q_offset (kv_offset) to zero, make_block_extend_mask_mod ignores the batch index
b, and tests use torch.full for q_offsets; update the test helpers so
compute_block_extend_reference accepts and uses per-request q_offset values (and
propagate kv_offset if applicable) and make_block_extend_mask_mod's inner
block_extend_mask uses the batch index b to look up per-sample offsets, then
modify the batch construction in tests to pass heterogeneous q_offsets (not
torch.full) and add cases exercising nonzero kv_block_expanding_offset and
cascade/current-chunk paths so the new plumbing is validated against
single_prefill_with_kv_cache and block_extend_mask behavior.
- Around line 166-210: The benchmarks (functions benchmark_fn and
benchmark_with_cuda_graph) and the direct perf_counter() timing in
test_total_memory_comparison should be replaced to use the repo timing harness
flashinfer.testing.bench_gpu_time() so results use CUPTI with CUDA-event
fallback and remain comparable across the suite; locate usages of benchmark_fn,
benchmark_with_cuda_graph, and the perf_counter() blocks in
test_total_memory_comparison and call bench_gpu_time() (passing the callable and
warmup/bench iteration params) instead of manual perf_counter/CUDAGraph timing,
ensuring any CUDA Graph replay loops are wrapped or adapted to the
bench_gpu_time() callable interface.
- Around line 20-56: The module currently performs CUDA/Hopper-specific work at
import time and must early-skip unsupported GPUs: query
flashinfer.utils.get_compute_capability(), is_sm90a_supported(), and
is_sm100a_supported() at module scope (and check
torch.cuda.is_available()/device count) and if the current GPU is unsupported,
set HAS_FLASHINFER = HAS_FLEX_ATTENTION = False and print a skip message before
attempting any flashinfer or flex_attention imports or CUDA allocations; wrap
the existing flashinfer imports and the flex_attention import logic behind this
guard so functions like single_prefill_with_kv_cache,
BatchBlockExtendPagedOffsetWrapper, BatchBlockExtendRaggedOffsetWrapper,
flex_attention and create_block_mask are only imported when the architecture
checks pass.
- Around line 355-399: The test currently only prints PASS/FAIL for
ragged/paged/flex comparisons (using ragged_pass, paged_pass, flex_pass) which
means failures don't fail the pytest; change each check to assert the diff is
below tol (e.g., assert ragged_diff < tol, assert paged_diff < tol, and assert
flex_diff < tol) or call pytest.fail with a clear message including the diff
when the condition is false so CI fails on regressions; keep the existing diff
variables (ragged_diff, paged_diff, flex_diff) and messages but replace the
print-only behavior with assertions/pytest.fail in the test function.
- Around line 216-229: The benchmark driver functions named with the
pytest-discovered prefix (e.g., test_flashinfer_vs_flex_attention and the other
top-level functions in this file referenced at lines 216–229, 664–672, 752–759,
832–841, 959–968) must be renamed so they do not start with "test_" (or moved
into a non-test harness/module); change their names (for example to
flashinfer_vs_flex_attention_bench or
run_flashinfer_vs_flex_attention_benchmark) or relocate them into a dedicated
benchmarks file to prevent pytest from collecting and executing the heavy
benchmark sweeps during CI, and update any callers/imports accordingly.


ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: b97072d3-ff07-4a26-b15a-3f3936d9e25f

📥 Commits

Reviewing files that changed from the base of the PR and between 65d6e4a and 33b195e.

📒 Files selected for processing (23)
  • csrc/batch_prefill_customize_config.jinja
  • csrc/batch_prefill_sm90_customize_config.jinja
  • csrc/single_prefill_customize_config.jinja
  • csrc/single_prefill_sm90_customize_config.jinja
  • flashinfer/__init__.py
  • flashinfer/dllm/__init__.py
  • flashinfer/dllm/batch_block_extend.py
  • flashinfer/dllm/block_extend.py
  • flashinfer/jit/attention/modules.py
  • flashinfer/jit/utils.py
  • flashinfer/prefill.py
  • flashinfer/utils.py
  • include/flashinfer/attention/block_expanding_prefill.cuh
  • include/flashinfer/attention/default_prefill_params.cuh
  • include/flashinfer/attention/hopper/mainloop.cuh
  • include/flashinfer/attention/hopper/mainloop_mma.cuh
  • include/flashinfer/attention/hopper/prefill_sm90.cuh
  • include/flashinfer/attention/hopper/sparse_mainloop.cuh
  • include/flashinfer/attention/mask.cuh
  • include/flashinfer/attention/prefill.cuh
  • include/flashinfer/utils.cuh
  • tests/attention/test_dllm_cascade_vs_blockwise_extend_attention.py
  • tests/attention/test_dllm_vs_flex_attention.py

Comment thread flashinfer/dllm/batch_block_extend.py Outdated
Comment on lines +179 to +181
def _get_batch_be_module_uri(head_dim: int, dtype: torch.dtype) -> str:
    dtype_str = {torch.float16: "fp16", torch.bfloat16: "bf16"}.get(dtype, "fp16")
    return f"batch_prefill_block_expanding_hd{head_dim}_{dtype_str}"
Contributor


⚠️ Potential issue | 🟠 Major

Don't let distinct kernel ABIs share the same URI.

_get_batch_be_module_uri() only keys on head_dim plus a coarse dtype string, but the generated module also depends on idtype, and every other dtype currently aliases to fp16. An int64 indptr variant or an FP8 variant can therefore load/register the same flashinfer::{uri}_* name as a different specialization, and the wrapper recreation check also ignores idtype. Please bake both dtype and idtype into the URI/cache identity, or reject unsupported dtypes up front.

Also applies to: 300-305, 331-332, 435-440, 463-464

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/dllm/batch_block_extend.py` around lines 179 - 181, The URI
generation in _get_batch_be_module_uri only encodes head_dim and a coarse dtype
string, causing different kernel ABIs (e.g., idtype variants like int64 or fp8)
to collide; update _get_batch_be_module_uri to include both the element dtype
and the index/id dtype (idtype) in the returned string (e.g., use explicit
mappings for torch.float16/torch.bfloat16/torch.int64/FP8 aliases or use
dtype.name and idtype.name), and ensure any other URI/cache identity builders
and the wrapper recreation check also incorporate idtype (or explicitly reject
unsupported dtypes up front) so each specialization yields a unique
flashinfer::{uri}_* name and the recreation logic compares both dtype and
idtype.
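One way to make the cache identity collision-proof is to key on explicit, closed dtype mappings and fail loudly for anything unmapped. A sketch with plain strings standing in for the torch dtype objects (the mapping values and URI scheme here are illustrative, not the project's actual naming):

```python
# Closed mappings: an unknown dtype raises instead of aliasing to fp16.
_DTYPE_STR = {"float16": "fp16", "bfloat16": "bf16", "float8_e4m3fn": "fp8e4m3"}
_IDTYPE_STR = {"int32": "i32", "int64": "i64"}

def get_batch_be_module_uri(head_dim: int, dtype: str, idtype: str) -> str:
    try:
        dtype_str = _DTYPE_STR[dtype]
        idtype_str = _IDTYPE_STR[idtype]
    except KeyError as e:
        raise ValueError(f"unsupported dtype for block-expanding module: {e}") from e
    return f"batch_prefill_block_expanding_hd{head_dim}_{dtype_str}_{idtype_str}"

# idtype now disambiguates otherwise-identical specializations.
assert (
    get_batch_be_module_uri(128, "float16", "int32")
    != get_batch_be_module_uri(128, "float16", "int64")
)
```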

Comment thread flashinfer/dllm/batch_block_extend.py Outdated
Comment thread flashinfer/dllm/batch_block_extend.py Outdated
Comment on lines +560 to +563
if q_offsets is None:
    q_offsets = torch.zeros(batch_size, dtype=torch.int32, device=device)
if kv_offsets is None:
    kv_offsets = q_offsets
Contributor


⚠️ Potential issue | 🟠 Major

batch_block_extend_cascade() defaults to the wrong global offsets when a prefix exists.

When has_prefix is true and the caller omits q_offsets/kv_offsets, both stages run as if the current chunk starts at position 0. That changes the block mask whenever the prefix length is nonzero, especially when it is not block-aligned. Derive the per-request prefix lengths from the paged prefix metadata here, or require the caller to pass them explicitly.

Suggested fix
-    if q_offsets is None:
-        q_offsets = torch.zeros(batch_size, dtype=torch.int32, device=device)
-    if kv_offsets is None:
-        kv_offsets = q_offsets
+    if q_offsets is None:
+        if has_prefix:
+            q_offsets = (
+                page_size * (paged_kv_indptr[1:] - paged_kv_indptr[:-1] - 1)
+                + paged_kv_last_page_len
+            ).to(device=device, dtype=qo_indptr.dtype)
+        else:
+            q_offsets = torch.zeros(batch_size, dtype=qo_indptr.dtype, device=device)
+    if kv_offsets is None:
+        kv_offsets = q_offsets
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/dllm/batch_block_extend.py` around lines 560 - 563, In
batch_block_extend_cascade(), when q_offsets or kv_offsets are None the code
currently defaults both to zero which is incorrect when has_prefix is true;
change this to derive per-request global offsets from the paged-prefix metadata
(the per-request prefix length stored in the paged prefix structure used by the
function) instead of using torch.zeros, so that q_offsets and kv_offsets reflect
each request's prefix length (and block-aligned adjustments) before the
two-stage extension; alternatively validate and require the caller to supply
q_offsets/kv_offsets and raise a clear error if they are omitted when has_prefix
is true. Ensure you update the logic referencing q_offsets, kv_offsets,
has_prefix, and the paged prefix metadata in batch_block_extend_cascade to use
the computed per-request offsets.
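The arithmetic behind the suggested default is the standard paged-KV length formula: a request's prefix length is `page_size * (num_pages - 1) + last_page_len`, since every page except the last is full. With plain Python lists standing in for the tensors (names mirror the wrapper arguments; the values are made up):

```python
def prefix_lengths(paged_kv_indptr, paged_kv_last_page_len, page_size):
    """Per-request prefix length from paged-KV metadata: all pages but
    the last are full, and the last holds last_page_len entries."""
    out = []
    for i in range(len(paged_kv_indptr) - 1):
        num_pages = paged_kv_indptr[i + 1] - paged_kv_indptr[i]
        out.append(page_size * (num_pages - 1) + paged_kv_last_page_len[i])
    return out

# Two requests: 3 pages with 5 entries used on the last page, and 1 full page.
assert prefix_lengths([0, 3, 4], [5, 16], 16) == [37, 16]
```

Feeding these per-request lengths in as the default `q_offsets`/`kv_offsets` keeps the block mask aligned with the true global positions whenever a prefix exists.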

Comment thread flashinfer/dllm/batch_block_extend.py Outdated
Comment on lines +611 to +613
    logits_soft_cap: float = 0.0,
    return_lse: bool = False,
    backend: str = "fa2",
Contributor


⚠️ Potential issue | 🟠 Major

The SGLang helper is not actually running the advertised current-chunk attention.

The docstring says "causal + merge", but stage 1 plans the ragged wrapper with causal=False, so current-chunk tokens can see future chunk tokens before the merge. The helper also drops the caller's logits_soft_cap entirely.

Suggested fix
     ragged_wrapper.plan(
         qo_indptr=qo_indptr, kv_indptr=kv_curr_indptr,
         num_qo_heads=num_qo_heads, num_kv_heads=num_kv_heads,
-        head_dim_qk=head_dim, head_dim_vo=head_dim, q_data_type=q.dtype, causal=False,
+        head_dim_qk=head_dim,
+        head_dim_vo=head_dim,
+        q_data_type=q.dtype,
+        causal=True,
+        logits_soft_cap=logits_soft_cap,
     )
@@
     paged_wrapper.plan(
         qo_indptr=qo_indptr, paged_kv_indptr=paged_kv_indptr,
         paged_kv_indices=paged_kv_indices, paged_kv_last_page_len=paged_kv_last_page_len,
         num_qo_heads=num_qo_heads, num_kv_heads=num_kv_heads,
         head_dim_qk=head_dim, head_dim_vo=head_dim, page_size=page_size,
-        q_data_type=q.dtype, causal=False,
+        q_data_type=q.dtype, causal=False, logits_soft_cap=logits_soft_cap,
     )

Also applies to: 638-659

🧰 Tools
🪛 Ruff (0.15.4)

[warning] 611-611: Unused function argument: logits_soft_cap

(ARG001)

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/dllm/batch_block_extend.py` around lines 611 - 613, The helper is
constructing stage-1 ragged attention with causal=False and drops the
logits_soft_cap parameter; update the ragged-wrapper construction(s) so stage 1
uses causal=True to enforce current-chunk (causal + merge) attention, and thread
the logits_soft_cap argument through to the helper calls instead of discarding
it; specifically, in batch_block_extend.py adjust the stage-1 ragged/planning
call(s) that currently pass causal=False to causal=True and ensure the helper
invocations (the ones near the signature containing logits_soft_cap, return_lse,
backend and the similar block at lines ~638-659) forward logits_soft_cap into
the downstream helper/function that applies the soft cap.

Comment thread flashinfer/dllm/block_extend.py Outdated
Comment on lines +20 to +56
import torch
import time
import math
import sys

# ============================================================
# FlashInfer imports
# ============================================================
try:
    from flashinfer import single_prefill_with_kv_cache
    from flashinfer.dllm import (
        BatchBlockExtendPagedOffsetWrapper,
        BatchBlockExtendRaggedOffsetWrapper,
    )
    HAS_FLASHINFER = True
except ImportError as e:
    HAS_FLASHINFER = False
    print(f"[WARN] flashinfer not available: {e}")
    print("       Will skip FlashInfer benchmarks")
except Exception as e:
    HAS_FLASHINFER = False
    print(f"[ERROR] flashinfer import failed with unexpected error: {e}")
    print("        Will skip FlashInfer benchmarks")

# ============================================================
# Flex Attention imports (requires PyTorch >= 2.5)
# ============================================================
try:
    from torch.nn.attention.flex_attention import (
        flex_attention,
        create_block_mask,
    )
    HAS_FLEX_ATTENTION = True
except ImportError:
    HAS_FLEX_ATTENTION = False
    print("[WARN] flex_attention not available (requires PyTorch >= 2.5)")

Contributor


⚠️ Potential issue | 🟠 Major

Skip unsupported GPU architectures at module scope.

This module allocates on cuda:0 and exercises Hopper-specific paths, but it never gates execution with the repo's architecture helpers. Unsupported runners will fail instead of skipping cleanly.

As per coding guidelines, tests/**/*.py: Use flashinfer.utils functions (get_compute_capability(), is_sm90a_supported(), is_sm100a_supported()) to skip tests on unsupported GPU architectures.

🧰 Tools
🪛 Ruff (0.15.4)

[warning] 39-39: Do not catch blind exception: Exception

(BLE001)

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/attention/test_dllm_vs_flex_attention.py` around lines 20 - 56, The
module currently performs CUDA/Hopper-specific work at import time and must
early-skip unsupported GPUs: query flashinfer.utils.get_compute_capability(),
is_sm90a_supported(), and is_sm100a_supported() at module scope (and check
torch.cuda.is_available()/device count) and if the current GPU is unsupported,
set HAS_FLASHINFER = HAS_FLEX_ATTENTION = False and print a skip message before
attempting any flashinfer or flex_attention imports or CUDA allocations; wrap
the existing flashinfer imports and the flex_attention import logic behind this
guard so functions like single_prefill_with_kv_cache,
BatchBlockExtendPagedOffsetWrapper, BatchBlockExtendRaggedOffsetWrapper,
flex_attention and create_block_mask are only imported when the architecture
checks pass.
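A guard of the requested shape can be sketched so it degrades gracefully when the libraries are absent. The `is_sm90a_supported`/`is_sm100a_supported` calls and their device argument follow the helpers named above; treat the exact signatures as assumptions:

```python
import importlib.util

def hopper_tests_supported() -> bool:
    """Module-scope gate: True only when torch, flashinfer, a CUDA device,
    and a supported architecture (SM90a/SM100a) are all present."""
    if importlib.util.find_spec("torch") is None:
        return False
    import torch
    if not torch.cuda.is_available() or torch.cuda.device_count() == 0:
        return False
    if importlib.util.find_spec("flashinfer") is None:
        return False
    # Assumed helper signatures per the coding guidelines cited above.
    from flashinfer.utils import is_sm90a_supported, is_sm100a_supported
    dev = torch.device("cuda:0")
    return is_sm90a_supported(dev) or is_sm100a_supported(dev)

# Run the gate before any heavy imports or CUDA allocations.
if not hopper_tests_supported():
    HAS_FLASHINFER = HAS_FLEX_ATTENTION = False
    print("[SKIP] unsupported GPU architecture; skipping DLLM attention tests")
```

In the test module the flashinfer and flex_attention imports would then move behind this check, so collection on an unsupported runner skips cleanly instead of failing.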

Comment on lines +61 to +104
def compute_block_extend_reference(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    dllm_block_size: int,
    q_offset: int = 0,
    sm_scale: float = None,
) -> torch.Tensor:
    """Reference: single_prefill_with_kv_cache + custom_mask"""
    qo_len = q.shape[0]
    kv_len = k.shape[0]
    head_dim = q.shape[-1]
    device = q.device

    if sm_scale is None:
        sm_scale = 1.0 / math.sqrt(head_dim)

    q_pos = torch.arange(qo_len, device=device) + q_offset
    k_pos = torch.arange(kv_len, device=device)
    q_block = q_pos.unsqueeze(1) // dllm_block_size
    k_block = k_pos.unsqueeze(0) // dllm_block_size
    mask_2d = (q_block >= k_block).to(torch.uint8)

    return single_prefill_with_kv_cache(
        q, k, v, custom_mask=mask_2d, sm_scale=sm_scale,
    )


# ============================================================
# Flex Attention helper: build block_extend mask_mod
# ============================================================
def make_block_extend_mask_mod(dllm_block_size: int, q_offset: int = 0):
    """
    Return the mask_mod function used by flex_attention.

    mask_mod(b, h, q_idx, kv_idx) -> bool
    True = allowed to attend, False = masked out
    """
    def block_extend_mask(b, h, q_idx, kv_idx):
        q_global = q_idx + q_offset
        q_blk = q_global // dllm_block_size
        kv_blk = kv_idx // dllm_block_size
        return q_blk >= kv_blk
    return block_extend_mask
Contributor


⚠️ Potential issue | 🟠 Major

The new per-request and kv_block_expanding_offset paths are still not validated.

compute_block_extend_reference() hardcodes kv_offset=0, make_block_extend_mask_mod() ignores b, and the batch setup uses torch.full(...) for q_offsets. That means the batch-indexed offset plumbing added in this PR never sees a heterogeneous batch, and the cascade/current-chunk kv_block_expanding_offset behavior is still untested.

Also applies to: 321-321, 381-384, 516-520

🧰 Tools
🪛 Ruff (0.15.4)

[warning] 67-67: PEP 484 prohibits implicit Optional

Convert to T | None

(RUF013)


[warning] 99-99: Unused function argument: b

(ARG001)


[warning] 99-99: Unused function argument: h

(ARG001)

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/attention/test_dllm_vs_flex_attention.py` around lines 61 - 104,
compute_block_extend_reference and make_block_extend_mask_mod currently never
exercise per-request offsets or kv_block_expanding_offset because
compute_block_extend_reference hardcodes q_offset (kv_offset) to zero,
make_block_extend_mask_mod ignores the batch index b, and tests use torch.full
for q_offsets; update the test helpers so compute_block_extend_reference accepts
and uses per-request q_offset values (and propagate kv_offset if applicable) and
make_block_extend_mask_mod's inner block_extend_mask uses the batch index b to
look up per-sample offsets, then modify the batch construction in tests to pass
heterogeneous q_offsets (not torch.full) and add cases exercising nonzero
kv_block_expanding_offset and cascade/current-chunk paths so the new plumbing is
validated against single_prefill_with_kv_cache and block_extend_mask behavior.
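A per-request variant closes the gap: the factory captures a sequence of offsets and the inner function indexes it with the batch index `b`, so a heterogeneous batch actually exercises the new plumbing. Plain-int sketch (with flex_attention the offsets would live in a tensor and be indexed the same way):

```python
def make_block_extend_mask_mod(dllm_block_size, q_offsets):
    """q_offsets: per-request global query offsets, indexed by batch b."""
    def block_extend_mask(b, h, q_idx, kv_idx):
        q_blk = (q_idx + q_offsets[b]) // dllm_block_size
        kv_blk = kv_idx // dllm_block_size
        return q_blk >= kv_blk
    return block_extend_mask

# Heterogeneous batch: request 0 starts at 0, request 1 mid-sequence at 96.
mask = make_block_extend_mask_mod(32, q_offsets=[0, 96])
assert not mask(0, 0, 0, 40)  # request 0: q block 0 vs kv block 1 -> masked
assert mask(1, 0, 0, 40)      # request 1: q block 3 vs kv block 1 -> visible
```

The same local `(q_idx, kv_idx)` pair resolves differently per request, which is exactly the behavior the `torch.full` setup could never distinguish.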

Comment on lines +166 to +210
def benchmark_fn(fn, warmup_iters=20, bench_iters=100, label=""):
    """Benchmark a callable, return average time in ms."""
    for _ in range(warmup_iters):
        fn()
    torch.cuda.synchronize()

    start = time.perf_counter()
    for _ in range(bench_iters):
        fn()
    torch.cuda.synchronize()
    elapsed_ms = (time.perf_counter() - start) / bench_iters * 1000
    return elapsed_ms


def benchmark_with_cuda_graph(fn, warmup_iters=20, bench_iters=100, label=""):
    """Benchmark with CUDA Graph capture, return average time in ms."""
    # warmup
    for _ in range(warmup_iters):
        fn()
    torch.cuda.synchronize()

    # capture
    stream = torch.cuda.Stream()
    with torch.cuda.stream(stream):
        fn()
    stream.synchronize()

    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph, stream=stream):
        fn()

    # warmup cuda_graph
    for _ in range(warmup_iters):
        graph.replay()
    torch.cuda.synchronize()

    # bench
    start = time.perf_counter()
    for _ in range(bench_iters):
        graph.replay()
    torch.cuda.synchronize()
    elapsed_ms = (time.perf_counter() - start) / bench_iters * 1000

    del graph
    return elapsed_ms
Contributor


🛠️ Refactor suggestion | 🟠 Major

Route these timings through bench_gpu_time().

The benchmark helpers and the direct perf_counter() paths in test_total_memory_comparison() bypass the repo timing harness, so these numbers will not use the standard CUPTI/CUDA-event fallback and will not be comparable with the rest of the benchmark suite.

As per coding guidelines, tests/**/*.py: Use flashinfer.testing.bench_gpu_time() for benchmarking kernels, preferring CUPTI timing with auto-fallback to CUDA events.

Also applies to: 1078-1083, 1119-1123

🧰 Tools
🪛 Ruff (0.15.4)

[warning] 166-166: Unused function argument: label

(ARG001)


[warning] 180-180: Unused function argument: label

(ARG001)

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/attention/test_dllm_vs_flex_attention.py` around lines 166 - 210, The
benchmarks (functions benchmark_fn and benchmark_with_cuda_graph) and the direct
perf_counter() timing in test_total_memory_comparison should be replaced to use
the repo timing harness flashinfer.testing.bench_gpu_time() so results use CUPTI
with CUDA-event fallback and remain comparable across the suite; locate usages
of benchmark_fn, benchmark_with_cuda_graph, and the perf_counter() blocks in
test_total_memory_comparison and call bench_gpu_time() (passing the callable and
warmup/bench iteration params) instead of manual perf_counter/CUDAGraph timing,
ensuring any CUDA Graph replay loops are wrapped or adapted to the
bench_gpu_time() callable interface.

Comment thread tests/attention/test_dllm_vs_flex_attention.py Outdated
Comment thread tests/attention/test_dllm_vs_flex_attention.py
Add DLLM Block Extend Attention feature with tile-level skip
optimization
using native MaskMode::kBlockExpanding.

Core API:
- Single-request: block_extend_attention_with_offset() with q/kv offset
  support
- Batch: BatchBlockExtendRaggedOffsetWrapper,
  BatchBlockExtendPagedOffsetWrapper
- Cascade: 3-stage attention (current chunk + prefix + merge state)
- Support both JIT and AOT compilation

Tests:
- Precision: block extend vs custom_mask reference, cascade vs blockwise
  correctness
- Performance: FlashInfer Block Extend vs PyTorch Flex Attention
  benchmark
  - Context length sweep (1K-32K), block size alignment analysis
  - Significant speedup over Flex Attention with lower memory usage
@fdz-1999 fdz-1999 force-pushed the feature/block-extend branch from 33b195e to 5e07b58 Compare March 9, 2026 10:06
Contributor

@coderabbitai coderabbitai Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 3

♻️ Duplicate comments (10)
include/flashinfer/attention/hopper/mainloop.cuh (1)

230-243: ⚠️ Potential issue | 🔴 Critical

Handle the zero-visible-KV case before the first TMA load.

get_num_kv_tiles() can now return 0 when block expansion plus kv_offset removes the whole visible range. In that case kv_tile_idx becomes -1, and Line 243 immediately indexes tKgK(_, kv_tile_idx), which is an invalid tile read. Please skip scheduling those work tiles or add a no-op path before any K/V load when num_kv_tiles == 0.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@include/flashinfer/attention/hopper/mainloop.cuh` around lines 230 - 243, The
code assumes get_num_kv_tiles(...) > 0 but it can return 0, which makes
kv_tile_idx = num_kv_tiles - 1 negative and leads to an invalid read from
tKgK(_, kv_tile_idx); add a guard for num_kv_tiles == 0 before any K/V TMA
scheduling: check num_kv_tiles and if it is 0 skip the
pipeline_k.producer_acquire(...) and the subsequent copy(...) that references
tKgK and smem_pipe_write_k (i.e., short-circuit the path that schedules the
first K load using mainloop_params.tma_load_K and tKgK), or route to a no-op
branch so no tile is indexed when num_kv_tiles == 0.
flashinfer/dllm/block_extend.py (2)

146-170: ⚠️ Potential issue | 🟠 Major

Use the target tensor device for the FA3 capability gate.

block_extend_attention_with_offset() selects fa3 from q.device, but get_block_extend_module_with_offset() re-checks FA3 against torch.device("cuda"). On mixed-architecture multi-GPU systems, a valid call on one device can fail just because the default CUDA device is older.

Proposed fix
 def get_block_extend_module_with_offset(
     head_dim: int = 128,
     dtype: torch.dtype = torch.float16,
     backend: str = "fa2",
+    device: Optional[torch.device] = None,
 ):
@@
-    if backend == "fa3" and not is_sm90a_supported(torch.device("cuda")):
+    device = device or torch.device("cuda")
+    if backend == "fa3" and not is_sm90a_supported(device):
         raise RuntimeError(
             "FA3 backend requires SM90 (Hopper) architecture. "
             "Use backend='fa2' for older architectures."
         )
@@
-    module = get_block_extend_module_with_offset(head_dim=head_dim, dtype=dtype, backend=backend)
+    module = get_block_extend_module_with_offset(
+        head_dim=head_dim,
+        dtype=dtype,
+        backend=backend,
+        device=q.device,
+    )

Also applies to: 281-285

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/dllm/block_extend.py` around lines 146 - 170, The capability check
for FA3 uses torch.device("cuda") which can be wrong on mixed-GPU systems;
update get_block_extend_module_with_offset to accept or derive the target device
(use the device of the input/target tensor or the same device used by
block_extend_attention_with_offset) and pass that device into is_sm90a_supported
instead of torch.device("cuda"); also apply the same change to the other FA3
check near the block_extend_attention_with_offset-related code (the second
occurrence) so the SM90 gate queries the actual tensor/device being compiled
for.

90-95: ⚠️ Potential issue | 🟠 Major

Don't alias unsupported dtypes to the FP16 URI.

_get_dtype_str() currently collapses every non-fp16/bf16 dtype to "fp16". That makes FP8/FP32 call paths reuse the FP16 module URI and cache entry, which can load the wrong specialization.

Proposed fix
 def _get_dtype_str(dtype: torch.dtype) -> str:
     """Get dtype string representation (unified interface)"""
-    return {
-        torch.float16: "fp16",
-        torch.bfloat16: "bf16",
-    }.get(dtype, "fp16")
+    mapping = {
+        torch.float16: "fp16",
+        torch.bfloat16: "bf16",
+        torch.float32: "fp32",
+        torch.float8_e4m3fn: "fp8e4m3",
+        torch.float8_e5m2: "fp8e5m2",
+    }
+    try:
+        return mapping[dtype]
+    except KeyError as exc:
+        raise ValueError(f"Unsupported block-extend dtype: {dtype}") from exc
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/dllm/block_extend.py` around lines 90 - 95, The helper
_get_dtype_str currently maps any non-fp16/bf16 dtype to "fp16", causing
FP8/FP32 to reuse FP16 URIs; update _get_dtype_str to explicitly map known
dtypes (at minimum torch.float16 -> "fp16", torch.bfloat16 -> "bf16",
torch.float32 -> "fp32") and do not alias unknown types to "fp16" — instead
either return a distinct string (e.g., dtype.name) or raise a clear ValueError
for unsupported dtypes so wrong specializations are never cached; edit the
_get_dtype_str function accordingly.
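A torch-free sketch of the strict mapping this fix proposes (dtype name strings stand in for torch.dtype objects; the function name is illustrative, not FlashInfer API):

```python
def dtype_to_uri_tag(dtype_name: str) -> str:
    """Map a dtype to its module-URI tag.

    Unknown dtypes raise instead of silently aliasing to "fp16",
    so a wrong kernel specialization can never be cached or reused.
    """
    mapping = {
        "float16": "fp16",
        "bfloat16": "bf16",
        "float32": "fp32",
    }
    try:
        return mapping[dtype_name]
    except KeyError as exc:
        raise ValueError(f"Unsupported block-extend dtype: {dtype_name}") from exc
```

The failure mode moves from "wrong kernel loaded at runtime" to a clear error at module-resolution time.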
flashinfer/prefill.py (1)

1651-1651: ⚠️ Potential issue | 🟠 Major

Reject unsupported mask_modes before dispatch.

Both wrappers now persist arbitrary mask_mode overrides, but the cudnn path still derives masking from self._causal and the paged trtllm-gen path never consumes mask_mode at all. A BLOCK_EXPANDING plan can therefore silently execute with causal/non-causal semantics instead of failing fast.

Also applies to: 2010-2010, 2205-2206, 2632-2632, 2938-2938, 3159-3160

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/prefill.py` at line 1651, The function accepting mask_mode should
reject unsupported or conflicting overrides before dispatch: at the start of the
routine that declares the mask_mode parameter, validate mask_mode against
supported values and fail fast (raise ValueError) instead of persisting
arbitrary overrides; specifically, if dispatch will use the cudnn path (which
derives masking from self._causal) and mask_mode is provided but inconsistent
with self._causal, raise an error, and if dispatch will use the paged trtllm-gen
path (which never consumes mask_mode) reject any non-None mask_mode; ensure
checks reference mask_mode, self._causal, and the BLOCK_EXPANDING plan name so
callers cannot silently run with wrong causal/non-causal semantics.
tests/attention/test_dllm_vs_flex_attention.py (4)

61-104: ⚠️ Potential issue | 🟠 Major

The new per-batch and kv_offset plumbing still isn't being validated.

compute_block_extend_reference() only shifts Q, make_block_extend_mask_mod() ignores b, and the batch setup uses a uniform torch.full(...) offset. The Hopper block-expanding path uses both per-request Q offsets and per-request KV offsets, so these checks can pass without touching the new code path at all.

Also applies to: 321-321, 381-384, 516-520

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/attention/test_dllm_vs_flex_attention.py` around lines 61 - 104, The
tests currently only shift Q globally and ignore per-batch and per-request KV
offsets, so update compute_block_extend_reference and make_block_extend_mask_mod
to accept and use per-batch q_offset and kv_offset (arrays/tensors indexed by b)
and ensure the returned mask_mod callback uses the b parameter to read
kv_offset[b] and q_offset[b] when computing q_global and kv_blk; also change the
test setup (replace torch.full(...) offsets) to supply non-uniform per-batch q
and kv offsets so the flex-attention path actually exercises per-request
plumbing (check functions compute_block_extend_reference,
make_block_extend_mask_mod, and the test offset construction).

216-229: ⚠️ Potential issue | 🟠 Major

These benchmark drivers will be collected and run as ordinary tests.

All of these top-level test_* functions have only defaulted parameters, so pytest will execute the full sweeps during normal test runs. Please move them to a benchmark module or rename them to a non-test_ prefix.

Also applies to: 664-672, 752-759, 832-841, 959-968

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/attention/test_dllm_vs_flex_attention.py` around lines 216 - 229, The
function named test_flashinfer_vs_flex_attention (and other top-level functions
at the noted ranges) are declared as pytest tests but are benchmark drivers with
only defaulted parameters, so rename each function to a non-test name (e.g.,
benchmark_flashinfer_vs_flex_attention) or move them into a dedicated benchmark
module so pytest won't auto-run them; update any references/calls to the
original function names accordingly and ensure imports/exports reflect the new
names (target symbols: test_flashinfer_vs_flex_attention and the other top-level
test_* functions referenced).

355-399: ⚠️ Potential issue | 🟠 Major

Make correctness mismatches fail the test.

ragged_pass, paged_pass, and flex_pass only affect logging right now. If any backend diverges from the reference, pytest still reports success and CI misses the regression.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/attention/test_dllm_vs_flex_attention.py` around lines 355 - 399, The
test currently only prints ragged_pass, paged_pass, and flex_pass but does not
assert them, so divergences are not failing CI; update the test to assert these
boolean checks (or assert the numeric diffs are below tol) after computing
ragged_diff, paged_diff, and flex_diff so failures raise in pytest.
Specifically, add assertions for ragged_pass and paged_pass (and for flex_pass
only if HAS_FLEX_ATTENTION) with informative messages including the
corresponding diff values to help debugging; reference the variables
ragged_pass, paged_pass, flex_pass and the diff variables ragged_diff,
paged_diff, flex_diff and the tolerance tol when adding the assertions.
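A minimal sketch of the asserting pattern being asked for (function and parameter names are illustrative):

```python
def assert_backends_match(ragged_diff, paged_diff, flex_diff=None, tol=2e-2):
    # Assertions (not prints) so any divergence fails the pytest run.
    assert ragged_diff < tol, f"ragged backend diverged: {ragged_diff:.3e} >= {tol:.0e}"
    assert paged_diff < tol, f"paged backend diverged: {paged_diff:.3e} >= {tol:.0e}"
    if flex_diff is not None:  # only checked when HAS_FLEX_ATTENTION
        assert flex_diff < tol, f"flex reference diverged: {flex_diff:.3e} >= {tol:.0e}"
```

The informative messages carry the numeric diff, so a CI failure immediately shows how far the backend drifted from the reference.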

28-56: ⚠️ Potential issue | 🟠 Major

Skip unsupported GPU architectures before running this module.

This file exercises architecture-specific attention paths but only soft-fails import errors. Unsupported runners will still execute the test functions and fail later instead of skipping cleanly. As per coding guidelines, tests/**/*.py: Use flashinfer.utils functions (get_compute_capability(), is_sm90a_supported(), is_sm100a_supported()) to skip tests on unsupported GPU architectures.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/attention/test_dllm_vs_flex_attention.py` around lines 28 - 56, Add an
early GPU-architecture guard at module import time before the
FlashInfer/Flex-Attention import blocks: call
flashinfer.utils.get_compute_capability() and use
flashinfer.utils.is_sm90a_supported() and is_sm100a_supported() (or import them
from flashinfer.utils) to detect unsupported GPUs and call pytest.skip(...) to
skip the whole module when the architecture is not supported; ensure this check
runs before or around the try/except blocks that set HAS_FLASHINFER and
HAS_FLEX_ATTENTION so tests are skipped cleanly instead of failing later (refer
to the module-level import area and the HAS_FLASHINFER / HAS_FLEX_ATTENTION
logic).
flashinfer/dllm/batch_block_extend.py (2)

611-613: ⚠️ Potential issue | 🟠 Major

Stage 1 is still non-causal, and logits_soft_cap is still a no-op.

The helper advertises “causal + merge”, but the current-chunk ragged plan still uses causal=False, so tokens can see later positions inside the chunk before the merge. logits_soft_cap is also accepted by the public API and then dropped in both planning calls.

Also applies to: 638-659

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/dllm/batch_block_extend.py` around lines 611 - 613, The helper
currently advertises "causal + merge" but leaves stage 1 non-causal and drops
logits_soft_cap; update the code so the current-chunk ragged plan is created
with causal=True (so tokens cannot see later positions inside the chunk prior to
merge) and pass the logits_soft_cap parameter through to both planning calls
instead of ignoring it; locate the helper that builds the current-chunk ragged
plan and the two planning call sites in batch_block_extend.py and modify their
call signatures/arguments to include logits_soft_cap and set causal=True for the
stage-1 plan.

560-563: ⚠️ Potential issue | 🟠 Major

Default offsets are still wrong when a paged prefix exists.

With a nonzero prefix, zeroing q_offsets/kv_offsets makes stage 1 run as if the current chunk started at global position 0. That changes the block-expanding mask whenever the prefix length is not block-aligned. Derive per-request offsets from the paged-prefix metadata or require the caller to pass them explicitly.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/dllm/batch_block_extend.py` around lines 560 - 563, The defaulting
of q_offsets/kv_offsets to zeros is incorrect when a paged prefix exists because
it makes stage 1 assume the chunk starts at global position 0; update the logic
in batch_block_extend.py where q_offsets and kv_offsets are initialized so that
when they are None you derive per-request offsets from the paged-prefix metadata
(use the paged-prefix length/offset fields for each request) and compute int32
offsets on the device instead of zeroing, or alternatively make the function
require the caller to pass explicit q_offsets/kv_offsets; ensure you propagate
these computed per-request offsets into the downstream stage 1 block-expanding
mask calculations so block alignment is correct for non-block-aligned prefix
lengths.
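The suggested derivation can be sketched in plain Python (a hypothetical helper, not the module's API; equating both offsets to the prefix length is an assumption about the stage-1 layout):

```python
def default_chunk_offsets(prefix_lens):
    """When the caller omits q/kv offsets, derive them from each request's
    paged-prefix length instead of zeroing them.

    The current chunk's first token sits at global position prefix_len, so an
    offset of 0 would shift the block-expanding mask whenever prefix_len is
    not block-aligned.
    """
    q_offsets = [int(p) for p in prefix_lens]
    kv_offsets = [int(p) for p in prefix_lens]
    return q_offsets, kv_offsets
```

In the real code these would be int32 tensors built on the target device from the paged-prefix metadata, as the comment describes.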
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@flashinfer/dllm/batch_block_extend.py`:
- Around line 287-299: The code mutates the wrapper's backend preference by
assigning effective_backend back to self._backend, breaking future
auto-selection; instead, keep self._backend unchanged and use a local variable
(effective_backend) only to select the URI/variant (via
select_best_backend_paged, _get_batch_be_module_uri, variant_name, variant_decl)
so that backend="auto" continues to work across re-plans—apply the same
non-mutating pattern used in
BatchBlockExtendRaggedOffsetWrapper._create_inner_wrapper and also fix the
analogous block at the other occurrence (around the 422-433 region).
- Around line 545-549: In batch_block_extend_cascade(), don't pre-resolve
backend=="auto" using is_sm90a_supported/device; remove the logic that sets
actual_backend = "fa3" if is_sm90a_supported(device) else "fa2" and instead
preserve "auto" (i.e., set actual_backend = backend) so the wrapper layers can
resolve it, or replace that branch with a call to the centralized availability
selector helper if you prefer availability-based resolution; update references
to actual_backend accordingly.

In `@flashinfer/jit/attention/modules.py`:
- Line 1283: The JIT cache key must include the mask_modes so different mask
sets produce distinct compiled artifacts; update the code that constructs the
JitSpec (the unique name/URI hash and/or sources list) to incorporate the
mask_modes value(s) alongside existing fields (uri, sources,
extra_cuda_cflags/extra_cflags/extra_ldflags). Specifically, when creating the
JitSpec for functions that accept mask_modes (parameter mask_modes), append or
mix the mask_modes representation into the unique name/hash used as the URI
and/or into the sources list so the compiled directory/shared object differs for
different mask_modes and prevents loading stale modules.


ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 79a6e941-787c-4ea3-a005-e934cd5ac5b1

📥 Commits

Reviewing files that changed from the base of the PR and between 33b195e and 5e07b58.

📒 Files selected for processing (23)
  • csrc/batch_prefill_customize_config.jinja
  • csrc/batch_prefill_sm90_customize_config.jinja
  • csrc/single_prefill_customize_config.jinja
  • csrc/single_prefill_sm90_customize_config.jinja
  • flashinfer/__init__.py
  • flashinfer/dllm/__init__.py
  • flashinfer/dllm/batch_block_extend.py
  • flashinfer/dllm/block_extend.py
  • flashinfer/jit/attention/modules.py
  • flashinfer/jit/utils.py
  • flashinfer/prefill.py
  • flashinfer/utils.py
  • include/flashinfer/attention/block_expanding_prefill.cuh
  • include/flashinfer/attention/default_prefill_params.cuh
  • include/flashinfer/attention/hopper/mainloop.cuh
  • include/flashinfer/attention/hopper/mainloop_mma.cuh
  • include/flashinfer/attention/hopper/prefill_sm90.cuh
  • include/flashinfer/attention/hopper/sparse_mainloop.cuh
  • include/flashinfer/attention/mask.cuh
  • include/flashinfer/attention/prefill.cuh
  • include/flashinfer/utils.cuh
  • tests/attention/test_dllm_cascade_vs_blockwise_extend_attention.py
  • tests/attention/test_dllm_vs_flex_attention.py
🚧 Files skipped from review as they are similar to previous changes (6)
  • csrc/single_prefill_customize_config.jinja
  • flashinfer/__init__.py
  • csrc/single_prefill_sm90_customize_config.jinja
  • flashinfer/utils.py
  • include/flashinfer/attention/block_expanding_prefill.cuh
  • flashinfer/jit/utils.py

use_logits_soft_cap: bool = False,
use_fp16_qk_reduction: bool = False,
fp8_enabled: bool = False,
mask_modes: Optional[List[int]] = None,

⚠️ Potential issue | 🟠 Major

Include mask_modes in the JIT cache key.

mask_modes now changes which kernel sources get generated, but both functions still compile under the caller-supplied uri unchanged. Reusing the same uri with a different mask set can therefore hit the wrong generated directory / shared object and silently load a stale module.

💡 Proposed fix
 def gen_customize_single_prefill_module(
     backend: str,
     uri: str,
@@
-    mask_modes: Optional[List[int]] = None,
+    mask_modes: Optional[List[int]] = None,
 ) -> JitSpec:
+    normalized_mask_modes = tuple(sorted(set(mask_modes))) if mask_modes is not None else None
+    if normalized_mask_modes is not None:
+        uri = f"{uri}_mask_modes_{'_'.join(map(str, normalized_mask_modes))}"
     kwargs = {
         "variant_decl": variant_decl,
@@
 def gen_customize_batch_prefill_module(
     backend: str,
     uri: str,
@@
-    mask_modes: Optional[List[int]] = None,
+    mask_modes: Optional[List[int]] = None,
 ) -> JitSpec:
+    normalized_mask_modes = tuple(sorted(set(mask_modes))) if mask_modes is not None else None
+    if normalized_mask_modes is not None:
+        uri = f"{uri}_mask_modes_{'_'.join(map(str, normalized_mask_modes))}"
     kwargs = {
         "variant_decl": variant_decl,

As per coding guidelines, "Structure JIT compilation parameters in JitSpec with unique name (URI hash), sources list, and compiler flags (extra_cuda_cflags, extra_cflags, extra_ldflags)".

Also applies to: 1532-1532

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/jit/attention/modules.py` at line 1283, The JIT cache key must
include the mask_modes so different mask sets produce distinct compiled
artifacts; update the code that constructs the JitSpec (the unique name/URI hash
and/or sources list) to incorporate the mask_modes value(s) alongside existing
fields (uri, sources, extra_cuda_cflags/extra_cflags/extra_ldflags).
Specifically, when creating the JitSpec for functions that accept mask_modes
(parameter mask_modes), append or mix the mask_modes representation into the
unique name/hash used as the URI and/or into the sources list so the compiled
directory/shared object differs for different mask_modes and prevents loading
stale modules.
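The normalization in the proposed fix can be factored into a small helper (a sketch; the name and exact suffix format are illustrative):

```python
def uri_with_mask_modes(uri: str, mask_modes=None) -> str:
    """Fold the requested mask set into the JIT URI so that different
    mask_modes compile into distinct cache directories / shared objects."""
    if mask_modes is None:
        return uri  # default mask set keeps the legacy URI
    normalized = sorted(set(mask_modes))  # order- and duplicate-insensitive key
    return uri + "_mask_modes_" + "_".join(map(str, normalized))
```

Sorting and deduplicating first keeps the cache key stable across call sites that pass the same mask set in different orders.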

@ClawSeven

Hi, @yzh119, @bkryu,

This PR is quite impactful for block diffusion LLMs: we've seen a several-fold TTFT improvement in our production environment at AntGroup when using SGLang.

By the way, I'm the code owner of SGLang-dLLM. I'd like to collaborate with you on making FlashInfer the official recommended backend for SGLang-dLLM. I think it would be helpful for both projects and the broader community.
Looking forward to working together!

SGLang dLLM roadmap: sgl-project/sglang#14199

@yzh119 left a comment (Collaborator)

Hi @fdz-1999, thanks for the great work; it makes sense to me in general.

Adding a new mask type would significantly bloat binary size (in the JIT cache, etc.). Can we make it a standalone class instead of changing csrc/batch_prefill_customize_config.jinja etc.?

} // namespace flashinfer

#endif // FLASHINFER_DECODE_PARAMS_CUH_
#endif // FLASHINFER_DECODE_PARAMS_CUH_ No newline at end of file
Collaborator

Please fix the lint issue (missing newline at end of file).

} // namespace flashinfer

#endif // FLASHINFER_UTILS_CUH_
#endif // FLASHINFER_UTILS_CUH_ No newline at end of file
Collaborator

Ditto.

@fdz-1999 (Author)

Hi @fdz-1999, thanks for the great work; it makes sense to me in general.

Adding a new mask type would significantly bloat binary size (in the JIT cache, etc.). Can we make it a standalone class instead of changing csrc/batch_prefill_customize_config.jinja etc.?

Thanks for the review and the concern about binary size — it's a really important consideration for a JIT-heavy project like FlashInfer, and I want to make sure the impact is clear.

I looked into this carefully, and I believe the current approach should not cause binary bloat for non-DLLM users. Here's the reasoning:

1) Mode 4 is not compiled by default

The JIT default in modules.py is mask_modes=[0, 1, 2, 3] across all 4 code paths (single prefill, batch ragged, batch paged, SM90 variants). Mode 4 (kBlockExpanding) is only compiled when DLLM wrappers explicitly request it via mask_modes=[4]. So for existing users, there are no extra .cu files generated and no extra JIT cache entries.

2) Jinja template changes are conditionally compiled away

The get_q/kv_block_expanding_offset() methods added to RaggedParams / PagedParams are guarded by Jinja {% if %} conditions:

{% if 'maybe_q_block_expanding_offset' in additional_params_decl %}
    return maybe_q_block_expanding_offset[batch_idx];
{% else %}
    return 0;
{% endif %}

Non-DLLM callers never pass maybe_q_block_expanding_offset, so these methods render as return 0 and get inlined away by the compiler.

3) The kernel-level changes follow the existing MaskMode pattern

The additions in prefill.cuh and the Hopper files are all if constexpr (MASK_MODE == MaskMode::kBlockExpanding) branches — the same mechanism used by kCausal, kCustom, and kMultiItemScoring. When mode 4 is not instantiated, these branches are eliminated at compile time with zero code generation.

The only "global" changes are:

  • mask.cuh: one enum member (kBlockExpanding = 4U) — no code generation impact
  • utils.cuh: one case in DISPATCH_MASK_MODE — only reached at runtime when mode 4 is explicitly requested, and non-DLLM users never pass mode 4

4) Regarding the standalone class alternative

I did consider this direction. The challenge is that block expanding only differs from the existing kernel at a few specific points (mask predicate, iteration count, mask_iteration boundary — ~50 lines of if constexpr branches in prefill.cuh, with similar patterns in the 4 Hopper files). The rest of the kernel — QKV matmul, online softmax, tiling, shared memory management, warp-level MMA — is identical.

A standalone implementation would need to duplicate the full prefill kernel stack (prefill.cuh + 4 Hopper files, ~4500 lines total), plus maintain a separate set of Jinja templates and JIT pipeline. This would create a significant ongoing maintenance burden, especially when upstream kernel optimizations or bugfixes need to be synced.

Given that the current approach has zero impact on non-DLLM users (no extra compilation, no binary size increase, no runtime overhead), I think extending the existing MaskMode infrastructure is the more sustainable path. But I'm happy to discuss further if there are specific scenarios you're concerned about!

@coderabbitai Bot left a comment (Contributor)

Actionable comments posted: 3

♻️ Duplicate comments (3)
flashinfer/dllm/batch_block_extend.py (2)

293-296: ⚠️ Potential issue | 🟠 Major

Don't overwrite the wrapper's backend preference on first plan.

Both _create_inner_wrapper implementations still assign self._backend = effective_backend. When the wrapper was constructed with backend="auto", the first plan locks the instance into whatever was picked for that (head_dim, dtype) pair, so a subsequent re-plan for a different shape no longer auto-falls-back and can fail even when the alternative backend is available. self._preferred_backend is already stored for this purpose — use the local effective_backend for URI/variant/inner construction and leave self._backend unchanged.

Also applies to: 430-433

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/dllm/batch_block_extend.py` around lines 293 - 296, The
_create_inner_wrapper method currently assigns self._backend = effective_backend
which locks the instance to the first chosen backend; instead, stop mutating
self._backend and only use the local effective_backend variable when
constructing the inner wrapper/URI/variant/inner objects (leave
self._preferred_backend and self._backend unchanged so future re-plans can
re-select backends). Apply the same change in the other _create_inner_wrapper
implementation as well (remove any assignment to self._backend and rely on
effective_backend locally).

179-186: ⚠️ Potential issue | 🟠 Major

URI still omits idtype — silent kernel-mismatch risk when indptr dtype changes.

The recreation check at line 338 / 472 now correctly includes self._idtype, but _get_batch_be_module_uri() does not encode idtype into the URI string. When a user re-plans with a different indptr dtype (e.g. int32 → int64), _create_inner_wrapper produces the same URI, and the downstream JIT module cache may return the previously-specialized kernel, leading to silent miscompute or undefined behavior. Bake idtype into the URI (and reject unsupported dtypes explicitly as done for the element dtype).

🛠️ Proposed fix
-def _get_batch_be_module_uri(head_dim: int, dtype: torch.dtype) -> str:
-    _dtype_map = {torch.float16: "fp16", torch.bfloat16: "bf16"}
+def _get_batch_be_module_uri(head_dim: int, dtype: torch.dtype, idtype: torch.dtype = torch.int32) -> str:
+    _dtype_map = {torch.float16: "fp16", torch.bfloat16: "bf16"}
+    _idtype_map = {torch.int32: "i32", torch.int64: "i64"}
     if dtype not in _dtype_map:
         raise ValueError(
             f"Unsupported dtype {dtype} for Block Extend Attention. "
             f"Supported: {list(_dtype_map.keys())}"
         )
-    return f"batch_prefill_block_expanding_hd{head_dim}_{_dtype_map[dtype]}"
+    if idtype not in _idtype_map:
+        raise ValueError(
+            f"Unsupported idtype {idtype}. Supported: {list(_idtype_map.keys())}"
+        )
+    return f"batch_prefill_block_expanding_hd{head_dim}_{_dtype_map[dtype]}_{_idtype_map[idtype]}"

Thread idtype into every call site (lines 86-88, 134-136, 298/302, 435/439).

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/dllm/batch_block_extend.py` around lines 179 - 186,
_get_batch_be_module_uri currently omits the indptr idtype which can cause
silent kernel-mismatch; change its signature to accept an idtype parameter (e.g.
def _get_batch_be_module_uri(head_dim: int, dtype: torch.dtype, idtype:
torch.dtype) -> str), add an _idtype_map (e.g. {torch.int32: "i32", torch.int64:
"i64"}) and raise ValueError for unsupported idtypes mirroring the element-dtype
check, and include the mapped idtype token in the returned URI string; then
update all callers (notably places that call _get_batch_be_module_uri and
_create_inner_wrapper and any other call sites that construct the module URI) to
pass self._idtype so the URI uniquely encodes indptr dtype and avoids returning
incorrectly specialized kernels.
tests/attention/test_dllm_vs_flex_attention.py (1)

61-104: ⚠️ Potential issue | 🟠 Major

Per-request offset plumbing is still untested.

compute_block_extend_reference hardcodes kv_offset=0, make_block_extend_mask_mod ignores b, and the batch driver at line 321 constructs q_offsets via torch.full(...), so every request gets the same offset. The per-batch maybe_q_block_expanding_offset / maybe_kv_block_expanding_offset arrays added in this PR — plus the cascade current-chunk kv_block_expanding_offset path — never see heterogeneous inputs in these tests. Please add at least one case with distinct per-request q_offsets and a non-zero kv_offset to exercise the batch-indexed array reads.

Also applies to: 321-321

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/attention/test_dllm_vs_flex_attention.py` around lines 61 - 104, The
tests never exercise per-request offsets: update compute_block_extend_reference
to accept and use a kv_offset parameter (and propagate it into q_pos/k_pos
calculation), modify make_block_extend_mask_mod so the returned
block_extend_mask reads per-request q_offset (use the b argument) and
accepts/uses a kv_offset path, and change the test batch setup that currently
uses torch.full(...) for q_offsets to create heterogeneous per-request q_offsets
and a non-zero kv_offset (so maybe_q_block_expanding_offset,
maybe_kv_block_expanding_offset and the kv_block_expanding_offset cascade
actually read varied entries); run the updated test to ensure the batch-indexed
reads are exercised.
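For intuition, here is a pure-Python sketch of why heterogeneous offsets matter (a simplified, hypothetical reference predicate, not the test's actual helpers):

```python
def visible(q_local: int, k_global: int, q_offset: int, block_size: int) -> bool:
    # Map the request-local query index to its global position via a
    # per-request offset, then apply the block-extend rule: a key is
    # visible iff its block index does not exceed the query's block index.
    q_global = q_local + q_offset
    return k_global // block_size <= q_global // block_size

# Two requests sharing q_local=0 but carrying different offsets disagree
# on which keys are visible; a torch.full(...) batch setup, where every
# request gets the same offset, can never observe this difference.
q_offsets = [0, 8]  # heterogeneous per-request offsets (block_size=4)
```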
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@flashinfer/dllm/batch_block_extend.py`:
- Line 177: The module-scope variable _BATCH_BE_MODULE_CACHE is unused; either
remove it or wire it into the wrapper memoization: if removing, delete the
_BATCH_BE_MODULE_CACHE definition and any unused references; if implementing
memoization, update the _create_inner_wrapper methods on
BatchBlockExtendPagedOffsetWrapper and BatchBlockExtendRaggedOffsetWrapper to
check _BATCH_BE_MODULE_CACHE for an existing wrapper keyed by the unique
parameters (e.g., backend module name and config), return the cached instance
when present, and store newly created wrappers in _BATCH_BE_MODULE_CACHE to
avoid recreating them.

In `@flashinfer/dllm/block_extend.py`:
- Around line 200-204: The condition that checks FLASHINFER_DISABLE_JIT is
incorrectly treating the string "0" as truthy; update the check in
block_extend.py so it uses an explicit comparison (e.g.,
os.environ.get("FLASHINFER_DISABLE_JIT") == "1" or the same boolean parsing used
for FLASHINFER_FORCE_JIT) before raising the RuntimeError, and keep the same
error message referencing _get_aot_path(uri).

In `@tests/attention/test_dllm_vs_flex_attention.py`:
- Around line 1089-1093: The try/except around benchmark_with_cuda_graph
silently swallows errors; change the except block in the section calling
benchmark_with_cuda_graph(_run_fi, warmup_iters, bench_iters) so it captures the
exception as e, logs or prints the exception (including a clear message that
CUDA-graph capture failed) and sets a sentinel in entry (e.g., entry["fi_cg_ms"]
= None or an explicit failure string) so the result row shows the failure;
update the except clause that currently just does `pass` to record the error and
preserve diagnostic visibility.


ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 2afdf1fd-f0a9-41ad-84de-e36b89558f65

📥 Commits

Reviewing files that changed from the base of the PR and between 5e07b58 and 2bcbb68.

📒 Files selected for processing (6)
  • csrc/single_prefill_sm90_customize_config.jinja
  • flashinfer/dllm/batch_block_extend.py
  • flashinfer/dllm/block_extend.py
  • include/flashinfer/attention/default_prefill_params.cuh
  • include/flashinfer/utils.cuh
  • tests/attention/test_dllm_vs_flex_attention.py
🚧 Files skipped from review as they are similar to previous changes (2)
  • csrc/single_prefill_sm90_customize_config.jinja
  • include/flashinfer/attention/default_prefill_params.cuh

Comment thread flashinfer/dllm/batch_block_extend.py Outdated
Comment thread flashinfer/dllm/block_extend.py Outdated
Comment on lines +1089 to +1093
try:
    fi_cg_ms = benchmark_with_cuda_graph(_run_fi, warmup_iters, bench_iters)
    entry["fi_cg_ms"] = fi_cg_ms
except Exception:
    pass

⚠️ Potential issue | 🟡 Minor

Don't silently swallow CUDA-graph capture failures.

The bare try: … except Exception: pass drops any exception from benchmark_with_cuda_graph, so a real capture failure (e.g., dynamic shapes, stream issues) silently produces a result row with fi_cg_ms missing and no diagnostic. At minimum, print the exception so the benchmark output makes the failure visible.

🛠️ Proposed fix
                 try:
                     fi_cg_ms = benchmark_with_cuda_graph(_run_fi, warmup_iters, bench_iters)
                     entry["fi_cg_ms"] = fi_cg_ms
-                except Exception:
-                    pass
+                except Exception as e:
+                    print(f"    [dllm_bs={dbs}, seq={seq_len}] FI CUDA Graph failed: {e}")
🧰 Tools
🪛 Ruff (0.15.10)

[error] 1092-1093: try-except-pass detected, consider logging the exception

(S110)


[warning] 1092-1092: Do not catch blind exception: Exception

(BLE001)

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/attention/test_dllm_vs_flex_attention.py` around lines 1089 - 1093, The
try/except around benchmark_with_cuda_graph silently swallows errors; change the
except block in the section calling benchmark_with_cuda_graph(_run_fi,
warmup_iters, bench_iters) so it captures the exception as e, logs or prints the
exception (including a clear message that CUDA-graph capture failed) and sets a
sentinel in entry (e.g., entry["fi_cg_ms"] = None or an explicit failure string)
so the result row shows the failure; update the except clause that currently
just does `pass` to record the error and preserve diagnostic visibility.

@fdz-1999 fdz-1999 force-pushed the feature/block-extend branch 2 times, most recently from 5ec571a to 93965f5 Compare April 21, 2026 12:56

@coderabbitai coderabbitai Bot left a comment


Actionable comments posted: 4

♻️ Duplicate comments (5)
flashinfer/dllm/batch_block_extend.py (1)

179-186: ⚠️ Potential issue | 🟠 Major

URI still ignores idtype (and coerces FP8 dtypes into fp16).

_get_batch_be_module_uri only encodes head_dim and the element dtype. Different idtype variants (torch.int32 vs torch.int64) will collide under the same flashinfer::{uri}_* name, and the wrapper-recreation check at Lines 338/472 tracks idtype but the URI does not. Please bake idtype (and any supported dtype explicitly) into the URI.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/dllm/batch_block_extend.py` around lines 179 - 186, The URI
builder _get_batch_be_module_uri currently only encodes head_dim and a limited
dtype map, causing collisions across different idtype values and misrepresenting
FP8 types; update _get_batch_be_module_uri to include the idtype (e.g.,
torch.int32 vs torch.int64) in the returned string and expand the dtype mapping
to explicitly handle all supported dtypes (including FP8 variants) so each
unique combination of head_dim, element dtype, and idtype produces a distinct
URI; reference the function name _get_batch_be_module_uri and ensure the
returned string format embeds both idtype and the normalized dtype token.
flashinfer/dllm/block_extend.py (1)

200-204: ⚠️ Potential issue | 🟠 Major

FLASHINFER_DISABLE_JIT=0 still disables JIT.

os.environ.get("FLASHINFER_DISABLE_JIT") returns the literal string, and "0" is truthy in Python — so users setting this to the natural "off" value get JIT disabled. Line 85 in the same file correctly uses == "1" for FLASHINFER_FORCE_JIT; please mirror that here.

🛠️ Proposed fix
-    if os.environ.get("FLASHINFER_DISABLE_JIT"):
+    if os.environ.get("FLASHINFER_DISABLE_JIT", "0") == "1":
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/dllm/block_extend.py` around lines 200 - 204, The environment
check currently uses os.environ.get("FLASHINFER_DISABLE_JIT") which treats "0"
as truthy and incorrectly disables JIT; update the condition to explicitly
compare the env value to "1" (i.e., FLASHINFER_DISABLE_JIT == "1") so only an
explicit "1" disables JIT, keeping the existing RuntimeError message that
references _get_aot_path(uri) unchanged.
tests/attention/test_dllm_vs_flex_attention.py (3)

99-104: ⚠️ Potential issue | 🟡 Minor

Batch-indexed kv_block_expanding_offset path still has no heterogeneous-batch test.

make_block_extend_mask_mod ignores b, the reference uses a scalar q_offset, and the batch setup uses torch.full((num_requests,), q_offset, ...) — so every request in the batch gets the same offset and kv_offsets is never exercised end-to-end against a per-request reference. The PR introduces per-batch q_block_expanding_offset/kv_block_expanding_offset accessors; please add at least one correctness case with heterogeneous q_offsets and a non-zero kv_offsets to cover the new plumbing.

Also applies to: 321-321

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/attention/test_dllm_vs_flex_attention.py` around lines 99 - 104, The
batch-indexed path isn't tested with heterogeneous per-request offsets: update
the test in tests/attention/test_dllm_vs_flex_attention.py to exercise per-batch
q_offsets and non-zero kv_offsets by fixing make_block_extend_mask_mod to
actually use the batch index parameter b (i.e., compute q_global from
q_offset[b] rather than a scalar), create a heterogeneous torch tensor for
q_block_expanding_offset (not torch.full) and set a non-zero
kv_block_expanding_offset/kv_offsets, then add an assertion comparing
block_extend_mask results (or end-to-end attention outputs) between the dllm
implementation and the flex/reference implementation so the new
q_block_expanding_offset and kv_block_expanding_offset plumbing is covered.

20-56: ⚠️ Potential issue | 🟡 Minor

Module-level import-time prints on non-CUDA/unsupported hosts.

Since this file is named test_dllm_vs_flex_attention.py, pytest will still import it during collection. On machines without CUDA, the top-of-file prints fire but nothing gets skipped cleanly. The driver entry points have been renamed to bench_* (good), but please gate the module with an explicit pytest.importorskip/arch check (e.g., is_sm90a_supported/is_sm80a_supported from flashinfer.utils) so collection is quiet and deterministic on unsupported runners. As per coding guidelines: "Skip test execution on unsupported GPU architectures using flashinfer.utils check functions".

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/attention/test_dllm_vs_flex_attention.py` around lines 20 - 56,
Top-level imports in test_dllm_vs_flex_attention.py cause prints during pytest
collection on unsupported/non-CUDA hosts; wrap/gate module import with
pytest.importorskip or an explicit arch check using flashinfer.utils functions
to prevent collection printing. Concretely, at the top of the module call
pytest.importorskip("flashinfer") or call
flashinfer.utils.is_sm90a_supported()/is_sm80a_supported() (and skip via
pytest.skip if neither is true) before importing single_prefill_with_kv_cache,
BatchBlockExtendPagedOffsetWrapper, BatchBlockExtendRaggedOffsetWrapper,
flex_attention, create_block_mask so the file is skipped quietly on unsupported
runners and no module-level prints occur.

1089-1093: ⚠️ Potential issue | 🟡 Minor

Bare except: pass still silently drops CUDA-graph capture failures.

entry["fi_cg_ms"] is omitted with no diagnostic, making real capture failures invisible. At minimum print the exception; the same pattern also appears at Line 1029 in the workspace probe and Line 1130 in the Flex path (the Flex path already prints, which is the right shape).

🛠️ Proposed fix
-                except Exception:
-                    pass
+                except Exception as e:
+                    print(f"    [dllm_bs={dbs}, seq={seq_len}] FI CUDA Graph failed: {e}")
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/attention/test_dllm_vs_flex_attention.py` around lines 1089 - 1093,
Replace the bare "except: pass" around the CUDA-graph capture so failures are
not silently dropped: wrap the benchmark_with_cuda_graph(_run_fi, ...) call in
"except Exception as e:" and print or log the exception (including context) and
set entry["fi_cg_ms"] to a sentinel (e.g., None or an "error" value) so the
failure is recorded; apply the same change to the other identical patterns (the
workspace probe capture and the Flex path) to ensure all capture exceptions are
surfaced rather than ignored.
🧹 Nitpick comments (3)
flashinfer/dllm/__init__.py (1)

16-19: Unused private imports in package __init__.

_check_batch_be_aot_available, _get_batch_be_aot_path, and _get_batch_be_module_uri are imported but not added to __all__ and not otherwise referenced here, so they're dead imports at the package level. Either drop them from the import list or promote them to the public surface consciously.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/dllm/__init__.py` around lines 16 - 19, The three private symbols
_check_batch_be_aot_available, _get_batch_be_aot_path, and
_get_batch_be_module_uri are imported into flashinfer.dllm.__init__ but never
exported or used; either remove them from the import list to eliminate dead
imports or intentionally expose them by adding their names to __all__ (or
re-exporting under public names) so they become part of the package surface;
update the import statement and/or the __all__ list in __init__.py accordingly
to keep imports and public API consistent.
tests/attention/test_dllm_vs_flex_attention.py (1)

1073-1074: Loop-variable closure in _run_fi (Ruff B023).

_run_fi captures wrapper/q/k/v by reference from the enclosing loop. It happens to work today because the closure is invoked inside the same iteration, but a future refactor (e.g., deferring the callable past the del ... wrapper at Line 1097) would silently pick up a later iteration's bindings. Bind explicitly:

def _run_fi(_w=wrapper, _q=q, _k=k, _v=v):
    _w.run(_q, _k, _v)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/attention/test_dllm_vs_flex_attention.py` around lines 1073 - 1074, The
nested function _run_fi closes over loop variables (wrapper, q, k, v) which can
change across iterations; update the definition of _run_fi to bind those values
as default parameters (e.g., def _run_fi(_w=wrapper, _q=q, _k=k, _v=v): ...) and
call _w.run(_q, _k, _v) so the callable captures the current iteration's
bindings instead of referencing them by closure.
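The pitfall described here is standard late-binding closure behavior, reproducible in a few lines:

```python
# Late binding: all three callables close over the same loop variable,
# so every one of them sees the final value of i when invoked.
late = [lambda: i for i in range(3)]

# Default-argument binding freezes each iteration's value at definition
# time, which is what the suggested _run_fi(_w=wrapper, ...) fix does.
early = [lambda _i=i: _i for i in range(3)]

late_results = [f() for f in late]    # every call returns the last i
early_results = [f() for f in early]  # one distinct value per iteration
```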
flashinfer/dllm/block_extend.py (1)

292-296: backend="auto" here does not check actual kernel availability.

Unlike batch_block_extend.py's select_best_backend[_paged], this helper picks purely on is_sm90a_supported(q.device) and then unconditionally calls get_block_extend_module_with_offset(..., backend=backend). On a Hopper box where only the FA2 variant is AOT-compiled and JIT is disabled, auto mode will hard-fail instead of falling back. Consider reusing/consolidating the availability-aware selector used in the batch module.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/dllm/block_extend.py` around lines 292 - 296, The current auto
backend selection uses is_sm90a_supported(q.device) and then directly calls
get_block_extend_module_with_offset, which can hard-fail if the chosen kernel
isn't actually available; change the logic to consult the availability-aware
selector used in batch_block_extend.py (e.g., call select_best_backend or its
_paged variant) instead of the simple is_sm90a_supported check, or wrap
get_block_extend_module_with_offset in a try/fallback that falls back to the
other backend on failure; update the backend local variable from that selector
and then call get_block_extend_module_with_offset(head_dim=head_dim,
dtype=dtype, backend=backend, device=q.device).
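A minimal sketch of what an availability-aware selector could look like (function name, backend strings, and the availability callback are illustrative, not FlashInfer's actual API):

```python
def select_backend_with_fallback(device_is_sm90a: bool, kernel_available) -> str:
    # "auto" selection that prefers FA3 on Hopper but only commits to it
    # if its kernel is actually available (AOT-compiled or JIT-able);
    # otherwise fall back to FA2 instead of hard-failing.
    order = ["fa3", "fa2"] if device_is_sm90a else ["fa2"]
    for backend in order:
        if kernel_available(backend):
            return backend
    raise RuntimeError("no Block Extend kernel available for this device")
```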
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@flashinfer/dllm/batch_block_extend.py`:
- Around line 313-317: The jit_kwargs currently lists mask_modes=[0,1,2,3,4]
which forces compilation of all mask-mode kernels even though the DLLM wrappers
call the inner prefill with mask_mode=MaskMode.BLOCK_EXPANDING.value; change the
mask_modes entry in jit_kwargs to only include [MaskMode.BLOCK_EXPANDING.value]
(or the literal 4) so only mode 4 is JIT/AOT-compiled; update both occurrences
around the jit_kwargs definitions (the one at lines ~313 and the similar block
at ~450-454) to avoid inflating build/AOT size and cache footprint.
- Around line 34-64: check_jit_environment currently shells out to a
PATH-resolved "nvcc" and uses a broad except that swallows real errors; update
it to locate nvcc via shutil.which and fallback to CUDA_HOME (e.g., check
os.environ["CUDA_HOME"] + "/bin/nvcc") before declaring nvcc missing and include
a clear issue message when nvcc is found/not found, and replace the broad
"except Exception as e" around the tvm_ffi probe with a narrower handler (or at
minimum log/append the full exception details to results["issues"] instead of
swallowing) so callers can see the real failure; refer to the function name
check_jit_environment, the results dict keys ("nvcc_ok", "issues"), and the
tvm_ffi probe block to make the changes.
- Around line 251-267: Add the required decorators to the public DLLM APIs:
import and apply `@flashinfer_api` to BatchBlockExtendPagedOffsetWrapper,
BatchBlockExtendRaggedOffsetWrapper, batch_block_extend_cascade, and
sglang_style_cascade_attention so these high-level Python APIs have crash-safe
logging; for the FA3/architecture-gated code paths (the functions/constructors
that require FA3), also import and apply `@backend_requirement` with the
appropriate backend constraint (e.g., the FA3 identifier used elsewhere in the
repo) to those symbols; ensure imports for flashinfer_api and
backend_requirement are added at the top if missing and keep decorator placement
immediately above the class or def declarations for the listed symbols.

In `@flashinfer/dllm/block_extend.py`:
- Around line 235-245: Add the missing decorators to make these public APIs
crash-safe and declare architecture requirements: annotate
block_extend_attention_with_offset and block_extend_cascade with `@flashinfer_api`
to enable crash-safe logging, and add `@backend_requirement` with the appropriate
backend tag(s) used for FA3 dispatch (e.g., `@backend_requirement`("fa3") or the
specific backend string your dispatch expects) to the same functions; ensure the
decorators are imported at the top of flashinfer.dllm (or block_extend.py) if
not already present and keep any existing signature and defaults intact so only
decorators are added.


ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: dd7c2332-0df7-4061-8e4f-2809936d2756

📥 Commits

Reviewing files that changed from the base of the PR and between 2bcbb68 and 5ec571a.

📒 Files selected for processing (7)
  • csrc/single_prefill_sm90_customize_config.jinja
  • flashinfer/dllm/__init__.py
  • flashinfer/dllm/batch_block_extend.py
  • flashinfer/dllm/block_extend.py
  • include/flashinfer/attention/default_prefill_params.cuh
  • include/flashinfer/utils.cuh
  • tests/attention/test_dllm_vs_flex_attention.py
🚧 Files skipped from review as they are similar to previous changes (3)
  • csrc/single_prefill_sm90_customize_config.jinja
  • include/flashinfer/utils.cuh
  • include/flashinfer/attention/default_prefill_params.cuh

Comment on lines +34 to +64
def check_jit_environment() -> dict:
    """Check if JIT compilation environment is working properly"""
    results = {
        "tvm_ffi_ok": False,
        "device_guard_ok": False,
        "nvcc_ok": False,
        "issues": [],
    }

    try:
        import tvm_ffi
        results["tvm_ffi_ok"] = True
        include_path = tvm_ffi.libinfo.find_include_path()
        device_guard_path = Path(include_path) / "tvm" / "ffi" / "extra" / "cuda" / "device_guard.h"
        results["device_guard_ok"] = device_guard_path.exists()
        if not results["device_guard_ok"]:
            results["issues"].append(f"Missing TVM header: {device_guard_path}")
    except ImportError:
        results["issues"].append("tvm_ffi package not installed")
    except Exception as e:
        results["issues"].append(f"Error checking tvm_ffi: {e}")

    import subprocess
    try:
        result = subprocess.run(["nvcc", "--version"], capture_output=True, text=True)
        results["nvcc_ok"] = result.returncode == 0
    except FileNotFoundError:
        results["nvcc_ok"] = False
        results["issues"].append("nvcc not found in PATH")

    return results
⚠️ Potential issue | 🟡 Minor

check_jit_environment relies on PATH-resolved nvcc and a blind except Exception.

Calling subprocess.run(["nvcc", "--version"]) is a partial-executable-path invocation (S607); on systems where CUDA is present but nvcc is not in PATH (common in production containers that ship only runtime libs) this will return jit_available=False and drop the wrapper into a confusing "no kernel available" error even when AOT would have worked. Also, except Exception around the tvm_ffi probe (line 53) swallows real errors. Consider: (1) resolving nvcc via CUDA_HOME/shutil.which with a helpful message, and (2) narrowing or at least logging the exception.

🧰 Tools
🪛 Ruff (0.15.10)

[warning] 53-53: Do not catch blind exception: Exception

(BLE001)


[error] 58-58: Starting a process with a partial executable path

(S607)

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/dllm/batch_block_extend.py` around lines 34 - 64,
check_jit_environment currently shells out to a PATH-resolved "nvcc" and uses a
broad except that swallows real errors; update it to locate nvcc via
shutil.which and fallback to CUDA_HOME (e.g., check os.environ["CUDA_HOME"] +
"/bin/nvcc") before declaring nvcc missing and include a clear issue message
when nvcc is found/not found, and replace the broad "except Exception as e"
around the tvm_ffi probe with a narrower handler (or at minimum log/append the
full exception details to results["issues"] instead of swallowing) so callers
can see the real failure; refer to the function name check_jit_environment, the
results dict keys ("nvcc_ok", "issues"), and the tvm_ffi probe block to make the
changes.
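A minimal sketch of the suggested lookup order, assuming nothing about the repo's helpers (`find_nvcc` is an illustrative name, not an existing function):

```python
import os
import shutil
from typing import Optional

def find_nvcc() -> Optional[str]:
    """Resolve nvcc from PATH first, then fall back to $CUDA_HOME/bin/nvcc
    (and $CUDA_PATH), so containers without nvcc on PATH still get a clear
    answer instead of a confusing 'no kernel available' error later."""
    resolved = shutil.which("nvcc")
    if resolved is not None:
        return resolved
    for env_var in ("CUDA_HOME", "CUDA_PATH"):
        home = os.environ.get(env_var)
        if home:
            candidate = os.path.join(home, "bin", "nvcc")
            if os.path.isfile(candidate) and os.access(candidate, os.X_OK):
                return candidate
    return None
```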

Comment on lines +251 to +267
class BatchBlockExtendPagedOffsetWrapper:
    """Batch Block Extend Paged Attention with Offset Support"""

    def __init__(
        self,
        float_workspace_buffer: torch.Tensor,
        kv_layout: str = "NHD",
        dllm_block_size: int = 256,
        use_cuda_graph: bool = False,
        qo_indptr_buf: Optional[torch.Tensor] = None,
        paged_kv_indptr_buf: Optional[torch.Tensor] = None,
        paged_kv_indices_buf: Optional[torch.Tensor] = None,
        paged_kv_last_page_len_buf: Optional[torch.Tensor] = None,
        q_offsets_buf: Optional[torch.Tensor] = None,
        kv_offsets_buf: Optional[torch.Tensor] = None,
        backend: str = "auto",
    ) -> None:
⚠️ Potential issue | 🟡 Minor

Missing @flashinfer_api / @backend_requirement decorators on public DLLM APIs.

BatchBlockExtendPagedOffsetWrapper, BatchBlockExtendRaggedOffsetWrapper, batch_block_extend_cascade, and sglang_style_cascade_attention are high-level Python APIs exposed from flashinfer.dllm but none carry the @flashinfer_api decorator for crash-safe logging, and the FA3 path is architecture-gated without @backend_requirement. As per coding guidelines: "Use @flashinfer_api decorator on high-level Python APIs" and "Use @backend_requirement decorator on APIs with architecture-specific requirements".

Also applies to: 392-406, 521-539, 608-623

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/dllm/batch_block_extend.py` around lines 251 - 267, Add the
required decorators to the public DLLM APIs: import and apply `@flashinfer_api` to
BatchBlockExtendPagedOffsetWrapper, BatchBlockExtendRaggedOffsetWrapper,
batch_block_extend_cascade, and sglang_style_cascade_attention so these
high-level Python APIs have crash-safe logging; for the FA3/architecture-gated
code paths (the functions/constructors that require FA3), also import and apply
`@backend_requirement` with the appropriate backend constraint (e.g., the FA3
identifier used elsewhere in the repo) to those symbols; ensure imports for
flashinfer_api and backend_requirement are added at the top if missing and keep
decorator placement immediately above the class or def declarations for the
listed symbols.

Comment on lines +313 to +317
        jit_kwargs = {
            "pos_encoding_mode": 0, "use_sliding_window": False,
            "use_logits_soft_cap": False, "use_fp16_qk_reduction": False,
            "mask_modes": [0, 1, 2, 3, 4],
        }
⚠️ Potential issue | 🟠 Major

mask_modes=[0, 1, 2, 3, 4] compiles every mask variant for a wrapper that only uses mode 4.

Both DLLM wrappers always call the inner prefill with mask_mode=MaskMode.BLOCK_EXPANDING.value (Lines 369, 502), but the JIT request instantiates all five mask-mode kernel specializations. This defeats the PR's stated design (mode 4 should be compiled only when DLLM wrappers request it) and inflates build/AOT size and cache footprint for each (head_dim, dtype, backend) combo. Restrict to mask_modes=[4] unless you deliberately want the other modes here.

🛠️ Proposed fix
-            "mask_modes": [0, 1, 2, 3, 4],
+            "mask_modes": [MaskMode.BLOCK_EXPANDING.value],

Also applies to: 450-454

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/dllm/batch_block_extend.py` around lines 313 - 317, The jit_kwargs
currently lists mask_modes=[0,1,2,3,4] which forces compilation of all mask-mode
kernels even though the DLLM wrappers call the inner prefill with
mask_mode=MaskMode.BLOCK_EXPANDING.value; change the mask_modes entry in
jit_kwargs to only include [MaskMode.BLOCK_EXPANDING.value] (or the literal 4)
so only mode 4 is JIT/AOT-compiled; update both occurrences around the
jit_kwargs definitions (the one at lines ~313 and the similar block at ~450-454)
to avoid inflating build/AOT size and cache footprint.

Comment on lines +235 to +245
def block_extend_attention_with_offset(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    dllm_block_size: int,
    q_offset: int = 0,
    kv_offset: int = 0,
    sm_scale: Optional[float] = None,
    return_lse: bool = False,
    backend: str = "auto",
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
⚠️ Potential issue | 🟡 Minor

Public APIs lack @flashinfer_api decorator.

block_extend_attention_with_offset and block_extend_cascade are exported from flashinfer.dllm as top-level APIs but don't carry @flashinfer_api for crash-safe logging, and FA3 dispatch is architecture-gated without @backend_requirement. As per coding guidelines: "Use @flashinfer_api decorator on high-level Python APIs for crash-safe logging" and "Use @backend_requirement decorator on APIs with architecture-specific requirements".

Also applies to: 314-324

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/dllm/block_extend.py` around lines 235 - 245, Add the missing
decorators to make these public APIs crash-safe and declare architecture
requirements: annotate block_extend_attention_with_offset and
block_extend_cascade with `@flashinfer_api` to enable crash-safe logging, and add
`@backend_requirement` with the appropriate backend tag(s) used for FA3 dispatch
(e.g., `@backend_requirement`("fa3") or the specific backend string your dispatch
expects) to the same functions; ensure the decorators are imported at the top of
flashinfer.dllm (or block_extend.py) if not already present and keep any
existing signature and defaults intact so only decorators are added.

@coderabbitai coderabbitai Bot left a comment

Actionable comments posted: 4

♻️ Duplicate comments (6)
flashinfer/dllm/batch_block_extend.py (4)

311-315: ⚠️ Potential issue | 🟠 Major

Compile only the block-expanding mask specialization here.

Both wrappers always run with MaskMode.BLOCK_EXPANDING.value, but the JIT request still instantiates all five mask modes. This directly expands JIT/AOT cache size for DLLM wrappers.

🛠️ Proposed fix
         jit_kwargs = {
             "pos_encoding_mode": 0, "use_sliding_window": False,
             "use_logits_soft_cap": False, "use_fp16_qk_reduction": False,
-            "mask_modes": [0, 1, 2, 3, 4],
+            "mask_modes": [MaskMode.BLOCK_EXPANDING.value],
         }

Apply this to both paged and ragged wrapper JIT kwargs.

Also applies to: 448-452

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/dllm/batch_block_extend.py` around lines 311 - 315, The jit_kwargs
currently lists all mask modes which forces JIT/AOT to compile five
specializations; change the jit_kwargs in both wrapper configurations (the
jit_kwargs dict used for the paged and ragged wrappers — the occurrences around
the current jit_kwargs and the other block at lines ~448-452) to only include
MaskMode.BLOCK_EXPANDING.value for the "mask_modes" key so the JIT compiles only
the block-expanding specialization; ensure you reference the same jit_kwargs
variable names and import/usage of MaskMode present in this module.

249-265: ⚠️ Potential issue | 🟡 Minor

Add the required decorators to exported DLLM APIs.

These wrapper classes and helper functions are public high-level APIs exported from flashinfer.dllm, but they do not carry @flashinfer_api; the backend-dispatching APIs should also declare backend requirements for architecture tracking.

As per coding guidelines, flashinfer/**/*.py: Use @flashinfer_api decorator on high-level Python APIs and @backend_requirement decorator on APIs with architecture-specific requirements.

Also applies to: 390-404, 519-537, 606-621

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/dllm/batch_block_extend.py` around lines 249 - 265, The exported
wrapper class BatchBlockExtendPagedOffsetWrapper (and the other public wrapper
classes and helper functions at the noted locations) are missing the required
decorators; add the `@flashinfer_api` decorator to each high-level API (e.g.,
class BatchBlockExtendPagedOffsetWrapper) and add `@backend_requirement` where the
API enforces architecture-specific backends (APIs that accept a backend: str
parameter or perform backend dispatching) so the backend-dispatching is tracked;
ensure you import these decorators and apply `@backend_requirement` to the
constructors or functions that accept the backend argument (e.g., any __init__
or factory functions with backend: str) while leaving non-backend-specific
helpers only with `@flashinfer_api`.

561-571: ⚠️ Potential issue | 🟠 Major

Do not default prefix requests to zero offsets.

When a prefix exists, q_offsets=None makes the current chunk start at global position 0, so the block-expanding mask is wrong for any nonzero prefix length. Derive per-request prefix lengths from paged metadata or require explicit offsets.

🛠️ Proposed fix
     if q_offsets is None:
         if has_prefix:
-            import warnings
-            warnings.warn(
-                "q_offsets is None but prefix exists. Block extend mask may be incorrect "
-                "if prefix length is nonzero. Consider passing explicit q_offsets.",
-                stacklevel=2,
-            )
-        q_offsets = torch.zeros(batch_size, dtype=torch.int32, device=device)
+            q_offsets = (
+                (paged_kv_indptr[1:] - paged_kv_indptr[:-1] - 1) * page_size
+                + paged_kv_last_page_len
+            ).to(device=device, dtype=qo_indptr.dtype)
+        else:
+            q_offsets = torch.zeros(batch_size, dtype=qo_indptr.dtype, device=device)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/dllm/batch_block_extend.py` around lines 561 - 571, The current
code defaults q_offsets to zeros when q_offsets is None (and aliases kv_offsets
to q_offsets), which breaks the block-extend mask whenever has_prefix is True;
instead, when has_prefix is True derive per-request prefix lengths from the
paged metadata and populate q_offsets/kv_offsets accordingly (or raise/require
explicit offsets) rather than assigning zeros; update the logic around
q_offsets, kv_offsets, and has_prefix in batch_block_extend.py so q_offsets is
computed from the request/page metadata for each of the batch_size entries on
the given device and dtype, and only fall back to a true zero-offset alias when
you have verified there is no prefix for all requests.

177-184: ⚠️ Potential issue | 🟠 Major

Include idtype in the module URI and reject unsupported index dtypes.

The generated ABI depends on idtype, but the URI only includes head_dim and data dtype. Replanning the same (head_dim, dtype) with int32 vs int64 can collide on the same registered module name; dtype_map_for_idtype() also aliases unknown dtypes to int32_t.

🛠️ Proposed fix
-def _get_batch_be_module_uri(head_dim: int, dtype: torch.dtype) -> str:
+def _get_batch_be_module_uri(
+    head_dim: int,
+    dtype: torch.dtype,
+    idtype: torch.dtype = torch.int32,
+) -> str:
     _dtype_map = {torch.float16: "fp16", torch.bfloat16: "bf16"}
+    _idtype_map = {torch.int32: "i32", torch.int64: "i64"}
     if dtype not in _dtype_map:
         raise ValueError(
             f"Unsupported dtype {dtype} for Block Extend Attention. "
             f"Supported: {list(_dtype_map.keys())}"
         )
-    return f"batch_prefill_block_expanding_hd{head_dim}_{_dtype_map[dtype]}"
+    if idtype not in _idtype_map:
+        raise ValueError(
+            f"Unsupported idtype {idtype} for Block Extend Attention. "
+            f"Supported: {list(_idtype_map.keys())}"
+        )
+    return (
+        f"batch_prefill_block_expanding_hd{head_dim}_"
+        f"{_dtype_map[dtype]}_{_idtype_map[idtype]}"
+    )
@@
-def dtype_map_for_idtype(idtype: torch.dtype) -> str:
-    return {torch.int32: "int32_t", torch.int64: "int64_t"}.get(idtype, "int32_t")
+def dtype_map_for_idtype(idtype: torch.dtype) -> str:
+    _idtype_map = {torch.int32: "int32_t", torch.int64: "int64_t"}
+    if idtype not in _idtype_map:
+        raise ValueError(f"Unsupported idtype {idtype}")
+    return _idtype_map[idtype]

Then pass idtype at each _get_batch_be_module_uri(...) call site.

Also applies to: 304-308, 386-387, 441-445

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/dllm/batch_block_extend.py` around lines 177 - 184, The module URI
builder _get_batch_be_module_uri currently only includes head_dim and tensor
dtype, which can collide across different index dtypes; update
_get_batch_be_module_uri to accept an idtype (torch.dtype for indices), extend
the internal mapping (use dtype_map_for_idtype or same mapping logic) to map
supported index dtypes (e.g., torch.int32->"i32", torch.int64->"i64") and raise
ValueError for unsupported index dtypes, then include the idtype token in the
returned URI string (e.g.,
..._hd{head_dim}_{_dtype_map[input_dtype]}_{idtype_token}). Finally, update all
call sites of _get_batch_be_module_uri (and any other variants at the other
locations mentioned) to pass the idtype argument so the generated ABI name is
unique per (head_dim, dtype, idtype).
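The intended URI scheme can be sketched standalone; string dtype names stand in for the torch dtypes here so the snippet runs without torch, and `batch_be_module_uri` is an illustrative name rather than the repo's function:

```python
_DTYPE_TOKENS = {"float16": "fp16", "bfloat16": "bf16"}
_IDTYPE_TOKENS = {"int32": "i32", "int64": "i64"}

def batch_be_module_uri(head_dim: int, dtype: str, idtype: str = "int32") -> str:
    """Build a module URI unique per (head_dim, dtype, idtype), rejecting
    unsupported dtypes instead of silently aliasing them."""
    if dtype not in _DTYPE_TOKENS:
        raise ValueError(f"Unsupported dtype {dtype}. Supported: {sorted(_DTYPE_TOKENS)}")
    if idtype not in _IDTYPE_TOKENS:
        raise ValueError(f"Unsupported idtype {idtype}. Supported: {sorted(_IDTYPE_TOKENS)}")
    return (
        f"batch_prefill_block_expanding_hd{head_dim}_"
        f"{_DTYPE_TOKENS[dtype]}_{_IDTYPE_TOKENS[idtype]}"
    )

print(batch_be_module_uri(128, "float16", "int64"))
# batch_prefill_block_expanding_hd128_fp16_i64
```

With the idtype token in the name, replanning the same (head_dim, dtype) with int32 vs int64 indices can no longer collide on one registered module.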
flashinfer/dllm/block_extend.py (1)

235-245: ⚠️ Potential issue | 🟡 Minor

Add the required public API decorators.

block_extend_attention_with_offset and block_extend_cascade are exported high-level APIs but still lack crash-safe API logging, and their FA2/FA3 backend dispatch should declare the architecture-gated backend requirement.

As per coding guidelines, flashinfer/**/*.py: Use @flashinfer_api decorator on high-level Python APIs and @backend_requirement decorator on APIs with architecture-specific requirements.

Also applies to: 314-324

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/dllm/block_extend.py` around lines 235 - 245, The high-level APIs
block_extend_attention_with_offset and block_extend_cascade need the public API
and architecture-gated backend decorators: add `@flashinfer_api` above each
function and add `@backend_requirement`("fa2","fa3") (or the project-specific
backend_requirement form used elsewhere) to declare FA2/FA3 architecture
requirements for the backend-dispatching paths; ensure the decorators are
imported from the flashinfer decorator module if not already, and place them
immediately above the def for block_extend_attention_with_offset and the
block_extend_cascade function (the other exported API at the 314-324 region) so
crash-safe API logging and backend gating are applied.
tests/attention/test_dllm_vs_flex_attention.py (1)

61-104: ⚠️ Potential issue | 🟠 Major

Exercise heterogeneous q_offsets and nonzero kv_offset in the reference path.

The reference and Flex mask still model only a scalar q_offset and implicit kv_offset=0, while the batch benchmark uses identical offsets for every request. That leaves the new per-request offset and current-chunk kv_block_expanding_offset plumbing largely unvalidated.

🛠️ Proposed direction
 def compute_block_extend_reference(
     q: torch.Tensor,
     k: torch.Tensor,
     v: torch.Tensor,
     dllm_block_size: int,
     q_offset: int = 0,
+    kv_offset: int = 0,
     sm_scale: float = None,
 ) -> torch.Tensor:
@@
     q_pos = torch.arange(qo_len, device=device) + q_offset
-    k_pos = torch.arange(kv_len, device=device)
+    k_pos = torch.arange(kv_len, device=device) + kv_offset
@@
-def make_block_extend_mask_mod(dllm_block_size: int, q_offset: int = 0):
+def make_block_extend_mask_mod(dllm_block_size: int, q_offsets, kv_offsets=None):
@@
     def block_extend_mask(b, h, q_idx, kv_idx):
-        q_global = q_idx + q_offset
+        q_global = q_idx + q_offsets[b]
+        kv_global = kv_idx + (0 if kv_offsets is None else kv_offsets[b])
         q_blk = q_global // dllm_block_size
-        kv_blk = kv_idx // dllm_block_size
+        kv_blk = kv_global // dllm_block_size
         return q_blk >= kv_blk

Also update the batch setup to use non-uniform offsets and add at least one nonzero kv_offset case.

Also applies to: 321-321, 381-384, 518-523

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/attention/test_dllm_vs_flex_attention.py` around lines 61 - 104, The
reference functions only handle a scalar q_offset and assume kv_offset=0, so
update compute_block_extend_reference to accept per-request heterogeneous
q_offsets (e.g., a 1D tensor/array matching batch or sequence of queries) and an
explicit kv_offset parameter, compute q_pos using per-request q_offsets and
k_pos using kv_offset, and build mask_2d accordingly so it mirrors
batched/variable offsets; likewise update make_block_extend_mask_mod (and its
returned block_extend_mask) to accept/close over a kv_offset and support
per-request q_offset (or an indexable q_offset source) so the Flex mask and the
reference use identical offset semantics, and adjust the batch setup to include
non-uniform q_offsets and at least one nonzero kv_offset test case.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@flashinfer/dllm/__init__.py`:
- Around line 21-37: Ruff flags the __all__ list ordering (RUF022); sort the
entries in the __all__ list alphabetically according to the project's convention
so the exported names (e.g., "BatchBlockExtendPagedOffsetWrapper",
"BatchBlockExtendRaggedOffsetWrapper",
"BLOCK_EXTEND_V2_WITH_OFFSET_VARIANT_DECL",
"BLOCK_EXTEND_V3_WITH_OFFSET_VARIANT_DECL",
"BLOCK_EXTEND_V3_WITH_OFFSET_VARIANT_DECL",
"BLOCK_EXTEND_V2_WITH_OFFSET_VARIANT_DECL",
"BLOCK_EXTEND_V3_WITH_OFFSET_VARIANT_DECL",
"BLOCK_EXTEND_V2_WITH_OFFSET_VARIANT_DECL",
"BatchBlockExtendPagedOffsetWrapper", "block_extend_attention_with_offset",
"block_extend_cascade", "batch_block_extend_cascade",
"get_block_extend_module_with_offset", "sglang_style_cascade_attention",
"_BATCH_BE_OFFSET_VARIANT_DECL", "_BATCH_BE_OFFSET_VARIANT_DECL_FA3") are in the
required sorted order; update the __all__ declaration in
flashinfer/dllm/__init__.py to the sorted list so the linter RUF022 is
satisfied.

In `@flashinfer/dllm/batch_block_extend.py`:
- Around line 82-127: select_best_backend currently considers fa3_available even
on non-Hopper GPUs; update it to only allow returning "fa3" when the device
supports SM90A (use is_sm90a_supported(device))—i.e., set fa3_effective =
fa3_available and then if not is_hopper set fa3_effective = False, use
fa3_effective for all selection logic (including the auto path and when
preferred_backend == "fa3" raise a RuntimeError if device is non-Hopper). Apply
the identical change to select_best_backend_paged so FA3 is never chosen or
accepted on non-SM90 devices; ensure device is defaulted to torch.device("cuda")
when None and reuse the same is_hopper check.

In `@flashinfer/dllm/block_extend.py`:
- Around line 368-370: The code currently treats a partially provided prefix
(only k_prefix or only v_prefix) as absent; update the logic in block_extend.py
around the has_prefix computation to validate that either both k_prefix and
v_prefix are provided or neither are—if one is None and the other is not, raise
a clear ValueError (or custom exception) indicating mismatched prefix arguments;
keep the existing prefix_len computation (prefix_len = k_prefix.size(0)) when
both are present. Ensure you reference and modify the has_prefix, k_prefix,
v_prefix, and prefix_len handling so callers fail fast instead of silently
dropping a single-side prefix.
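The fail-fast check the prompt describes is small enough to sketch standalone (strings stand in for tensors; `validate_prefix` is an illustrative name):

```python
def validate_prefix(k_prefix, v_prefix):
    """Return True when a prefix is present; reject a one-sided prefix
    instead of silently treating it as absent."""
    if (k_prefix is None) != (v_prefix is None):
        raise ValueError(
            "k_prefix and v_prefix must both be provided or both be None "
            f"(got k_prefix={'set' if k_prefix is not None else 'None'}, "
            f"v_prefix={'set' if v_prefix is not None else 'None'})"
        )
    return k_prefix is not None
```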

In `@tests/attention/test_dllm_vs_flex_attention.py`:
- Line 241: Multiple plain string print statements in
tests/attention/test_dllm_vs_flex_attention.py were written as f-strings without
placeholders (e.g., print(f"FlashInfer Block Extend vs PyTorch Flex
Attention")); remove the unnecessary 'f' prefix on those 16 instances so they
become regular string literals. Search for all print/assignment/logging lines in
that file that start with f" or f' but contain no braces/format placeholders and
change them to plain "..." or '...' (example symbol to locate: the print call
containing "FlashInfer Block Extend vs PyTorch Flex Attention"). Ensure no
formatting behavior is altered and run linters to confirm Ruff F541 is resolved.


ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: c8dc3936-3642-48fa-b28c-0da34a217dee

📥 Commits

Reviewing files that changed from the base of the PR and between 5ec571a and 93965f5.

📒 Files selected for processing (7)
  • csrc/single_prefill_sm90_customize_config.jinja
  • flashinfer/dllm/__init__.py
  • flashinfer/dllm/batch_block_extend.py
  • flashinfer/dllm/block_extend.py
  • include/flashinfer/attention/default_prefill_params.cuh
  • include/flashinfer/utils.cuh
  • tests/attention/test_dllm_vs_flex_attention.py
🚧 Files skipped from review as they are similar to previous changes (3)
  • csrc/single_prefill_sm90_customize_config.jinja
  • include/flashinfer/utils.cuh
  • include/flashinfer/attention/default_prefill_params.cuh

Comment on lines +21 to +37
__all__ = [
    # Single Prefill with offset (FA2/FA3 auto-select)
    "block_extend_attention_with_offset",
    "get_block_extend_module_with_offset",
    "BLOCK_EXTEND_V2_WITH_OFFSET_VARIANT_DECL",
    "BLOCK_EXTEND_V3_WITH_OFFSET_VARIANT_DECL",
    # Cascade + block extend (SGLang style: causal + merge_state)
    "block_extend_cascade",
    "batch_block_extend_cascade",
    "sglang_style_cascade_attention",
    # Batch Prefill with offset versions
    "BatchBlockExtendPagedOffsetWrapper",
    "BatchBlockExtendRaggedOffsetWrapper",
    # Batch Offset variant declarations
    "_BATCH_BE_OFFSET_VARIANT_DECL",
    "_BATCH_BE_OFFSET_VARIANT_DECL_FA3",
]
⚠️ Potential issue | 🟡 Minor

Sort __all__ to satisfy Ruff.

Ruff flags this list with RUF022; please apply the project’s __all__ sorting convention or the lint step may fail.

🛠️ Proposed fix
 __all__ = [
-    # Single Prefill with offset (FA2/FA3 auto-select)
-    "block_extend_attention_with_offset",
-    "get_block_extend_module_with_offset",
     "BLOCK_EXTEND_V2_WITH_OFFSET_VARIANT_DECL",
     "BLOCK_EXTEND_V3_WITH_OFFSET_VARIANT_DECL",
-    # Cascade + block extend (SGLang style: causal + merge_state)
-    "block_extend_cascade",
-    "batch_block_extend_cascade",
-    "sglang_style_cascade_attention",
-    # Batch Prefill with offset versions
     "BatchBlockExtendPagedOffsetWrapper",
     "BatchBlockExtendRaggedOffsetWrapper",
-    # Batch Offset variant declarations
     "_BATCH_BE_OFFSET_VARIANT_DECL",
     "_BATCH_BE_OFFSET_VARIANT_DECL_FA3",
+    "batch_block_extend_cascade",
+    "block_extend_attention_with_offset",
+    "block_extend_cascade",
+    "get_block_extend_module_with_offset",
+    "sglang_style_cascade_attention",
 ]
🧰 Tools
🪛 Ruff (0.15.10)

[warning] 21-37: __all__ is not sorted

Apply an isort-style sorting to __all__

(RUF022)

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/dllm/__init__.py` around lines 21 - 37, Ruff flags the __all__
list ordering (RUF022); sort the entries in the __all__ list alphabetically
according to the project's convention so the exported names
("BLOCK_EXTEND_V2_WITH_OFFSET_VARIANT_DECL",
"BLOCK_EXTEND_V3_WITH_OFFSET_VARIANT_DECL",
"BatchBlockExtendPagedOffsetWrapper", "BatchBlockExtendRaggedOffsetWrapper",
"_BATCH_BE_OFFSET_VARIANT_DECL", "_BATCH_BE_OFFSET_VARIANT_DECL_FA3",
"batch_block_extend_cascade", "block_extend_attention_with_offset",
"block_extend_cascade", "get_block_extend_module_with_offset",
"sglang_style_cascade_attention") are in the required sorted order; update the
__all__ declaration in flashinfer/dllm/__init__.py to the sorted list so the
linter RUF022 is satisfied.

Comment on lines +82 to +127
def select_best_backend(head_dim: int, dtype: torch.dtype, preferred_backend: str = "auto", device: torch.device = None) -> str:
    """Select backend based on kernel availability and compute capability"""
    from ..utils import is_sm90a_supported

    base_uri = _get_batch_be_module_uri(head_dim, dtype)
    fa2_uri = base_uri + "_ragged_offset"
    fa3_uri = base_uri + "_ragged_offset_fa3"

    fa2_aot, fa2_jit, _ = check_kernel_availability(fa2_uri)
    fa3_aot, fa3_jit, _ = check_kernel_availability(fa3_uri)

    fa2_available = fa2_aot or fa2_jit
    fa3_available = fa3_aot or fa3_jit

    if preferred_backend == "auto":
        if device is None:
            device = torch.device("cuda")
        is_hopper = is_sm90a_supported(device)

        if is_hopper:
            if fa3_available:
                return "fa3"
            if fa2_available:
                return "fa2"
        else:
            if fa2_available:
                return "fa2"
            if fa3_available:
                return "fa3"

        raise RuntimeError(
            f"No Block Extend kernel available for head_dim={head_dim}, dtype={dtype}. "
            f"FA2: AOT={fa2_aot}, JIT={fa2_jit}; FA3: AOT={fa3_aot}, JIT={fa3_jit}"
        )

    if preferred_backend == "fa2":
        if fa2_available:
            return "fa2"
        raise RuntimeError(f"FA2 kernel '{fa2_uri}' not available")

    if preferred_backend == "fa3":
        if fa3_available:
            return "fa3"
        raise RuntimeError(f"FA3 kernel '{fa3_uri}' not available")

    raise ValueError(f"Unknown backend: {preferred_backend}")

⚠️ Potential issue | 🟠 Major

Do not select FA3 on non-Hopper devices.

`fa3_available` only means an AOT or JIT build exists; it does not mean the current GPU can run FA3. On non-SM90 devices, these selectors can still return `"fa3"` when FA2 is unavailable or explicitly requested, causing a later Hopper-kernel failure.

🛠️ Proposed fix
     if preferred_backend == "auto":
         if device is None:
             device = torch.device("cuda")
         is_hopper = is_sm90a_supported(device)
         
         if is_hopper:
             if fa3_available:
                 return "fa3"
             if fa2_available:
                 return "fa2"
         else:
             if fa2_available:
                 return "fa2"
-            if fa3_available:
-                return "fa3"
         
         raise RuntimeError(
@@
     if preferred_backend == "fa3":
+        if device is None:
+            device = torch.device("cuda")
+        if not is_sm90a_supported(device):
+            raise RuntimeError("FA3 backend requires SM90/Hopper architecture")
         if fa3_available:
             return "fa3"
         raise RuntimeError(f"FA3 kernel '{fa3_uri}' not available")

Apply the same guard in select_best_backend_paged.

Also applies to: 130-174

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/dllm/batch_block_extend.py` around lines 82 - 127,
select_best_backend currently considers fa3_available even on non-Hopper GPUs;
update it to only allow returning "fa3" when the device supports SM90A (use
is_sm90a_supported(device))—i.e., set fa3_effective = fa3_available and then if
not is_hopper set fa3_effective = False, use fa3_effective for all selection
logic (including the auto path and when preferred_backend == "fa3" raise a
RuntimeError if device is non-Hopper). Apply the identical change to
select_best_backend_paged so FA3 is never chosen or accepted on non-SM90
devices; ensure device is defaulted to torch.device("cuda") when None and reuse
the same is_hopper check.
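The `fa3_effective` gating the prompt describes can be checked in isolation from the torch/kernel plumbing; a minimal sketch of just the selection logic (function name and signature are hypothetical, not the PR's API):

```python
def pick_backend(fa2_available: bool, fa3_available: bool,
                 is_hopper: bool, preferred: str = "auto") -> str:
    """Choose "fa2" or "fa3", never returning FA3 on non-Hopper devices."""
    # Mask out FA3 when the device cannot run SM90 kernels, regardless of
    # whether an FA3 module happens to be compiled (AOT) or JIT-able.
    fa3_effective = fa3_available and is_hopper

    if preferred == "auto":
        # Prefer FA3 on Hopper; on other devices only FA2 is eligible.
        if is_hopper and fa3_effective:
            return "fa3"
        if fa2_available:
            return "fa2"
        raise RuntimeError("No Block Extend kernel available for this device")

    if preferred == "fa2":
        if fa2_available:
            return "fa2"
        raise RuntimeError("FA2 kernel not available")

    if preferred == "fa3":
        if not is_hopper:
            raise RuntimeError("FA3 backend requires SM90/Hopper architecture")
        if fa3_effective:
            return "fa3"
        raise RuntimeError("FA3 kernel not available")

    raise ValueError(f"Unknown backend: {preferred}")
```

With this shape, an explicit `preferred="fa3"` request on a non-Hopper device fails immediately with an architecture error instead of surfacing later as a kernel-launch failure.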

Comment on lines +368 to +370
has_prefix = k_prefix is not None and v_prefix is not None
prefix_len = k_prefix.size(0) if has_prefix else 0


⚠️ Potential issue | 🟠 Major

Reject partially specified prefix KV pairs.

If callers pass only `k_prefix` or only `v_prefix`, `has_prefix` becomes `False` and the provided prefix is silently ignored, so the call returns current-chunk-only attention. Fail fast instead.

🛠️ Proposed fix
-    has_prefix = k_prefix is not None and v_prefix is not None
+    if (k_prefix is None) != (v_prefix is None):
+        raise ValueError("k_prefix and v_prefix must be provided together")
+    has_prefix = k_prefix is not None
     prefix_len = k_prefix.size(0) if has_prefix else 0
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/dllm/block_extend.py` around lines 368 - 370, The code currently
treats a partially provided prefix (only k_prefix or only v_prefix) as absent;
update the logic in block_extend.py around the has_prefix computation to
validate that either both k_prefix and v_prefix are provided or neither are—if
one is None and the other is not, raise a clear ValueError (or custom exception)
indicating mismatched prefix arguments; keep the existing prefix_len computation
(prefix_len = k_prefix.size(0)) when both are present. Ensure you reference and
modify the has_prefix, k_prefix, v_prefix, and prefix_len handling so callers
fail fast instead of silently dropping a single-side prefix.

q_offset = total_kv_len - qo_len

print(f"\n{'='*80}")
print(f"FlashInfer Block Extend vs PyTorch Flex Attention")

⚠️ Potential issue | 🟡 Minor

🧩 Analysis chain

🏁 Script executed:

#!/bin/bash
# Verify remaining placeholder-free f-strings in this file.
ruff check tests/attention/test_dllm_vs_flex_attention.py --select F541



Convert placeholder-free f-strings to plain string literals.

These are Ruff F541 lint failures. Remove the f prefix from 16 f-string instances that contain no placeholders across lines 241, 328, 557, 647, 656, 745–747, 827, 860–862, 995–1002, 1145, and 1190.

🧰 Tools
🪛 Ruff (0.15.10)

[error] 241-241: f-string without any placeholders

Remove extraneous f prefix

(F541)

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/attention/test_dllm_vs_flex_attention.py` at line 241, Multiple plain
string print statements in tests/attention/test_dllm_vs_flex_attention.py were
written as f-strings without placeholders (e.g., print(f"FlashInfer Block Extend
vs PyTorch Flex Attention")); remove the unnecessary 'f' prefix on those 16
instances so they become regular string literals. Search for all
print/assignment/logging lines in that file that start with f" or f' but contain
no braces/format placeholders and change them to plain "..." or '...' (example
symbol to locate: the print call containing "FlashInfer Block Extend vs PyTorch
Flex Attention"). Ensure no formatting behavior is altered and run linters to
confirm Ruff F541 is resolved.
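The 16 offending lines can also be located mechanically rather than by eyeballing every `print`; a minimal sketch of an F541-style check using the stdlib `ast` module (an approximation, not Ruff's implementation):

```python
import ast


def find_placeholderless_fstrings(source: str):
    """Return line numbers of f-strings with no {...} placeholders.

    An f-string parses as an ast.JoinedStr; if none of its parts is an
    ast.FormattedValue, the f prefix is doing nothing (Ruff F541).
    """
    hits = []
    for node in ast.walk(ast.parse(source)):
        if isinstance(node, ast.JoinedStr):
            has_placeholder = any(
                isinstance(part, ast.FormattedValue) for part in node.values
            )
            if not has_placeholder:
                hits.append(node.lineno)
    return hits


code = '''print(f"FlashInfer Block Extend vs PyTorch Flex Attention")
print(f"q_offset = {1 + 2}")
'''
print(find_placeholderless_fstrings(code))  # [1]
```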

- Validate dllm_block_size > 0 to reject zero and negative values
- Raise ValueError on unsupported dtype instead of silent fallback to fp16
- Preserve user's preferred backend across wrapper re-creation
- Track idtype to correctly invalidate plan when index dtype changes
- Defer backend auto-selection to wrappers instead of pre-resolving in cascade
- Warn when q_offsets is None but prefix exists in cascade attention
- Pass device to FA3 SM90 check and include device in module cache key
- Remove unused logits_soft_cap parameter from sglang_style_cascade_attention
- Fix causal=True comment to causal=False in sglang_style_cascade_attention
- Fix docstring function names (block_expanding_* -> block_extend_*)
- Add assert statements to test correctness checks
- Rename benchmark functions from test_* to bench_* to avoid pytest collection
- Fix missing trailing newline in .cuh and .jinja files
@fdz-1999 fdz-1999 force-pushed the feature/block-extend branch from 93965f5 to 4284113 Compare April 21, 2026 13:12

4 participants