feat(dllm): add Block Extend Attention for Diffusion LLM #2722
fdz-1999 wants to merge 12 commits into flashinfer-ai:main from
Conversation
Commits (excerpt):

- …er-ai#2086)" This reverts commit 9a79b78.
- Merge branch fa2-fa3-opt of git@code.alipay.com:deep-xpu/flashinfer.git into main (https://code.alipay.com/deep-xpu/flashinfer/pull_requests/6, Reviewed-by: 明泓 <mingliang.gml@antgroup.com>)
- feat(dllm,ci): add Block Expanding Attention & PyPI release CI
- build: add date suffix to ant-deepxpu-flashinfer-python version (0.5.3.20260202)
- fix(jit): compare base version only to allow date/cuda suffix
- drop clone api
- build flashinfer-jit-cache
- …ion blockwise mask benchmark
Note: Reviews paused. This branch is under active development, so CodeRabbit automatically paused its review to avoid overwhelming the author with comments from an influx of new commits.
📝 Walkthrough

Adds block‑expanding mask mode and per‑batch DLLM block offsets across CUDA/C++ kernels, parameter structs, JIT/AOT generation, and Python DLLM wrappers; implements offset‑aware block‑extend attention (single/batch, ragged/paged), cascade helpers, and a comprehensive FlashInfer vs Flex benchmark.
Sequence Diagram(s)

sequenceDiagram
participant User as User Code
participant Cascade as block_extend_cascade
participant Ragged as Ragged Kernel (current chunk)
participant Paged as Paged Kernel (prefix)
participant Device as CUDA Device
User->>Cascade: q, k_current, v_current, k_prefix?, v_prefix?, offsets
Cascade->>Ragged: Stage1: run block‑extend (q_offsets, kv_offsets)
Ragged->>Device: execute kernel (block‑expanding mask)
Device-->>Ragged: O1, LSE1
Ragged-->>Cascade: return O1, LSE1
Cascade->>Paged: Stage2: run prefix attention (if present)
Paged->>Device: execute kernel (fully visible / paged)
Device-->>Paged: O2, LSE2
Paged-->>Cascade: return O2, LSE2
Cascade->>Device: Stage3: merge O1 + O2 using LSEs
Device-->>Cascade: merged output
Cascade-->>User: final attention output
sequenceDiagram
participant User as User Code
participant Wrapper as BatchBlockExtendPagedOffsetWrapper
participant Backend as Backend Selector
participant Cache as AOT/JIT Module Cache
participant Kernel as Compiled Kernel / CUDA
User->>Wrapper: plan(indptrs..., backend="auto", mask_mode?)
Wrapper->>Backend: select_best_backend(head_dim, dtype)
Backend-->>Wrapper: backend (FA2/FA3)
Wrapper->>Cache: get_or_build_module(backend, head_dim, dtype, mask_modes)
Cache->>Cache: check AOT / generate JIT module
Cache-->>Wrapper: compiled module
User->>Wrapper: run(q, kv_cache, offsets, sm_scale)
Wrapper->>Kernel: invoke compiled kernel with per‑batch offsets
Kernel->>Kernel: execute on device
Kernel-->>Wrapper: results (o, lse?)
Wrapper-->>User: outputs
Estimated code review effort: 🎯 5 (Critical) | ⏱️ ~120 minutes

🚥 Pre-merge checks: ✅ 4 passed | ❌ 1 failed (1 warning)
Summary of Changes (Gemini Code Assist)

This pull request introduces a specialized Block Extend Attention mechanism tailored for Diffusion LLMs, significantly improving attention computation efficiency through tile-level skip optimizations. It provides comprehensive support for both single-request and batched operations, accommodating ragged and paged KV cache layouts, and is designed for seamless integration with JIT and AOT compilation workflows. The changes enhance the system's ability to handle complex attention patterns with improved performance.
Code Review
This pull request introduces Block Extend Attention for Diffusion LLMs, a significant new feature. The changes are comprehensive, adding support for single-request and batch processing with both ragged and paged KV caches, and including JIT/AOT compilation capabilities. The implementation is well-structured, with clear separation of concerns between Python wrappers, JIT logic, and CUDA kernels for different architectures. The inclusion of a thorough benchmark against PyTorch's flex_attention is also a great addition. My review focuses on a potential correctness issue in a helper function and a small opportunity to improve code clarity in one of the CUDA kernels.
    ragged_wrapper.plan(
        qo_indptr=qo_indptr, kv_indptr=kv_curr_indptr,
        num_qo_heads=num_qo_heads, num_kv_heads=num_kv_heads,
        head_dim_qk=head_dim, head_dim_vo=head_dim, q_data_type=q.dtype, causal=False,

There is a discrepancy between the comment on line 638, which states causal=True, and the code here, which passes causal=False. For a cascade attention pattern, the attention for the current chunk is typically causal. If the intention is to have causal attention for this stage, this should be corrected. If non-causal attention is intended, the comment on line 638 should be updated to avoid confusion.

Suggested change:

-        head_dim_qk=head_dim, head_dim_vo=head_dim, q_data_type=q.dtype, causal=False,
+        head_dim_qk=head_dim, head_dim_vo=head_dim, q_data_type=q.dtype, causal=True,
    bool needs_mask = (MASK_MODE == MaskMode::kCustom) ||
                      (MASK_MODE == MaskMode::kBlockExpanding && iter >= mask_iteration) ||
                      (iter >= mask_iteration || iter < window_iteration);

The condition for needs_mask is slightly redundant. The term (MASK_MODE == MaskMode::kBlockExpanding && iter >= mask_iteration) is already covered by the subsequent (iter >= mask_iteration) in the OR chain. Simplifying this expression will improve code readability without changing the logic.
bool needs_mask = (MASK_MODE == MaskMode::kCustom) || (iter >= mask_iteration || iter < window_iteration);
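The claimed equivalence can be checked exhaustively over the four booleans involved; a quick sketch (the predicate names below are paraphrases of the kernel's conditions, not real identifiers):

```python
from itertools import product

def needs_mask_original(is_custom, is_block_expanding, past_mask_iter, in_window_iter):
    # Mirrors: kCustom || (kBlockExpanding && iter >= mask_iteration)
    #                  || (iter >= mask_iteration || iter < window_iteration)
    return is_custom or (is_block_expanding and past_mask_iter) \
        or (past_mask_iter or in_window_iter)

def needs_mask_simplified(is_custom, is_block_expanding, past_mask_iter, in_window_iter):
    # Mirrors: kCustom || (iter >= mask_iteration || iter < window_iteration)
    return is_custom or past_mask_iter or in_window_iter

# The dropped middle term implies `past_mask_iter`, so removing it changes nothing:
for combo in product([False, True], repeat=4):
    assert needs_mask_original(*combo) == needs_mask_simplified(*combo)
```

All 16 input combinations agree, confirming the simplification is purely a readability change.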
Actionable comments posted: 15
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
include/flashinfer/attention/hopper/sparse_mainloop.cuh (1)
159-208: ⚠️ Potential issue | 🔴 Critical

Handle num_kv_tiles == 0 before the sparse prefetch path runs.

The new block-extend bound can shrink num_kv_tiles to 0. Here that becomes kv_tile_idx = -1, and the very first prefetch_kv_offset(kv_tile_idx, true) / load_kv_with_gather(..., kv_tile_idx, ...) sequence will read the page table with a negative tile index.

Also applies to: 242-243, 376-382
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@include/flashinfer/attention/hopper/sparse_mainloop.cuh` around lines 159 - 208, get_num_kv_tiles can return 0 under BLOCK_EXPANDING, which leads to kv_tile_idx == -1 and then a negative index passed into prefetch_kv_offset and load_kv_with_gather; update the caller(s) that compute kv_tile_idx from get_num_kv_tiles (and any code paths around lines referenced) to check for num_kv_tiles == 0 and early-skip the sparse prefetch/load sequence, or clamp/skip any prefetch_kv_offset(kv_tile_idx, ...) and load_kv_with_gather(..., kv_tile_idx, ...) calls when kv_tile_idx < 0; locate calls by the symbols get_num_kv_tiles, kv_tile_idx, prefetch_kv_offset, and load_kv_with_gather and add a guard that prevents negative tile indices from reaching the page-table access.
🧹 Nitpick comments (1)
flashinfer/__init__.py (1)
23-23: Re-export the new DLLM entry points at the package root.

Importing the submodule here makes flashinfer.dllm.* available, but the new public wrappers/functions still are not top-level flashinfer.* exports like the rest of the package surface. Please add explicit imports for the public DLLM APIs in this file as well. As per coding guidelines: "flashinfer/__init__.py: Export all public operations in flashinfer/__init__.py after implementing."
Verify each finding against the current code and only fix it if needed. In `@flashinfer/__init__.py` at line 23, Replace the current submodule-only import with explicit re-exports of the DLLM public API: import the public symbols from .dllm (e.g., the public classes/functions in that module such as DLLMClient, run_inference, load_model — or the actual names defined in dllm) via "from .dllm import <PublicName1>, <PublicName2>, ..." and add those names to the package __all__ list so they become top-level flashinfer.* exports; keep the module-level alias (dllm) if you still need it, but ensure all public DLLM symbols are explicitly imported and included in __all__ in flashinfer/__init__.py.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@flashinfer/dllm/batch_block_extend.py`:
- Around line 263-264: The current assertion allows dllm_block_size==0 because
(0 & -1)==0; update the check to explicitly reject zero and require a positive
power of two: replace the existing assert that tests (dllm_block_size &
(dllm_block_size - 1)) == 0 with a compound check that dllm_block_size > 0 and
(dllm_block_size & (dllm_block_size - 1)) == 0, and raise an informative
AssertionError/ValueError if it fails; apply the same change to the other
occurrence in this file (the second assertion at the later block).
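The pitfall here is that the bare bit trick accepts zero, since `0 & -1 == 0`. A minimal standalone sketch of the strengthened check (the helper name is illustrative, not the wrapper's actual code):

```python
def assert_valid_dllm_block_size(dllm_block_size: int) -> None:
    # 0 & -1 == 0, so the bit trick alone would accept zero; require > 0 first,
    # then the single-set-bit test confirms a power of two.
    assert dllm_block_size > 0 and (dllm_block_size & (dllm_block_size - 1)) == 0, (
        f"dllm_block_size must be a positive power of two, got {dllm_block_size}"
    )

assert_valid_dllm_block_size(32)     # accepted
try:
    assert_valid_dllm_block_size(0)  # rejected now; silently passed the old check
except AssertionError as e:
    print(e)
```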
- Around line 611-613: The helper is constructing stage-1 ragged attention with
causal=False and drops the logits_soft_cap parameter; update the ragged-wrapper
construction(s) so stage 1 uses causal=True to enforce current-chunk (causal +
merge) attention, and thread the logits_soft_cap argument through to the helper
calls instead of discarding it; specifically, in batch_block_extend.py adjust
the stage-1 ragged/planning call(s) that currently pass causal=False to
causal=True and ensure the helper invocations (the ones near the signature
containing logits_soft_cap, return_lse, backend and the similar block at lines
~638-659) forward logits_soft_cap into the downstream helper/function that
applies the soft cap.
- Around line 179-181: The URI generation in _get_batch_be_module_uri only
encodes head_dim and a coarse dtype string, causing different kernel ABIs (e.g.,
idtype variants like int64 or fp8) to collide; update _get_batch_be_module_uri
to include both the element dtype and the index/id dtype (idtype) in the
returned string (e.g., use explicit mappings for
torch.float16/torch.bfloat16/torch.int64/FP8 aliases or use dtype.name and
idtype.name), and ensure any other URI/cache identity builders and the wrapper
recreation check also incorporate idtype (or explicitly reject unsupported
dtypes up front) so each specialization yields a unique flashinfer::{uri}_* name
and the recreation logic compares both dtype and idtype.
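The suggested fix can be sketched as a URI builder keyed on both the element dtype and the index dtype that rejects anything unmapped instead of aliasing it to fp16. The mapping tables and function name below are hypothetical, not the repo's actual API:

```python
# Hypothetical sketch: key the module URI on (head_dim, dtype, idtype) and
# fail fast on unsupported dtypes rather than defaulting to fp16.
_DTYPE_STR = {"float16": "fp16", "bfloat16": "bf16"}
_IDTYPE_STR = {"int32": "i32", "int64": "i64"}

def make_batch_be_module_uri(head_dim: int, dtype: str, idtype: str) -> str:
    if dtype not in _DTYPE_STR or idtype not in _IDTYPE_STR:
        raise ValueError(f"unsupported (dtype, idtype) = ({dtype}, {idtype})")
    return (
        f"batch_prefill_block_expanding_hd{head_dim}"
        f"_{_DTYPE_STR[dtype]}_{_IDTYPE_STR[idtype]}"
    )

# int32 and int64 index variants no longer collide on one cache entry:
assert make_batch_be_module_uri(128, "float16", "int32") != \
       make_batch_be_module_uri(128, "float16", "int64")
```

An FP8 dtype now raises ValueError up front instead of silently loading the fp16 specialization.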
- Around line 560-563: In batch_block_extend_cascade(), when q_offsets or
kv_offsets are None the code currently defaults both to zero which is incorrect
when has_prefix is true; change this to derive per-request global offsets from
the paged-prefix metadata (the per-request prefix length stored in the paged
prefix structure used by the function) instead of using torch.zeros, so that
q_offsets and kv_offsets reflect each request's prefix length (and block-aligned
adjustments) before the two-stage extension; alternatively validate and require
the caller to supply q_offsets/kv_offsets and raise a clear error if they are
omitted when has_prefix is true. Ensure you update the logic referencing
q_offsets, kv_offsets, has_prefix, and the paged prefix metadata in
batch_block_extend_cascade to use the computed per-request offsets.
In `@flashinfer/dllm/block_extend.py`:
- Around line 221-245: The FA3 capability check in
get_block_extend_module_with_offset uses a hardcoded torch.device("cuda") which
breaks mixed-arch multi-GPU setups; update get_block_extend_module_with_offset
to accept a device (torch.device or device-like) parameter and use that device
when calling is_sm90a_supported, and update any callers (notably
block_extend_attention_with_offset) to pass the q/kv tensor's device through
into get_block_extend_module_with_offset; also apply the same
device-threading/fix at the other occurrence referenced (around the second call
at line ~360) so all FA3 checks use the intended device rather than the default
CUDA device.
- Around line 131-136: The function _get_dtype_str currently maps unknown dtypes
to "fp16", causing wrong module URIs; update _get_dtype_str to explicitly map
all supported dtypes (e.g., torch.float16 -> "fp16", torch.bfloat16 -> "bf16",
torch.float32 -> "fp32", and any FP8 dtype used in this project -> the correct
string such as "fp8") and make the default case raise a ValueError (or return a
distinct sentinel like "unknown") instead of aliasing to "fp16"; apply the same
explicit mapping/failure behavior to the other similar helper usages referenced
in the diff (the other _get_dtype_str-like mappings at the other locations) so
that module names are unique per actual dtype and FP8 does not resolve to the
FP16 specialization.
In `@flashinfer/prefill.py`:
- Line 1651: The code accepts mask_mode in plan/run but doesn't enforce backend
support: detect and reject unsupported MaskMode values early (in plan and run)
or filter backends that cannot handle them; specifically, when mask_mode ==
MaskMode.BLOCK_EXPANDING.value (or any non-default), ensure you do not route to
the cudnn path that uses self._causal nor to the trtllm-gen paged path that
ignores mask_mode—validate mask_mode against the selected backend and either
remove unsupported backends from backend selection or raise an error before
continuing (update the functions that accept the mask_mode parameter and the
backend selection logic to perform this check).
In `@include/flashinfer/attention/hopper/mainloop.cuh`:
- Around line 141-188: get_num_kv_tiles can return 0 when
kv_block_expanding_offset pushes the valid KV range past the query block,
causing later code in load() to compute kv_tile_idx = -1 and perform invalid
tile reads via tKgK/tVgV; fix by adding a guard after computing num_kv_tiles (or
before the first copy() in load()) that checks for zero and early-returns or
skips scheduling that tile so kv_tile_idx is never -1. Specifically, in the
caller load() (or right after calling get_num_kv_tiles) detect num_kv_tiles == 0
and do an early return/no-op for that q_tile_idx, or in the scheduling loop
ensure you don't iterate when get_num_kv_tiles(...) == 0; this prevents tKgK(_,
kv_tile_idx) and tVgV(_, kv_tile_idx) from being invoked with an invalid index
and avoids the invalid tile access.
In `@include/flashinfer/attention/prefill.cuh`:
- Around line 1451-1457: The block-expanding iteration bound computed in the
MASK_MODE == MaskMode::kBlockExpanding branches only passes q_offset into
block_expanding_num_iterations but the later legality checks use both q_offset
and kv_offset; update the call to compute kv_offset (e.g., call
params.get_kv_block_expanding_offset(batch_idx) or the appropriate getter) and
pass that kv_offset into block_expanding_num_iterations so the loop bounds
exclude fully-masked KV tiles and avoid poisoning update_mdo_states(); make the
same change for the other block-expanding call sites that mirror this logic (the
other MaskMode::kBlockExpanding branches).
In `@tests/attention/test_dllm_vs_flex_attention.py`:
- Around line 61-104: compute_block_extend_reference and
make_block_extend_mask_mod currently never exercise per-request offsets or
kv_block_expanding_offset because compute_block_extend_reference hardcodes
q_offset (kv_offset) to zero, make_block_extend_mask_mod ignores the batch index
b, and tests use torch.full for q_offsets; update the test helpers so
compute_block_extend_reference accepts and uses per-request q_offset values (and
propagate kv_offset if applicable) and make_block_extend_mask_mod's inner
block_extend_mask uses the batch index b to look up per-sample offsets, then
modify the batch construction in tests to pass heterogeneous q_offsets (not
torch.full) and add cases exercising nonzero kv_block_expanding_offset and
cascade/current-chunk paths so the new plumbing is validated against
single_prefill_with_kv_cache and block_extend_mask behavior.
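The mask semantics these helpers must honor can be modeled in plain Python: a KV token is visible when the query's offset-shifted block index reaches the KV token's block index. A small sketch (pure Python; the block size and offsets are illustrative):

```python
def block_extend_visible(q_idx: int, kv_idx: int, q_offset: int, block_size: int) -> bool:
    # The query's global block index (after the per-request offset) must be
    # at least the KV token's block index for the KV token to be visible.
    return (q_idx + q_offset) // block_size >= kv_idx // block_size

B = 4  # illustrative dllm block size
# Offset 0: query 0 sits in block 0 and sees only KV block 0.
assert block_extend_visible(0, 3, q_offset=0, block_size=B)
assert not block_extend_visible(0, 4, q_offset=0, block_size=B)
# A nonzero per-request offset shifts the query into block 1, exposing KV
# block 1 as well -- exactly the behavior a mask_mod that ignores the batch
# index (or a test using one uniform offset) can never exercise.
assert block_extend_visible(0, 4, q_offset=4, block_size=B)
```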
- Around line 166-210: The benchmarks (functions benchmark_fn and
benchmark_with_cuda_graph) and the direct perf_counter() timing in
test_total_memory_comparison should be replaced to use the repo timing harness
flashinfer.testing.bench_gpu_time() so results use CUPTI with CUDA-event
fallback and remain comparable across the suite; locate usages of benchmark_fn,
benchmark_with_cuda_graph, and the perf_counter() blocks in
test_total_memory_comparison and call bench_gpu_time() (passing the callable and
warmup/bench iteration params) instead of manual perf_counter/CUDAGraph timing,
ensuring any CUDA Graph replay loops are wrapped or adapted to the
bench_gpu_time() callable interface.
- Around line 20-56: The module currently performs CUDA/Hopper-specific work at
import time and must early-skip unsupported GPUs: query
flashinfer.utils.get_compute_capability(), is_sm90a_supported(), and
is_sm100a_supported() at module scope (and check
torch.cuda.is_available()/device count) and if the current GPU is unsupported,
set HAS_FLASHINFER = HAS_FLEX_ATTENTION = False and print a skip message before
attempting any flashinfer or flex_attention imports or CUDA allocations; wrap
the existing flashinfer imports and the flex_attention import logic behind this
guard so functions like single_prefill_with_kv_cache,
BatchBlockExtendPagedOffsetWrapper, BatchBlockExtendRaggedOffsetWrapper,
flex_attention and create_block_mask are only imported when the architecture
checks pass.
- Around line 355-399: The test currently only prints PASS/FAIL for
ragged/paged/flex comparisons (using ragged_pass, paged_pass, flex_pass) which
means failures don't fail the pytest; change each check to assert the diff is
below tol (e.g., assert ragged_diff < tol, assert paged_diff < tol, and assert
flex_diff < tol) or call pytest.fail with a clear message including the diff
when the condition is false so CI fails on regressions; keep the existing diff
variables (ragged_diff, paged_diff, flex_diff) and messages but replace the
print-only behavior with assertions/pytest.fail in the test function.
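The print-to-assert change can be illustrated with a tiny standalone helper (the name and tolerance are illustrative):

```python
tol = 2e-2  # illustrative tolerance

def check(name: str, diff: float, tol: float) -> None:
    # Before: print(f"{name}: {'PASS' if diff < tol else 'FAIL'}") -- CI stays green.
    # After: a failing comparison raises, so pytest marks the test failed.
    assert diff < tol, f"{name} max diff {diff:.3e} exceeds tol {tol:.1e}"

check("ragged", 1.5e-3, tol)   # passes silently
check("paged", 3.2e-3, tol)    # passes silently
try:
    check("flex", 5e-2, tol)   # now fails loudly instead of printing FAIL
except AssertionError as e:
    print(e)
```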
- Around line 216-229: The benchmark driver functions named with the
pytest-discovered prefix (e.g., test_flashinfer_vs_flex_attention and the other
top-level functions in this file referenced at lines 216–229, 664–672, 752–759,
832–841, 959–968) must be renamed so they do not start with "test_" (or moved
into a non-test harness/module); change their names (for example to
flashinfer_vs_flex_attention_bench or
run_flashinfer_vs_flex_attention_benchmark) or relocate them into a dedicated
benchmarks file to prevent pytest from collecting and executing the heavy
benchmark sweeps during CI, and update any callers/imports accordingly.
---
Outside diff comments:
In `@include/flashinfer/attention/hopper/sparse_mainloop.cuh`:
- Around line 159-208: get_num_kv_tiles can return 0 under BLOCK_EXPANDING,
which leads to kv_tile_idx == -1 and then a negative index passed into
prefetch_kv_offset and load_kv_with_gather; update the caller(s) that compute
kv_tile_idx from get_num_kv_tiles (and any code paths around lines referenced)
to check for num_kv_tiles == 0 and early-skip the sparse prefetch/load sequence,
or clamp/skip any prefetch_kv_offset(kv_tile_idx, ...) and
load_kv_with_gather(..., kv_tile_idx, ...) calls when kv_tile_idx < 0; locate
calls by the symbols get_num_kv_tiles, kv_tile_idx, prefetch_kv_offset, and
load_kv_with_gather and add a guard that prevents negative tile indices from
reaching the page-table access.
---
Nitpick comments:
In `@flashinfer/__init__.py`:
- Line 23: Replace the current submodule-only import with explicit re-exports of
the DLLM public API: import the public symbols from .dllm (e.g., the public
classes/functions in that module such as DLLMClient, run_inference, load_model —
or the actual names defined in dllm) via "from .dllm import <PublicName1>,
<PublicName2>, ..." and add those names to the package __all__ list so they
become top-level flashinfer.* exports; keep the module-level alias (dllm) if you
still need it, but ensure all public DLLM symbols are explicitly imported and
included in __all__ in flashinfer/__init__.py.
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: b97072d3-ff07-4a26-b15a-3f3936d9e25f
📒 Files selected for processing (23)
- csrc/batch_prefill_customize_config.jinja
- csrc/batch_prefill_sm90_customize_config.jinja
- csrc/single_prefill_customize_config.jinja
- csrc/single_prefill_sm90_customize_config.jinja
- flashinfer/__init__.py
- flashinfer/dllm/__init__.py
- flashinfer/dllm/batch_block_extend.py
- flashinfer/dllm/block_extend.py
- flashinfer/jit/attention/modules.py
- flashinfer/jit/utils.py
- flashinfer/prefill.py
- flashinfer/utils.py
- include/flashinfer/attention/block_expanding_prefill.cuh
- include/flashinfer/attention/default_prefill_params.cuh
- include/flashinfer/attention/hopper/mainloop.cuh
- include/flashinfer/attention/hopper/mainloop_mma.cuh
- include/flashinfer/attention/hopper/prefill_sm90.cuh
- include/flashinfer/attention/hopper/sparse_mainloop.cuh
- include/flashinfer/attention/mask.cuh
- include/flashinfer/attention/prefill.cuh
- include/flashinfer/utils.cuh
- tests/attention/test_dllm_cascade_vs_blockwise_extend_attention.py
- tests/attention/test_dllm_vs_flex_attention.py
    def _get_batch_be_module_uri(head_dim: int, dtype: torch.dtype) -> str:
        dtype_str = {torch.float16: "fp16", torch.bfloat16: "bf16"}.get(dtype, "fp16")
        return f"batch_prefill_block_expanding_hd{head_dim}_{dtype_str}"

Don't let distinct kernel ABIs share the same URI.

_get_batch_be_module_uri() only keys on head_dim plus a coarse dtype string, but the generated module also depends on idtype, and every other dtype currently aliases to fp16. An int64 indptr variant or an FP8 variant can therefore load/register the same flashinfer::{uri}_* name as a different specialization, and the wrapper recreation check also ignores idtype. Please bake both dtype and idtype into the URI/cache identity, or reject unsupported dtypes up front.
Also applies to: 300-305, 331-332, 435-440, 463-464
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@flashinfer/dllm/batch_block_extend.py` around lines 179 - 181, The URI
generation in _get_batch_be_module_uri only encodes head_dim and a coarse dtype
string, causing different kernel ABIs (e.g., idtype variants like int64 or fp8)
to collide; update _get_batch_be_module_uri to include both the element dtype
and the index/id dtype (idtype) in the returned string (e.g., use explicit
mappings for torch.float16/torch.bfloat16/torch.int64/FP8 aliases or use
dtype.name and idtype.name), and ensure any other URI/cache identity builders
and the wrapper recreation check also incorporate idtype (or explicitly reject
unsupported dtypes up front) so each specialization yields a unique
flashinfer::{uri}_* name and the recreation logic compares both dtype and
idtype.
    if q_offsets is None:
        q_offsets = torch.zeros(batch_size, dtype=torch.int32, device=device)
    if kv_offsets is None:
        kv_offsets = q_offsets
batch_block_extend_cascade() defaults to the wrong global offsets when a prefix exists.
When has_prefix is true and the caller omits q_offsets/kv_offsets, both stages run as if the current chunk starts at position 0. That changes the block mask whenever the prefix length is nonzero, especially when it is not block-aligned. Derive the per-request prefix lengths from the paged prefix metadata here, or require the caller to pass them explicitly.
Suggested fix

-    if q_offsets is None:
-        q_offsets = torch.zeros(batch_size, dtype=torch.int32, device=device)
-    if kv_offsets is None:
-        kv_offsets = q_offsets
+    if q_offsets is None:
+        if has_prefix:
+            q_offsets = (
+                page_size * (paged_kv_indptr[1:] - paged_kv_indptr[:-1] - 1)
+                + paged_kv_last_page_len
+            ).to(device=device, dtype=qo_indptr.dtype)
+        else:
+            q_offsets = torch.zeros(batch_size, dtype=qo_indptr.dtype, device=device)
+    if kv_offsets is None:
+        kv_offsets = q_offsets

📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
    if q_offsets is None:
        if has_prefix:
            q_offsets = (
                page_size * (paged_kv_indptr[1:] - paged_kv_indptr[:-1] - 1)
                + paged_kv_last_page_len
            ).to(device=device, dtype=qo_indptr.dtype)
        else:
            q_offsets = torch.zeros(batch_size, dtype=qo_indptr.dtype, device=device)
    if kv_offsets is None:
        kv_offsets = q_offsets
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@flashinfer/dllm/batch_block_extend.py` around lines 560 - 563, In
batch_block_extend_cascade(), when q_offsets or kv_offsets are None the code
currently defaults both to zero which is incorrect when has_prefix is true;
change this to derive per-request global offsets from the paged-prefix metadata
(the per-request prefix length stored in the paged prefix structure used by the
function) instead of using torch.zeros, so that q_offsets and kv_offsets reflect
each request's prefix length (and block-aligned adjustments) before the
two-stage extension; alternatively validate and require the caller to supply
q_offsets/kv_offsets and raise a clear error if they are omitted when has_prefix
is true. Ensure you update the logic referencing q_offsets, kv_offsets,
has_prefix, and the paged prefix metadata in batch_block_extend_cascade to use
the computed per-request offsets.
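The per-request prefix length the suggested fix derives can be modeled without torch: each request owns `indptr[i+1] - indptr[i]` pages, all full except the last. A pure-Python sketch with illustrative values:

```python
# prefix_len[i] = page_size * (num_pages_i - 1) + last_page_len[i]
page_size = 16
paged_kv_indptr = [0, 3, 5]        # request 0 owns 3 pages, request 1 owns 2
paged_kv_last_page_len = [7, 16]   # valid tokens in each request's last page

q_offsets = [
    page_size * (paged_kv_indptr[i + 1] - paged_kv_indptr[i] - 1)
    + paged_kv_last_page_len[i]
    for i in range(len(paged_kv_last_page_len))
]
assert q_offsets == [39, 32]  # 16*2 + 7 = 39, 16*1 + 16 = 32
```

Note that request 0's prefix (39) is not block-aligned, which is precisely the case where the zero-default changes the block mask.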
    logits_soft_cap: float = 0.0,
    return_lse: bool = False,
    backend: str = "fa2",

The SGLang helper is not actually running the advertised current-chunk attention.
The docstring says "causal + merge", but stage 1 plans the ragged wrapper with causal=False, so current-chunk tokens can see future chunk tokens before the merge. The helper also drops the caller's logits_soft_cap entirely.
Suggested fix
ragged_wrapper.plan(
qo_indptr=qo_indptr, kv_indptr=kv_curr_indptr,
num_qo_heads=num_qo_heads, num_kv_heads=num_kv_heads,
- head_dim_qk=head_dim, head_dim_vo=head_dim, q_data_type=q.dtype, causal=False,
+ head_dim_qk=head_dim,
+ head_dim_vo=head_dim,
+ q_data_type=q.dtype,
+ causal=True,
+ logits_soft_cap=logits_soft_cap,
)
@@
paged_wrapper.plan(
qo_indptr=qo_indptr, paged_kv_indptr=paged_kv_indptr,
paged_kv_indices=paged_kv_indices, paged_kv_last_page_len=paged_kv_last_page_len,
num_qo_heads=num_qo_heads, num_kv_heads=num_kv_heads,
head_dim_qk=head_dim, head_dim_vo=head_dim, page_size=page_size,
- q_data_type=q.dtype, causal=False,
+ q_data_type=q.dtype, causal=False, logits_soft_cap=logits_soft_cap,
    )

Also applies to: 638-659
🧰 Tools
🪛 Ruff (0.15.4)
[warning] 611-611: Unused function argument: logits_soft_cap
(ARG001)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@flashinfer/dllm/batch_block_extend.py` around lines 611 - 613, The helper is
constructing stage-1 ragged attention with causal=False and drops the
logits_soft_cap parameter; update the ragged-wrapper construction(s) so stage 1
uses causal=True to enforce current-chunk (causal + merge) attention, and thread
the logits_soft_cap argument through to the helper calls instead of discarding
it; specifically, in batch_block_extend.py adjust the stage-1 ragged/planning
call(s) that currently pass causal=False to causal=True and ensure the helper
invocations (the ones near the signature containing logits_soft_cap, return_lse,
backend and the similar block at lines ~638-659) forward logits_soft_cap into
the downstream helper/function that applies the soft cap.
    import torch
    import time
    import math
    import sys

    # ============================================================
    # FlashInfer imports
    # ============================================================
    try:
        from flashinfer import single_prefill_with_kv_cache
        from flashinfer.dllm import (
            BatchBlockExtendPagedOffsetWrapper,
            BatchBlockExtendRaggedOffsetWrapper,
        )
        HAS_FLASHINFER = True
    except ImportError as e:
        HAS_FLASHINFER = False
        print(f"[WARN] flashinfer not available: {e}")
        print("       Will skip FlashInfer benchmarks")
    except Exception as e:
        HAS_FLASHINFER = False
        print(f"[ERROR] flashinfer import failed with unexpected error: {e}")
        print("        Will skip FlashInfer benchmarks")

    # ============================================================
    # Flex Attention imports (requires PyTorch >= 2.5)
    # ============================================================
    try:
        from torch.nn.attention.flex_attention import (
            flex_attention,
            create_block_mask,
        )
        HAS_FLEX_ATTENTION = True
    except ImportError:
        HAS_FLEX_ATTENTION = False
        print("[WARN] flex_attention not available (requires PyTorch >= 2.5)")
Skip unsupported GPU architectures at module scope.
This module allocates on cuda:0 and exercises Hopper-specific paths, but it never gates execution with the repo's architecture helpers. Unsupported runners will fail instead of skipping cleanly.
As per coding guidelines, tests/**/*.py: Use flashinfer.utils functions (get_compute_capability(), is_sm90a_supported(), is_sm100a_supported()) to skip tests on unsupported GPU architectures.
🧰 Tools
🪛 Ruff (0.15.4)
[warning] 39-39: Do not catch blind exception: Exception
(BLE001)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@tests/attention/test_dllm_vs_flex_attention.py` around lines 20 - 56, The
module currently performs CUDA/Hopper-specific work at import time and must
early-skip unsupported GPUs: query flashinfer.utils.get_compute_capability(),
is_sm90a_supported(), and is_sm100a_supported() at module scope (and check
torch.cuda.is_available()/device count) and if the current GPU is unsupported,
set HAS_FLASHINFER = HAS_FLEX_ATTENTION = False and print a skip message before
attempting any flashinfer or flex_attention imports or CUDA allocations; wrap
the existing flashinfer imports and the flex_attention import logic behind this
guard so functions like single_prefill_with_kv_cache,
BatchBlockExtendPagedOffsetWrapper, BatchBlockExtendRaggedOffsetWrapper,
flex_attention and create_block_mask are only imported when the architecture
checks pass.
def compute_block_extend_reference(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    dllm_block_size: int,
    q_offset: int = 0,
    sm_scale: float = None,
) -> torch.Tensor:
    """Reference: single_prefill_with_kv_cache + custom_mask"""
    qo_len = q.shape[0]
    kv_len = k.shape[0]
    head_dim = q.shape[-1]
    device = q.device

    if sm_scale is None:
        sm_scale = 1.0 / math.sqrt(head_dim)

    q_pos = torch.arange(qo_len, device=device) + q_offset
    k_pos = torch.arange(kv_len, device=device)
    q_block = q_pos.unsqueeze(1) // dllm_block_size
    k_block = k_pos.unsqueeze(0) // dllm_block_size
    mask_2d = (q_block >= k_block).to(torch.uint8)

    return single_prefill_with_kv_cache(
        q, k, v, custom_mask=mask_2d, sm_scale=sm_scale,
    )


# ============================================================
# Flex Attention helper: build block_extend mask_mod
# ============================================================
def make_block_extend_mask_mod(dllm_block_size: int, q_offset: int = 0):
    """
    Return the mask_mod function used by flex_attention.

    mask_mod(b, h, q_idx, kv_idx) -> bool
    True = allowed to attend, False = masked out
    """
    def block_extend_mask(b, h, q_idx, kv_idx):
        q_global = q_idx + q_offset
        q_blk = q_global // dllm_block_size
        kv_blk = kv_idx // dllm_block_size
        return q_blk >= kv_blk
    return block_extend_mask
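For intuition, the block-expanding visibility pattern that compute_block_extend_reference builds with torch can be reproduced in plain Python; the sketch below (not part of the test file) shows how a nonzero q_offset widens the visible KV range:

```python
def block_extend_mask_matrix(qo_len, kv_len, block_size, q_offset=0):
    """2-D visibility matrix for the block-expanding rule q_blk >= kv_blk."""
    return [
        [1 if (q + q_offset) // block_size >= k // block_size else 0
         for k in range(kv_len)]
        for q in range(qo_len)
    ]

# 3 query tokens appended at global offset 1, block size 2, 6 KV tokens:
for row in block_extend_mask_matrix(3, 6, 2, q_offset=1):
    print(row)
# [1, 1, 0, 0, 0, 0]
# [1, 1, 1, 1, 0, 0]
# [1, 1, 1, 1, 0, 0]
```

Note how the second query token crosses a dllm block boundary (global position 2, block 1) and gains visibility over a whole new KV block at once, which is exactly what makes the mask tile-skippable.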
The new per-request and kv_block_expanding_offset paths are still not validated.
compute_block_extend_reference() hardcodes kv_offset=0, make_block_extend_mask_mod() ignores b, and the batch setup uses torch.full(...) for q_offsets. That means the batch-indexed offset plumbing added in this PR never sees a heterogeneous batch, and the cascade/current-chunk kv_block_expanding_offset behavior is still untested.
Also applies to: 321-321, 381-384, 516-520
🧰 Tools
🪛 Ruff (0.15.4)
[warning] 67-67: PEP 484 prohibits implicit Optional
Convert to T | None
(RUF013)
[warning] 99-99: Unused function argument: b
(ARG001)
[warning] 99-99: Unused function argument: h
(ARG001)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@tests/attention/test_dllm_vs_flex_attention.py` around lines 61 - 104,
compute_block_extend_reference and make_block_extend_mask_mod currently never
exercise per-request offsets or kv_block_expanding_offset because
compute_block_extend_reference hardcodes q_offset (kv_offset) to zero,
make_block_extend_mask_mod ignores the batch index b, and tests use torch.full
for q_offsets; update the test helpers so compute_block_extend_reference accepts
and uses per-request q_offset values (and propagate kv_offset if applicable) and
make_block_extend_mask_mod's inner block_extend_mask uses the batch index b to
look up per-sample offsets, then modify the batch construction in tests to pass
heterogeneous q_offsets (not torch.full) and add cases exercising nonzero
kv_block_expanding_offset and cascade/current-chunk paths so the new plumbing is
validated against single_prefill_with_kv_cache and block_extend_mask behavior.
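A per-request variant of the mask_mod along the lines the prompt suggests could look like the sketch below; q_offsets/kv_offsets are hypothetical per-request containers (plain lists here for clarity, whereas a real flex_attention mask_mod would index tensors):

```python
def make_batched_block_extend_mask_mod(dllm_block_size, q_offsets, kv_offsets):
    """Per-request mask_mod: batch index b selects that request's offsets."""
    def block_extend_mask(b, h, q_idx, kv_idx):
        q_blk = (q_idx + q_offsets[b]) // dllm_block_size
        kv_blk = (kv_idx + kv_offsets[b]) // dllm_block_size
        return q_blk >= kv_blk
    return block_extend_mask

# Heterogeneous batch: request 1 starts mid-stream at global position 96.
mask = make_batched_block_extend_mask_mod(32, q_offsets=[0, 96], kv_offsets=[0, 0])
print(mask(0, 0, 0, 40))  # block 0 vs block 1 -> False
print(mask(1, 0, 0, 40))  # (0 + 96) // 32 = 3 >= 1 -> True
```

Feeding such non-uniform offsets (instead of a uniform torch.full(...) value) is what would actually exercise the per-request plumbing the review asks about.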
def benchmark_fn(fn, warmup_iters=20, bench_iters=100, label=""):
    """Benchmark a callable, return average time in ms."""
    for _ in range(warmup_iters):
        fn()
    torch.cuda.synchronize()

    start = time.perf_counter()
    for _ in range(bench_iters):
        fn()
    torch.cuda.synchronize()
    elapsed_ms = (time.perf_counter() - start) / bench_iters * 1000
    return elapsed_ms


def benchmark_with_cuda_graph(fn, warmup_iters=20, bench_iters=100, label=""):
    """Benchmark with CUDA Graph capture, return average time in ms."""
    # warmup
    for _ in range(warmup_iters):
        fn()
    torch.cuda.synchronize()

    # capture
    stream = torch.cuda.Stream()
    with torch.cuda.stream(stream):
        fn()
    stream.synchronize()

    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph, stream=stream):
        fn()

    # warmup cuda_graph
    for _ in range(warmup_iters):
        graph.replay()
    torch.cuda.synchronize()

    # bench
    start = time.perf_counter()
    for _ in range(bench_iters):
        graph.replay()
    torch.cuda.synchronize()
    elapsed_ms = (time.perf_counter() - start) / bench_iters * 1000

    del graph
    return elapsed_ms
🛠️ Refactor suggestion | 🟠 Major
Route these timings through bench_gpu_time().
The benchmark helpers and the direct perf_counter() paths in test_total_memory_comparison() bypass the repo timing harness, so these numbers will not use the standard CUPTI/CUDA-event fallback and will not be comparable with the rest of the benchmark suite.
As per coding guidelines, tests/**/*.py: Use flashinfer.testing.bench_gpu_time() for benchmarking kernels, preferring CUPTI timing with auto-fallback to CUDA events.
Also applies to: 1078-1083, 1119-1123
🧰 Tools
🪛 Ruff (0.15.4)
[warning] 166-166: Unused function argument: label
(ARG001)
[warning] 180-180: Unused function argument: label
(ARG001)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@tests/attention/test_dllm_vs_flex_attention.py` around lines 166 - 210, The
benchmarks (functions benchmark_fn and benchmark_with_cuda_graph) and the direct
perf_counter() timing in test_total_memory_comparison should be replaced to use
the repo timing harness flashinfer.testing.bench_gpu_time() so results use CUPTI
with CUDA-event fallback and remain comparable across the suite; locate usages
of benchmark_fn, benchmark_with_cuda_graph, and the perf_counter() blocks in
test_total_memory_comparison and call bench_gpu_time() (passing the callable and
warmup/bench iteration params) instead of manual perf_counter/CUDAGraph timing,
ensuring any CUDA Graph replay loops are wrapped or adapted to the
bench_gpu_time() callable interface.
Add DLLM Block Extend Attention feature with tile-level skip optimization using native MaskMode::kBlockExpanding.

Core API:
- Single-request: block_extend_attention_with_offset() with q/kv offset support
- Batch: BatchBlockExtendRaggedOffsetWrapper, BatchBlockExtendPagedOffsetWrapper
- Cascade: 3-stage attention (current chunk + prefix + merge state)
- Support both JIT and AOT compilation

Tests:
- Precision: block extend vs custom_mask reference, cascade vs blockwise correctness
- Performance: FlashInfer Block Extend vs PyTorch Flex Attention benchmark
- Context length sweep (1K-32K), block size alignment analysis
- Significant speedup over Flex Attention with lower memory usage
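The merge stage of the cascade described in the commit message combines two partial attention outputs via their log-sum-exp (LSE) values; a minimal numeric sketch of that merge (toy 1-D outputs, not the FlashInfer merge-state API):

```python
import math

def merge_states(o1, lse1, o2, lse2):
    """Merge two partial attention outputs given their log-sum-exp values.

    o1/o2: partial outputs over the same head dim (plain float lists here),
    lse1/lse2: natural-log sum-exp of the corresponding attention logits.
    """
    m = max(lse1, lse2)            # subtract the max for numerical stability
    w1 = math.exp(lse1 - m)
    w2 = math.exp(lse2 - m)
    return [(w1 * a + w2 * b) / (w1 + w2) for a, b in zip(o1, o2)]

# Equal log-sum-exp weights reduce to a plain average:
print(merge_states([1.0, 0.0], 0.0, [0.0, 1.0], 0.0))  # [0.5, 0.5]
```

This is why the current-chunk pass and the prefix pass can run as independent kernels: each returns (output, LSE), and the final stage reweights them into the exact full-context softmax result.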
Force-pushed from 33b195e to 5e07b58
Actionable comments posted: 3
♻️ Duplicate comments (10)
include/flashinfer/attention/hopper/mainloop.cuh (1)
230-243: ⚠️ Potential issue | 🔴 Critical
Handle the zero-visible-KV case before the first TMA load.
`get_num_kv_tiles()` can now return 0 when block expansion plus `kv_offset` removes the whole visible range. In that case `kv_tile_idx` becomes -1, and Line 243 immediately indexes `tKgK(_, kv_tile_idx)`, which is an invalid tile read. Please skip scheduling those work tiles or add a no-op path before any K/V load when `num_kv_tiles == 0`.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@include/flashinfer/attention/hopper/mainloop.cuh` around lines 230 - 243, the code assumes `get_num_kv_tiles(...) > 0`, but it can return 0, which makes `kv_tile_idx = num_kv_tiles - 1` negative and leads to an invalid read from `tKgK(_, kv_tile_idx)`; add a guard for `num_kv_tiles == 0` before any K/V TMA scheduling: check `num_kv_tiles` and if it is 0 skip the `pipeline_k.producer_acquire(...)` and the subsequent `copy(...)` that references `tKgK` and `smem_pipe_write_k` (i.e., short-circuit the path that schedules the first K load using `mainloop_params.tma_load_K` and `tKgK`), or route to a no-op branch so no tile is indexed when `num_kv_tiles == 0`.
flashinfer/dllm/block_extend.py (2)
146-170: ⚠️ Potential issue | 🟠 Major
Use the target tensor device for the FA3 capability gate.
`block_extend_attention_with_offset()` selects `fa3` from `q.device`, but `get_block_extend_module_with_offset()` re-checks FA3 against `torch.device("cuda")`. On mixed-architecture multi-GPU systems, a valid call on one device can fail just because the default CUDA device is older.
Proposed fix

 def get_block_extend_module_with_offset(
     head_dim: int = 128,
     dtype: torch.dtype = torch.float16,
     backend: str = "fa2",
+    device: Optional[torch.device] = None,
 ):
@@
-    if backend == "fa3" and not is_sm90a_supported(torch.device("cuda")):
+    device = device or torch.device("cuda")
+    if backend == "fa3" and not is_sm90a_supported(device):
         raise RuntimeError(
             "FA3 backend requires SM90 (Hopper) architecture. "
             "Use backend='fa2' for older architectures."
         )
@@
-    module = get_block_extend_module_with_offset(head_dim=head_dim, dtype=dtype, backend=backend)
+    module = get_block_extend_module_with_offset(
+        head_dim=head_dim,
+        dtype=dtype,
+        backend=backend,
+        device=q.device,
+    )

Also applies to: 281-285
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/dllm/block_extend.py` around lines 146 - 170, the capability check for FA3 uses `torch.device("cuda")`, which can be wrong on mixed-GPU systems; update `get_block_extend_module_with_offset` to accept or derive the target device (use the device of the input/target tensor or the same device used by `block_extend_attention_with_offset`) and pass that device into `is_sm90a_supported` instead of `torch.device("cuda")`; also apply the same change to the other FA3 check near the `block_extend_attention_with_offset`-related code (the second occurrence) so the SM90 gate queries the actual tensor/device being compiled for.
90-95: ⚠️ Potential issue | 🟠 Major
Don't alias unsupported dtypes to the FP16 URI.
`_get_dtype_str()` currently collapses every non-fp16/bf16 dtype to `"fp16"`. That makes FP8/FP32 call paths reuse the FP16 module URI and cache entry, which can load the wrong specialization.
Proposed fix

 def _get_dtype_str(dtype: torch.dtype) -> str:
     """Get dtype string representation (unified interface)"""
-    return {
-        torch.float16: "fp16",
-        torch.bfloat16: "bf16",
-    }.get(dtype, "fp16")
+    mapping = {
+        torch.float16: "fp16",
+        torch.bfloat16: "bf16",
+        torch.float32: "fp32",
+        torch.float8_e4m3fn: "fp8e4m3",
+        torch.float8_e5m2: "fp8e5m2",
+    }
+    try:
+        return mapping[dtype]
+    except KeyError as exc:
+        raise ValueError(f"Unsupported block-extend dtype: {dtype}") from exc

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/dllm/block_extend.py` around lines 90 - 95, the helper `_get_dtype_str` currently maps any non-fp16/bf16 dtype to "fp16", causing FP8/FP32 to reuse FP16 URIs; update `_get_dtype_str` to explicitly map known dtypes (at minimum torch.float16 -> "fp16", torch.bfloat16 -> "bf16", torch.float32 -> "fp32") and do not alias unknown types to "fp16"; instead either return a distinct string (e.g., dtype.name) or raise a clear ValueError for unsupported dtypes so wrong specializations are never cached.
flashinfer/prefill.py (1)
1651-1651:⚠️ Potential issue | 🟠 MajorReject unsupported
mask_modes before dispatch.Both wrappers now persist arbitrary
mask_modeoverrides, but thecudnnpath still derives masking fromself._causaland the pagedtrtllm-genpath never consumesmask_modeat all. ABLOCK_EXPANDINGplan can therefore silently execute with causal/non-causal semantics instead of failing fast.Also applies to: 2010-2010, 2205-2206, 2632-2632, 2938-2938, 3159-3160
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/prefill.py` at line 1651, The function accepting mask_mode should reject unsupported or conflicting overrides before dispatch: at the start of the routine that declares the mask_mode parameter, validate mask_mode against supported values and fail fast (raise ValueError) instead of persisting arbitrary overrides; specifically, if dispatch will use the cudnn path (which derives masking from self._causal) and mask_mode is provided but inconsistent with self._causal, raise an error, and if dispatch will use the paged trtllm-gen path (which never consumes mask_mode) reject any non-None mask_mode; ensure checks reference mask_mode, self._causal, and the BLOCK_EXPANDING plan name so callers cannot silently run with wrong causal/non-causal semantics.tests/attention/test_dllm_vs_flex_attention.py (4)
61-104: ⚠️ Potential issue | 🟠 Major
The new per-batch and `kv_offset` plumbing still isn't being validated.
`compute_block_extend_reference()` only shifts Q, `make_block_extend_mask_mod()` ignores `b`, and the batch setup uses a uniform `torch.full(...)` offset. The Hopper block-expanding path uses both per-request Q offsets and per-request KV offsets, so these checks can pass without touching the new code path at all.
Also applies to: 321-321, 381-384, 516-520
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/attention/test_dllm_vs_flex_attention.py` around lines 61 - 104, the tests currently only shift Q globally and ignore per-batch and per-request KV offsets, so update `compute_block_extend_reference` and `make_block_extend_mask_mod` to accept and use per-batch `q_offset` and `kv_offset` (arrays/tensors indexed by `b`) and ensure the returned mask_mod callback uses the `b` parameter to read `kv_offset[b]` and `q_offset[b]` when computing `q_global` and `kv_blk`; also change the test setup (replace `torch.full(...)` offsets) to supply non-uniform per-batch q and kv offsets so the flex-attention path actually exercises per-request plumbing, and add cases exercising nonzero `kv_block_expanding_offset` and the cascade/current-chunk paths so the new plumbing is validated against `single_prefill_with_kv_cache` and `block_extend_mask` behavior.
216-229: ⚠️ Potential issue | 🟠 Major
These benchmark drivers will be collected and run as ordinary tests.
All of these top-level `test_*` functions have only defaulted parameters, so pytest will execute the full sweeps during normal test runs. Please move them to a benchmark module or rename them to a non-`test_` prefix.
Also applies to: 664-672, 752-759, 832-841, 959-968
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/attention/test_dllm_vs_flex_attention.py` around lines 216 - 229, the function `test_flashinfer_vs_flex_attention` (and the other top-level functions at the noted ranges) are declared as pytest tests but are benchmark drivers with only defaulted parameters, so rename each function to a non-test name (e.g., `benchmark_flashinfer_vs_flex_attention`) or move them into a dedicated benchmark module so pytest won't auto-run them; update any references/calls to the original function names accordingly and ensure imports/exports reflect the new names.
355-399: ⚠️ Potential issue | 🟠 Major
Make correctness mismatches fail the test.
`ragged_pass`, `paged_pass`, and `flex_pass` only affect logging right now. If any backend diverges from the reference, pytest still reports success and CI misses the regression.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/attention/test_dllm_vs_flex_attention.py` around lines 355 - 399, the test currently only prints `ragged_pass`, `paged_pass`, and `flex_pass` but does not assert them, so divergences are not failing CI; update the test to assert these boolean checks (or assert the numeric diffs are below `tol`) after computing `ragged_diff`, `paged_diff`, and `flex_diff` so failures raise in pytest. Specifically, add assertions for `ragged_pass` and `paged_pass` (and for `flex_pass` only if `HAS_FLEX_ATTENTION`) with informative messages including the corresponding diff values to help debugging.
28-56: ⚠️ Potential issue | 🟠 Major
Skip unsupported GPU architectures before running this module.
This file exercises architecture-specific attention paths but only soft-fails import errors, so unsupported runners will still execute the test functions and fail later instead of skipping cleanly. As per coding guidelines for tests/**/*.py: use flashinfer.utils functions (`get_compute_capability()`, `is_sm90a_supported()`, `is_sm100a_supported()`) to skip tests on unsupported GPU architectures.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/attention/test_dllm_vs_flex_attention.py` around lines 28 - 56, add an early GPU-architecture guard at module import time before the FlashInfer/Flex-Attention import blocks: call `flashinfer.utils.get_compute_capability()` and use `flashinfer.utils.is_sm90a_supported()` and `is_sm100a_supported()` to detect unsupported GPUs and call `pytest.skip(...)` to skip the whole module when the architecture is not supported; ensure this check runs before or around the try/except blocks that set `HAS_FLASHINFER` and `HAS_FLEX_ATTENTION` so tests are skipped cleanly instead of failing later.
flashinfer/dllm/batch_block_extend.py (2)
611-613: ⚠️ Potential issue | 🟠 Major
Stage 1 is still non-causal, and `logits_soft_cap` is still a no-op.
The helper advertises "causal + merge", but the current-chunk ragged plan still uses `causal=False`, so tokens can see later positions inside the chunk before the merge. `logits_soft_cap` is also accepted by the public API and then dropped in both planning calls.
Also applies to: 638-659
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/dllm/batch_block_extend.py` around lines 611 - 613, the helper currently advertises "causal + merge" but leaves stage 1 non-causal and drops `logits_soft_cap`; update the code so the current-chunk ragged plan is created with `causal=True` (so tokens cannot see later positions inside the chunk prior to merge) and pass the `logits_soft_cap` parameter through to both planning calls instead of ignoring it; locate the helper that builds the current-chunk ragged plan and the two planning call sites and modify their arguments to include `logits_soft_cap` and set `causal=True` for the stage-1 plan.
560-563: ⚠️ Potential issue | 🟠 Major
Default offsets are still wrong when a paged prefix exists.
With a nonzero prefix, zeroing `q_offsets`/`kv_offsets` makes stage 1 run as if the current chunk started at global position 0. That changes the block-expanding mask whenever the prefix length is not block-aligned. Derive per-request offsets from the paged-prefix metadata or require the caller to pass them explicitly.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/dllm/batch_block_extend.py` around lines 560 - 563, the defaulting of `q_offsets`/`kv_offsets` to zeros is incorrect when a paged prefix exists because it makes stage 1 assume the chunk starts at global position 0; when they are None, derive per-request offsets from the paged-prefix metadata (use the paged-prefix length/offset fields for each request) and compute int32 offsets on the device instead of zeroing, or alternatively require the caller to pass explicit `q_offsets`/`kv_offsets`; ensure these per-request offsets propagate into the downstream stage-1 block-expanding mask calculations so block alignment is correct for non-block-aligned prefix lengths.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@flashinfer/dllm/batch_block_extend.py`:
- Around line 287-299: The code mutates the wrapper's backend preference by
assigning effective_backend back to self._backend, breaking future
auto-selection; instead, keep self._backend unchanged and use a local variable
(effective_backend) only to select the URI/variant (via
select_best_backend_paged, _get_batch_be_module_uri, variant_name, variant_decl)
so that backend="auto" continues to work across re-plans—apply the same
non-mutating pattern used in
BatchBlockExtendRaggedOffsetWrapper._create_inner_wrapper and also fix the
analogous block at the other occurrence (around the 422-433 region).
- Around line 545-549: In batch_block_extend_cascade(), don't pre-resolve
backend=="auto" using is_sm90a_supported/device; remove the logic that sets
actual_backend = "fa3" if is_sm90a_supported(device) else "fa2" and instead
preserve "auto" (i.e., set actual_backend = backend) so the wrapper layers can
resolve it, or replace that branch with a call to the centralized availability
selector helper if you prefer availability-based resolution; update references
to actual_backend accordingly.
In `@flashinfer/jit/attention/modules.py`:
- Line 1283: The JIT cache key must include the mask_modes so different mask
sets produce distinct compiled artifacts; update the code that constructs the
JitSpec (the unique name/URI hash and/or sources list) to incorporate the
mask_modes value(s) alongside existing fields (uri, sources,
extra_cuda_cflags/extra_cflags/extra_ldflags). Specifically, when creating the
JitSpec for functions that accept mask_modes (parameter mask_modes), append or
mix the mask_modes representation into the unique name/hash used as the URI
and/or into the sources list so the compiled directory/shared object differs for
different mask_modes and prevents loading stale modules.
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: 79a6e941-787c-4ea3-a005-e934cd5ac5b1
📒 Files selected for processing (23)
- csrc/batch_prefill_customize_config.jinja
- csrc/batch_prefill_sm90_customize_config.jinja
- csrc/single_prefill_customize_config.jinja
- csrc/single_prefill_sm90_customize_config.jinja
- flashinfer/__init__.py
- flashinfer/dllm/__init__.py
- flashinfer/dllm/batch_block_extend.py
- flashinfer/dllm/block_extend.py
- flashinfer/jit/attention/modules.py
- flashinfer/jit/utils.py
- flashinfer/prefill.py
- flashinfer/utils.py
- include/flashinfer/attention/block_expanding_prefill.cuh
- include/flashinfer/attention/default_prefill_params.cuh
- include/flashinfer/attention/hopper/mainloop.cuh
- include/flashinfer/attention/hopper/mainloop_mma.cuh
- include/flashinfer/attention/hopper/prefill_sm90.cuh
- include/flashinfer/attention/hopper/sparse_mainloop.cuh
- include/flashinfer/attention/mask.cuh
- include/flashinfer/attention/prefill.cuh
- include/flashinfer/utils.cuh
- tests/attention/test_dllm_cascade_vs_blockwise_extend_attention.py
- tests/attention/test_dllm_vs_flex_attention.py
🚧 Files skipped from review as they are similar to previous changes (6)
- csrc/single_prefill_customize_config.jinja
- flashinfer/__init__.py
- csrc/single_prefill_sm90_customize_config.jinja
- flashinfer/utils.py
- include/flashinfer/attention/block_expanding_prefill.cuh
- flashinfer/jit/utils.py
    use_logits_soft_cap: bool = False,
    use_fp16_qk_reduction: bool = False,
    fp8_enabled: bool = False,
    mask_modes: Optional[List[int]] = None,
Include mask_modes in the JIT cache key.
mask_modes now changes which kernel sources get generated, but both functions still compile under the caller-supplied uri unchanged. Reusing the same uri with a different mask set can therefore hit the wrong generated directory / shared object and silently load a stale module.
💡 Proposed fix
 def gen_customize_single_prefill_module(
     backend: str,
     uri: str,
@@
-    mask_modes: Optional[List[int]] = None,
+    mask_modes: Optional[List[int]] = None,
 ) -> JitSpec:
+    normalized_mask_modes = tuple(sorted(set(mask_modes))) if mask_modes is not None else None
+    if normalized_mask_modes is not None:
+        uri = f"{uri}_mask_modes_{'_'.join(map(str, normalized_mask_modes))}"
     kwargs = {
         "variant_decl": variant_decl,
@@
 def gen_customize_batch_prefill_module(
     backend: str,
     uri: str,
@@
-    mask_modes: Optional[List[int]] = None,
+    mask_modes: Optional[List[int]] = None,
 ) -> JitSpec:
+    normalized_mask_modes = tuple(sorted(set(mask_modes))) if mask_modes is not None else None
+    if normalized_mask_modes is not None:
+        uri = f"{uri}_mask_modes_{'_'.join(map(str, normalized_mask_modes))}"
     kwargs = {
         "variant_decl": variant_decl,

As per coding guidelines, "Structure JIT compilation parameters in JitSpec with unique name (URI hash), sources list, and compiler flags (extra_cuda_cflags, extra_cflags, extra_ldflags)".
Also applies to: 1532-1532
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@flashinfer/jit/attention/modules.py` at line 1283, The JIT cache key must
include the mask_modes so different mask sets produce distinct compiled
artifacts; update the code that constructs the JitSpec (the unique name/URI hash
and/or sources list) to incorporate the mask_modes value(s) alongside existing
fields (uri, sources, extra_cuda_cflags/extra_cflags/extra_ldflags).
Specifically, when creating the JitSpec for functions that accept mask_modes
(parameter mask_modes), append or mix the mask_modes representation into the
unique name/hash used as the URI and/or into the sources list so the compiled
directory/shared object differs for different mask_modes and prevents loading
stale modules.
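The normalization the proposed fix describes can be sketched as a standalone helper; the name `uri_with_mask_modes` is hypothetical, and the real fix would inline this logic into `gen_customize_single_prefill_module` / `gen_customize_batch_prefill_module`:

```python
from typing import List, Optional


def uri_with_mask_modes(uri: str, mask_modes: Optional[List[int]]) -> str:
    """Derive a deterministic cache-key suffix from mask_modes so that
    different mask sets can never collide in the JIT cache."""
    if mask_modes is None:
        return uri  # legacy callers keep their old cache entries
    # Sort and dedupe so [4, 0] and [0, 4] resolve to the same kernel set.
    normalized = sorted(set(mask_modes))
    return f"{uri}_mask_modes_{'_'.join(map(str, normalized))}"
```

Sorting and deduplicating first means `[0, 4]` and `[4, 0]` map to the same cache entry, while `None` preserves the existing cache keys for callers that never pass `mask_modes`.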
This PR is quite impactful for block diffusion LLM — we've seen a several-fold improvement in TTFT in our production environment at AntGroup when using SGLang. By the way, I'm the code owner of SGLang-dLLM. I'd like to collaborate with you to make FlashInfer the official recommended backend for SGLang-dLLM. I think it would be helpful for both projects and the broader community. SGLang dLLM roadmap: sgl-project/sglang#14199
yzh119 left a comment
Hi @fdz-1999, thanks for the great work — it makes sense to me in general.

Adding a new mask type would significantly bloat binary size (in jit-cache, etc.); can we make it a standalone class instead of changing csrc/batch_prefill_customize_config.jinja etc.?
```cpp
}  // namespace flashinfer

#endif  // FLASHINFER_DECODE_PARAMS_CUH_
```

No newline at end of file.
Please fix the lint issue (add a newline at the end of the file).
```cpp
}  // namespace flashinfer

#endif  // FLASHINFER_UTILS_CUH_
```

No newline at end of file.
Thanks for the review and the concern about binary size — it's a really important consideration for a JIT-heavy project like FlashInfer, and I want to make sure the impact is clear. I looked into this carefully, and I believe the current approach should not cause binary bloat for non-DLLM users. Here's the reasoning:

1) Mode 4 is not compiled by default

The JIT default in

2) Jinja template changes are conditionally compiled away

The

```jinja
{% if 'maybe_q_block_expanding_offset' in additional_params_decl %}
return maybe_q_block_expanding_offset[batch_idx];
{% else %}
return 0;
{% endif %}
```

Non-DLLM callers never pass

3) The kernel-level changes follow the existing MaskMode pattern

The additions in

The only "global" changes are:

4) Regarding the standalone class alternative

I did consider this direction. The challenge is that block expanding only differs from the existing kernel at a few specific points (mask predicate, iteration count, mask_iteration boundary — ~50 lines of

A standalone implementation would need to duplicate the full prefill kernel stack (

Given that the current approach has zero impact on non-DLLM users (no extra compilation, no binary size increase, no runtime overhead), I think extending the existing MaskMode infrastructure is the more sustainable path. But I'm happy to discuss further if there are specific scenarios you're concerned about!
Actionable comments posted: 3
♻️ Duplicate comments (3)
flashinfer/dllm/batch_block_extend.py (2)
293-296: ⚠️ Potential issue | 🟠 Major

Don't overwrite the wrapper's backend preference on first plan.

Both `_create_inner_wrapper` implementations still assign `self._backend = effective_backend`. When the wrapper was constructed with `backend="auto"`, the first plan locks the instance into whatever was picked for that `(head_dim, dtype)` pair, so a subsequent re-plan for a different shape no longer auto-falls-back and can fail even when the alternative backend is available. `self._preferred_backend` is already stored for this purpose — use the local `effective_backend` for URI/variant/inner construction and leave `self._backend` unchanged.

Also applies to: 430-433
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/dllm/batch_block_extend.py` around lines 293 - 296, The _create_inner_wrapper method currently assigns self._backend = effective_backend which locks the instance to the first chosen backend; instead, stop mutating self._backend and only use the local effective_backend variable when constructing the inner wrapper/URI/variant/inner objects (leave self._preferred_backend and self._backend unchanged so future re-plans can re-select backends). Apply the same change in the other _create_inner_wrapper implementation as well (remove any assignment to self._backend and rely on effective_backend locally).
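A minimal sketch of the suggested separation between the stored preference and the per-plan resolution; the class name and the availability probe are hypothetical stand-ins (real code would check which AOT/JIT kernels actually exist):

```python
class WrapperSketch:
    """Keep the user's preference immutable; resolve a concrete backend per plan."""

    def __init__(self, backend: str = "auto"):
        self._preferred_backend = backend  # never mutated after construction

    def _resolve(self, head_dim: int) -> str:
        if self._preferred_backend != "auto":
            return self._preferred_backend
        # Stand-in for real availability probing of compiled kernels.
        return "fa3" if head_dim in (64, 128) else "fa2"

    def plan(self, head_dim: int) -> str:
        effective_backend = self._resolve(head_dim)  # local only, not stored
        return effective_backend
```

Because `plan` never writes the resolved value back to the instance, a later re-plan with a different shape still re-selects instead of staying locked to the first choice.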
179-186: ⚠️ Potential issue | 🟠 Major

URI still omits `idtype` — silent kernel-mismatch risk when indptr dtype changes.

The recreation check at line 338 / 472 now correctly includes `self._idtype`, but `_get_batch_be_module_uri()` does not encode `idtype` into the URI string. When a user re-plans with a different indptr dtype (e.g. `int32` → `int64`), `_create_inner_wrapper` produces the same URI, and the downstream JIT module cache may return the previously-specialized kernel, leading to silent miscompute or undefined behavior. Bake `idtype` into the URI (and reject unsupported dtypes explicitly, as done for the element dtype).

🛠️ Proposed fix

```diff
-def _get_batch_be_module_uri(head_dim: int, dtype: torch.dtype) -> str:
-    _dtype_map = {torch.float16: "fp16", torch.bfloat16: "bf16"}
+def _get_batch_be_module_uri(head_dim: int, dtype: torch.dtype, idtype: torch.dtype = torch.int32) -> str:
+    _dtype_map = {torch.float16: "fp16", torch.bfloat16: "bf16"}
+    _idtype_map = {torch.int32: "i32", torch.int64: "i64"}
     if dtype not in _dtype_map:
         raise ValueError(
             f"Unsupported dtype {dtype} for Block Extend Attention. "
             f"Supported: {list(_dtype_map.keys())}"
         )
-    return f"batch_prefill_block_expanding_hd{head_dim}_{_dtype_map[dtype]}"
+    if idtype not in _idtype_map:
+        raise ValueError(
+            f"Unsupported idtype {idtype}. Supported: {list(_idtype_map.keys())}"
+        )
+    return f"batch_prefill_block_expanding_hd{head_dim}_{_dtype_map[dtype]}_{_idtype_map[idtype]}"
```

Thread `idtype` into every call site (lines 86-88, 134-136, 298/302, 435/439).

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/dllm/batch_block_extend.py` around lines 179 - 186, _get_batch_be_module_uri currently omits the indptr idtype which can cause silent kernel-mismatch; change its signature to accept an idtype parameter (e.g. def _get_batch_be_module_uri(head_dim: int, dtype: torch.dtype, idtype: torch.dtype) -> str), add an _idtype_map (e.g. {torch.int32: "i32", torch.int64: "i64"}) and raise ValueError for unsupported idtypes mirroring the element-dtype check, and include the mapped idtype token in the returned URI string; then update all callers (notably places that call _get_batch_be_module_uri and _create_inner_wrapper and any other call sites that construct the module URI) to pass self._idtype so the URI uniquely encodes indptr dtype and avoids returning incorrectly specialized kernels.

tests/attention/test_dllm_vs_flex_attention.py (1)
61-104: ⚠️ Potential issue | 🟠 Major

Per-request offset plumbing is still untested.

`compute_block_extend_reference` hardcodes `kv_offset=0`, `make_block_extend_mask_mod` ignores `b`, and the batch driver at line 321 constructs `q_offsets` via `torch.full(...)`, so every request gets the same offset. The per-batch `maybe_q_block_expanding_offset` / `maybe_kv_block_expanding_offset` arrays added in this PR — plus the cascade current-chunk `kv_block_expanding_offset` path — never see heterogeneous inputs in these tests. Please add at least one case with distinct per-request `q_offset`s and a non-zero `kv_offset` to exercise the batch-indexed array reads.

Also applies to: 321-321
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/attention/test_dllm_vs_flex_attention.py` around lines 61 - 104, The tests never exercise per-request offsets: update compute_block_extend_reference to accept and use a kv_offset parameter (and propagate it into q_pos/k_pos calculation), modify make_block_extend_mask_mod so the returned block_extend_mask reads per-request q_offset (use the b argument) and accepts/uses a kv_offset path, and change the test batch setup that currently uses torch.full(...) for q_offsets to create heterogeneous per-request q_offsets and a non-zero kv_offset (so maybe_q_block_expanding_offset, maybe_kv_block_expanding_offset and the kv_block_expanding_offset cascade actually read varied entries); run the updated test to ensure the batch-indexed reads are exercised.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@flashinfer/dllm/batch_block_extend.py`:
- Line 177: The module-scope variable _BATCH_BE_MODULE_CACHE is unused; either
remove it or wire it into the wrapper memoization: if removing, delete the
_BATCH_BE_MODULE_CACHE definition and any unused references; if implementing
memoization, update the _create_inner_wrapper methods on
BatchBlockExtendPagedOffsetWrapper and BatchBlockExtendRaggedOffsetWrapper to
check _BATCH_BE_MODULE_CACHE for an existing wrapper keyed by the unique
parameters (e.g., backend module name and config), return the cached instance
when present, and store newly created wrappers in _BATCH_BE_MODULE_CACHE to
avoid recreating them.
In `@flashinfer/dllm/block_extend.py`:
- Around line 200-204: The condition that checks FLASHINFER_DISABLE_JIT is
incorrectly treating the string "0" as truthy; update the check in
block_extend.py so it uses an explicit comparison (e.g.,
os.environ.get("FLASHINFER_DISABLE_JIT") == "1" or the same boolean parsing used
for FLASHINFER_FORCE_JIT) before raising the RuntimeError, and keep the same
error message referencing _get_aot_path(uri).
In `@tests/attention/test_dllm_vs_flex_attention.py`:
- Around line 1089-1093: The try/except around benchmark_with_cuda_graph
silently swallows errors; change the except block in the section calling
benchmark_with_cuda_graph(_run_fi, warmup_iters, bench_iters) so it captures the
exception as e, logs or prints the exception (including a clear message that
CUDA-graph capture failed) and sets a sentinel in entry (e.g., entry["fi_cg_ms"]
= None or an explicit failure string) so the result row shows the failure;
update the except clause that currently just does `pass` to record the error and
preserve diagnostic visibility.
---
Duplicate comments:
In `@flashinfer/dllm/batch_block_extend.py`:
- Around line 293-296: The _create_inner_wrapper method currently assigns
self._backend = effective_backend which locks the instance to the first chosen
backend; instead, stop mutating self._backend and only use the local
effective_backend variable when constructing the inner wrapper/URI/variant/inner
objects (leave self._preferred_backend and self._backend unchanged so future
re-plans can re-select backends). Apply the same change in the other
_create_inner_wrapper implementation as well (remove any assignment to
self._backend and rely on effective_backend locally).
- Around line 179-186: _get_batch_be_module_uri currently omits the indptr
idtype which can cause silent kernel-mismatch; change its signature to accept an
idtype parameter (e.g. def _get_batch_be_module_uri(head_dim: int, dtype:
torch.dtype, idtype: torch.dtype) -> str), add an _idtype_map (e.g.
{torch.int32: "i32", torch.int64: "i64"}) and raise ValueError for unsupported
idtypes mirroring the element-dtype check, and include the mapped idtype token
in the returned URI string; then update all callers (notably places that call
_get_batch_be_module_uri and _create_inner_wrapper and any other call sites that
construct the module URI) to pass self._idtype so the URI uniquely encodes
indptr dtype and avoids returning incorrectly specialized kernels.
In `@tests/attention/test_dllm_vs_flex_attention.py`:
- Around line 61-104: The tests never exercise per-request offsets: update
compute_block_extend_reference to accept and use a kv_offset parameter (and
propagate it into q_pos/k_pos calculation), modify make_block_extend_mask_mod so
the returned block_extend_mask reads per-request q_offset (use the b argument)
and accepts/uses a kv_offset path, and change the test batch setup that
currently uses torch.full(...) for q_offsets to create heterogeneous per-request
q_offsets and a non-zero kv_offset (so maybe_q_block_expanding_offset,
maybe_kv_block_expanding_offset and the kv_block_expanding_offset cascade
actually read varied entries); run the updated test to ensure the batch-indexed
reads are exercised.
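The `FLASHINFER_DISABLE_JIT` finding above comes down to string truthiness; a standalone demonstration:

```python
import os

# os.environ values are strings, and any non-empty string (including "0")
# is truthy under a bare `if` check.
os.environ["FLASHINFER_DISABLE_JIT"] = "0"  # user's intent: keep JIT enabled

naive_disabled = bool(os.environ.get("FLASHINFER_DISABLE_JIT"))
strict_disabled = os.environ.get("FLASHINFER_DISABLE_JIT", "0") == "1"
```

`naive_disabled` ends up `True` even though the user asked for JIT to stay on; the explicit `== "1"` comparison matches the file's existing `FLASHINFER_FORCE_JIT` handling.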
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: 2afdf1fd-f0a9-41ad-84de-e36b89558f65
📒 Files selected for processing (6)
- csrc/single_prefill_sm90_customize_config.jinja
- flashinfer/dllm/batch_block_extend.py
- flashinfer/dllm/block_extend.py
- include/flashinfer/attention/default_prefill_params.cuh
- include/flashinfer/utils.cuh
- tests/attention/test_dllm_vs_flex_attention.py
🚧 Files skipped from review as they are similar to previous changes (2)
- csrc/single_prefill_sm90_customize_config.jinja
- include/flashinfer/attention/default_prefill_params.cuh
```python
try:
    fi_cg_ms = benchmark_with_cuda_graph(_run_fi, warmup_iters, bench_iters)
    entry["fi_cg_ms"] = fi_cg_ms
except Exception:
    pass
```
Don't silently swallow CUDA-graph capture failures.
The bare `try: … except Exception: pass` drops any exception from `benchmark_with_cuda_graph`, so a real capture failure (e.g., dynamic shapes, stream issues) silently produces a result row with `fi_cg_ms` missing and no diagnostic. At minimum, print the exception so the benchmark output makes the failure visible.
🛠️ Proposed fix

```diff
 try:
     fi_cg_ms = benchmark_with_cuda_graph(_run_fi, warmup_iters, bench_iters)
     entry["fi_cg_ms"] = fi_cg_ms
-except Exception:
-    pass
+except Exception as e:
+    print(f"  [dllm_bs={dbs}, seq={seq_len}] FI CUDA Graph failed: {e}")
```

🧰 Tools
🪛 Ruff (0.15.10)
[error] 1092-1093: try-except-pass detected, consider logging the exception
(S110)
[warning] 1092-1092: Do not catch blind exception: Exception
(BLE001)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@tests/attention/test_dllm_vs_flex_attention.py` around lines 1089 - 1093, The
try/except around benchmark_with_cuda_graph silently swallows errors; change the
except block in the section calling benchmark_with_cuda_graph(_run_fi,
warmup_iters, bench_iters) so it captures the exception as e, logs or prints the
exception (including a clear message that CUDA-graph capture failed) and sets a
sentinel in entry (e.g., entry["fi_cg_ms"] = None or an explicit failure string)
so the result row shows the failure; update the except clause that currently
just does `pass` to record the error and preserve diagnostic visibility.
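The suggested pattern (record a sentinel, surface the error) can be factored into a small helper; the names here are hypothetical, not the benchmark's actual API:

```python
def time_with_fallback(fn, entry, key):
    """Run a timing callable, store its result under `key`, and on failure
    record an explicit None sentinel plus a printed diagnostic instead of
    silently dropping the row."""
    try:
        entry[key] = fn()  # e.g. benchmark_with_cuda_graph(run, warmup, iters)
    except Exception as e:
        print(f"{key} failed: {e}")  # surface the capture failure
        entry[key] = None            # sentinel keeps the row inspectable
```

With a sentinel in place, the results table still shows the row and makes the failure visible, instead of a key that is simply missing.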
Force-pushed from 5ec571a to 93965f5 (compare).
Actionable comments posted: 4
♻️ Duplicate comments (5)
flashinfer/dllm/batch_block_extend.py (1)
179-186: ⚠️ Potential issue | 🟠 Major

URI still ignores `idtype` (and coerces FP8 dtypes into fp16).

`_get_batch_be_module_uri` only encodes `head_dim` and the element dtype. Different `idtype` variants (`torch.int32` vs `torch.int64`) will collide under the same `flashinfer::{uri}_*` name, and the wrapper-recreation check at Lines 338/472 tracks idtype but the URI does not. Please bake `idtype` (and any supported dtype explicitly) into the URI.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/dllm/batch_block_extend.py` around lines 179 - 186, The URI builder _get_batch_be_module_uri currently only encodes head_dim and a limited dtype map, causing collisions across different idtype values and misrepresenting FP8 types; update _get_batch_be_module_uri to include the idtype (e.g., torch.int32 vs torch.int64) in the returned string and expand the dtype mapping to explicitly handle all supported dtypes (including FP8 variants) so each unique combination of head_dim, element dtype, and idtype produces a distinct URI; reference the function name _get_batch_be_module_uri and ensure the returned string format embeds both idtype and the normalized dtype token.

flashinfer/dllm/block_extend.py (1)
200-204: ⚠️ Potential issue | 🟠 Major

`FLASHINFER_DISABLE_JIT=0` still disables JIT.

`os.environ.get("FLASHINFER_DISABLE_JIT")` returns the literal string, and `"0"` is truthy in Python — so users setting this to the natural "off" value get JIT disabled. Line 85 in the same file correctly uses `== "1"` for `FLASHINFER_FORCE_JIT`; please mirror that here.

🛠️ Proposed fix

```diff
-    if os.environ.get("FLASHINFER_DISABLE_JIT"):
+    if os.environ.get("FLASHINFER_DISABLE_JIT", "0") == "1":
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/dllm/block_extend.py` around lines 200 - 204, The environment check currently uses os.environ.get("FLASHINFER_DISABLE_JIT") which treats "0" as truthy and incorrectly disables JIT; update the condition to explicitly compare the env value to "1" (i.e., FLASHINFER_DISABLE_JIT == "1") so only an explicit "1" disables JIT, keeping the existing RuntimeError message that references _get_aot_path(uri) unchanged.

tests/attention/test_dllm_vs_flex_attention.py (3)
99-104: ⚠️ Potential issue | 🟡 Minor

Batch-indexed `kv_block_expanding_offset` path still has no heterogeneous-batch test.

`make_block_extend_mask_mod` ignores `b`, the reference uses a scalar `q_offset`, and the batch setup uses `torch.full((num_requests,), q_offset, ...)` — so every request in the batch gets the same offset and `kv_offsets` is never exercised end-to-end against a per-request reference. The PR introduces per-batch `q_block_expanding_offset` / `kv_block_expanding_offset` accessors; please add at least one correctness case with heterogeneous `q_offsets` and non-zero `kv_offsets` to cover the new plumbing.

Also applies to: 321-321
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/attention/test_dllm_vs_flex_attention.py` around lines 99 - 104, The batch-indexed path isn't tested with heterogeneous per-request offsets: update the test in tests/attention/test_dllm_vs_flex_attention.py to exercise per-batch q_offsets and non-zero kv_offsets by fixing make_block_extend_mask_mod to actually use the batch index parameter b (i.e., compute q_global from q_offset[b] rather than a scalar), create a heterogeneous torch tensor for q_block_expanding_offset (not torch.full) and set a non-zero kv_block_expanding_offset/kv_offsets, then add an assertion comparing block_extend_mask results (or end-to-end attention outputs) between the dllm implementation and the flex/reference implementation so the new q_block_expanding_offset and kv_block_expanding_offset plumbing is covered.
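A pure-Python sketch of what a heterogeneous-offset case could look like; the block-causal predicate below is an assumed stand-in for the kernel's actual block-expanding mask, and the point is only that per-request offsets must differ for the batch-indexed reads to be exercised:

```python
BLOCK_SIZE = 4  # illustrative diffusion block size


def block_extend_visible(b, q_idx, kv_idx, q_offsets, kv_offsets):
    """Assumed block-causal predicate: a query token may attend to every kv
    token whose global block index does not exceed the query's block index.
    Positions are shifted by per-request offsets before comparison."""
    q_global = q_idx + q_offsets[b]   # per-request query offset
    k_global = kv_idx + kv_offsets[b]  # per-request kv offset
    return k_global // BLOCK_SIZE <= q_global // BLOCK_SIZE


# Heterogeneous offsets: each request's chunk starts at a different position,
# unlike torch.full(...), which replicates one offset across the batch.
q_offsets = [0, 4, 8]
kv_offsets = [0, 0, 4]
masks = [
    [[block_extend_visible(b, q, k, q_offsets, kv_offsets) for k in range(8)]
     for q in range(4)]
    for b in range(3)
]
```

Because the offsets differ, requests produce different visibility patterns, so a reference comparison actually distinguishes correct batch indexing from a broadcast of request 0's offset.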
20-56: ⚠️ Potential issue | 🟡 Minor

Module-level import-time prints on non-CUDA/unsupported hosts.

Since this file is named `test_dllm_vs_flex_attention.py`, pytest will still import it during collection. On machines without CUDA, the top-of-file prints fire but nothing gets skipped cleanly. The driver entry points have been renamed to `bench_*` (good), but please gate the module with an explicit `pytest.importorskip` / arch check (e.g., `is_sm90a_supported` / `is_sm80a_supported` from `flashinfer.utils`) so collection is quiet and deterministic on unsupported runners. As per coding guidelines: "Skip test execution on unsupported GPU architectures using `flashinfer.utils` check functions".
Verify each finding against the current code and only fix it if needed. In `@tests/attention/test_dllm_vs_flex_attention.py` around lines 20 - 56, Top-level imports in test_dllm_vs_flex_attention.py cause prints during pytest collection on unsupported/non-CUDA hosts; wrap/gate module import with pytest.importorskip or an explicit arch check using flashinfer.utils functions to prevent collection printing. Concretely, at the top of the module call pytest.importorskip("flashinfer") or call flashinfer.utils.is_sm90a_supported()/is_sm80a_supported() (and skip via pytest.skip if neither is true) before importing single_prefill_with_kv_cache, BatchBlockExtendPagedOffsetWrapper, BatchBlockExtendRaggedOffsetWrapper, flex_attention, create_block_mask so the file is skipped quietly on unsupported runners and no module-level prints occur.
1089-1093: ⚠️ Potential issue | 🟡 Minor

Bare `except: pass` still silently drops CUDA-graph capture failures.

`entry["fi_cg_ms"]` is omitted with no diagnostic, making real capture failures invisible. At minimum print the exception; the same pattern also appears at Line 1029 in the workspace probe and Line 1130 in the Flex path (the Flex path already prints, which is the right shape).

🛠️ Proposed fix

```diff
-        except Exception:
-            pass
+        except Exception as e:
+            print(f"  [dllm_bs={dbs}, seq={seq_len}] FI CUDA Graph failed: {e}")
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/attention/test_dllm_vs_flex_attention.py` around lines 1089 - 1093, Replace the bare "except: pass" around the CUDA-graph capture so failures are not silently dropped: wrap the benchmark_with_cuda_graph(_run_fi, ...) call in "except Exception as e:" and print or log the exception (including context) and set entry["fi_cg_ms"] to a sentinel (e.g., None or an "error" value) so the failure is recorded; apply the same change to the other identical patterns (the workspace probe capture and the Flex path) to ensure all capture exceptions are surfaced rather than ignored.
🧹 Nitpick comments (3)
flashinfer/dllm/__init__.py (1)
16-19: Unused private imports in package `__init__`.

`_check_batch_be_aot_available`, `_get_batch_be_aot_path`, and `_get_batch_be_module_uri` are imported but not added to `__all__` and not otherwise referenced here, so they're dead imports at the package level. Either drop them from the import list or promote them to the public surface consciously.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/dllm/__init__.py` around lines 16 - 19, The three private symbols _check_batch_be_aot_available, _get_batch_be_aot_path, and _get_batch_be_module_uri are imported into flashinfer.dllm.__init__ but never exported or used; either remove them from the import list to eliminate dead imports or intentionally expose them by adding their names to __all__ (or re-exporting under public names) so they become part of the package surface; update the import statement and/or the __all__ list in __init__.py accordingly to keep imports and public API consistent.

tests/attention/test_dllm_vs_flex_attention.py (1)
1073-1074: Loop-variable closure in `_run_fi` (Ruff B023).

`_run_fi` captures `wrapper`/`q`/`k`/`v` by reference from the enclosing loop. It happens to work today because the closure is invoked inside the same iteration, but a future refactor (e.g., deferring the callable past the `del ... wrapper` at Line 1097) would silently pick up a later iteration's bindings. Bind explicitly:

```python
def _run_fi(_w=wrapper, _q=q, _k=k, _v=v):
    _w.run(_q, _k, _v)
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/attention/test_dllm_vs_flex_attention.py` around lines 1073 - 1074, The nested function _run_fi closes over loop variables (wrapper, q, k, v) which can change across iterations; update the definition of _run_fi to bind those values as default parameters (e.g., def _run_fi(_w=wrapper, _q=q, _k=k, _v=v): ...) and call _w.run(_q, _k, _v) so the callable captures the current iteration's bindings instead of referencing them by closure.

flashinfer/dllm/block_extend.py (1)
292-296: `backend="auto"` here does not check actual kernel availability.

Unlike `batch_block_extend.py`'s `select_best_backend[_paged]`, this helper picks purely on `is_sm90a_supported(q.device)` and then unconditionally calls `get_block_extend_module_with_offset(..., backend=backend)`. On a Hopper box where only the FA2 variant is AOT-compiled and JIT is disabled, auto mode will hard-fail instead of falling back. Consider reusing/consolidating the availability-aware selector used in the batch module.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/dllm/block_extend.py` around lines 292 - 296, The current auto backend selection uses is_sm90a_supported(q.device) and then directly calls get_block_extend_module_with_offset, which can hard-fail if the chosen kernel isn't actually available; change the logic to consult the availability-aware selector used in batch_block_extend.py (e.g., call select_best_backend or its _paged variant) instead of the simple is_sm90a_supported check, or wrap get_block_extend_module_with_offset in a try/fallback that falls back to the other backend on failure; update the backend local variable from that selector and then call get_block_extend_module_with_offset(head_dim=head_dim, dtype=dtype, backend=backend, device=q.device).
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@flashinfer/dllm/batch_block_extend.py`:
- Around line 313-317: The jit_kwargs currently lists mask_modes=[0,1,2,3,4]
which forces compilation of all mask-mode kernels even though the DLLM wrappers
call the inner prefill with mask_mode=MaskMode.BLOCK_EXPANDING.value; change the
mask_modes entry in jit_kwargs to only include [MaskMode.BLOCK_EXPANDING.value]
(or the literal 4) so only mode 4 is JIT/AOT-compiled; update both occurrences
around the jit_kwargs definitions (the one at lines ~313 and the similar block
at ~450-454) to avoid inflating build/AOT size and cache footprint.
- Around line 34-64: check_jit_environment currently shells out to a
PATH-resolved "nvcc" and uses a broad except that swallows real errors; update
it to locate nvcc via shutil.which and fallback to CUDA_HOME (e.g., check
os.environ["CUDA_HOME"] + "/bin/nvcc") before declaring nvcc missing and include
a clear issue message when nvcc is found/not found, and replace the broad
"except Exception as e" around the tvm_ffi probe with a narrower handler (or at
minimum log/append the full exception details to results["issues"] instead of
swallowing) so callers can see the real failure; refer to the function name
check_jit_environment, the results dict keys ("nvcc_ok", "issues"), and the
tvm_ffi probe block to make the changes.
- Around line 251-267: Add the required decorators to the public DLLM APIs:
import and apply `@flashinfer_api` to BatchBlockExtendPagedOffsetWrapper,
BatchBlockExtendRaggedOffsetWrapper, batch_block_extend_cascade, and
sglang_style_cascade_attention so these high-level Python APIs have crash-safe
logging; for the FA3/architecture-gated code paths (the functions/constructors
that require FA3), also import and apply `@backend_requirement` with the
appropriate backend constraint (e.g., the FA3 identifier used elsewhere in the
repo) to those symbols; ensure imports for flashinfer_api and
backend_requirement are added at the top if missing and keep decorator placement
immediately above the class or def declarations for the listed symbols.
In `@flashinfer/dllm/block_extend.py`:
- Around line 235-245: Add the missing decorators to make these public APIs
crash-safe and declare architecture requirements: annotate
block_extend_attention_with_offset and block_extend_cascade with `@flashinfer_api`
to enable crash-safe logging, and add `@backend_requirement` with the appropriate
backend tag(s) used for FA3 dispatch (e.g., `@backend_requirement`("fa3") or the
specific backend string your dispatch expects) to the same functions; ensure the
decorators are imported at the top of flashinfer.dllm (or block_extend.py) if
not already present and keep any existing signature and defaults intact so only
decorators are added.
---
Duplicate comments:
In `@flashinfer/dllm/batch_block_extend.py`:
- Around line 179-186: The URI builder _get_batch_be_module_uri currently only
encodes head_dim and a limited dtype map, causing collisions across different
idtype values and misrepresenting FP8 types; update _get_batch_be_module_uri to
include the idtype (e.g., torch.int32 vs torch.int64) in the returned string and
expand the dtype mapping to explicitly handle all supported dtypes (including
FP8 variants) so each unique combination of head_dim, element dtype, and idtype
produces a distinct URI; reference the function name _get_batch_be_module_uri
and ensure the returned string format embeds both idtype and the normalized
dtype token.
In `@flashinfer/dllm/block_extend.py`:
- Around line 200-204: The environment check currently uses
os.environ.get("FLASHINFER_DISABLE_JIT") which treats "0" as truthy and
incorrectly disables JIT; update the condition to explicitly compare the env
value to "1" (i.e., FLASHINFER_DISABLE_JIT == "1") so only an explicit "1"
disables JIT, keeping the existing RuntimeError message that references
_get_aot_path(uri) unchanged.
In `@tests/attention/test_dllm_vs_flex_attention.py`:
- Around line 99-104: The batch-indexed path isn't tested with heterogeneous
per-request offsets: update the test in
tests/attention/test_dllm_vs_flex_attention.py to exercise per-batch q_offsets
and non-zero kv_offsets by fixing make_block_extend_mask_mod to actually use the
batch index parameter b (i.e., compute q_global from q_offset[b] rather than a
scalar), create a heterogeneous torch tensor for q_block_expanding_offset (not
torch.full) and set a non-zero kv_block_expanding_offset/kv_offsets, then add an
assertion comparing block_extend_mask results (or end-to-end attention outputs)
between the dllm implementation and the flex/reference implementation so the new
q_block_expanding_offset and kv_block_expanding_offset plumbing is covered.
- Around line 20-56: Top-level imports in test_dllm_vs_flex_attention.py cause
prints during pytest collection on unsupported/non-CUDA hosts; wrap/gate module
import with pytest.importorskip or an explicit arch check using flashinfer.utils
functions to prevent collection printing. Concretely, at the top of the module
call pytest.importorskip("flashinfer") or call
flashinfer.utils.is_sm90a_supported()/is_sm80a_supported() (and skip via
pytest.skip if neither is true) before importing single_prefill_with_kv_cache,
BatchBlockExtendPagedOffsetWrapper, BatchBlockExtendRaggedOffsetWrapper,
flex_attention, create_block_mask so the file is skipped quietly on unsupported
runners and no module-level prints occur.
- Around line 1089-1093: Replace the bare "except: pass" around the CUDA-graph
capture so failures are not silently dropped: wrap the
benchmark_with_cuda_graph(_run_fi, ...) call in "except Exception as e:" and
print or log the exception (including context) and set entry["fi_cg_ms"] to a
sentinel (e.g., None or an "error" value) so the failure is recorded; apply the
same change to the other identical patterns (the workspace probe capture and the
Flex path) to ensure all capture exceptions are surfaced rather than ignored.
---
Nitpick comments:
In `@flashinfer/dllm/__init__.py`:
- Around line 16-19: The three private symbols _check_batch_be_aot_available,
_get_batch_be_aot_path, and _get_batch_be_module_uri are imported into
flashinfer.dllm.__init__ but never exported or used; either remove them from the
import list to eliminate dead imports or intentionally expose them by adding
their names to __all__ (or re-exporting under public names) so they become part
of the package surface; update the import statement and/or the __all__ list in
__init__.py accordingly to keep imports and public API consistent.
In `@flashinfer/dllm/block_extend.py`:
- Around line 292-296: The current auto backend selection uses
is_sm90a_supported(q.device) and then directly calls
get_block_extend_module_with_offset, which can hard-fail if the chosen kernel
isn't actually available; change the logic to consult the availability-aware
selector used in batch_block_extend.py (e.g., call select_best_backend or its
_paged variant) instead of the simple is_sm90a_supported check, or wrap
get_block_extend_module_with_offset in a try/fallback that falls back to the
other backend on failure; update the backend local variable from that selector
and then call get_block_extend_module_with_offset(head_dim=head_dim,
dtype=dtype, backend=backend, device=q.device).
In `@tests/attention/test_dllm_vs_flex_attention.py`:
- Around line 1073-1074: The nested function _run_fi closes over loop variables
(wrapper, q, k, v) which can change across iterations; update the definition of
_run_fi to bind those values as default parameters (e.g., def
_run_fi(_w=wrapper, _q=q, _k=k, _v=v): ...) and call _w.run(_q, _k, _v) so the
callable captures the current iteration's bindings instead of referencing them
by closure.
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: dd7c2332-0df7-4061-8e4f-2809936d2756
📒 Files selected for processing (7)
- csrc/single_prefill_sm90_customize_config.jinja
- flashinfer/dllm/__init__.py
- flashinfer/dllm/batch_block_extend.py
- flashinfer/dllm/block_extend.py
- include/flashinfer/attention/default_prefill_params.cuh
- include/flashinfer/utils.cuh
- tests/attention/test_dllm_vs_flex_attention.py
🚧 Files skipped from review as they are similar to previous changes (3)
- csrc/single_prefill_sm90_customize_config.jinja
- include/flashinfer/utils.cuh
- include/flashinfer/attention/default_prefill_params.cuh
```python
def check_jit_environment() -> dict:
    """Check if JIT compilation environment is working properly"""
    results = {
        "tvm_ffi_ok": False,
        "device_guard_ok": False,
        "nvcc_ok": False,
        "issues": [],
    }

    try:
        import tvm_ffi

        results["tvm_ffi_ok"] = True
        include_path = tvm_ffi.libinfo.find_include_path()
        device_guard_path = Path(include_path) / "tvm" / "ffi" / "extra" / "cuda" / "device_guard.h"
        results["device_guard_ok"] = device_guard_path.exists()
        if not results["device_guard_ok"]:
            results["issues"].append(f"Missing TVM header: {device_guard_path}")
    except ImportError:
        results["issues"].append("tvm_ffi package not installed")
    except Exception as e:
        results["issues"].append(f"Error checking tvm_ffi: {e}")

    import subprocess

    try:
        result = subprocess.run(["nvcc", "--version"], capture_output=True, text=True)
        results["nvcc_ok"] = result.returncode == 0
    except FileNotFoundError:
        results["nvcc_ok"] = False
        results["issues"].append("nvcc not found in PATH")

    return results
```
check_jit_environment relies on PATH-resolved nvcc and a blind except Exception.
Calling subprocess.run(["nvcc", "--version"]) is a partial-executable-path invocation (S607); on systems where CUDA is present but nvcc is not in PATH (common in production containers that ship only runtime libs) this will return jit_available=False and drop the wrapper into a confusing "no kernel available" error even when AOT would have worked. Also, except Exception around the tvm_ffi probe (line 53) swallows real errors. Consider: (1) resolving nvcc via CUDA_HOME/shutil.which with a helpful message, and (2) narrowing or at least logging the exception.
🧰 Tools
🪛 Ruff (0.15.10)
[warning] 53-53: Do not catch blind exception: Exception
(BLE001)
[error] 58-58: Starting a process with a partial executable path
(S607)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@flashinfer/dllm/batch_block_extend.py` around lines 34 - 64,
check_jit_environment currently shells out to a PATH-resolved "nvcc" and uses a
broad except that swallows real errors; update it to locate nvcc via
shutil.which and fallback to CUDA_HOME (e.g., check os.environ["CUDA_HOME"] +
"/bin/nvcc") before declaring nvcc missing and include a clear issue message
when nvcc is found/not found, and replace the broad "except Exception as e"
around the tvm_ffi probe with a narrower handler (or at minimum log/append the
full exception details to results["issues"] instead of swallowing) so callers
can see the real failure; refer to the function name check_jit_environment, the
results dict keys ("nvcc_ok", "issues"), and the tvm_ffi probe block to make the
changes.
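A resolver along the suggested lines might look like this (a sketch; treating `CUDA_PATH` as a secondary environment variable is an assumption):

```python
import os
import shutil

def resolve_nvcc(env=None):
    """Locate nvcc via PATH first, then fall back to $CUDA_HOME/bin/nvcc,
    returning None (rather than raising) when neither is present."""
    env = os.environ if env is None else env
    found = shutil.which("nvcc")
    if found:
        return found
    cuda_home = env.get("CUDA_HOME") or env.get("CUDA_PATH")
    if cuda_home:
        candidate = os.path.join(cuda_home, "bin", "nvcc")
        if os.path.isfile(candidate) and os.access(candidate, os.X_OK):
            return candidate
    return None
```

A `None` result can then feed a precise issue message ("nvcc not found in PATH or $CUDA_HOME/bin") instead of the current blanket failure.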
```python
class BatchBlockExtendPagedOffsetWrapper:
    """Batch Block Extend Paged Attention with Offset Support"""

    def __init__(
        self,
        float_workspace_buffer: torch.Tensor,
        kv_layout: str = "NHD",
        dllm_block_size: int = 256,
        use_cuda_graph: bool = False,
        qo_indptr_buf: Optional[torch.Tensor] = None,
        paged_kv_indptr_buf: Optional[torch.Tensor] = None,
        paged_kv_indices_buf: Optional[torch.Tensor] = None,
        paged_kv_last_page_len_buf: Optional[torch.Tensor] = None,
        q_offsets_buf: Optional[torch.Tensor] = None,
        kv_offsets_buf: Optional[torch.Tensor] = None,
        backend: str = "auto",
    ) -> None:
```
Missing @flashinfer_api / @backend_requirement decorators on public DLLM APIs.
BatchBlockExtendPagedOffsetWrapper, BatchBlockExtendRaggedOffsetWrapper, batch_block_extend_cascade, and sglang_style_cascade_attention are high-level Python APIs exposed from flashinfer.dllm but none carry the @flashinfer_api decorator for crash-safe logging, and the FA3 path is architecture-gated without @backend_requirement. As per coding guidelines: "Use @flashinfer_api decorator on high-level Python APIs" and "Use @backend_requirement decorator on APIs with architecture-specific requirements".
Also applies to: 392-406, 521-539, 608-623
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@flashinfer/dllm/batch_block_extend.py` around lines 251 - 267, Add the
required decorators to the public DLLM APIs: import and apply `@flashinfer_api` to
BatchBlockExtendPagedOffsetWrapper, BatchBlockExtendRaggedOffsetWrapper,
batch_block_extend_cascade, and sglang_style_cascade_attention so these
high-level Python APIs have crash-safe logging; for the FA3/architecture-gated
code paths (the functions/constructors that require FA3), also import and apply
`@backend_requirement` with the appropriate backend constraint (e.g., the FA3
identifier used elsewhere in the repo) to those symbols; ensure imports for
flashinfer_api and backend_requirement are added at the top if missing and keep
decorator placement immediately above the class or def declarations for the
listed symbols.
```python
jit_kwargs = {
    "pos_encoding_mode": 0, "use_sliding_window": False,
    "use_logits_soft_cap": False, "use_fp16_qk_reduction": False,
    "mask_modes": [0, 1, 2, 3, 4],
}
```
mask_modes=[0, 1, 2, 3, 4] compiles every mask variant for a wrapper that only uses mode 4.
Both DLLM wrappers always call the inner prefill with mask_mode=MaskMode.BLOCK_EXPANDING.value (Lines 369, 502), but the JIT request instantiates all five mask-mode kernel specializations. This defeats the PR's stated design (mode 4 should be compiled only when DLLM wrappers request it) and inflates build/AOT size and cache footprint for each (head_dim, dtype, backend) combo. Restrict to mask_modes=[4] unless you deliberately want the other modes here.
🛠️ Proposed fix

```diff
-    "mask_modes": [0, 1, 2, 3, 4],
+    "mask_modes": [MaskMode.BLOCK_EXPANDING.value],
```

Also applies to: 450-454
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@flashinfer/dllm/batch_block_extend.py` around lines 313 - 317, The jit_kwargs
currently lists mask_modes=[0,1,2,3,4] which forces compilation of all mask-mode
kernels even though the DLLM wrappers call the inner prefill with
mask_mode=MaskMode.BLOCK_EXPANDING.value; change the mask_modes entry in
jit_kwargs to only include [MaskMode.BLOCK_EXPANDING.value] (or the literal 4)
so only mode 4 is JIT/AOT-compiled; update both occurrences around the
jit_kwargs definitions (the one at lines ~313 and the similar block at ~450-454)
to avoid inflating build/AOT size and cache footprint.
```python
def block_extend_attention_with_offset(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    dllm_block_size: int,
    q_offset: int = 0,
    kv_offset: int = 0,
    sm_scale: Optional[float] = None,
    return_lse: bool = False,
    backend: str = "auto",
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
```
Public APIs lack @flashinfer_api decorator.
block_extend_attention_with_offset and block_extend_cascade are exported from flashinfer.dllm as top-level APIs but don't carry @flashinfer_api for crash-safe logging, and FA3 dispatch is architecture-gated without @backend_requirement. As per coding guidelines: "Use @flashinfer_api decorator on high-level Python APIs for crash-safe logging" and "Use @backend_requirement decorator on APIs with architecture-specific requirements".
Also applies to: 314-324
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@flashinfer/dllm/block_extend.py` around lines 235 - 245, Add the missing
decorators to make these public APIs crash-safe and declare architecture
requirements: annotate block_extend_attention_with_offset and
block_extend_cascade with `@flashinfer_api` to enable crash-safe logging, and add
`@backend_requirement` with the appropriate backend tag(s) used for FA3 dispatch
(e.g., `@backend_requirement`("fa3") or the specific backend string your dispatch
expects) to the same functions; ensure the decorators are imported at the top of
flashinfer.dllm (or block_extend.py) if not already present and keep any
existing signature and defaults intact so only decorators are added.
Actionable comments posted: 4
♻️ Duplicate comments (6)
flashinfer/dllm/batch_block_extend.py (4)
311-315: ⚠️ Potential issue | 🟠 Major — Compile only the block-expanding mask specialization here.

Both wrappers always run with `MaskMode.BLOCK_EXPANDING.value`, but the JIT request still instantiates all five mask modes. This directly expands JIT/AOT cache size for DLLM wrappers.

🛠️ Proposed fix

```diff
 jit_kwargs = {
     "pos_encoding_mode": 0, "use_sliding_window": False,
     "use_logits_soft_cap": False, "use_fp16_qk_reduction": False,
-    "mask_modes": [0, 1, 2, 3, 4],
+    "mask_modes": [MaskMode.BLOCK_EXPANDING.value],
 }
```

Apply this to both paged and ragged wrapper JIT kwargs.

Also applies to: 448-452
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/dllm/batch_block_extend.py` around lines 311 - 315, The jit_kwargs currently lists all mask modes which forces JIT/AOT to compile five specializations; change the jit_kwargs in both wrapper configurations (the jit_kwargs dict used for the paged and ragged wrappers — the occurrences around the current jit_kwargs and the other block at lines ~448-452) to only include MaskMode.BLOCK_EXPANDING.value for the "mask_modes" key so the JIT compiles only the block-expanding specialization; ensure you reference the same jit_kwargs variable names and import/usage of MaskMode present in this module.
249-265: ⚠️ Potential issue | 🟡 Minor — Add the required decorators to exported DLLM APIs.

These wrapper classes and helper functions are public high-level APIs exported from `flashinfer.dllm`, but they do not carry `@flashinfer_api`; the backend-dispatching APIs should also declare backend requirements for architecture tracking. As per coding guidelines, `flashinfer/**/*.py`: use the `@flashinfer_api` decorator on high-level Python APIs and the `@backend_requirement` decorator on APIs with architecture-specific requirements.

Also applies to: 390-404, 519-537, 606-621
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/dllm/batch_block_extend.py` around lines 249 - 265, The exported wrapper class BatchBlockExtendPagedOffsetWrapper (and the other public wrapper classes and helper functions at the noted locations) are missing the required decorators; add the `@flashinfer_api` decorator to each high-level API (e.g., class BatchBlockExtendPagedOffsetWrapper) and add `@backend_requirement` where the API enforces architecture-specific backends (APIs that accept a backend: str parameter or perform backend dispatching) so the backend-dispatching is tracked; ensure you import these decorators and apply `@backend_requirement` to the constructors or functions that accept the backend argument (e.g., any __init__ or factory functions with backend: str) while leaving non-backend-specific helpers only with `@flashinfer_api`.
561-571: ⚠️ Potential issue | 🟠 Major — Do not default prefix requests to zero offsets.

When a prefix exists, `q_offsets=None` makes the current chunk start at global position 0, so the block-expanding mask is wrong for any nonzero prefix length. Derive per-request prefix lengths from paged metadata or require explicit offsets.

🛠️ Proposed fix

```diff
 if q_offsets is None:
     if has_prefix:
-        import warnings
-        warnings.warn(
-            "q_offsets is None but prefix exists. Block extend mask may be incorrect "
-            "if prefix length is nonzero. Consider passing explicit q_offsets.",
-            stacklevel=2,
-        )
-        q_offsets = torch.zeros(batch_size, dtype=torch.int32, device=device)
+        q_offsets = (
+            (paged_kv_indptr[1:] - paged_kv_indptr[:-1] - 1) * page_size
+            + paged_kv_last_page_len
+        ).to(device=device, dtype=qo_indptr.dtype)
+    else:
+        q_offsets = torch.zeros(batch_size, dtype=qo_indptr.dtype, device=device)
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/dllm/batch_block_extend.py` around lines 561 - 571, The current code defaults q_offsets to zeros when q_offsets is None (and aliases kv_offsets to q_offsets), which breaks the block-extend mask whenever has_prefix is True; instead, when has_prefix is True derive per-request prefix lengths from the paged metadata and populate q_offsets/kv_offsets accordingly (or raise/require explicit offsets) rather than assigning zeros; update the logic around q_offsets, kv_offsets, and has_prefix in batch_block_extend.py so q_offsets is computed from the request/page metadata for each of the batch_size entries on the given device and dtype, and only fall back to a true zero-offset alias when you have verified there is no prefix for all requests.
177-184: ⚠️ Potential issue | 🟠 Major — Include `idtype` in the module URI and reject unsupported index dtypes.

The generated ABI depends on `idtype`, but the URI only includes `head_dim` and the data dtype. Replanning the same `(head_dim, dtype)` with `int32` vs `int64` can collide on the same registered module name; `dtype_map_for_idtype()` also aliases unknown dtypes to `int32_t`.

🛠️ Proposed fix

```diff
-def _get_batch_be_module_uri(head_dim: int, dtype: torch.dtype) -> str:
+def _get_batch_be_module_uri(
+    head_dim: int,
+    dtype: torch.dtype,
+    idtype: torch.dtype = torch.int32,
+) -> str:
     _dtype_map = {torch.float16: "fp16", torch.bfloat16: "bf16"}
+    _idtype_map = {torch.int32: "i32", torch.int64: "i64"}
     if dtype not in _dtype_map:
         raise ValueError(
             f"Unsupported dtype {dtype} for Block Extend Attention. "
             f"Supported: {list(_dtype_map.keys())}"
         )
-    return f"batch_prefill_block_expanding_hd{head_dim}_{_dtype_map[dtype]}"
+    if idtype not in _idtype_map:
+        raise ValueError(
+            f"Unsupported idtype {idtype} for Block Extend Attention. "
+            f"Supported: {list(_idtype_map.keys())}"
+        )
+    return (
+        f"batch_prefill_block_expanding_hd{head_dim}_"
+        f"{_dtype_map[dtype]}_{_idtype_map[idtype]}"
+    )
@@
-def dtype_map_for_idtype(idtype: torch.dtype) -> str:
-    return {torch.int32: "int32_t", torch.int64: "int64_t"}.get(idtype, "int32_t")
+def dtype_map_for_idtype(idtype: torch.dtype) -> str:
+    _idtype_map = {torch.int32: "int32_t", torch.int64: "int64_t"}
+    if idtype not in _idtype_map:
+        raise ValueError(f"Unsupported idtype {idtype}")
+    return _idtype_map[idtype]
```

Then pass `idtype` at each `_get_batch_be_module_uri(...)` call site.

Also applies to: 304-308, 386-387, 441-445

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/dllm/batch_block_extend.py` around lines 177 - 184, The module URI builder _get_batch_be_module_uri currently only includes head_dim and tensor dtype, which can collide across different index dtypes; update _get_batch_be_module_uri to accept an idtype (torch.dtype for indices), extend the internal mapping (use dtype_map_for_idtype or same mapping logic) to map supported index dtypes (e.g., torch.int32->"i32", torch.int64->"i64") and raise ValueError for unsupported index dtypes, then include the idtype token in the returned URI string (e.g., ..._hd{head_dim}_{_dtype_map[input_dtype]}_{idtype_token}). Finally, update all call sites of _get_batch_be_module_uri (and any other variants at the other locations mentioned) to pass the idtype argument so the generated ABI name is unique per (head_dim, dtype, idtype).

flashinfer/dllm/block_extend.py (1)
Verify each finding against the current code and only fix it if needed. In `@flashinfer/dllm/batch_block_extend.py` around lines 177 - 184, The module URI builder _get_batch_be_module_uri currently only includes head_dim and tensor dtype, which can collide across different index dtypes; update _get_batch_be_module_uri to accept an idtype (torch.dtype for indices), extend the internal mapping (use dtype_map_for_idtype or same mapping logic) to map supported index dtypes (e.g., torch.int32->"i32", torch.int64->"i64") and raise ValueError for unsupported index dtypes, then include the idtype token in the returned URI string (e.g., ..._hd{head_dim}_{_dtype_map[input_dtype]}_{idtype_token}). Finally, update all call sites of _get_batch_be_module_uri (and any other variants at the other locations mentioned) to pass the idtype argument so the generated ABI name is unique per (head_dim, dtype, idtype).flashinfer/dllm/block_extend.py (1)
235-245:⚠️ Potential issue | 🟡 MinorAdd the required public API decorators.
block_extend_attention_with_offsetandblock_extend_cascadeare exported high-level APIs but still lack crash-safe API logging, and their FA2/FA3 backend dispatch should declare the architecture-gated backend requirement.As per coding guidelines,
flashinfer/**/*.py: Use@flashinfer_apidecorator on high-level Python APIs and@backend_requirementdecorator on APIs with architecture-specific requirements.Also applies to: 314-324
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/dllm/block_extend.py` around lines 235 - 245, The high-level APIs block_extend_attention_with_offset and block_extend_cascade need the public API and architecture-gated backend decorators: add `@flashinfer_api` above each function and add `@backend_requirement`("fa2","fa3") (or the project-specific backend_requirement form used elsewhere) to declare FA2/FA3 architecture requirements for the backend-dispatching paths; ensure the decorators are imported from the flashinfer decorator module if not already, and place them immediately above the def for block_extend_attention_with_offset and the block_extend_cascade function (the other exported API at the 314-324 region) so crash-safe API logging and backend gating are applied.tests/attention/test_dllm_vs_flex_attention.py (1)
61-104:⚠️ Potential issue | 🟠 MajorExercise heterogeneous
q_offsetsand nonzerokv_offsetin the reference path.The reference and Flex mask still model only a scalar
q_offsetand implicitkv_offset=0, while the batch benchmark uses identical offsets for every request. That leaves the new per-request offset and current-chunkkv_block_expanding_offsetplumbing largely unvalidated.🛠️ Proposed direction
def compute_block_extend_reference( q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, dllm_block_size: int, q_offset: int = 0, + kv_offset: int = 0, sm_scale: float = None, ) -> torch.Tensor: @@ q_pos = torch.arange(qo_len, device=device) + q_offset - k_pos = torch.arange(kv_len, device=device) + k_pos = torch.arange(kv_len, device=device) + kv_offset @@ -def make_block_extend_mask_mod(dllm_block_size: int, q_offset: int = 0): +def make_block_extend_mask_mod(dllm_block_size: int, q_offsets, kv_offsets=None): @@ def block_extend_mask(b, h, q_idx, kv_idx): - q_global = q_idx + q_offset + q_global = q_idx + q_offsets[b] + kv_global = kv_idx + (0 if kv_offsets is None else kv_offsets[b]) q_blk = q_global // dllm_block_size - kv_blk = kv_idx // dllm_block_size + kv_blk = kv_global // dllm_block_size return q_blk >= kv_blkAlso update the batch setup to use non-uniform offsets and add at least one nonzero
kv_offsetcase.Also applies to: 321-321, 381-384, 518-523
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/attention/test_dllm_vs_flex_attention.py` around lines 61 - 104, The reference functions only handle a scalar q_offset and assume kv_offset=0, so update compute_block_extend_reference to accept per-request heterogeneous q_offsets (e.g., a 1D tensor/array matching batch or sequence of queries) and an explicit kv_offset parameter, compute q_pos using per-request q_offsets and k_pos using kv_offset, and build mask_2d accordingly so it mirrors batched/variable offsets; likewise update make_block_extend_mask_mod (and its returned block_extend_mask) to accept/close over a kv_offset and support per-request q_offset (or an indexable q_offset source) so the Flex mask and the reference use identical offset semantics, and adjust the batch setup to include non-uniform q_offsets and at least one nonzero kv_offset test case.
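The offset semantics under test reduce to a small predicate. A NumPy sketch of the single-request visibility rule (the `q_blk >= kv_blk` check from the reference mask, with explicit global offsets; the function name is illustrative):

```python
import numpy as np

def block_extend_mask(qo_len, kv_len, block_size, q_offset=0, kv_offset=0):
    """Visibility mask for one request: a query token may attend to a KV
    token iff the query's global block index is >= the KV token's global
    block index; the offsets shift local indices to global positions."""
    q_blk = (np.arange(qo_len) + q_offset) // block_size
    kv_blk = (np.arange(kv_len) + kv_offset) // block_size
    return q_blk[:, None] >= kv_blk[None, :]
```

With `q_offset` equal to the prefix length and `kv_offset=0`, queries in the current chunk see every earlier block plus their own; replacing the scalars with per-batch offset tensors generalizes this per request.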
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@flashinfer/dllm/__init__.py`:
- Around line 21-37: Ruff flags the __all__ list ordering (RUF022); sort the
entries in the __all__ list alphabetically according to the project's convention
so the exported names ("BLOCK_EXTEND_V2_WITH_OFFSET_VARIANT_DECL",
"BLOCK_EXTEND_V3_WITH_OFFSET_VARIANT_DECL",
"BatchBlockExtendPagedOffsetWrapper", "BatchBlockExtendRaggedOffsetWrapper",
"_BATCH_BE_OFFSET_VARIANT_DECL", "_BATCH_BE_OFFSET_VARIANT_DECL_FA3",
"batch_block_extend_cascade", "block_extend_attention_with_offset",
"block_extend_cascade", "get_block_extend_module_with_offset",
"sglang_style_cascade_attention") are in the required sorted order; update the
__all__ declaration in flashinfer/dllm/__init__.py to the sorted list so the
linter RUF022 is satisfied.
In `@flashinfer/dllm/batch_block_extend.py`:
- Around line 82-127: select_best_backend currently considers fa3_available even
on non-Hopper GPUs; update it to only allow returning "fa3" when the device
supports SM90A (use is_sm90a_supported(device))—i.e., set fa3_effective =
fa3_available and then if not is_hopper set fa3_effective = False, use
fa3_effective for all selection logic (including the auto path and when
preferred_backend == "fa3" raise a RuntimeError if device is non-Hopper). Apply
the identical change to select_best_backend_paged so FA3 is never chosen or
accepted on non-SM90 devices; ensure device is defaulted to torch.device("cuda")
when None and reuse the same is_hopper check.
In `@flashinfer/dllm/block_extend.py`:
- Around line 368-370: The code currently treats a partially provided prefix
(only k_prefix or only v_prefix) as absent; update the logic in block_extend.py
around the has_prefix computation to validate that either both k_prefix and
v_prefix are provided or neither are—if one is None and the other is not, raise
a clear ValueError (or custom exception) indicating mismatched prefix arguments;
keep the existing prefix_len computation (prefix_len = k_prefix.size(0)) when
both are present. Ensure you reference and modify the has_prefix, k_prefix,
v_prefix, and prefix_len handling so callers fail fast instead of silently
dropping a single-side prefix.
In `@tests/attention/test_dllm_vs_flex_attention.py`:
- Line 241: Multiple plain string print statements in
tests/attention/test_dllm_vs_flex_attention.py were written as f-strings without
placeholders (e.g., print(f"FlashInfer Block Extend vs PyTorch Flex
Attention")); remove the unnecessary 'f' prefix on those 16 instances so they
become regular string literals. Search for all print/assignment/logging lines in
that file that start with f" or f' but contain no braces/format placeholders and
change them to plain "..." or '...' (example symbol to locate: the print call
containing "FlashInfer Block Extend vs PyTorch Flex Attention"). Ensure no
formatting behavior is altered and run linters to confirm Ruff F541 is resolved.
---
Duplicate comments:
In `@flashinfer/dllm/batch_block_extend.py`:
- Around line 311-315: The jit_kwargs currently lists all mask modes which
forces JIT/AOT to compile five specializations; change the jit_kwargs in both
wrapper configurations (the jit_kwargs dict used for the paged and ragged
wrappers — the occurrences around the current jit_kwargs and the other block at
lines ~448-452) to only include MaskMode.BLOCK_EXPANDING.value for the
"mask_modes" key so the JIT compiles only the block-expanding specialization;
ensure you reference the same jit_kwargs variable names and import/usage of
MaskMode present in this module.
- Around line 249-265: The exported wrapper class
BatchBlockExtendPagedOffsetWrapper (and the other public wrapper classes and
helper functions at the noted locations) are missing the required decorators;
add the `@flashinfer_api` decorator to each high-level API (e.g., class
BatchBlockExtendPagedOffsetWrapper) and add `@backend_requirement` where the API
enforces architecture-specific backends (APIs that accept a backend: str
parameter or perform backend dispatching) so the backend-dispatching is tracked;
ensure you import these decorators and apply `@backend_requirement` to the
constructors or functions that accept the backend argument (e.g., any __init__
or factory functions with backend: str) while leaving non-backend-specific
helpers only with `@flashinfer_api`.
- Around line 561-571: The current code defaults q_offsets to zeros when
q_offsets is None (and aliases kv_offsets to q_offsets), which breaks the
block-extend mask whenever has_prefix is True; instead, when has_prefix is True
derive per-request prefix lengths from the paged metadata and populate
q_offsets/kv_offsets accordingly (or raise/require explicit offsets) rather than
assigning zeros; update the logic around q_offsets, kv_offsets, and has_prefix
in batch_block_extend.py so q_offsets is computed from the request/page metadata
for each of the batch_size entries on the given device and dtype, and only fall
back to a true zero-offset alias when you have verified there is no prefix for
all requests.
- Around line 177-184: The module URI builder _get_batch_be_module_uri currently
only includes head_dim and tensor dtype, which can collide across different
index dtypes; update _get_batch_be_module_uri to accept an idtype (torch.dtype
for indices), extend the internal mapping (use dtype_map_for_idtype or same
mapping logic) to map supported index dtypes (e.g., torch.int32->"i32",
torch.int64->"i64") and raise ValueError for unsupported index dtypes, then
include the idtype token in the returned URI string (e.g.,
..._hd{head_dim}_{_dtype_map[input_dtype]}_{idtype_token}). Finally, update all
call sites of _get_batch_be_module_uri (and any other variants at the other
locations mentioned) to pass the idtype argument so the generated ABI name is
unique per (head_dim, dtype, idtype).
In `@flashinfer/dllm/block_extend.py`:
- Around line 235-245: The high-level APIs block_extend_attention_with_offset
and block_extend_cascade need the public API and architecture-gated backend
decorators: add `@flashinfer_api` above each function and add
`@backend_requirement`("fa2","fa3") (or the project-specific backend_requirement
form used elsewhere) to declare FA2/FA3 architecture requirements for the
backend-dispatching paths; ensure the decorators are imported from the
flashinfer decorator module if not already, and place them immediately above the
def for block_extend_attention_with_offset and the block_extend_cascade function
(the other exported API at the 314-324 region) so crash-safe API logging and
backend gating are applied.
In `@tests/attention/test_dllm_vs_flex_attention.py`:
- Around line 61-104: The reference functions only handle a scalar q_offset and
assume kv_offset=0, so update compute_block_extend_reference to accept
per-request heterogeneous q_offsets (e.g., a 1D tensor/array matching batch or
sequence of queries) and an explicit kv_offset parameter, compute q_pos using
per-request q_offsets and k_pos using kv_offset, and build mask_2d accordingly
so it mirrors batched/variable offsets; likewise update
make_block_extend_mask_mod (and its returned block_extend_mask) to accept/close
over a kv_offset and support per-request q_offset (or an indexable q_offset
source) so the Flex mask and the reference use identical offset semantics, and
adjust the batch setup to include non-uniform q_offsets and at least one nonzero
kv_offset test case.
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: c8dc3936-3642-48fa-b28c-0da34a217dee
📒 Files selected for processing (7)
- csrc/single_prefill_sm90_customize_config.jinja
- flashinfer/dllm/__init__.py
- flashinfer/dllm/batch_block_extend.py
- flashinfer/dllm/block_extend.py
- include/flashinfer/attention/default_prefill_params.cuh
- include/flashinfer/utils.cuh
- tests/attention/test_dllm_vs_flex_attention.py
🚧 Files skipped from review as they are similar to previous changes (3)
- csrc/single_prefill_sm90_customize_config.jinja
- include/flashinfer/utils.cuh
- include/flashinfer/attention/default_prefill_params.cuh
```python
__all__ = [
    # Single Prefill with offset (FA2/FA3 auto-select)
    "block_extend_attention_with_offset",
    "get_block_extend_module_with_offset",
    "BLOCK_EXTEND_V2_WITH_OFFSET_VARIANT_DECL",
    "BLOCK_EXTEND_V3_WITH_OFFSET_VARIANT_DECL",
    # Cascade + block extend (SGLang style: causal + merge_state)
    "block_extend_cascade",
    "batch_block_extend_cascade",
    "sglang_style_cascade_attention",
    # Batch Prefill with offset versions
    "BatchBlockExtendPagedOffsetWrapper",
    "BatchBlockExtendRaggedOffsetWrapper",
    # Batch Offset variant declarations
    "_BATCH_BE_OFFSET_VARIANT_DECL",
    "_BATCH_BE_OFFSET_VARIANT_DECL_FA3",
]
```
Sort __all__ to satisfy Ruff.
Ruff flags this list with RUF022; please apply the project’s __all__ sorting convention or the lint step may fail.
🛠️ Proposed fix
__all__ = [
- # Single Prefill with offset (FA2/FA3 auto-select)
- "block_extend_attention_with_offset",
- "get_block_extend_module_with_offset",
"BLOCK_EXTEND_V2_WITH_OFFSET_VARIANT_DECL",
"BLOCK_EXTEND_V3_WITH_OFFSET_VARIANT_DECL",
- # Cascade + block extend (SGLang style: causal + merge_state)
- "block_extend_cascade",
- "batch_block_extend_cascade",
- "sglang_style_cascade_attention",
- # Batch Prefill with offset versions
"BatchBlockExtendPagedOffsetWrapper",
"BatchBlockExtendRaggedOffsetWrapper",
- # Batch Offset variant declarations
"_BATCH_BE_OFFSET_VARIANT_DECL",
"_BATCH_BE_OFFSET_VARIANT_DECL_FA3",
+ "batch_block_extend_cascade",
+ "block_extend_attention_with_offset",
+ "block_extend_cascade",
+ "get_block_extend_module_with_offset",
+ "sglang_style_cascade_attention",
]

🧰 Tools
🪛 Ruff (0.15.10)
[warning] 21-37: __all__ is not sorted
Apply an isort-style sorting to __all__
(RUF022)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@flashinfer/dllm/__init__.py` around lines 21 - 37, Ruff flags the __all__
list ordering (RUF022); sort the entries in the __all__ list alphabetically
according to the project's convention so the exported names (e.g.,
"BatchBlockExtendPagedOffsetWrapper", "BatchBlockExtendRaggedOffsetWrapper",
"BLOCK_EXTEND_V2_WITH_OFFSET_VARIANT_DECL",
"BLOCK_EXTEND_V3_WITH_OFFSET_VARIANT_DECL",
"block_extend_attention_with_offset",
"block_extend_cascade", "batch_block_extend_cascade",
"get_block_extend_module_with_offset", "sglang_style_cascade_attention",
"_BATCH_BE_OFFSET_VARIANT_DECL", "_BATCH_BE_OFFSET_VARIANT_DECL_FA3") are in the
required sorted order; update the __all__ declaration in
flashinfer/dllm/__init__.py to the sorted list so the linter RUF022 is
satisfied.
```python
def select_best_backend(head_dim: int, dtype: torch.dtype, preferred_backend: str = "auto", device: torch.device = None) -> str:
    """Select backend based on kernel availability and compute capability"""
    from ..utils import is_sm90a_supported

    base_uri = _get_batch_be_module_uri(head_dim, dtype)
    fa2_uri = base_uri + "_ragged_offset"
    fa3_uri = base_uri + "_ragged_offset_fa3"

    fa2_aot, fa2_jit, _ = check_kernel_availability(fa2_uri)
    fa3_aot, fa3_jit, _ = check_kernel_availability(fa3_uri)

    fa2_available = fa2_aot or fa2_jit
    fa3_available = fa3_aot or fa3_jit

    if preferred_backend == "auto":
        if device is None:
            device = torch.device("cuda")
        is_hopper = is_sm90a_supported(device)

        if is_hopper:
            if fa3_available:
                return "fa3"
            if fa2_available:
                return "fa2"
        else:
            if fa2_available:
                return "fa2"
            if fa3_available:
                return "fa3"

        raise RuntimeError(
            f"No Block Extend kernel available for head_dim={head_dim}, dtype={dtype}. "
            f"FA2: AOT={fa2_aot}, JIT={fa2_jit}; FA3: AOT={fa3_aot}, JIT={fa3_jit}"
        )

    if preferred_backend == "fa2":
        if fa2_available:
            return "fa2"
        raise RuntimeError(f"FA2 kernel '{fa2_uri}' not available")

    if preferred_backend == "fa3":
        if fa3_available:
            return "fa3"
        raise RuntimeError(f"FA3 kernel '{fa3_uri}' not available")

    raise ValueError(f"Unknown backend: {preferred_backend}")
```
Do not select FA3 on non-Hopper devices.
fa3_available only means AOT/JIT exists; it does not mean the current GPU can run FA3. On non-SM90 devices, these selectors can still return "fa3" when FA2 is unavailable or explicitly requested, causing a later Hopper-kernel failure.
🛠️ Proposed fix
if preferred_backend == "auto":
if device is None:
device = torch.device("cuda")
is_hopper = is_sm90a_supported(device)
if is_hopper:
if fa3_available:
return "fa3"
if fa2_available:
return "fa2"
else:
if fa2_available:
return "fa2"
- if fa3_available:
- return "fa3"
raise RuntimeError(
@@
if preferred_backend == "fa3":
+ if device is None:
+ device = torch.device("cuda")
+ if not is_sm90a_supported(device):
+ raise RuntimeError("FA3 backend requires SM90/Hopper architecture")
if fa3_available:
return "fa3"
raise RuntimeError(f"FA3 kernel '{fa3_uri}' not available")

Apply the same guard in select_best_backend_paged.
Also applies to: 130-174
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@flashinfer/dllm/batch_block_extend.py` around lines 82 - 127,
select_best_backend currently considers fa3_available even on non-Hopper GPUs;
update it to only allow returning "fa3" when the device supports SM90A (use
is_sm90a_supported(device))—i.e., set fa3_effective = fa3_available and then if
not is_hopper set fa3_effective = False, use fa3_effective for all selection
logic (including the auto path and when preferred_backend == "fa3" raise a
RuntimeError if device is non-Hopper). Apply the identical change to
select_best_backend_paged so FA3 is never chosen or accepted on non-SM90
devices; ensure device is defaulted to torch.device("cuda") when None and reuse
the same is_hopper check.
```python
has_prefix = k_prefix is not None and v_prefix is not None
prefix_len = k_prefix.size(0) if has_prefix else 0
```
Reject partially specified prefix KV pairs.
If callers pass only k_prefix or only v_prefix, has_prefix becomes false and the provided prefix is silently ignored, returning current-chunk-only attention. Fail fast instead.
🛠️ Proposed fix
- has_prefix = k_prefix is not None and v_prefix is not None
+ if (k_prefix is None) != (v_prefix is None):
+ raise ValueError("k_prefix and v_prefix must be provided together")
+ has_prefix = k_prefix is not None
prefix_len = k_prefix.size(0) if has_prefix else 0🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@flashinfer/dllm/block_extend.py` around lines 368 - 370, The code currently
treats a partially provided prefix (only k_prefix or only v_prefix) as absent;
update the logic in block_extend.py around the has_prefix computation to
validate that either both k_prefix and v_prefix are provided or neither are—if
one is None and the other is not, raise a clear ValueError (or custom exception)
indicating mismatched prefix arguments; keep the existing prefix_len computation
(prefix_len = k_prefix.size(0)) when both are present. Ensure you reference and
modify the has_prefix, k_prefix, v_prefix, and prefix_len handling so callers
fail fast instead of silently dropping a single-side prefix.
```python
q_offset = total_kv_len - qo_len

print(f"\n{'='*80}")
print(f"FlashInfer Block Extend vs PyTorch Flex Attention")
```
🧩 Analysis chain
🏁 Script executed:
#!/bin/bash
# Verify remaining placeholder-free f-strings in this file.
ruff check tests/attention/test_dllm_vs_flex_attention.py --select F541

Repository: flashinfer-ai/flashinfer
Length of output: 10824
Convert placeholder-free f-strings to plain string literals.
These are Ruff F541 lint failures. Remove the f prefix from 16 f-string instances that contain no placeholders across lines 241, 328, 557, 647, 656, 745–747, 827, 860–862, 995–1002, 1145, and 1190.
🧰 Tools
🪛 Ruff (0.15.10)
[error] 241-241: f-string without any placeholders
Remove extraneous f prefix
(F541)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@tests/attention/test_dllm_vs_flex_attention.py` at line 241, Multiple plain
string print statements in tests/attention/test_dllm_vs_flex_attention.py were
written as f-strings without placeholders (e.g., print(f"FlashInfer Block Extend
vs PyTorch Flex Attention")); remove the unnecessary 'f' prefix on those 16
instances so they become regular string literals. Search for all
print/assignment/logging lines in that file that start with f" or f' but contain
no braces/format placeholders and change them to plain "..." or '...' (example
symbol to locate: the print call containing "FlashInfer Block Extend vs PyTorch
Flex Attention"). Ensure no formatting behavior is altered and run linters to
confirm Ruff F541 is resolved.
- Validate dllm_block_size > 0 to reject zero and negative values
- Raise ValueError on unsupported dtype instead of silent fallback to fp16
- Preserve user's preferred backend across wrapper re-creation
- Track idtype to correctly invalidate plan when index dtype changes
- Defer backend auto-selection to wrappers instead of pre-resolving in cascade
- Warn when q_offsets is None but prefix exists in cascade attention
- Pass device to FA3 SM90 check and include device in module cache key
- Remove unused logits_soft_cap parameter from sglang_style_cascade_attention
- Fix causal=True comment to causal=False in sglang_style_cascade_attention
- Fix docstring function names (block_expanding_* -> block_extend_*)
- Add assert statements to test correctness checks
- Rename benchmark functions from test_* to bench_* to avoid pytest collection
- Fix missing trailing newline in .cuh and .jinja files
Force-pushed from 93965f5 to 4284113
Motivation
What is Diffusion LLM (DLLM)?
Diffusion LLM (DLLM) is an emerging text generation paradigm. Unlike traditional Auto-Regressive LLMs that generate one token at a time, DLLM generates multiple tokens in parallel at the block level. In each iteration, all tokens within the current block are produced simultaneously through multi-step denoising. As a result, tokens within the same block require bidirectional visibility, while tokens in subsequent blocks must be completely invisible — this is the semantic origin of the Block Extend Mask.
Similarity Between Block Diffusion and Chunked Prefill
The execution flow of Block Diffusion closely resembles SGLang's existing Chunked Prefill — both split the full sequence into chunks and process them step by step:
The only difference lies in the intra-block query: Chunked Prefill uses a causal mask (earlier tokens cannot attend to later ones), whereas Block Diffusion requires bidirectional full attention since tokens within the same block are generated in parallel.
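The contrast can be made concrete with a toy mask comparison. A minimal NumPy sketch (block size and lengths are illustrative, not taken from the PR) builds the causal mask Chunked Prefill would apply to a chunk and the bidirectional block mask Block Diffusion needs:

```python
import numpy as np

B = 4                                      # DLLM block size (illustrative)
prefix_len, chunk_len = 8, 4
kv_len = prefix_len + chunk_len
q_global = np.arange(prefix_len, kv_len)   # global positions of the current chunk's queries
k_global = np.arange(kv_len)

# Chunked Prefill: causal mask over global positions
causal = q_global[:, None] >= k_global[None, :]

# Block Diffusion: same-block tokens are mutually visible, later blocks hidden
block = (q_global[:, None] // B) >= (k_global[None, :] // B)

# Both masks agree on the prefix; they differ only inside the current block,
# where the causal mask hides "future" tokens that the block mask exposes.
print((causal != block).sum())  # → 6 (the strict upper triangle of a 4x4 block)
```

The six differing entries are exactly the intra-block pairs where a later token must be visible to an earlier one, which is what the causal path cannot express.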
Why Native Kernel Support is Needed
Chunked Prefill is SGLang's default and well-established execution path. Given the strong similarity between Block Diffusion and Chunked Prefill, the initial approach was to reuse the Chunked Prefill path: split by DLLM block size and use `causal=False` for the current chunk to approximate the Block Extend mask. However, this indirect Cascade Attention-based approach has fundamental limitations:
`causal=False` is only correct when the chunk exactly equals one complete DLLM block; larger chunk sizes that would reduce the number of iteration steps cannot be used.

An alternative path is to use a Custom Mask (`MaskMode::kCustom`), but it requires O(qo_len × kv_len) of GPU memory to store the 2D mask tensor, which is completely infeasible for long sequences (a single request at seq_len=32K needs 1GB of mask memory).

Therefore, we implemented the native `MaskMode::kBlockExpanding` mask mode in FlashInfer, embedding the Block Extend mask semantics directly into the CUDA kernel.

📌 Description
Add DLLM Block Extend Attention with native `MaskMode::kBlockExpanding` tile-level skip optimization for Diffusion LLM inference.

Block Extend Mask Rule
mask[q, k] = floor(q_global / B) >= floor(k_global / B)
Tokens within the same DLLM block are bidirectionally visible; they can see previous blocks but not subsequent blocks. This is the core mask mode for Diffusion LLM (DLLM).
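For checking outputs on small shapes, the rule can be expressed as a dense PyTorch reference. This is a sketch, not the kernel: the function name `block_extend_attention_ref` is hypothetical, and `q_offset` merely mirrors the idea of the offset arguments in this PR's API (mapping chunk-local query rows to global positions).

```python
import torch

def block_extend_attention_ref(q, k, v, B, q_offset=0):
    """Dense reference for mask[q, k] = (q_global // B) >= (k_global // B)."""
    qo_len, kv_len = q.size(0), k.size(0)
    q_global = torch.arange(qo_len) + q_offset   # chunk-local rows -> global positions
    k_global = torch.arange(kv_len)
    mask = (q_global[:, None] // B) >= (k_global[None, :] // B)
    scores = (q @ k.T) / q.size(-1) ** 0.5
    scores = scores.masked_fill(~mask, float("-inf"))
    return torch.softmax(scores, dim=-1) @ v

torch.manual_seed(0)
k = torch.randn(12, 16)
v = torch.randn(12, 16)
q = torch.randn(4, 16)                           # current chunk of 4 queries
out = block_extend_attention_ref(q, k, v, B=4, q_offset=8)
```

When `B` covers the whole sequence, the mask is all-True and the reference degenerates to plain full attention, which is a convenient sanity check.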
Core API
- `block_extend_attention_with_offset()`
- `BatchBlockExtendRaggedOffsetWrapper`
- `BatchBlockExtendPagedOffsetWrapper`
- `block_extend_cascade()` / `batch_block_extend_cascade()`

Why faster than Custom Mask
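Part of the gap is plain arithmetic on mask memory. Assuming one byte per (q, k) entry, as with a boolean mask tensor:

```python
seq_len = 32 * 1024
mask_bytes = seq_len * seq_len   # one byte per (q, k) pair
print(mask_bytes / 2**30)        # → 1.0, i.e. 1 GiB for a single 32K request
# The native mask mode evaluates the block predicate inside the kernel,
# so no mask tensor is materialized or read from global memory at all.
```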
Why faster than Cascade Attention (SGLang-style)
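For context, the SGLang-style cascade baseline computes each query's attention in separate stages and combines the partial results with a log-sum-exp weighted merge (the rule behind flashinfer's `merge_state`). A minimal pure-PyTorch sketch of that structure, with illustrative shapes:

```python
import torch

def attend(q, k, v):
    # One cascade stage: returns partial output and per-row log-sum-exp
    s = (q @ k.T) / q.size(-1) ** 0.5
    lse = torch.logsumexp(s, dim=-1)
    return torch.softmax(s, dim=-1) @ v, lse

def merge_state(o1, lse1, o2, lse2):
    # LSE-weighted merge of two partial attention states
    m = torch.maximum(lse1, lse2)
    w1 = torch.exp(lse1 - m)[:, None]
    w2 = torch.exp(lse2 - m)[:, None]
    return (o1 * w1 + o2 * w2) / (w1 + w2)

torch.manual_seed(0)
d, B = 16, 4
q = torch.randn(B, d)
k_prefix, v_prefix = torch.randn(8, d), torch.randn(8, d)
k_cur, v_cur = torch.randn(B, d), torch.randn(B, d)

# Stage 1: current chunk (exactly one DLLM block -> fully bidirectional, no mask)
o1, lse1 = attend(q, k_cur, v_cur)
# Stage 2: prefix (fully visible to every query)
o2, lse2 = attend(q, k_prefix, v_prefix)
merged = merge_state(o1, lse1, o2, lse2)

# Reference: one attention pass over the concatenated KV
ref, _ = attend(q, torch.cat([k_prefix, k_cur]), torch.cat([v_prefix, v_cur]))
print(torch.allclose(merged, ref, atol=1e-5))  # → True
```

The merge reproduces single-pass attention exactly, but at the cost of extra kernel launches and an extra read/write of the partial outputs, which the native block-extend kernel avoids by computing the result in one pass.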
Why faster than PyTorch Flex Attention
Compilation
🔍 Related Issues
None
🚀 Pull Request Checklist
- [ ] Pre-commit Checks
- [ ] I have installed `pre-commit` by running `pip install pre-commit`
- [ ] I have installed the hooks with `pre-commit install`
- [ ] I have run the hooks with `pre-commit run --all-files` and fixed any reported issues

🧪 Tests
📊 Benchmark Results
1. Block Extend vs Custom Mask
Baseline:
`single_prefill_with_kv_cache(custom_mask=...)` (FA2 only, FA3 not supported)

Case 1: Single-Request Incremental Prefill (tokens=8192, CUDA Graph)
Case 2: Multi-Request BatchPrefill (256 reqs × 512 tokens, CUDA Graph)
2. Block Extend vs Cascade Attention
Baseline: SGLang-style 3-stage Cascade (FA3)
Case 1: Single-Request (tokens=8192, CUDA Graph)
Case 2: Multi-Request BatchPrefill (256 reqs × 512 tokens, CUDA Graph)
3. Block Extend vs PyTorch Flex Attention
Key Takeaways:
4. LLaDA2.0-flash-CAP End-to-End Benchmark Results
TTFT Optimization Highlights
The most significant improvement is in TTFT (Time To First Token), which directly reflects the prefill-stage kernel efficiency:
With cache-aware scheduling enabled in version 260209, TTFT is reduced by 4.3x–8.2x across all scenarios. The benefit scales with both prompt length and concurrency — under the most demanding setting (long prompt, concurrency 4), TTFT drops from 22.3s to 2.7s.
refs : sgl-project/sglang#12766
Summary by CodeRabbit
New Features
Tests