diff --git a/csrc/batch_prefill_customize_config.jinja b/csrc/batch_prefill_customize_config.jinja index 77490d71b2..e96f499cd8 100644 --- a/csrc/batch_prefill_customize_config.jinja +++ b/csrc/batch_prefill_customize_config.jinja @@ -76,6 +76,22 @@ struct RaggedParams { __host__ __device__ __forceinline__ uint32_t get_kv_len(uint32_t batch_idx) const { return kv_indptr[batch_idx + 1] - kv_indptr[batch_idx]; } + + __host__ __device__ __forceinline__ uint32_t get_q_block_expanding_offset(uint32_t batch_idx) const { +{% if 'maybe_q_block_expanding_offset' in additional_params_decl %} + return (maybe_q_block_expanding_offset != nullptr) ? maybe_q_block_expanding_offset[batch_idx] : 0; +{% else %} + return 0; +{% endif %} + } + + __host__ __device__ __forceinline__ uint32_t get_kv_block_expanding_offset(uint32_t batch_idx) const { +{% if 'maybe_kv_block_expanding_offset' in additional_params_decl %} + return (maybe_kv_block_expanding_offset != nullptr) ? maybe_kv_block_expanding_offset[batch_idx] : 0; +{% else %} + return 0; +{% endif %} + } }; struct PagedParams { @@ -116,6 +132,22 @@ struct PagedParams { __host__ __device__ __forceinline__ uint32_t get_kv_len(uint32_t batch_idx) const { return paged_kv.get_length(batch_idx); } + + __host__ __device__ __forceinline__ uint32_t get_q_block_expanding_offset(uint32_t batch_idx) const { +{% if 'maybe_q_block_expanding_offset' in additional_params_decl %} + return (maybe_q_block_expanding_offset != nullptr) ? maybe_q_block_expanding_offset[batch_idx] : 0; +{% else %} + return 0; +{% endif %} + } + + __host__ __device__ __forceinline__ uint32_t get_kv_block_expanding_offset(uint32_t batch_idx) const { +{% if 'maybe_kv_block_expanding_offset' in additional_params_decl %} + return (maybe_kv_block_expanding_offset != nullptr) ? 
maybe_kv_block_expanding_offset[batch_idx] : 0; +{% else %} + return 0; +{% endif %} + } }; {{ variant_decl }} diff --git a/csrc/batch_prefill_sm90_customize_config.jinja b/csrc/batch_prefill_sm90_customize_config.jinja index 640637c7df..373e469343 100644 --- a/csrc/batch_prefill_sm90_customize_config.jinja +++ b/csrc/batch_prefill_sm90_customize_config.jinja @@ -66,6 +66,31 @@ struct RaggedParams { int window_left; bool causal; + + // Block Expanding support + __host__ __device__ __forceinline__ uint32_t get_qo_len(uint32_t batch_idx) const { + return qo_lens[batch_idx]; + } + + __host__ __device__ __forceinline__ uint32_t get_kv_len(uint32_t batch_idx) const { + return kv_lens[batch_idx]; + } + + __host__ __device__ __forceinline__ uint32_t get_q_block_expanding_offset(uint32_t batch_idx) const { +{% if 'maybe_q_block_expanding_offset' in additional_params_decl %} + return (additional_params.maybe_q_block_expanding_offset != nullptr) ? additional_params.maybe_q_block_expanding_offset[batch_idx] : 0; +{% else %} + return 0; +{% endif %} + } + + __host__ __device__ __forceinline__ uint32_t get_kv_block_expanding_offset(uint32_t batch_idx) const { +{% if 'maybe_kv_block_expanding_offset' in additional_params_decl %} + return (additional_params.maybe_kv_block_expanding_offset != nullptr) ? 
additional_params.maybe_kv_block_expanding_offset[batch_idx] : 0; +{% else %} + return 0; +{% endif %} + } }; struct PagedParams { @@ -117,6 +142,31 @@ struct PagedParams { int window_left; bool causal; + + // Block Expanding support + __host__ __device__ __forceinline__ uint32_t get_qo_len(uint32_t batch_idx) const { + return qo_lens[batch_idx]; + } + + __host__ __device__ __forceinline__ uint32_t get_kv_len(uint32_t batch_idx) const { + return kv_lens[batch_idx]; + } + + __host__ __device__ __forceinline__ uint32_t get_q_block_expanding_offset(uint32_t batch_idx) const { +{% if 'maybe_q_block_expanding_offset' in additional_params_decl %} + return (additional_params.maybe_q_block_expanding_offset != nullptr) ? additional_params.maybe_q_block_expanding_offset[batch_idx] : 0; +{% else %} + return 0; +{% endif %} + } + + __host__ __device__ __forceinline__ uint32_t get_kv_block_expanding_offset(uint32_t batch_idx) const { +{% if 'maybe_kv_block_expanding_offset' in additional_params_decl %} + return (additional_params.maybe_kv_block_expanding_offset != nullptr) ? 
additional_params.maybe_kv_block_expanding_offset[batch_idx] : 0; +{% else %} + return 0; +{% endif %} + } }; {{ variant_decl }} diff --git a/csrc/single_prefill_customize_config.jinja b/csrc/single_prefill_customize_config.jinja index fa31e08b7b..0592531a5c 100644 --- a/csrc/single_prefill_customize_config.jinja +++ b/csrc/single_prefill_customize_config.jinja @@ -67,6 +67,26 @@ struct Params { __host__ __device__ __forceinline__ uint32_t get_kv_len(uint32_t batch_idx) const { return kv_len; } + + // SinglePrefill: q_block_expanding_offset support + // If q_block_expanding_offset parameter is provided, use it; otherwise return 0 + __host__ __device__ __forceinline__ uint32_t get_q_block_expanding_offset(uint32_t batch_idx) const { +{% if has_q_block_expanding_offset %} + return static_cast<uint32_t>(q_block_expanding_offset); +{% else %} + return 0; +{% endif %} + } + + // SinglePrefill: kv_block_expanding_offset support (for Cascade Current Chunk) + // If kv_block_expanding_offset parameter is provided, use it; otherwise return 0 + __host__ __device__ __forceinline__ uint32_t get_kv_block_expanding_offset(uint32_t batch_idx) const { +{% if has_kv_block_expanding_offset %} + return static_cast<uint32_t>(kv_block_expanding_offset); +{% else %} + return 0; +{% endif %} + } }; {{ variant_decl }} diff --git a/csrc/single_prefill_sm90_customize_config.jinja b/csrc/single_prefill_sm90_customize_config.jinja index 7922ca2ba9..4446dde203 100644 --- a/csrc/single_prefill_sm90_customize_config.jinja +++ b/csrc/single_prefill_sm90_customize_config.jinja @@ -62,6 +62,33 @@ struct Params { int window_left; bool causal; + + // Block Expanding support + __host__ __device__ __forceinline__ uint32_t get_qo_len(uint32_t batch_idx) const { + return qo_len; + } + + __host__ __device__ __forceinline__ uint32_t get_kv_len(uint32_t batch_idx) const { + return kv_len; + } + + // SinglePrefill: q_block_expanding_offset support + __host__ __device__ __forceinline__ uint32_t 
get_q_block_expanding_offset(uint32_t batch_idx) const { +{% if has_q_block_expanding_offset %} + return static_cast<uint32_t>(additional_params.q_block_expanding_offset); +{% else %} + return 0; +{% endif %} + } + + // SinglePrefill: kv_block_expanding_offset support (for Cascade Current Chunk) + __host__ __device__ __forceinline__ uint32_t get_kv_block_expanding_offset(uint32_t batch_idx) const { +{% if has_kv_block_expanding_offset %} + return static_cast<uint32_t>(additional_params.kv_block_expanding_offset); +{% else %} + return 0; +{% endif %} + } }; {{ variant_decl }} diff --git a/flashinfer/__init__.py b/flashinfer/__init__.py index 26b5f97894..72f91069ba 100644 --- a/flashinfer/__init__.py +++ b/flashinfer/__init__.py @@ -20,6 +20,7 @@ from .version import __version__ as __version__ from .version import __git_version__ as __git_version__ +from . import dllm as dllm from . import jit as jit from .activation import gelu_and_mul as gelu_and_mul diff --git a/flashinfer/dllm/__init__.py b/flashinfer/dllm/__init__.py new file mode 100644 index 0000000000..9925b272b9 --- /dev/null +++ b/flashinfer/dllm/__init__.py @@ -0,0 +1,37 @@ +from .block_extend import ( + block_extend_attention_with_offset, + block_extend_cascade, + get_block_extend_module_with_offset, + BLOCK_EXTEND_V2_WITH_OFFSET_VARIANT_DECL, + BLOCK_EXTEND_V3_WITH_OFFSET_VARIANT_DECL, +) + +from .batch_block_extend import ( + BatchBlockExtendPagedOffsetWrapper, + BatchBlockExtendRaggedOffsetWrapper, + batch_block_extend_cascade, + sglang_style_cascade_attention, + _BATCH_BE_OFFSET_VARIANT_DECL, + _BATCH_BE_OFFSET_VARIANT_DECL_FA3, + _check_batch_be_aot_available, + _get_batch_be_aot_path, + _get_batch_be_module_uri, +) + +__all__ = [ + # Single Prefill with offset (FA2/FA3 auto-select) + "block_extend_attention_with_offset", + "get_block_extend_module_with_offset", + "BLOCK_EXTEND_V2_WITH_OFFSET_VARIANT_DECL", + "BLOCK_EXTEND_V3_WITH_OFFSET_VARIANT_DECL", + # Cascade + block extend (SGLang style: causal + merge_state) + 
"block_extend_cascade", + "batch_block_extend_cascade", + "sglang_style_cascade_attention", + # Batch Prefill with offset versions + "BatchBlockExtendPagedOffsetWrapper", + "BatchBlockExtendRaggedOffsetWrapper", + # Batch Offset variant declarations + "_BATCH_BE_OFFSET_VARIANT_DECL", + "_BATCH_BE_OFFSET_VARIANT_DECL_FA3", +] diff --git a/flashinfer/dllm/batch_block_extend.py b/flashinfer/dllm/batch_block_extend.py new file mode 100644 index 0000000000..fe09e1c769 --- /dev/null +++ b/flashinfer/dllm/batch_block_extend.py @@ -0,0 +1,678 @@ +""" +Batch Block Extend Attention for DLLM (Diffusion LLM) + +Block Extend Mask Rules: + q_global = q_offset + q_idx + kv_global = kv_offset + kv_idx + mask[q, k] = (q_global / dllm_block_size) >= (kv_global / dllm_block_size) + Bidirectional visibility within the same block, can see previous blocks, cannot see subsequent blocks + +Usage: + from flashinfer.dllm import BatchBlockExtendRaggedOffsetWrapper + wrapper = BatchBlockExtendRaggedOffsetWrapper(workspace, kv_layout="NHD", dllm_block_size=32) + wrapper.plan(qo_indptr, kv_indptr, num_heads, num_kv_heads, head_dim) + output = wrapper.run(q, k, v) +""" + +from __future__ import annotations + +import math +import os +import torch +from pathlib import Path +from typing import Optional, Tuple, Union, List, Dict, Any + +from ..prefill import ( + BatchPrefillWithRaggedKVCacheWrapper, + BatchPrefillWithPagedKVCacheWrapper, +) +from ..jit import gen_customize_batch_prefill_module +from ..jit import env as jit_env +from ..utils import MaskMode +from ..api_logging import flashinfer_api + + +def check_jit_environment() -> dict: + """Check if JIT compilation environment is working properly""" + results = { + "tvm_ffi_ok": False, + "device_guard_ok": False, + "nvcc_ok": False, + "issues": [], + } + + try: + import tvm_ffi + results["tvm_ffi_ok"] = True + include_path = tvm_ffi.libinfo.find_include_path() + device_guard_path = Path(include_path) / "tvm" / "ffi" / "extra" / "cuda" / 
"device_guard.h" + results["device_guard_ok"] = device_guard_path.exists() + if not results["device_guard_ok"]: + results["issues"].append(f"Missing TVM header: {device_guard_path}") + except ImportError: + results["issues"].append("tvm_ffi package not installed") + except Exception as e: + results["issues"].append(f"Error checking tvm_ffi: {e}") + + import subprocess + try: + result = subprocess.run(["nvcc", "--version"], capture_output=True, text=True) + results["nvcc_ok"] = result.returncode == 0 + except FileNotFoundError: + results["nvcc_ok"] = False + results["issues"].append("nvcc not found in PATH") + + return results + + +def check_kernel_availability(uri: str) -> tuple: + """Check availability of specified kernel""" + aot_path = jit_env.FLASHINFER_AOT_DIR / uri / f"{uri}.so" + aot_available = aot_path.exists() + + jit_env_check = check_jit_environment() + jit_available = ( + jit_env_check["tvm_ffi_ok"] and + jit_env_check["device_guard_ok"] and + jit_env_check["nvcc_ok"] + ) + + return aot_available, jit_available, aot_path + + +def select_best_backend(head_dim: int, dtype: torch.dtype, preferred_backend: str = "auto", device: torch.device = None) -> str: + """Select backend based on kernel availability and compute capability""" + from ..utils import is_sm90a_supported + + base_uri = _get_batch_be_module_uri(head_dim, dtype) + fa2_uri = base_uri + "_ragged_offset" + fa3_uri = base_uri + "_ragged_offset_fa3" + + fa2_aot, fa2_jit, _ = check_kernel_availability(fa2_uri) + fa3_aot, fa3_jit, _ = check_kernel_availability(fa3_uri) + + fa2_available = fa2_aot or fa2_jit + fa3_available = fa3_aot or fa3_jit + + if preferred_backend == "auto": + if device is None: + device = torch.device("cuda") + is_hopper = is_sm90a_supported(device) + + if is_hopper: + if fa3_available: + return "fa3" + if fa2_available: + return "fa2" + else: + if fa2_available: + return "fa2" + if fa3_available: + return "fa3" + + raise RuntimeError( + f"No Block Extend kernel available for 
head_dim={head_dim}, dtype={dtype}. " + f"FA2: AOT={fa2_aot}, JIT={fa2_jit}; FA3: AOT={fa3_aot}, JIT={fa3_jit}" + ) + + if preferred_backend == "fa2": + if fa2_available: + return "fa2" + raise RuntimeError(f"FA2 kernel '{fa2_uri}' not available") + + if preferred_backend == "fa3": + if fa3_available: + return "fa3" + raise RuntimeError(f"FA3 kernel '{fa3_uri}' not available") + + raise ValueError(f"Unknown backend: {preferred_backend}") + + +def select_best_backend_paged(head_dim: int, dtype: torch.dtype, preferred_backend: str = "auto", device: torch.device = None) -> str: + """Select backend based on Paged kernel availability and compute capability""" + from ..utils import is_sm90a_supported + + base_uri = _get_batch_be_module_uri(head_dim, dtype) + fa2_uri = base_uri + "_paged_offset" + fa3_uri = base_uri + "_paged_offset_fa3" + + fa2_aot, fa2_jit, _ = check_kernel_availability(fa2_uri) + fa3_aot, fa3_jit, _ = check_kernel_availability(fa3_uri) + + fa2_available = fa2_aot or fa2_jit + fa3_available = fa3_aot or fa3_jit + + if preferred_backend == "auto": + if device is None: + device = torch.device("cuda") + is_hopper = is_sm90a_supported(device) + + if is_hopper: + if fa3_available: + return "fa3" + if fa2_available: + return "fa2" + else: + if fa2_available: + return "fa2" + if fa3_available: + return "fa3" + + raise RuntimeError( + f"No Paged Block Extend kernel available for head_dim={head_dim}, dtype={dtype}" + ) + + if preferred_backend == "fa2": + if fa2_available: + return "fa2" + raise RuntimeError(f"FA2 paged kernel '{fa2_uri}' not available") + + if preferred_backend == "fa3": + if fa3_available: + return "fa3" + raise RuntimeError(f"FA3 paged kernel '{fa3_uri}' not available") + + raise ValueError(f"Unknown backend: {preferred_backend}") + + +def _get_batch_be_module_uri(head_dim: int, dtype: torch.dtype) -> str: + _dtype_map = {torch.float16: "fp16", torch.bfloat16: "bf16"} + if dtype not in _dtype_map: + raise ValueError( + f"Unsupported dtype 
{dtype} for Block Extend Attention. " + f"Supported: {list(_dtype_map.keys())}" + ) + return f"batch_prefill_block_expanding_hd{head_dim}_{_dtype_map[dtype]}" + + +def _get_batch_be_aot_path(uri: str) -> Path: + return jit_env.FLASHINFER_AOT_DIR / uri / f"{uri}.so" + + +def _check_batch_be_aot_available(uri: str) -> bool: + if os.environ.get("FLASHINFER_FORCE_JIT", "0") == "1": + return False + return _get_batch_be_aot_path(uri).exists() + + +# FA2 Offset Variant +_BATCH_BE_OFFSET_VARIANT_DECL = r""" +struct BatchBlockExtendOffsetAttention : AttentionVariantBase { + static constexpr bool use_softmax = true; + + uint32_t qo_len; + uint32_t kv_len; + uint32_t window_left; + float sm_scale_log2; + + template <typename Params> + __device__ __host__ BatchBlockExtendOffsetAttention(const Params& params, uint32_t batch_idx, + uint8_t* smem_ptr) { + qo_len = params.get_qo_len(batch_idx); + kv_len = params.get_kv_len(batch_idx); + sm_scale_log2 = params.sm_scale * math::log2e; + window_left = kv_len; + } + + REGISTER_LOGITS_MASK(params, batch_idx, qo_idx, kv_idx, qo_head_idx, kv_head_idx, { + return true; + }); + + REGISTER_LOGITS_TRANSFORM(params, logits, batch_idx, qo_idx, kv_idx, qo_head_idx, kv_head_idx, { + return logits; + }); +}; +""" + +# FA3 Offset Variant +_BATCH_BE_OFFSET_VARIANT_DECL_FA3 = r""" +struct BatchBlockExtendOffsetAttentionFA3 : AttentionVariantBase { + float sm_scale_log2; + + template <typename MainloopParams, typename BlockCoord> + __device__ __host__ BatchBlockExtendOffsetAttentionFA3( + const MainloopParams& params, const BlockCoord& block_coord) { + sm_scale_log2 = params.additional_params.sm_scale * math::log2e; + } + + template <int NUM_ROWS_PER_THREAD> + __device__ auto GetAttentionUpdater() { + return OnlineSoftmax<NUM_ROWS_PER_THREAD>(sm_scale_log2); + } + + REGISTER_LOGITS_TRANSFORM(params, logits, batch_idx, qo_idx, kv_idx, qo_head_idx, kv_head_idx, { + return logits; + }); +}; +""" + + +class BatchBlockExtendPagedOffsetWrapper: + """Batch Block Extend Paged Attention with Offset Support""" + + @flashinfer_api + def __init__( + self, + 
float_workspace_buffer: torch.Tensor, + kv_layout: str = "NHD", + dllm_block_size: int = 256, + use_cuda_graph: bool = False, + qo_indptr_buf: Optional[torch.Tensor] = None, + paged_kv_indptr_buf: Optional[torch.Tensor] = None, + paged_kv_indices_buf: Optional[torch.Tensor] = None, + paged_kv_last_page_len_buf: Optional[torch.Tensor] = None, + q_offsets_buf: Optional[torch.Tensor] = None, + kv_offsets_buf: Optional[torch.Tensor] = None, + backend: str = "auto", + ) -> None: + assert dllm_block_size > 0 and (dllm_block_size & (dllm_block_size - 1)) == 0, \ + f"dllm_block_size must be a positive power of 2, got {dllm_block_size}" + + self._dllm_block_size = dllm_block_size + self._kv_layout = kv_layout + self._backend = backend + self._preferred_backend = backend + self._device = float_workspace_buffer.device + self._dtype: Optional[torch.dtype] = None + self._head_dim: Optional[int] = None + self._idtype: Optional[torch.dtype] = None + + self._float_workspace_buffer = float_workspace_buffer + self._use_cuda_graph = use_cuda_graph + self._qo_indptr_buf = qo_indptr_buf + self._paged_kv_indptr_buf = paged_kv_indptr_buf + self._paged_kv_indices_buf = paged_kv_indices_buf + self._paged_kv_last_page_len_buf = paged_kv_last_page_len_buf + self._q_offsets_buf = q_offsets_buf + self._kv_offsets_buf = kv_offsets_buf + + self._inner_wrapper: Optional[BatchPrefillWithPagedKVCacheWrapper] = None + self._q_offsets: Optional[torch.Tensor] = None + self._kv_offsets: Optional[torch.Tensor] = None + + def _create_inner_wrapper(self, dtype: torch.dtype, head_dim: int, idtype: torch.dtype = torch.int32) -> None: + effective_backend = select_best_backend_paged(head_dim, dtype, self._preferred_backend, self._device) + self._backend = effective_backend + + if self._backend == "fa3": + uri = _get_batch_be_module_uri(head_dim, dtype) + "_paged_offset_fa3" + variant_name = "BatchBlockExtendOffsetAttentionFA3" + variant_decl = _BATCH_BE_OFFSET_VARIANT_DECL_FA3 + else: + uri = 
_get_batch_be_module_uri(head_dim, dtype) + "_paged_offset" + variant_name = "BatchBlockExtendOffsetAttention" + variant_decl = _BATCH_BE_OFFSET_VARIANT_DECL + + jit_args = [ + uri, dtype, dtype, dtype, idtype, head_dim, head_dim, + ["maybe_q_block_expanding_offset", "maybe_kv_block_expanding_offset"], + [dtype_map_for_idtype(idtype), dtype_map_for_idtype(idtype)], + ["sm_scale", "dllm_block_size"], ["double", "int64_t"], + variant_name, variant_decl, + ] + jit_kwargs = { + "pos_encoding_mode": 0, "use_sliding_window": False, + "use_logits_soft_cap": False, "use_fp16_qk_reduction": False, + "mask_modes": [0, 1, 2, 3, 4], + } + + self._inner_wrapper = BatchPrefillWithPagedKVCacheWrapper( + self._float_workspace_buffer, kv_layout=self._kv_layout, + use_cuda_graph=self._use_cuda_graph, qo_indptr_buf=self._qo_indptr_buf, + paged_kv_indptr_buf=self._paged_kv_indptr_buf, + paged_kv_indices_buf=self._paged_kv_indices_buf, + paged_kv_last_page_len_buf=self._paged_kv_last_page_len_buf, + backend=self._backend, jit_args=jit_args, jit_kwargs=jit_kwargs, + ) + self._dtype = dtype + self._head_dim = head_dim + self._idtype = idtype + + def plan( + self, qo_indptr: torch.Tensor, paged_kv_indptr: torch.Tensor, + paged_kv_indices: torch.Tensor, paged_kv_last_page_len: torch.Tensor, + num_qo_heads: int, num_kv_heads: int, head_dim: int, page_size: int, + q_data_type: torch.dtype = torch.float16, sm_scale: Optional[float] = None, + q_offsets: Optional[torch.Tensor] = None, kv_offsets: Optional[torch.Tensor] = None, + ) -> None: + if self._inner_wrapper is None or self._head_dim != head_dim or self._dtype != q_data_type or self._idtype != qo_indptr.dtype: + self._create_inner_wrapper(q_data_type, head_dim, qo_indptr.dtype) + + self._sm_scale = sm_scale if sm_scale is not None else 1.0 / math.sqrt(head_dim) + + if self._use_cuda_graph: + if q_offsets is not None: + if self._q_offsets_buf is None: + raise ValueError("q_offsets_buf must be provided in CUDA Graph mode") + 
self._q_offsets_buf[:len(q_offsets)].copy_(q_offsets, non_blocking=True) + self._q_offsets = self._q_offsets_buf[:len(q_offsets)] + else: + self._q_offsets = None + + if kv_offsets is not None: + if self._kv_offsets_buf is None: + raise ValueError("kv_offsets_buf must be provided in CUDA Graph mode") + self._kv_offsets_buf[:len(kv_offsets)].copy_(kv_offsets, non_blocking=True) + self._kv_offsets = self._kv_offsets_buf[:len(kv_offsets)] + else: + self._kv_offsets = None + else: + self._q_offsets = q_offsets + self._kv_offsets = kv_offsets + + self._inner_wrapper.plan( + qo_indptr=qo_indptr, paged_kv_indptr=paged_kv_indptr, + paged_kv_indices=paged_kv_indices, paged_kv_last_page_len=paged_kv_last_page_len, + num_qo_heads=num_qo_heads, num_kv_heads=num_kv_heads, + head_dim_qk=head_dim, head_dim_vo=head_dim, page_size=page_size, + causal=False, pos_encoding_mode="NONE", + q_data_type=q_data_type, mask_mode=MaskMode.BLOCK_EXPANDING.value, + ) + + def run( + self, q: torch.Tensor, paged_kv_cache: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], + sm_scale: Optional[float] = None, return_lse: bool = False, + ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + assert self._inner_wrapper is not None, "Must call plan() before run()" + effective_sm_scale = sm_scale if sm_scale is not None else self._sm_scale + return self._inner_wrapper.run( + q, paged_kv_cache, self._q_offsets, self._kv_offsets, + effective_sm_scale, self._dllm_block_size, return_lse=return_lse, + ) + + @property + def dllm_block_size(self) -> int: + return self._dllm_block_size + + +def dtype_map_for_idtype(idtype: torch.dtype) -> str: + return {torch.int32: "int32_t", torch.int64: "int64_t"}.get(idtype, "int32_t") + + +class BatchBlockExtendRaggedOffsetWrapper: + """Batch Block Extend Ragged Attention with Offset Support""" + + @flashinfer_api + def __init__( + self, + float_workspace_buffer: torch.Tensor, + kv_layout: str = "NHD", + dllm_block_size: int = 256, + use_cuda_graph: bool = 
False, + qo_indptr_buf: Optional[torch.Tensor] = None, + kv_indptr_buf: Optional[torch.Tensor] = None, + q_offsets_buf: Optional[torch.Tensor] = None, + kv_offsets_buf: Optional[torch.Tensor] = None, + backend: str = "auto", + ) -> None: + assert dllm_block_size > 0 and (dllm_block_size & (dllm_block_size - 1)) == 0, \ + f"dllm_block_size must be a positive power of 2, got {dllm_block_size}" + + self._dllm_block_size = dllm_block_size + self._kv_layout = kv_layout + self._backend = backend + self._preferred_backend = backend + self._device = float_workspace_buffer.device + self._dtype: Optional[torch.dtype] = None + self._head_dim: Optional[int] = None + self._idtype: Optional[torch.dtype] = None + + self._float_workspace_buffer = float_workspace_buffer + self._use_cuda_graph = use_cuda_graph + self._qo_indptr_buf = qo_indptr_buf + self._kv_indptr_buf = kv_indptr_buf + self._q_offsets_buf = q_offsets_buf + self._kv_offsets_buf = kv_offsets_buf + + self._inner_wrapper: Optional[BatchPrefillWithRaggedKVCacheWrapper] = None + self._q_offsets: Optional[torch.Tensor] = None + self._kv_offsets: Optional[torch.Tensor] = None + + def _create_inner_wrapper(self, dtype: torch.dtype, head_dim: int, idtype: torch.dtype = torch.int32) -> None: + effective_backend = select_best_backend(head_dim, dtype, self._preferred_backend, self._device) + self._backend = effective_backend + + if self._backend == "fa3": + uri = _get_batch_be_module_uri(head_dim, dtype) + "_ragged_offset_fa3" + variant_name = "BatchBlockExtendOffsetAttentionFA3" + variant_decl = _BATCH_BE_OFFSET_VARIANT_DECL_FA3 + else: + uri = _get_batch_be_module_uri(head_dim, dtype) + "_ragged_offset" + variant_name = "BatchBlockExtendOffsetAttention" + variant_decl = _BATCH_BE_OFFSET_VARIANT_DECL + + jit_args = [ + uri, dtype, dtype, dtype, idtype, head_dim, head_dim, + ["maybe_q_block_expanding_offset", "maybe_kv_block_expanding_offset"], + [dtype_map_for_idtype(idtype), dtype_map_for_idtype(idtype)], + ["sm_scale", 
"dllm_block_size"], ["double", "int64_t"], + variant_name, variant_decl, + ] + jit_kwargs = { + "pos_encoding_mode": 0, "use_sliding_window": False, + "use_logits_soft_cap": False, "use_fp16_qk_reduction": False, + "mask_modes": [0, 1, 2, 3, 4], + } + + self._inner_wrapper = BatchPrefillWithRaggedKVCacheWrapper( + self._float_workspace_buffer, kv_layout=self._kv_layout, + use_cuda_graph=self._use_cuda_graph, qo_indptr_buf=self._qo_indptr_buf, + kv_indptr_buf=self._kv_indptr_buf, backend=self._backend, + jit_args=jit_args, jit_kwargs=jit_kwargs, + ) + self._dtype = dtype + self._head_dim = head_dim + self._idtype = idtype + + def plan( + self, qo_indptr: torch.Tensor, kv_indptr: torch.Tensor, + num_qo_heads: int, num_kv_heads: int, head_dim: int, + q_data_type: torch.dtype = torch.float16, sm_scale: Optional[float] = None, + q_offsets: Optional[torch.Tensor] = None, kv_offsets: Optional[torch.Tensor] = None, + ) -> None: + if self._inner_wrapper is None or self._head_dim != head_dim or self._dtype != q_data_type or self._idtype != qo_indptr.dtype: + self._create_inner_wrapper(q_data_type, head_dim, qo_indptr.dtype) + + self._sm_scale = sm_scale if sm_scale is not None else 1.0 / math.sqrt(head_dim) + + if self._use_cuda_graph: + if q_offsets is not None: + if self._q_offsets_buf is None: + raise ValueError("q_offsets_buf must be provided in CUDA Graph mode") + self._q_offsets_buf[:len(q_offsets)].copy_(q_offsets, non_blocking=True) + self._q_offsets = self._q_offsets_buf[:len(q_offsets)] + else: + self._q_offsets = None + + if kv_offsets is not None: + if self._kv_offsets_buf is None: + raise ValueError("kv_offsets_buf must be provided in CUDA Graph mode") + self._kv_offsets_buf[:len(kv_offsets)].copy_(kv_offsets, non_blocking=True) + self._kv_offsets = self._kv_offsets_buf[:len(kv_offsets)] + else: + self._kv_offsets = None + else: + self._q_offsets = q_offsets + self._kv_offsets = kv_offsets + + self._inner_wrapper.plan( + qo_indptr=qo_indptr, kv_indptr=kv_indptr, 
+ num_qo_heads=num_qo_heads, num_kv_heads=num_kv_heads, + head_dim_qk=head_dim, head_dim_vo=head_dim, + causal=False, pos_encoding_mode="NONE", + q_data_type=q_data_type, mask_mode=MaskMode.BLOCK_EXPANDING.value, + ) + + def run( + self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, + sm_scale: Optional[float] = None, return_lse: bool = False, + ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + assert self._inner_wrapper is not None, "Must call plan() before run()" + effective_sm_scale = sm_scale if sm_scale is not None else self._sm_scale + return self._inner_wrapper.run( + q, k, v, self._q_offsets, self._kv_offsets, + effective_sm_scale, self._dllm_block_size, return_lse=return_lse, + ) + + @property + def dllm_block_size(self) -> int: + return self._dllm_block_size + + +@flashinfer_api +def batch_block_extend_cascade( + q: torch.Tensor, + k_current: torch.Tensor, + v_current: torch.Tensor, + qo_indptr: torch.Tensor, + kv_curr_indptr: torch.Tensor, + paged_kv_cache: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None, + paged_kv_indptr: Optional[torch.Tensor] = None, + paged_kv_indices: Optional[torch.Tensor] = None, + paged_kv_last_page_len: Optional[torch.Tensor] = None, + page_size: int = 16, + dllm_block_size: int = 256, + q_offsets: Optional[torch.Tensor] = None, + kv_offsets: Optional[torch.Tensor] = None, + workspace_buffer: Optional[torch.Tensor] = None, + sm_scale: Optional[float] = None, + return_lse: bool = False, + backend: str = "auto", +) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + """Batch Block Extend Cascade Attention (Current Chunk + Prefix + Merge State)""" + from ..cascade import merge_state + + assert q.dim() == 3 and k_current.dim() == 3 and v_current.dim() == 3 + + device = q.device + head_dim = q.size(-1) + num_qo_heads = q.size(1) + num_kv_heads = k_current.size(1) + batch_size = qo_indptr.size(0) - 1 + + if sm_scale is None: + sm_scale = 1.0 / math.sqrt(head_dim) + + if 
workspace_buffer is None: + workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device=device) + + has_prefix = ( + paged_kv_cache is not None and paged_kv_indptr is not None and + paged_kv_indices is not None and paged_kv_last_page_len is not None and + paged_kv_indices.size(0) > 0 + ) + + if q_offsets is None: + if has_prefix: + import warnings + warnings.warn( + "q_offsets is None but prefix exists. Block extend mask may be incorrect " + "if prefix length is nonzero. Consider passing explicit q_offsets.", + stacklevel=2, + ) + q_offsets = torch.zeros(batch_size, dtype=torch.int32, device=device) + if kv_offsets is None: + kv_offsets = q_offsets # safe alias: neither is mutated downstream + + # Stage 1: Current Chunk (Ragged) + current_wrapper = BatchBlockExtendRaggedOffsetWrapper( + workspace_buffer, kv_layout="NHD", dllm_block_size=dllm_block_size, backend=backend, + ) + current_wrapper.plan( + qo_indptr=qo_indptr, kv_indptr=kv_curr_indptr, + num_qo_heads=num_qo_heads, num_kv_heads=num_kv_heads, head_dim=head_dim, + q_data_type=q.dtype, sm_scale=sm_scale, q_offsets=q_offsets, kv_offsets=kv_offsets, + ) + + if has_prefix: + o1, s1 = current_wrapper.run(q, k_current, v_current, return_lse=True) + else: + return current_wrapper.run(q, k_current, v_current, return_lse=return_lse) + + # Stage 2: Prefix (Paged) + prefix_wrapper = BatchBlockExtendPagedOffsetWrapper( + workspace_buffer, kv_layout="NHD", dllm_block_size=dllm_block_size, backend=backend, + ) + prefix_wrapper.plan( + qo_indptr=qo_indptr, paged_kv_indptr=paged_kv_indptr, + paged_kv_indices=paged_kv_indices, paged_kv_last_page_len=paged_kv_last_page_len, + num_qo_heads=num_qo_heads, num_kv_heads=num_kv_heads, head_dim=head_dim, + page_size=page_size, q_data_type=q.dtype, sm_scale=sm_scale, + q_offsets=q_offsets, kv_offsets=None, + ) + o2, s2 = prefix_wrapper.run(q, paged_kv_cache, return_lse=True) + + # Stage 3: Merge State + o, s = merge_state(o1, s1, o2, s2) + return (o, s) if return_lse 
else o + + +@flashinfer_api +def sglang_style_cascade_attention( + q: torch.Tensor, + k_current: torch.Tensor, + v_current: torch.Tensor, + qo_indptr: torch.Tensor, + kv_curr_indptr: torch.Tensor, + paged_kv_cache: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None, + paged_kv_indptr: Optional[torch.Tensor] = None, + paged_kv_indices: Optional[torch.Tensor] = None, + paged_kv_last_page_len: Optional[torch.Tensor] = None, + page_size: int = 16, + workspace_buffer: Optional[torch.Tensor] = None, + sm_scale: Optional[float] = None, + return_lse: bool = False, + backend: str = "fa2", +) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + """SGLang style Cascade Attention (non-causal current chunk + non-causal prefix + merge)""" + from ..cascade import merge_state + from ..prefill import BatchPrefillWithRaggedKVCacheWrapper, BatchPrefillWithPagedKVCacheWrapper + + assert q.dim() == 3 and k_current.dim() == 3 and v_current.dim() == 3 + + device = q.device + head_dim = q.size(-1) + num_qo_heads = q.size(1) + num_kv_heads = k_current.size(1) + + if sm_scale is None: + sm_scale = 1.0 / math.sqrt(head_dim) + + if workspace_buffer is None: + workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device=device) + + has_prefix = ( + paged_kv_cache is not None and paged_kv_indptr is not None and + paged_kv_indices is not None and paged_kv_last_page_len is not None and + paged_kv_indices.size(0) > 0 + ) + + # Stage 1: Current Chunk (Ragged, causal=False) + # Non-causal because the current chunk uses full bidirectional attention; + # the causal relationship with the prefix is handled by the merge stage. 
+ ragged_wrapper = BatchPrefillWithRaggedKVCacheWrapper(workspace_buffer, kv_layout="NHD", backend=backend) + ragged_wrapper.plan( + qo_indptr=qo_indptr, kv_indptr=kv_curr_indptr, + num_qo_heads=num_qo_heads, num_kv_heads=num_kv_heads, + head_dim_qk=head_dim, head_dim_vo=head_dim, q_data_type=q.dtype, causal=False, + ) + + if has_prefix: + o1, s1 = ragged_wrapper.run(q, k_current, v_current, return_lse=True) + else: + return ragged_wrapper.run(q, k_current, v_current, return_lse=return_lse) + + # Stage 2: Prefix (Paged, causal=False) + paged_wrapper = BatchPrefillWithPagedKVCacheWrapper(workspace_buffer, kv_layout="NHD", backend=backend) + paged_wrapper.plan( + qo_indptr=qo_indptr, paged_kv_indptr=paged_kv_indptr, + paged_kv_indices=paged_kv_indices, paged_kv_last_page_len=paged_kv_last_page_len, + num_qo_heads=num_qo_heads, num_kv_heads=num_kv_heads, + head_dim_qk=head_dim, head_dim_vo=head_dim, page_size=page_size, + q_data_type=q.dtype, causal=False, + ) + o2, s2 = paged_wrapper.run(q, paged_kv_cache, return_lse=True) + + # Stage 3: Merge State + o, s = merge_state(o1, s1, o2, s2) + return (o, s) if return_lse else o diff --git a/flashinfer/dllm/block_extend.py b/flashinfer/dllm/block_extend.py new file mode 100644 index 0000000000..dd8e39c420 --- /dev/null +++ b/flashinfer/dllm/block_extend.py @@ -0,0 +1,422 @@ +""" +Blockwise Extend Attention with Tile-Level Skip Optimization + +Optimization Principle: + +Use native MaskMode::kBlockExpanding to trigger kernel's built-in tile-level skip optimization: + +1. num_iterations calculation: Precisely calculate KV tiles to iterate based on Block Expanding boundaries + kv_valid_end = ((q_tile_end - 1) / dllm_block_size + 1) * dllm_block_size + Completely invisible KV tiles are skipped directly, not loaded or computed + +2. 
mask_iteration calculation: Determine the first iteration that needs mask checking + kv_fully_visible_end = (q_tile_start / dllm_block_size + 1) * dllm_block_size + Tiles before this are fully visible, no per-element check needed + +3. Native mask calculation: Use (q_block >= k_block) rule on boundary tiles + +Block Extend Mask Rules: + mask[q, k] = (q / dllm_block_size) >= (k / dllm_block_size) + Bidirectional visibility within the same block, can see previous blocks, cannot see subsequent blocks + +Usage: + from flashinfer.dllm import block_extend_attention_with_offset + o = block_extend_attention_with_offset(q, k, v, dllm_block_size=32) +""" + +import math +import torch +from pathlib import Path +from typing import Optional, Union, Tuple + +from ..jit import env as jit_env +from ..jit.attention import gen_customize_single_prefill_module +from ..prefill import single_prefill_with_kv_cache_with_jit_module +from ..utils import MaskMode, is_sm90a_supported +from ..api_logging import flashinfer_api + +BLOCK_EXTEND_V2_WITH_OFFSET_VARIANT_DECL = r""" +// For incremental Chunk Prefill scenarios: +// - Each chunk's Q has global offset q_offset +// - Kernel retrieves offset via params.get_q_block_expanding_offset() +// - position_mask internally calculates: (q_global_block >= k_block) + +struct BlockExtendAttentionV2WithOffset : AttentionVariantBase { + static constexpr bool use_softmax = true; + + uint32_t qo_len; + uint32_t kv_len; + uint32_t window_left; + float sm_scale_log2; + + template + __device__ __host__ BlockExtendAttentionV2WithOffset(const Params& params, uint32_t batch_idx, + uint8_t* smem_ptr) { + qo_len = params.get_qo_len(batch_idx); + kv_len = params.get_kv_len(batch_idx); + sm_scale_log2 = params.sm_scale * math::log2e; + window_left = kv_len; // No sliding window + } + + // CUDA kernel natively supports MaskMode::kBlockExpanding: + // - q_offset retrieved via params.get_q_block_expanding_offset(batch_idx) + // - position_mask internally handles: 
(q_global_block >= k_block)
+  //
+  // Therefore LogitsMask only needs to return true
+
+  REGISTER_LOGITS_MASK(params, batch_idx, qo_idx, kv_idx, qo_head_idx, kv_head_idx, {
+    return true;  // kernel's position_mask already handles Block Expanding + q_offset logic
+  });
+
+  // No additional logits transformation needed
+  REGISTER_LOGITS_TRANSFORM(params, logits, batch_idx, qo_idx, kv_idx, qo_head_idx, kv_head_idx, {
+    return logits;
+  });
+};
+"""
+
+def _get_aot_path(uri: str) -> Path:
+    """Get AOT precompiled path (unified interface)"""
+    return jit_env.FLASHINFER_AOT_DIR / uri / f"{uri}.so"
+
+
+def _check_aot_available(uri: str) -> bool:
+    """Check if AOT kernel is available (unified interface)"""
+    import os
+    if os.environ.get("FLASHINFER_FORCE_JIT", "0") == "1":
+        return False
+    return _get_aot_path(uri).exists()
+
+
+def _get_dtype_str(dtype: torch.dtype) -> str:
+    """Get dtype string representation (unified interface)"""
+    _dtype_map = {
+        torch.float16: "fp16",
+        torch.bfloat16: "bf16",
+    }
+    if dtype not in _dtype_map:
+        raise ValueError(
+            f"Unsupported dtype {dtype} for Block Extend Attention. "
+            f"Supported: {list(_dtype_map.keys())}"
+        )
+    return _dtype_map[dtype]
+
+
+def _get_module_uri_with_offset(head_dim: int, dtype: torch.dtype, backend: str) -> str:
+    """Generate a unique identifier for the with-offset module.
+
+    v2: 4 scalar params (sm_scale, dllm_block_size, q_block_expanding_offset,
+    kv_block_expanding_offset).
+    The old version (without the _v2 suffix) has only 3 scalars; it will
+    automatically match the new URI when recompilation is needed.
+ """ + return f"block_expanding_{backend}_with_offset_v2_hdim{head_dim}_{_get_dtype_str(dtype)}" + + +_MODULE_CACHE_WITH_OFFSET = {} + +# V3 FA3 Variant Definition: Block Expanding Attention for Hopper (SM90) architecture + +# FA3 uses different variant interface: +# - Constructor receives MainloopParams and BlockCoord +# - Requires GetAttentionUpdater() template function +# - Access custom parameters via params.additional_params.xxx +# +# FA3 kernel natively supports kBlockExpanding, so LogitsTransform only needs to return logits + +BLOCK_EXTEND_V3_WITH_OFFSET_VARIANT_DECL = r""" +// FA3 kernel natively supports MaskMode::kBlockExpanding: +// - get_num_kv_tiles(): Precisely calculates KV valid range based on Block Expanding boundaries +// - mma_f16(): BLOCK_EXPANDING template parameter controls n_masking_steps and col_limit +// - position_mask: (q_global_block >= k_block) && (kv_idx < kv_len) +// +// Therefore LogitsTransform only needs to return logits, letting kernel's native mask logic take effect + +struct BlockExtendAttentionV3WithOffset : AttentionVariantBase { + float sm_scale_log2; + template + __device__ __host__ BlockExtendAttentionV3WithOffset( + const MainloopParams& params, const BlockCoord& block_coord) { + sm_scale_log2 = params.additional_params.sm_scale * math::log2e; + } + template + __device__ auto GetAttentionUpdater() { + return OnlineSoftmax(sm_scale_log2); + } + + REGISTER_LOGITS_TRANSFORM(params, logits, batch_idx, qo_idx, kv_idx, qo_head_idx, kv_head_idx, { + return logits; // kernel's native BLOCK_EXPANDING mask already handles this + }); +}; +""" + + +def get_block_extend_module_with_offset( + head_dim: int = 128, + dtype: torch.dtype = torch.float16, + backend: str = "fa2", + device: Optional[torch.device] = None, +): + """ + Get Block Extend Attention module with q_offset/kv_offset support + + Args: + head_dim: Head dimension + dtype: Data type + backend: "fa2" or "fa3" + device: Target CUDA device (default: current CUDA device) + + 
Returns: + Compiled module + + Raises: + RuntimeError: If backend="fa3" but GPU doesn't support SM90 + """ + import os + import tvm_ffi + + if device is None: + device = torch.device("cuda") + + # FA3 requires SM90 support + if backend == "fa3" and not is_sm90a_supported(device): + raise RuntimeError( + "FA3 backend requires SM90 (Hopper) architecture. " + "Use backend='fa2' for older architectures." + ) + + cache_key = (head_dim, dtype, backend, device) + if cache_key in _MODULE_CACHE_WITH_OFFSET: + return _MODULE_CACHE_WITH_OFFSET[cache_key] + + uri = _get_module_uri_with_offset(head_dim, dtype, backend) + + # AOT mode + if _check_aot_available(uri): + aot_path = _get_aot_path(uri) + module = tvm_ffi.load_module(str(aot_path)) + _MODULE_CACHE_WITH_OFFSET[cache_key] = module + return module + + # AOT not available, check if JIT is disabled + if os.environ.get("FLASHINFER_DISABLE_JIT", "0") == "1": + raise RuntimeError( + f"JIT compilation is disabled via FLASHINFER_DISABLE_JIT environment variable, " + f"but the required AOT module is not found at: {_get_aot_path(uri)}." 
+ ) + + # JIT mode + if backend == "fa3": + variant_name = "BlockExtendAttentionV3WithOffset" + variant_decl = BLOCK_EXTEND_V3_WITH_OFFSET_VARIANT_DECL + else: + variant_name = "BlockExtendAttentionV2WithOffset" + variant_decl = BLOCK_EXTEND_V2_WITH_OFFSET_VARIANT_DECL + + spec = gen_customize_single_prefill_module( + backend=backend, + uri=uri, + dtype_q=dtype, + dtype_kv=dtype, + dtype_o=dtype, + head_dim_qk=head_dim, + head_dim_vo=head_dim, + additional_tensor_names=[], + additional_tensor_dtypes=[], + additional_scalar_names=["sm_scale", "dllm_block_size", "q_block_expanding_offset", "kv_block_expanding_offset"], + additional_scalar_dtypes=["double", "int64_t", "int64_t", "int64_t"], + variant_name=variant_name, + variant_decl=variant_decl, + mask_modes=[4], # kBlockExpanding = 4 + ) + module = spec.build_and_load() + + _MODULE_CACHE_WITH_OFFSET[cache_key] = module + return module + +@flashinfer_api +def block_extend_attention_with_offset( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + dllm_block_size: int, + q_offset: int = 0, + kv_offset: int = 0, + sm_scale: Optional[float] = None, + return_lse: bool = False, + backend: str = "auto", +) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + """ + Block Extend Attention with Q and KV Offset Support + + Supports incremental Chunk Prefill and Cascade Current Chunk scenarios. + + Args: + q: Query tensor [qo_len, num_heads, head_dim] + k: Key tensor [kv_len, num_heads, head_dim] + v: Value tensor [kv_len, num_heads, head_dim] + dllm_block_size: DLLM block size (must be power of 2) + q_offset: Q's global starting position (default 0) + kv_offset: KV's global starting position (default 0) + sm_scale: Softmax scale (default 1/sqrt(head_dim)) + return_lse: Whether to return log-sum-exp + backend: "auto" (auto-select), "fa2" or "fa3" + + Returns: + Output tensor [qo_len, num_heads, head_dim] + + Example: + >>> # Incremental chunk prefill + >>> o = block_extend_attention_with_offset( + ... 
q, k_cumul, v_cumul, + ... dllm_block_size=32, + ... q_offset=i * chunk_len, + ... ) + >>> + >>> # Cascade Current Chunk + >>> o = block_extend_attention_with_offset( + ... q, k_current, v_current, + ... dllm_block_size=256, + ... q_offset=prefix_len, + ... kv_offset=prefix_len, + ... ) + """ + assert q.dim() == 3 and k.dim() == 3 and v.dim() == 3, \ + "q, k, v must be 3D tensors [seq_len, num_heads, head_dim]" + assert dllm_block_size > 0 and (dllm_block_size & (dllm_block_size - 1)) == 0, \ + f"dllm_block_size must be a positive power of 2, got {dllm_block_size}" + + head_dim = q.size(-1) + dtype = q.dtype + + if sm_scale is None: + sm_scale = 1.0 / math.sqrt(head_dim) + + # backend selection + if backend == "auto": + backend = "fa3" if is_sm90a_supported(q.device) else "fa2" + + module = get_block_extend_module_with_offset(head_dim=head_dim, dtype=dtype, backend=backend, device=q.device) + + return single_prefill_with_kv_cache_with_jit_module( + module, + q, k, v, + sm_scale, + dllm_block_size, + q_offset, + kv_offset, + mask_mode=MaskMode.BLOCK_EXPANDING.value, + return_lse=return_lse, + ) + +# FA2 Cascade version: Current Chunk (causal=True) + Prefix (causal=False) + merge_state +# Similar to SGLang's Cascade Attention implementation: +# - When chunk_size = dllm_block_size, causal mask ≡ block extend mask +# - Uses standard FlashInfer API, no custom_mask needed +# - Prefix fully visible (Q's block >= all prefix's blocks) +@flashinfer_api +def block_extend_cascade( + q: torch.Tensor, + k_current: torch.Tensor, + v_current: torch.Tensor, + k_prefix: Optional[torch.Tensor] = None, + v_prefix: Optional[torch.Tensor] = None, + dllm_block_size: int = 64, + sm_scale: Optional[float] = None, + return_lse: bool = False, + backend: str = "auto", +) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + """ + Cascade Block Extend Attention (Single Request) + + Implements Cascade Attention using native Block Extend mask. + + Three-stage Cascade: + 1. 
Current Chunk: Q attends to K_current/V_current (Block Expanding mask)
+    2. Prefix: Q attends to K_prefix/V_prefix (causal=False, fully visible)
+    3. Merge State: Merge softmax states from both stages
+
+    Args:
+        q: Query tensor [qo_len, num_heads, head_dim] - Q of current chunk
+        k_current: Key tensor [curr_kv_len, num_kv_heads, head_dim] - K of current chunk
+        v_current: Value tensor [curr_kv_len, num_kv_heads, head_dim] - V of current chunk
+        k_prefix: Key tensor [prefix_len, num_kv_heads, head_dim] - Prefix K (optional)
+        v_prefix: Value tensor [prefix_len, num_kv_heads, head_dim] - Prefix V (optional)
+        dllm_block_size: DLLM block size (must be a power of 2)
+        sm_scale: Softmax scale (default 1/sqrt(head_dim))
+        return_lse: Whether to return log-sum-exp
+        backend: Backend selection ("auto", "fa2", "fa3")
+
+    Returns:
+        If return_lse=False: Output tensor [qo_len, num_heads, head_dim]
+        If return_lse=True: (output, lse) tuple
+
+    Example:
+        >>> # Incremental chunk prefill
+        >>> # Step 0: Q=[0:64), K=[0:64), V=[0:64)
+        >>> o0 = block_extend_cascade(q0, k0, v0, dllm_block_size=64)
+        >>>
+        >>> # Step 1: Q=[64:128), K_curr=[64:128), K_prefix=[0:64)
+        >>> o1 = block_extend_cascade(q1, k1, v1, k_prefix=k0, v_prefix=v0, dllm_block_size=64)
+    """
+    from ..cascade import merge_state_in_place
+    from ..prefill import single_prefill_with_kv_cache
+
+    assert q.dim() == 3 and k_current.dim() == 3 and v_current.dim() == 3, \
+        "q, k_current, v_current must be 3D tensors [seq_len, num_heads, head_dim]"
+
+    head_dim = q.size(-1)
+    if sm_scale is None:
+        sm_scale = 1.0 / math.sqrt(head_dim)
+
+    has_prefix = k_prefix is not None and v_prefix is not None
+    prefix_len = k_prefix.size(0) if has_prefix else 0
+
+    # Stage 1: Current Chunk (Block Expanding mask)
+    #
+    # Use native Block Expanding mask with offset support:
+    # - q_offset = prefix_len: Q's global position starts from prefix_len
+    # - kv_offset = prefix_len: K_current's global position also starts from prefix_len
+    # - mask[q, 
k] = ((q_offset + q_idx) / B) >= ((kv_offset + kv_idx) / B) + # + # Return directly when no prefix, force return_lse=True when prefix exists for merge + if not has_prefix: + return block_extend_attention_with_offset( + q, k_current, v_current, + dllm_block_size=dllm_block_size, + q_offset=prefix_len, # = 0 + kv_offset=prefix_len, # = 0 + sm_scale=sm_scale, + return_lse=return_lse, + backend=backend, + ) + + o1, s1 = block_extend_attention_with_offset( + q, k_current, v_current, + dllm_block_size=dllm_block_size, + q_offset=prefix_len, + kv_offset=prefix_len, + sm_scale=sm_scale, + return_lse=True, # merge requires lse + backend=backend, + ) + + # Stage 2: Prefix (causal=False, fully visible) + # + # Q's global position >= prefix's end position, so: + # - Q_block >= all prefix's K_block (since q_offset = prefix_len) + # - Block Expanding mask is all 1 for prefix, equivalent to causal=False + # + o2, s2 = single_prefill_with_kv_cache( + q, k_prefix, v_prefix, + causal=False, # prefix fully visible + sm_scale=sm_scale, + return_lse=True, + ) + + # Stage 3: Merge State (in-place merge) + merge_state_in_place(o1, s1, o2, s2) + + if return_lse: + return o1, s1 + else: + return o1 diff --git a/flashinfer/jit/attention/modules.py b/flashinfer/jit/attention/modules.py index bb6962b791..0e6a2480a6 100755 --- a/flashinfer/jit/attention/modules.py +++ b/flashinfer/jit/attention/modules.py @@ -15,7 +15,7 @@ """ import os -from typing import List +from typing import List, Optional import jinja2 import torch @@ -1280,6 +1280,7 @@ def gen_customize_single_prefill_module( use_logits_soft_cap: bool = False, use_fp16_qk_reduction: bool = False, fp8_enabled: bool = False, + mask_modes: Optional[List[int]] = None, ) -> JitSpec: kwargs = { "variant_decl": variant_decl, @@ -1321,6 +1322,7 @@ def gen_customize_single_prefill_module( "additional_func_params": additional_func_params, "additional_params_decl": additional_params_decl, "additional_params_setter": additional_params_setter, + 
"has_q_block_expanding_offset": "q_block_expanding_offset" in additional_scalar_names, } generated_inc_str = config_templ.render( @@ -1329,7 +1331,8 @@ def gen_customize_single_prefill_module( os.makedirs(gen_directory, exist_ok=True) source_paths = [] - for mask_mode in [0, 1, 2, 3]: + _mask_modes = mask_modes if mask_modes is not None else [0, 1, 2, 3] + for mask_mode in _mask_modes: filename = f"single_prefill_kernel_mask_{mask_mode}.cu" dest_path = gen_directory / filename source_paths.append(dest_path) @@ -1393,7 +1396,8 @@ def gen_customize_single_prefill_module( os.makedirs(gen_directory, exist_ok=True) source_paths = [] - for mask_mode in [0, 1, 2, 3]: + _mask_modes = mask_modes if mask_modes is not None else [0, 1, 2, 3] + for mask_mode in _mask_modes: filename = f"single_prefill_sm90_kernel_mask_{mask_mode}.cu" dest_path = gen_directory / filename source_paths.append(dest_path) @@ -1525,6 +1529,7 @@ def gen_customize_batch_prefill_module( use_logits_soft_cap: bool = False, use_fp16_qk_reduction: bool = False, fp8_enabled: bool = False, + mask_modes: Optional[List[int]] = None, ) -> JitSpec: kwargs = { "variant_decl": variant_decl, @@ -1580,7 +1585,8 @@ def gen_customize_batch_prefill_module( os.makedirs(gen_directory, exist_ok=True) source_paths = [] - for mask_mode in [0, 1, 2, 3]: + _mask_modes = mask_modes if mask_modes is not None else [0, 1, 2, 3] + for mask_mode in _mask_modes: dest_path = ( gen_directory / f"batch_prefill_paged_kernel_mask_{mask_mode}.cu" ) @@ -1654,7 +1660,8 @@ def gen_customize_batch_prefill_module( generated_inc_str = config_templ.render(**kwargs) source_paths = [] - for mask_mode in [0, 1, 2, 3]: + _mask_modes = mask_modes if mask_modes is not None else [0, 1, 2, 3] + for mask_mode in _mask_modes: filename = f"batch_prefill_paged_sm90_kernel_mask_{mask_mode}.cu" dest_path = gen_directory / filename source_paths.append(dest_path) diff --git a/flashinfer/jit/utils.py b/flashinfer/jit/utils.py index 4e19212e14..20d5ea47ff 100644 
--- a/flashinfer/jit/utils.py +++ b/flashinfer/jit/utils.py @@ -80,4 +80,5 @@ def write_if_different(path: pathlib.Path, content: str) -> None: 1: "MaskMode::kCausal", 2: "MaskMode::kCustom", 3: "MaskMode::kMultiItemScoring", + 4: "MaskMode::kBlockExpanding", } diff --git a/flashinfer/prefill.py b/flashinfer/prefill.py index 62d886042b..8c3c7b07ee 100755 --- a/flashinfer/prefill.py +++ b/flashinfer/prefill.py @@ -182,6 +182,7 @@ def get_customize_batch_prefill_module( use_logits_soft_cap: bool = False, use_fp16_qk_reduction: bool = False, fp8_enabled: bool = False, + mask_modes: Optional[List[int]] = None, ): return gen_customize_batch_prefill_module( backend, @@ -203,6 +204,7 @@ def get_customize_batch_prefill_module( use_logits_soft_cap, use_fp16_qk_reduction, fp8_enabled, + mask_modes, ).build_and_load() @@ -1579,6 +1581,7 @@ def __init__( self._seq_lens_kv = None self._seq_lens_q = None self._block_tables = None + self._mask_mode = None @property def is_cuda_graph_enabled(self) -> bool: @@ -1645,6 +1648,7 @@ def plan( max_sequence_kv: Optional[int] = None, fixed_split_size: Optional[int] = None, disable_split_kv: bool = False, + mask_mode: Optional[int] = None, ) -> None: r"""Plan batch prefill/append attention on Paged KV-Cache for given problem specification. 
@@ -2003,6 +2007,7 @@ def plan( self._rope_theta = rope_theta self._seq_lens_kv = seq_lens self._seq_lens_q = seq_lens_q if seq_lens_q is not None else seq_lens + self._mask_mode = mask_mode begin_forward = plan @@ -2197,6 +2202,8 @@ def run( if self._custom_mask_buf is not None: mask_mode = MaskMode.CUSTOM.value + elif self._mask_mode is not None: + mask_mode = self._mask_mode else: if self._causal: mask_mode = MaskMode.CAUSAL.value @@ -2563,6 +2570,7 @@ def __init__( self._max_total_num_rows: Optional[int] = None self._backend = backend self._cached_module = None + self._mask_mode = None @property def is_cuda_graph_enabled(self) -> bool: @@ -2621,6 +2629,7 @@ def plan( max_item_len_ptr: Optional[torch.Tensor] = None, fixed_split_size: Optional[int] = None, disable_split_kv: bool = False, + mask_mode:Optional[int] = None, seq_lens: Optional[torch.Tensor] = None, seq_lens_q: Optional[torch.Tensor] = None, max_token_per_sequence: Optional[int] = None, @@ -2926,6 +2935,7 @@ def plan( self._sm_scale = sm_scale self._rope_scale = rope_scale self._rope_theta = rope_theta + self._mask_mode = mask_mode begin_forward = plan @@ -3146,6 +3156,8 @@ def run( if self._custom_mask_buf is not None: mask_mode = MaskMode.CUSTOM.value + elif self._mask_mode is not None: + mask_mode = self._mask_mode else: if self._causal: mask_mode = MaskMode.CAUSAL.value diff --git a/flashinfer/utils.py b/flashinfer/utils.py index e6c2bd836d..febb95594a 100644 --- a/flashinfer/utils.py +++ b/flashinfer/utils.py @@ -40,6 +40,7 @@ class MaskMode(Enum): CAUSAL = 1 CUSTOM = 2 MULTIITEMSCORING = 3 + BLOCK_EXPANDING = 4 class TensorLayout(Enum): diff --git a/include/flashinfer/attention/block_expanding_prefill.cuh b/include/flashinfer/attention/block_expanding_prefill.cuh new file mode 100644 index 0000000000..69d7c0a60a --- /dev/null +++ b/include/flashinfer/attention/block_expanding_prefill.cuh @@ -0,0 +1,166 @@ +/* + * Copyright (c) 2023-2024 by FlashInfer team. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef FLASHINFER_ATTENTION_BLOCK_EXPANDING_PREFILL_CUH_ +#define FLASHINFER_ATTENTION_BLOCK_EXPANDING_PREFILL_CUH_ + +#include +#include "../math.cuh" + +namespace flashinfer { + +// Block Expanding Mask: mask[q, k] = (q / B) >= (k / B) +// For Q tile [q_start, q_end), visible KV range: [0, ceil(q_end / B) * B) + +template +struct BlockExpandingTileSkipController { + uint32_t dllm_block_size; + uint32_t log2_block_size; + uint32_t kv_len; + + __device__ __host__ __forceinline__ + BlockExpandingTileSkipController(uint32_t block_size, uint32_t kv_length) + : dllm_block_size(block_size), kv_len(kv_length) { +#ifdef __CUDA_ARCH__ + log2_block_size = __popc(block_size - 1); +#else + log2_block_size = 0; + while ((1u << log2_block_size) < block_size) ++log2_block_size; +#endif + } + + __device__ __forceinline__ + uint32_t get_kv_valid_end(uint32_t q_tile_end) const { + uint32_t q_last_idx = q_tile_end - 1; + uint32_t q_block_id = q_last_idx >> log2_block_size; + uint32_t valid_end = (q_block_id + 1) << log2_block_size; + return min(valid_end, kv_len); + } + + __device__ __forceinline__ + uint32_t get_num_iterations(uint32_t q_tile_end, uint32_t chunk_start, + uint32_t chunk_size) const { + uint32_t kv_valid_end = get_kv_valid_end(q_tile_end); + uint32_t effective_chunk_size; + if (kv_valid_end <= chunk_start) { + effective_chunk_size = 0; + } else { + effective_chunk_size = min(chunk_size, kv_valid_end - 
chunk_start); + } + return (effective_chunk_size + CTA_TILE_KV - 1) / CTA_TILE_KV; + } + + __device__ __forceinline__ + uint32_t get_mask_iteration(uint32_t q_tile_start, uint32_t chunk_start, + uint32_t chunk_size) const { + uint32_t q_first_block_id = q_tile_start >> log2_block_size; + uint32_t kv_fully_visible_end = (q_first_block_id + 1) << log2_block_size; + uint32_t kv_fully_visible_in_chunk; + if (kv_fully_visible_end <= chunk_start) { + kv_fully_visible_in_chunk = 0; + } else { + kv_fully_visible_in_chunk = min(chunk_size, kv_fully_visible_end - chunk_start); + } + return kv_fully_visible_in_chunk / CTA_TILE_KV; + } + + __device__ __forceinline__ + bool needs_mask(uint32_t iter, uint32_t mask_iteration, + uint32_t num_iterations) const { + return iter >= mask_iteration && iter < num_iterations; + } + + __device__ __forceinline__ + bool is_tile_fully_visible(uint32_t q_tile_start, uint32_t kv_tile_start, + uint32_t kv_tile_end) const { + uint32_t q_first_block = q_tile_start >> log2_block_size; + uint32_t k_last_block = (kv_tile_end - 1) >> log2_block_size; + return q_first_block >= k_last_block; + } + + __device__ __forceinline__ + bool is_tile_fully_masked(uint32_t q_tile_end, uint32_t kv_tile_start) const { + uint32_t q_last_block = (q_tile_end - 1) >> log2_block_size; + uint32_t k_first_block = kv_tile_start >> log2_block_size; + return q_last_block < k_first_block; + } +}; + + +__device__ __forceinline__ uint32_t block_expanding_kv_valid_end( + uint32_t q_tile_end, uint32_t dllm_block_size, uint32_t q_offset = 0) { + uint32_t q_global_last_idx = q_offset + q_tile_end - 1; + uint32_t q_block_id = q_global_last_idx / dllm_block_size; + return (q_block_id + 1) * dllm_block_size; +} + +__device__ __forceinline__ uint32_t block_expanding_num_iterations( + uint32_t q_tile_end, uint32_t chunk_start, uint32_t chunk_size, + uint32_t dllm_block_size, uint32_t CTA_TILE_KV, uint32_t q_offset = 0) { + uint32_t kv_valid_end = block_expanding_kv_valid_end(q_tile_end, 
dllm_block_size, q_offset); + uint32_t effective_chunk_size; + if (kv_valid_end <= chunk_start) { + effective_chunk_size = 0; + } else { + effective_chunk_size = min(chunk_size, kv_valid_end - chunk_start); + } + return (effective_chunk_size + CTA_TILE_KV - 1) / CTA_TILE_KV; +} + +__device__ __forceinline__ uint32_t block_expanding_mask_iteration( + uint32_t q_tile_start, uint32_t chunk_start, uint32_t chunk_size, + uint32_t dllm_block_size, uint32_t CTA_TILE_KV, uint32_t q_offset = 0, uint32_t kv_offset = 0) { + uint32_t q_global_start = q_offset + q_tile_start; + uint32_t q_first_block_id = q_global_start / dllm_block_size; + // Consider kv_offset: kv_global < (q_block + 1) * B + // kv_offset + kv_local < (q_block + 1) * B + // kv_local < (q_block + 1) * B - kv_offset + int64_t max_kv_global = (int64_t)(q_first_block_id + 1) * dllm_block_size; + int64_t kv_fully_visible_end_local = max_kv_global - (int64_t)kv_offset; + uint32_t kv_fully_visible_end = (kv_fully_visible_end_local > 0) ? 
(uint32_t)kv_fully_visible_end_local : 0; + uint32_t kv_fully_visible_in_chunk; + if (kv_fully_visible_end <= chunk_start) { + kv_fully_visible_in_chunk = 0; + } else { + kv_fully_visible_in_chunk = min(chunk_size, kv_fully_visible_end - chunk_start); + } + return kv_fully_visible_in_chunk / CTA_TILE_KV; +} + +__device__ __forceinline__ bool block_expanding_needs_mask( + uint32_t iter, uint32_t mask_iteration, uint32_t num_iterations) { + return iter >= mask_iteration && iter < num_iterations; +} + +// MMA tile level check +__device__ __forceinline__ bool block_expanding_tile_fully_visible( + uint32_t q_start, uint32_t q_end, uint32_t kv_start, uint32_t kv_end, + uint32_t log2_block_size) { + uint32_t q_first_block = q_start >> log2_block_size; + uint32_t k_last_block = (kv_end - 1) >> log2_block_size; + return q_first_block >= k_last_block; +} + +__device__ __forceinline__ bool block_expanding_tile_fully_masked( + uint32_t q_end, uint32_t kv_start, uint32_t log2_block_size) { + uint32_t q_last_block = (q_end - 1) >> log2_block_size; + uint32_t k_first_block = kv_start >> log2_block_size; + return q_last_block < k_first_block; +} + +} // namespace flashinfer + +#endif // FLASHINFER_ATTENTION_BLOCK_EXPANDING_PREFILL_CUH_ diff --git a/include/flashinfer/attention/default_prefill_params.cuh b/include/flashinfer/attention/default_prefill_params.cuh index 2e857fcc72..e8b1dae661 100644 --- a/include/flashinfer/attention/default_prefill_params.cuh +++ b/include/flashinfer/attention/default_prefill_params.cuh @@ -57,6 +57,9 @@ struct SinglePrefillParams { float rope_rcp_theta; uint32_t partition_kv; + uint32_t dllm_block_size; // DLLM block size for BlockExpanding mask (0 = disabled) + uint32_t q_block_expanding_offset; // Q global offset (for incremental chunk prefill) + uint32_t kv_block_expanding_offset; // KV global offset (for Cascade Current Chunk) __host__ SinglePrefillParams() : q(nullptr), @@ -83,7 +86,10 @@ struct SinglePrefillParams { sm_scale(0.0f), 
rope_rcp_scale(0.0f), rope_rcp_theta(0.0f), - partition_kv(false) {} + partition_kv(false), + dllm_block_size(0), + q_block_expanding_offset(0), + kv_block_expanding_offset(0) {} __host__ SinglePrefillParams(DTypeQ* q, DTypeKV* k, DTypeKV* v, uint8_t* maybe_custom_mask, DTypeO* o, float* lse, float* maybe_alibi_slopes, @@ -116,7 +122,10 @@ struct SinglePrefillParams { sm_scale(sm_scale), rope_rcp_scale(1. / rope_scale), rope_rcp_theta(1. / rope_theta), - partition_kv(false) {} + partition_kv(false), + dllm_block_size(0), + q_block_expanding_offset(0), + kv_block_expanding_offset(0) {} __host__ __device__ __forceinline__ uint32_t get_qo_len(uint32_t batch_idx) const { return qo_len; @@ -125,6 +134,16 @@ struct SinglePrefillParams { __host__ __device__ __forceinline__ uint32_t get_kv_len(uint32_t batch_idx) const { return kv_len; } + + // Single prefill supports incremental chunk prefill + __host__ __device__ __forceinline__ uint32_t get_q_block_expanding_offset(uint32_t batch_idx) const { + return q_block_expanding_offset; + } + + // Single prefill supports Cascade Current Chunk + __host__ __device__ __forceinline__ uint32_t get_kv_block_expanding_offset(uint32_t batch_idx) const { + return kv_block_expanding_offset; + } }; template @@ -177,6 +196,11 @@ struct BatchPrefillRaggedParams { uint32_t token_pos_in_items_len; uint16_t* maybe_max_item_len_ptr; + // Block Expanding mask supports incremental chunk prefill + uint32_t dllm_block_size; // DLLM block size (0 = disabled) + IdType* maybe_q_block_expanding_offset; // Q global offset (one value per batch item) + IdType* maybe_kv_block_expanding_offset; // KV global offset (one value per batch item, for Cascade Current Chunk) + __host__ BatchPrefillRaggedParams() : q(nullptr), k(nullptr), @@ -218,7 +242,10 @@ struct BatchPrefillRaggedParams { maybe_prefix_len_ptr(nullptr), maybe_token_pos_in_items_ptr(nullptr), token_pos_in_items_len(0), - maybe_max_item_len_ptr(nullptr) {} + maybe_max_item_len_ptr(nullptr), + 
dllm_block_size(0), + maybe_q_block_expanding_offset(nullptr), + maybe_kv_block_expanding_offset(nullptr) {} __host__ BatchPrefillRaggedParams(DTypeQ* q, DTypeKV* k, DTypeKV* v, uint8_t* maybe_custom_mask, IdType* q_indptr, IdType* kv_indptr, IdType* maybe_mask_indptr, @@ -269,7 +296,10 @@ struct BatchPrefillRaggedParams { maybe_prefix_len_ptr(nullptr), maybe_token_pos_in_items_ptr(nullptr), token_pos_in_items_len(0), - maybe_max_item_len_ptr(nullptr) {} + maybe_max_item_len_ptr(nullptr), + dllm_block_size(0), + maybe_q_block_expanding_offset(nullptr), + maybe_kv_block_expanding_offset(nullptr) {} __host__ __device__ __forceinline__ uint32_t get_qo_len(uint32_t batch_idx) const { return q_indptr[batch_idx + 1] - q_indptr[batch_idx]; @@ -278,6 +308,16 @@ struct BatchPrefillRaggedParams { __host__ __device__ __forceinline__ uint32_t get_kv_len(uint32_t batch_idx) const { return kv_indptr[batch_idx + 1] - kv_indptr[batch_idx]; } + + // Get batch item's Q global offset (for incremental chunk prefill) + __host__ __device__ __forceinline__ uint32_t get_q_block_expanding_offset(uint32_t batch_idx) const { + return (maybe_q_block_expanding_offset != nullptr) ? maybe_q_block_expanding_offset[batch_idx] : 0; + } + + // Get batch item's KV global offset (for Cascade Current Chunk) + __host__ __device__ __forceinline__ uint32_t get_kv_block_expanding_offset(uint32_t batch_idx) const { + return (maybe_kv_block_expanding_offset != nullptr) ? 
maybe_kv_block_expanding_offset[batch_idx] : 0; + } }; template @@ -322,6 +362,11 @@ struct BatchPrefillPagedParams { uint32_t token_pos_in_items_len; uint16_t* maybe_max_item_len_ptr; + // Block Expanding mask supports incremental chunk prefill + uint32_t dllm_block_size; // DLLM block size (0 = disabled) + IdType* maybe_q_block_expanding_offset; // Q global offset (one value per batch item) + IdType* maybe_kv_block_expanding_offset; // KV global offset (one value per batch item, for Cascade Current Chunk) + __host__ BatchPrefillPagedParams() : q(nullptr), paged_kv(), @@ -355,7 +400,10 @@ struct BatchPrefillPagedParams { maybe_prefix_len_ptr(nullptr), maybe_token_pos_in_items_ptr(nullptr), token_pos_in_items_len(0), - maybe_max_item_len_ptr(nullptr) {} + maybe_max_item_len_ptr(nullptr), + dllm_block_size(0), + maybe_q_block_expanding_offset(nullptr), + maybe_kv_block_expanding_offset(nullptr) {} __host__ BatchPrefillPagedParams(DTypeQ* q, paged_kv_t paged_kv, uint8_t* maybe_custom_mask, IdType* q_indptr, @@ -396,7 +444,10 @@ struct BatchPrefillPagedParams { maybe_prefix_len_ptr(nullptr), maybe_token_pos_in_items_ptr(nullptr), token_pos_in_items_len(0), - maybe_max_item_len_ptr(nullptr) {} + maybe_max_item_len_ptr(nullptr), + dllm_block_size(0), + maybe_q_block_expanding_offset(nullptr), + maybe_kv_block_expanding_offset(nullptr) {} __host__ __device__ __forceinline__ uint32_t get_qo_len(uint32_t batch_idx) const { return q_indptr[batch_idx + 1] - q_indptr[batch_idx]; @@ -405,6 +456,16 @@ struct BatchPrefillPagedParams { __host__ __device__ __forceinline__ uint32_t get_kv_len(uint32_t batch_idx) const { return paged_kv.get_length(batch_idx); } + + // Get batch item's Q global offset (for incremental chunk prefill) + __host__ __device__ __forceinline__ uint32_t get_q_block_expanding_offset(uint32_t batch_idx) const { + return (maybe_q_block_expanding_offset != nullptr) ? 
maybe_q_block_expanding_offset[batch_idx] : 0; + } + + // Get batch item's KV global offset (for Cascade Current Chunk) + __host__ __device__ __forceinline__ uint32_t get_kv_block_expanding_offset(uint32_t batch_idx) const { + return (maybe_kv_block_expanding_offset != nullptr) ? maybe_kv_block_expanding_offset[batch_idx] : 0; + } }; } // namespace flashinfer diff --git a/include/flashinfer/attention/hopper/mainloop.cuh b/include/flashinfer/attention/hopper/mainloop.cuh index e5bf4ffb9f..171c828cdf 100644 --- a/include/flashinfer/attention/hopper/mainloop.cuh +++ b/include/flashinfer/attention/hopper/mainloop.cuh @@ -24,7 +24,7 @@ namespace flashinfer { using namespace cute; -template +template struct CollectiveMainloop { using DTypeQ = typename Ktraits::DTypeQ; using DTypeKV = typename Ktraits::DTypeKV; @@ -137,9 +137,9 @@ struct CollectiveMainloop { cute::prefetch_tma_descriptor(mainloop_params.tma_load_V.get_tma_descriptor()); } - CUTLASS_DEVICE +CUTLASS_DEVICE int get_num_kv_tiles(Params const& mainloop_params, int q_tile_idx, const int qo_len, - const int kv_len) { + const int kv_len, const int batch_idx = 0) { static constexpr int CTA_Q = get<0>(TileShape_QKD{}); static constexpr int CTA_KV = get<1>(TileShape_QKD{}); int num_kv_tiles = cute::ceil_div(kv_len, CTA_KV); @@ -147,6 +147,43 @@ struct CollectiveMainloop { num_kv_tiles = std::min(num_kv_tiles, cute::ceil_div((q_tile_idx + 1) * CTA_Q + kv_len - qo_len, CTA_KV)); } + if constexpr (BLOCK_EXPANDING) { + // Block Expanding: Calculate valid KV range based on block boundaries + // q_tile_end = min((q_tile_idx + 1) * CTA_Q, qo_len) + // q_block = (q_offset + q_tile_end - 1) / B + // kv_valid_end = (q_block + 1) * B - kv_offset + int64_t dllm_block_size = 1; + int64_t q_offset = 0; + int64_t kv_offset = 0; // kv_offset support for Cascade Current Chunk + if constexpr (has_dllm_block_size_v) { + dllm_block_size = mainloop_params.additional_params.dllm_block_size; + } + // Prefer reading per-batch offset from 
maybe_q_block_expanding_offset array + if constexpr (has_maybe_q_block_expanding_offset_v) { + auto* offset_ptr = mainloop_params.additional_params.maybe_q_block_expanding_offset; + if (offset_ptr != nullptr) { + q_offset = offset_ptr[batch_idx]; + } + } else if constexpr (has_q_block_expanding_offset_v) { + q_offset = mainloop_params.additional_params.q_block_expanding_offset; + } + // Read kv_offset (for Cascade Current Chunk scenario) + if constexpr (has_maybe_kv_block_expanding_offset_v) { + auto* offset_ptr = mainloop_params.additional_params.maybe_kv_block_expanding_offset; + if (offset_ptr != nullptr) { + kv_offset = offset_ptr[batch_idx]; + } + } else if constexpr (has_kv_block_expanding_offset_v) { + kv_offset = mainloop_params.additional_params.kv_block_expanding_offset; + } + int q_tile_end = std::min((q_tile_idx + 1) * CTA_Q, qo_len); + int64_t q_global_end = q_offset + q_tile_end - 1; + int64_t q_block = q_global_end / dllm_block_size; + // Consider kv_offset: max valid kv_idx = (q_block + 1) * B - kv_offset + int64_t max_kv_global = (q_block + 1) * dllm_block_size; + int kv_valid_end = static_cast(std::max(max_kv_global - kv_offset, int64_t(0))); + num_kv_tiles = std::min(num_kv_tiles, cute::ceil_div(std::min(kv_len, kv_valid_end), CTA_KV)); + } return num_kv_tiles; } @@ -190,7 +227,7 @@ struct CollectiveMainloop { tma_partition(mainloop_params.tma_load_V, _0{}, Layout<_1>{}, group_modes<0, 2>(sV), group_modes<0, 2>(gV)); // (TMA, k), (TMA, PIPE) - int num_kv_tiles = get_num_kv_tiles(mainloop_params, q_tile_idx, qo_len, kv_len); + int num_kv_tiles = get_num_kv_tiles(mainloop_params, q_tile_idx, qo_len, kv_len, batch_idx); int kv_tile_idx = num_kv_tiles - 1; int swa_begin_kv_tile_idx = 0; if constexpr (LEFT_SLIDING_WINDOW) { diff --git a/include/flashinfer/attention/hopper/mainloop_mma.cuh b/include/flashinfer/attention/hopper/mainloop_mma.cuh index 27522f3187..5499ffda47 100644 --- a/include/flashinfer/attention/hopper/mainloop_mma.cuh +++ 
b/include/flashinfer/attention/hopper/mainloop_mma.cuh @@ -11,12 +11,13 @@ #include #include #include +#include "../../utils.cuh" #include "variants.cuh" namespace flashinfer { -template @@ -26,7 +27,7 @@ CUTLASS_DEVICE void mma_f16( FrgTensorO& tOrO, AttentionUpdater& attention_updater, int kv_tile_idx_count, int swa_begin_kv_tile_idx, int swa_end_kv_tile_idx, int thread_idx, int work_idx, int q_tile_idx, SharedStorage& shared_storage, const int32_t qo_len, const int32_t kv_len, - const int32_t qo_head_idx, const int32_t kv_head_idx, const uint32_t prefix_len, + const int32_t qo_head_idx, const int32_t kv_head_idx, const int32_t batch_idx, const uint32_t prefix_len, uint16_t* token_pos_in_items, const int num_kv_tiles_outside_items_window = 0, const int num_kv_tiles_prefix = 0) { using DTypeQ = typename Ktraits::DTypeQ; @@ -95,6 +96,54 @@ CUTLASS_DEVICE void mma_f16( auto col_limit_left = [&](int qo_idx) { return qo_idx + kv_len - qo_len - mainloop_params.window_left; }; + + // ════════════════════════════════════════════════════════════════════════════════════ + // Block Expanding Mask Helper + // ════════════════════════════════════════════════════════════════════════════════════ + // mask[q, k] = (q_global / B) >= (kv_global / B) + // q_global = q_offset + qo_idx + // kv_global = kv_offset + kv_idx + // ════════════════════════════════════════════════════════════════════════════════════ + int64_t dllm_block_size = 1; + int64_t q_block_expanding_offset = 0; + int64_t kv_block_expanding_offset = 0; // kv_offset support for Cascade Current Chunk + if constexpr (BLOCK_EXPANDING) { + if constexpr (has_dllm_block_size_v) { + dllm_block_size = mainloop_params.additional_params.dllm_block_size; + } + // Prefer reading per-batch offset from maybe_q_block_expanding_offset array + // Otherwise fallback to scalar q_block_expanding_offset + if constexpr (has_maybe_q_block_expanding_offset_v) { + auto* offset_ptr = 
mainloop_params.additional_params.maybe_q_block_expanding_offset; + if (offset_ptr != nullptr) { + q_block_expanding_offset = offset_ptr[batch_idx]; + } + } else if constexpr (has_q_block_expanding_offset_v) { + q_block_expanding_offset = mainloop_params.additional_params.q_block_expanding_offset; + } + // Read kv_offset (for Cascade Current Chunk scenario) + // Prefer reading per-batch offset from maybe_kv_block_expanding_offset array + // Otherwise fallback to scalar kv_block_expanding_offset + if constexpr (has_maybe_kv_block_expanding_offset_v) { + auto* offset_ptr = mainloop_params.additional_params.maybe_kv_block_expanding_offset; + if (offset_ptr != nullptr) { + kv_block_expanding_offset = offset_ptr[batch_idx]; + } + } else if constexpr (has_kv_block_expanding_offset_v) { + kv_block_expanding_offset = mainloop_params.additional_params.kv_block_expanding_offset; + } + } + auto block_expanding_col_limit = [&](int qo_idx) -> int { + // q_block = (q_offset + qo_idx) / B + // Consider kv_offset: kv_global = kv_offset + kv_idx < (q_block + 1) * B + // So kv_idx < (q_block + 1) * B - kv_offset + // Fix: Ensure result is non-negative (return 0 when max_kv_global <= kv_offset) + int64_t q_global = q_block_expanding_offset + qo_idx; + int64_t q_block = q_global / dllm_block_size; + int64_t max_kv_global = (q_block + 1) * dllm_block_size; + return static_cast(std::max(max_kv_global - kv_block_expanding_offset, int64_t(0))); + }; + auto mask_multi_item_scoring = [&](decltype(tSrS)& tSrS, int i, int qo_idx, int kv_idx) { const uint32_t idx_in_original_seq = qo_idx + kv_len - qo_len; const bool out_of_boundary = @@ -150,10 +199,15 @@ CUTLASS_DEVICE void mma_f16( for (int i = 0; i < size(tSrS); ++i) { int qo_idx = get<0>(tScS(i)) + q_tile_idx * CTA_Q; int kv_idx = get<1>(tScS(i)) + kv_tile_idx * CTA_KV; - tSrS(i) = variant.LogitsTransform(mainloop_params, tSrS(i), /*batch_idx=*/0, qo_idx, kv_idx, + tSrS(i) = variant.LogitsTransform(mainloop_params, tSrS(i), batch_idx, 
qo_idx, kv_idx, qo_head_idx, kv_head_idx); if constexpr (MULTIITEMSCORING) { mask_multi_item_scoring(tSrS, i, qo_idx, kv_idx); + } else if constexpr (BLOCK_EXPANDING) { + // Block Expanding Mask: (q_block >= k_block) && (kv_idx < kv_len) + if (kv_idx >= std::min(kv_len, block_expanding_col_limit(qo_idx))) { + tSrS(i) = AttentionUpdater::fill_value; + } } else if constexpr (!CAUSAL) { // Just masking based on col if (kv_idx >= kv_len) { tSrS(i) = AttentionUpdater::fill_value; @@ -176,7 +230,7 @@ CUTLASS_DEVICE void mma_f16( convert_layout_acc_Aregs(tSrS.layout())); constexpr int n_masking_steps = MULTIITEMSCORING ? (cute::ceil_div(CTA_Q, CTA_KV) + 1) - : (CAUSAL ? cute::ceil_div(CTA_Q, CTA_KV) : 0); + : ((CAUSAL || BLOCK_EXPANDING) ? cute::ceil_div(CTA_Q, CTA_KV) : 0); // masking loops // ziangl@nvidia.com: for multi item scoring, we use this loop only to mask along the diagonal #pragma unroll @@ -202,10 +256,15 @@ CUTLASS_DEVICE void mma_f16( for (int i = 0; i < size(tSrS); ++i) { int qo_idx = get<0>(tScS(i)) + q_tile_idx * CTA_Q; int kv_idx = get<1>(tScS(i)) + kv_tile_idx_decrement(kv_tile_idx) * CTA_KV; - tSrS(i) = variant.LogitsTransform(mainloop_params, tSrS(i), /*batch_idx=*/0, qo_idx, kv_idx, + tSrS(i) = variant.LogitsTransform(mainloop_params, tSrS(i), batch_idx, qo_idx, kv_idx, qo_head_idx, kv_head_idx); if (MULTIITEMSCORING) { mask_multi_item_scoring(tSrS, i, qo_idx, kv_idx); + } else if constexpr (BLOCK_EXPANDING) { + // Fix: Add kv_len boundary check to be consistent with initial mask logic + if (kv_idx >= std::min(kv_len, block_expanding_col_limit(qo_idx))) { + tSrS(i) = AttentionUpdater::fill_value; + } } else { if (kv_idx >= col_limit_right(qo_idx)) { tSrS(i) = AttentionUpdater::fill_value; @@ -248,7 +307,7 @@ CUTLASS_DEVICE void mma_f16( for (int i = 0; i < size(tSrS); ++i) { int qo_idx = get<0>(tScS(i)) + q_tile_idx * CTA_Q; int kv_idx = get<1>(tScS(i)) + kv_tile_idx_decrement(kv_tile_idx) * CTA_KV; - tSrS(i) = 
variant.LogitsTransform(mainloop_params, tSrS(i), /*batch_idx=*/0, qo_idx, kv_idx, + tSrS(i) = variant.LogitsTransform(mainloop_params, tSrS(i), batch_idx, qo_idx, kv_idx, qo_head_idx, kv_head_idx); } if constexpr (MULTIITEMSCORING) { @@ -294,7 +353,7 @@ CUTLASS_DEVICE void mma_f16( for (int i = 0; i < size(tSrS); ++i) { int qo_idx = get<0>(tScS(i)) + q_tile_idx * CTA_Q; int kv_idx = get<1>(tScS(i)) + (kv_tile_idx - 1) * CTA_KV; - tSrS(i) = variant.LogitsTransform(mainloop_params, tSrS(i), /*batch_idx=*/0, qo_idx, kv_idx, + tSrS(i) = variant.LogitsTransform(mainloop_params, tSrS(i), batch_idx, qo_idx, kv_idx, qo_head_idx, kv_head_idx); if (kv_idx < col_limit_left(qo_idx)) { tSrS(i) = AttentionUpdater::fill_value; diff --git a/include/flashinfer/attention/hopper/prefill_sm90.cuh b/include/flashinfer/attention/hopper/prefill_sm90.cuh index f1e441a53b..bc3313809e 100644 --- a/include/flashinfer/attention/hopper/prefill_sm90.cuh +++ b/include/flashinfer/attention/hopper/prefill_sm90.cuh @@ -41,7 +41,7 @@ DEFINE_HAS_MEMBER(token_pos_in_items_len) DEFINE_HAS_MEMBER(maybe_max_item_len_ptr) template __global__ void __launch_bounds__(Ktraits::NUM_WARPS* cutlass::NumThreadsPerWarp, 1) PrefillWithKVCacheKernel(CUTE_GRID_CONSTANT @@ -172,7 +172,7 @@ __global__ void __launch_bounds__(Ktraits::NUM_WARPS* cutlass::NumThreadsPerWarp continue; } int num_kv_tiles = - collective_mainloop.get_num_kv_tiles(mainloop_params, q_tile_idx, qo_len, kv_len); + collective_mainloop.get_num_kv_tiles(mainloop_params, q_tile_idx, qo_len, kv_len, batch_idx); if (num_kv_tiles <= 0) { scheduler.prefetch_next_work(scheduler_params, work_tile_info); scheduler.broadcast_next_work(work_tile_info); @@ -241,7 +241,7 @@ __global__ void __launch_bounds__(Ktraits::NUM_WARPS* cutlass::NumThreadsPerWarp continue; } int num_kv_tiles = - collective_mainloop.get_num_kv_tiles(mainloop_params, q_tile_idx, qo_len, kv_len); + collective_mainloop.get_num_kv_tiles(mainloop_params, q_tile_idx, qo_len, kv_len, batch_idx); 
if (num_kv_tiles <= 0) { // We exit early and write 0 to gO and -inf to gLSE. collective_epilogue.store_zero(epilogue_params, shared_storage, threadIdx.x - NUM_COPY_THREADS, block_coord); @@ -273,12 +273,12 @@ __global__ void __launch_bounds__(Ktraits::NUM_WARPS* cutlass::NumThreadsPerWarp num_kv_tiles_outside_items_window = valid_items_window_len / CTA_KV; num_kv_tiles_prefix = cute::ceil_div(prefix_len, CTA_KV); } - mma_f16( mainloop_params, variant, pipeline_k, pipeline_v, smem_pipe_read_k, smem_pipe_read_v, tOrO, attention_updater, num_kv_tiles, swa_begin_kv_tile_idx, swa_end_kv_tile_idx, threadIdx.x - NUM_COPY_THREADS, work_idx, q_tile_idx, shared_storage, qo_len, kv_len, - qo_head_idx, kv_head_idx, prefix_len, token_pos_in_items, + qo_head_idx, kv_head_idx, batch_idx, prefix_len, token_pos_in_items, num_kv_tiles_outside_items_window, num_kv_tiles_prefix); collective_epilogue.store(epilogue_params, tOrO, attention_updater.get_lse(), shared_storage, tiled_mma_pv, threadIdx.x - NUM_COPY_THREADS, block_coord); @@ -289,14 +289,14 @@ __global__ void __launch_bounds__(Ktraits::NUM_WARPS* cutlass::NumThreadsPerWarp } } -template +template cudaError_t SinglePrefillWithKVCacheKernelTraitsDispatched(Params& params, cudaStream_t stream) { using DTypeQ = typename KernelTraits::DTypeQ; using DTypeKV = typename KernelTraits::DTypeKV; using DTypeO = typename KernelTraits::DTypeO; using CollectiveMainloop = - CollectiveMainloop; + CollectiveMainloop; using CollectiveEpilogue = CollectiveEpilogue; using Scheduler = SingleTileScheduler; typename CollectiveMainloop::Params mainloop_params = CollectiveMainloop::to_underlying_arguments( @@ -332,7 +332,7 @@ cudaError_t SinglePrefillWithKVCacheKernelTraitsDispatched(Params& params, cudaS auto kernel = (void*)PrefillWithKVCacheKernel; + LEFT_SLIDING_WINDOW, CAUSAL, BLOCK_EXPANDING, Scheduler>; int smem_size = sizeof(typename KernelTraits::SharedStorage); FLASHINFER_CUDA_CALL( cudaFuncSetAttribute(kernel, 
cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); @@ -351,7 +351,7 @@ cudaError_t SinglePrefillWithKVCacheKernelTraitsDispatched(Params& params, cudaS return cudaSuccess; } -template cudaError_t BatchPrefillWithPagedKVCacheKernelTraitsDispatched(Params& params, cudaStream_t stream) { @@ -361,7 +361,7 @@ cudaError_t BatchPrefillWithPagedKVCacheKernelTraitsDispatched(Params& params, using IdType = typename KernelTraits::IdType; using CollectiveMainloop = SparseCollectiveMainloop; + KernelTraits, CAUSAL, BLOCK_EXPANDING, MULTIITEMSCORING>; using CollectiveEpilogue = CollectiveEpilogue; using Scheduler = std::conditional_t, @@ -409,7 +409,8 @@ cudaError_t BatchPrefillWithPagedKVCacheKernelTraitsDispatched(Params& params, // Get the ptr to kernel function. auto kernel = (void*)PrefillWithKVCacheKernel; + LEFT_SLIDING_WINDOW, CAUSAL, BLOCK_EXPANDING, Scheduler, + MULTIITEMSCORING>; int smem_size = sizeof(typename KernelTraits::SharedStorage); FLASHINFER_CUDA_CALL( cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); @@ -428,7 +429,7 @@ cudaError_t BatchPrefillWithPagedKVCacheKernelTraitsDispatched(Params& params, return cudaSuccess; } -template cudaError_t BatchPrefillWithRaggedKVCacheKernelTraitsDispatched(Params& params, cudaStream_t stream) { @@ -438,7 +439,7 @@ cudaError_t BatchPrefillWithRaggedKVCacheKernelTraitsDispatched(Params& params, using IdType = typename KernelTraits::IdType; using CollectiveMainloop = - CollectiveMainloop; + CollectiveMainloop; using CollectiveEpilogue = CollectiveEpilogue; using Scheduler = std::conditional_t, @@ -484,7 +485,7 @@ cudaError_t BatchPrefillWithRaggedKVCacheKernelTraitsDispatched(Params& params, // Get the ptr to kernel function. 
auto kernel = (void*)PrefillWithKVCacheKernel; + LEFT_SLIDING_WINDOW, CAUSAL, BLOCK_EXPANDING, Scheduler>; int smem_size = sizeof(typename KernelTraits::SharedStorage); FLASHINFER_CUDA_CALL( cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); @@ -503,13 +504,13 @@ cudaError_t BatchPrefillWithRaggedKVCacheKernelTraitsDispatched(Params& params, return cudaSuccess; } -template +template constexpr auto getCTATileSize() { if constexpr (HEAD_DIM_QK == HEAD_DIM_VO) { if constexpr (HEAD_DIM_QK == 64) { return std::make_tuple(192, 128); } else if constexpr (HEAD_DIM_QK == 128) { - if constexpr (CAUSAL) { + if constexpr (CAUSAL_OR_BLOCK_EXPANDING) { return std::make_tuple(128, 128); } else { return std::make_tuple(128, 192); @@ -532,14 +533,15 @@ cudaError_t SinglePrefillWithKVCacheDispatched(Params& params, cudaStream_t stre return cudaErrorNotSupported; // Not supported yet. } constexpr bool CAUSAL = MASK_MODE == MaskMode::kCausal; - constexpr auto CTA_TILE_SIZE = getCTATileSize(); + constexpr bool BLOCK_EXPANDING = MASK_MODE == MaskMode::kBlockExpanding; + constexpr auto CTA_TILE_SIZE = getCTATileSize(); SinglePrefillWithKVCacheKernelTraitsDispatched< AttentionKernelTraits(CTA_TILE_SIZE), /*CTA_KV_=*/get<1>(CTA_TILE_SIZE), /*NUM_STAGES_=*/2, typename Params::DTypeQ, typename Params::DTypeKV, typename Params::DTypeO, typename Params::IdType, AttentionVariant>, - LEFT_SLIDING_WINDOW, CAUSAL>(params, stream); + LEFT_SLIDING_WINDOW, CAUSAL, BLOCK_EXPANDING>(params, stream); cudaError_t status = cudaGetLastError(); return status; } @@ -553,14 +555,15 @@ cudaError_t BatchPrefillWithRaggedKVCacheDispatched(Params& params, bool enable_ return cudaErrorNotSupported; // Not supported yet. 
} constexpr bool CAUSAL = MASK_MODE == MaskMode::kCausal; - constexpr auto CTA_TILE_SIZE = getCTATileSize(); + constexpr bool BLOCK_EXPANDING = MASK_MODE == MaskMode::kBlockExpanding; + constexpr auto CTA_TILE_SIZE = getCTATileSize(); BatchPrefillWithRaggedKVCacheKernelTraitsDispatched< AttentionKernelTraits(CTA_TILE_SIZE), /*CTA_KV_=*/get<1>(CTA_TILE_SIZE), /*NUM_STAGES_=*/2, typename Params::DTypeQ, typename Params::DTypeKV, typename Params::DTypeO, typename Params::IdType, AttentionVariant>, - LEFT_SLIDING_WINDOW, CAUSAL, SAME_SCHEDULE_FOR_ALL_HEADS>(params, stream); + LEFT_SLIDING_WINDOW, CAUSAL, BLOCK_EXPANDING, SAME_SCHEDULE_FOR_ALL_HEADS>(params, stream); cudaError_t status = cudaGetLastError(); return status; } @@ -574,6 +577,7 @@ cudaError_t BatchPrefillWithPagedKVCacheDispatched(Params& params, bool enable_p return cudaErrorNotSupported; // Not supported yet. } constexpr bool CAUSAL = MASK_MODE == MaskMode::kCausal; + constexpr bool BLOCK_EXPANDING = MASK_MODE == MaskMode::kBlockExpanding; constexpr bool MULTIITEMSCORING = MASK_MODE == MaskMode::kMultiItemScoring; if constexpr (HEAD_DIM_QK == HEAD_DIM_VO) { if constexpr (HEAD_DIM_VO == 64) { @@ -585,7 +589,7 @@ cudaError_t BatchPrefillWithPagedKVCacheDispatched(Params& params, bool enable_p /*NUM_STAGES_=*/2, typename Params::DTypeQ, typename Params::DTypeKV, typename Params::DTypeO, typename Params::IdType, AttentionVariant>, - LEFT_SLIDING_WINDOW, CAUSAL, SAME_SCHEDULE_FOR_ALL_HEADS, Params, MULTIITEMSCORING>( + LEFT_SLIDING_WINDOW, CAUSAL, BLOCK_EXPANDING, SAME_SCHEDULE_FOR_ALL_HEADS, Params, MULTIITEMSCORING>( params, stream); } else if constexpr (HEAD_DIM_VO == 128) { BatchPrefillWithPagedKVCacheKernelTraitsDispatched< @@ -595,7 +599,7 @@ cudaError_t BatchPrefillWithPagedKVCacheDispatched(Params& params, bool enable_p /*NUM_STAGES_=*/2, typename Params::DTypeQ, typename Params::DTypeKV, typename Params::DTypeO, typename Params::IdType, AttentionVariant>, - LEFT_SLIDING_WINDOW, CAUSAL, 
SAME_SCHEDULE_FOR_ALL_HEADS, Params, MULTIITEMSCORING>( + LEFT_SLIDING_WINDOW, CAUSAL, BLOCK_EXPANDING, SAME_SCHEDULE_FOR_ALL_HEADS, Params, MULTIITEMSCORING>( params, stream); } else { // HEAD_DIM == 256; @@ -607,7 +611,7 @@ cudaError_t BatchPrefillWithPagedKVCacheDispatched(Params& params, bool enable_p /*NUM_STAGES_=*/2, typename Params::DTypeQ, typename Params::DTypeKV, typename Params::DTypeO, typename Params::IdType, AttentionVariant>, - LEFT_SLIDING_WINDOW, CAUSAL, SAME_SCHEDULE_FOR_ALL_HEADS, Params, MULTIITEMSCORING>( + LEFT_SLIDING_WINDOW, CAUSAL, BLOCK_EXPANDING, SAME_SCHEDULE_FOR_ALL_HEADS, Params, MULTIITEMSCORING>( params, stream); } } else { diff --git a/include/flashinfer/attention/hopper/sparse_mainloop.cuh b/include/flashinfer/attention/hopper/sparse_mainloop.cuh index 1d9697af4e..1cecd9d581 100644 --- a/include/flashinfer/attention/hopper/sparse_mainloop.cuh +++ b/include/flashinfer/attention/hopper/sparse_mainloop.cuh @@ -33,7 +33,7 @@ namespace flashinfer { using namespace cute; -template +template struct SparseCollectiveMainloop { using DTypeQ = typename Ktraits::DTypeQ; using DTypeKV = typename Ktraits::DTypeKV; @@ -157,7 +157,7 @@ struct SparseCollectiveMainloop { CUTLASS_DEVICE int get_num_kv_tiles(Params const& mainloop_params, int q_tile_idx, const int qo_len, - const int kv_len) { + const int kv_len, const int batch_idx = 0) { static constexpr int CTA_Q = get<0>(TileShape_QKD{}); static constexpr int CTA_KV = get<1>(TileShape_QKD{}); int num_kv_tiles = cute::ceil_div(kv_len, CTA_KV); @@ -165,6 +165,41 @@ struct SparseCollectiveMainloop { num_kv_tiles = std::min(num_kv_tiles, cute::ceil_div((q_tile_idx + 1) * CTA_Q + kv_len - qo_len, CTA_KV)); } + if constexpr (BLOCK_EXPANDING) { + // Block Expanding: Calculate valid KV range based on block boundaries + int64_t dllm_block_size = 1; + int64_t q_offset = 0; + int64_t kv_offset = 0; + if constexpr (has_dllm_block_size_v) { + dllm_block_size = 
mainloop_params.additional_params.dllm_block_size; + } + // Prefer reading per-batch offset from maybe_q_block_expanding_offset array + if constexpr (has_maybe_q_block_expanding_offset_v) { + auto* offset_ptr = mainloop_params.additional_params.maybe_q_block_expanding_offset; + if (offset_ptr != nullptr) { + q_offset = offset_ptr[batch_idx]; + } + } else if constexpr (has_q_block_expanding_offset_v) { + q_offset = mainloop_params.additional_params.q_block_expanding_offset; + } + // Read kv_offset (for Cascade Current Chunk scenario) + if constexpr (has_maybe_kv_block_expanding_offset_v) { + auto* offset_ptr = mainloop_params.additional_params.maybe_kv_block_expanding_offset; + if (offset_ptr != nullptr) { + kv_offset = offset_ptr[batch_idx]; + } + } else if constexpr (has_kv_block_expanding_offset_v) { + kv_offset = mainloop_params.additional_params.kv_block_expanding_offset; + } + int q_tile_end = std::min((q_tile_idx + 1) * CTA_Q, qo_len); + int64_t q_global_end = q_offset + q_tile_end - 1; + int64_t q_block = q_global_end / dllm_block_size; + // Consider kv_offset: max valid kv_idx = (q_block + 1) * B - kv_offset + // Ensure kv_valid_end is non-negative (return 0 when max_kv_global <= kv_offset) + int64_t max_kv_global = (q_block + 1) * dllm_block_size; + int kv_valid_end = static_cast(std::max(max_kv_global - kv_offset, int64_t(0))); + num_kv_tiles = std::min(num_kv_tiles, cute::ceil_div(std::min(kv_len, kv_valid_end), CTA_KV)); + } if constexpr (MULTIITEMSCORING) { num_kv_tiles = std::min(num_kv_tiles, cute::ceil_div((q_tile_idx + 1) * CTA_Q + kv_len - qo_len, CTA_KV)); @@ -204,7 +239,7 @@ struct SparseCollectiveMainloop { tma_partition(mainloop_params.tma_load_Q, _0{}, Layout<_1>{}, group_modes<0, 2>(sQ_x), group_modes<0, 2>(gQ_x)); // (TMA), (TMA) - int num_kv_tiles = get_num_kv_tiles(mainloop_params, q_tile_idx, qo_len, kv_len); + int num_kv_tiles = get_num_kv_tiles(mainloop_params, q_tile_idx, qo_len, kv_len, batch_idx); int kv_tile_idx = num_kv_tiles - 1; 
int swa_begin_kv_tile_idx = 0; if constexpr (LEFT_SLIDING_WINDOW) { @@ -445,4 +480,4 @@ struct SparseCollectiveMainloop { } // namespace flashinfer -#endif // FLASHINFER_ATTENTION_HOPPER_SPARSE_MAINLOOP_CUH_ +#endif // FLASHINFER_ATTENTION_HOPPER_SPARSE_MAINLOOP_CUH_ \ No newline at end of file diff --git a/include/flashinfer/attention/mask.cuh b/include/flashinfer/attention/mask.cuh index 6692b0cf3f..6af1d829d6 100644 --- a/include/flashinfer/attention/mask.cuh +++ b/include/flashinfer/attention/mask.cuh @@ -23,6 +23,7 @@ enum class MaskMode { kCausal = 1U, // Causal mask kCustom = 2U, // Custom mask kMultiItemScoring = 3U, + kBlockExpanding = 4U, }; } // namespace flashinfer diff --git a/include/flashinfer/attention/prefill.cuh b/include/flashinfer/attention/prefill.cuh index 5db013bb03..071df839c5 100644 --- a/include/flashinfer/attention/prefill.cuh +++ b/include/flashinfer/attention/prefill.cuh @@ -37,6 +37,7 @@ #include "cascade.cuh" #include "mask.cuh" #include "variants.cuh" +#include "block_expanding_prefill.cuh" namespace flashinfer { DEFINE_HAS_MEMBER(maybe_q_rope_offset) @@ -779,10 +780,22 @@ __device__ __forceinline__ void logits_mask( 2 * (lane_idx % 4) + 8 * (reg_id / 4) + reg_id % 2; const uint32_t qo_head_idx = kv_head_idx * group_size + r[mma_q][(reg_id % 4) / 2]; - const bool mask = - (!(MASK_MODE == MaskMode::kCausal || MASK_MODE == MaskMode::kMultiItemScoring - ? 
(kv_idx + qo_len > kv_len + q_idx || (kv_idx >= chunk_end)) - : kv_idx >= chunk_end)) && + bool position_mask; + if constexpr (MASK_MODE == MaskMode::kBlockExpanding) { + const uint32_t dllm_block_size = params.dllm_block_size; + const uint32_t q_offset = params.get_q_block_expanding_offset(batch_idx); + const uint32_t kv_offset = params.get_kv_block_expanding_offset(batch_idx); + const uint32_t q_global = q_offset + q_idx; + const uint32_t q_block = q_global / dllm_block_size; + const uint32_t kv_global = kv_offset + kv_idx; + const uint32_t k_block = kv_global / dllm_block_size; + position_mask = (q_block >= k_block) && (kv_idx < chunk_end); + } else if constexpr (MASK_MODE == MaskMode::kCausal || MASK_MODE == MaskMode::kMultiItemScoring) { + position_mask = (kv_idx + qo_len <= kv_len + q_idx) && (kv_idx < chunk_end); + } else { + position_mask = (kv_idx < chunk_end); + } + const bool mask = position_mask && variant.LogitsMask(params, batch_idx, q_idx, kv_idx, qo_head_idx, kv_head_idx); s_frag[mma_q][mma_kv][reg_id] = (mask) ? s_frag[mma_q][mma_kv][reg_id] : (KTraits::MaskFillValue); @@ -1434,26 +1447,46 @@ __device__ __forceinline__ void SinglePrefillWithKVCacheDevice( smem_t k_smem(smem_storage.k_smem), v_smem(smem_storage.v_smem); - const uint32_t num_iterations = ceil_div( - MASK_MODE == MaskMode::kCausal - ? 
min(chunk_size, - sub_if_greater_or_zero( - kv_len - qo_len + ceil_div(((bx + 1) * CTA_TILE_Q), group_size), chunk_start)) - : chunk_size, - CTA_TILE_KV); + uint32_t num_iterations; + if constexpr (MASK_MODE == MaskMode::kBlockExpanding) { + // Block Expanding: Calculate kv_valid_end based on block boundaries + const uint32_t dllm_block_size = params.dllm_block_size; + const uint32_t q_tile_end = min(qo_len, ceil_div(((bx + 1) * CTA_TILE_Q), group_size)); + const uint32_t q_offset = params.get_q_block_expanding_offset(0); // Single prefill: batch_idx=0 + num_iterations = block_expanding_num_iterations( + q_tile_end, chunk_start, chunk_size, dllm_block_size, CTA_TILE_KV, q_offset); + } else if constexpr (MASK_MODE == MaskMode::kCausal) { + num_iterations = ceil_div( + min(chunk_size, + sub_if_greater_or_zero( + kv_len - qo_len + ceil_div(((bx + 1) * CTA_TILE_Q), group_size), chunk_start)), + CTA_TILE_KV); + } else { + num_iterations = ceil_div(chunk_size, CTA_TILE_KV); + } const uint32_t window_iteration = ceil_div(sub_if_greater_or_zero(kv_len + ceil_div((bx + 1) * CTA_TILE_Q, group_size), qo_len + window_left + chunk_start), CTA_TILE_KV); - const uint32_t mask_iteration = - (MASK_MODE == MaskMode::kCausal - ? 
min(chunk_size, - sub_if_greater_or_zero(kv_len + ceil_div((bx * CTA_TILE_Q), group_size) - qo_len, - chunk_start)) - : chunk_size) / - CTA_TILE_KV; + uint32_t mask_iteration; + if constexpr (MASK_MODE == MaskMode::kBlockExpanding) { + const uint32_t dllm_block_size = params.dllm_block_size; + const uint32_t q_tile_start = ceil_div((bx * CTA_TILE_Q), group_size); + const uint32_t q_offset = params.get_q_block_expanding_offset(0); // Single prefill: batch_idx=0 + const uint32_t kv_offset = params.get_kv_block_expanding_offset(0); // Single prefill: batch_idx=0 + mask_iteration = block_expanding_mask_iteration( + q_tile_start, chunk_start, chunk_size, dllm_block_size, CTA_TILE_KV, q_offset, kv_offset); + } else if constexpr (MASK_MODE == MaskMode::kCausal) { + mask_iteration = + min(chunk_size, + sub_if_greater_or_zero(kv_len + ceil_div((bx * CTA_TILE_Q), group_size) - qo_len, + chunk_start)) / + CTA_TILE_KV; + } else { + mask_iteration = chunk_size / CTA_TILE_KV; + } DTypeKV* k_ptr = k + @@ -1502,7 +1535,14 @@ __device__ __forceinline__ void SinglePrefillWithKVCacheDevice( qo_len, kv_len, group_size, s_frag, tid, kv_head_idx); // apply mask - if (MASK_MODE == MaskMode::kCustom || (iter >= mask_iteration || iter < window_iteration)) { + // if (MASK_MODE == MaskMode::kCustom || (iter >= mask_iteration || iter < window_iteration)) { + // logits_mask(params, variant, /*batch_idx=*/0, qo_packed_idx_base, kv_idx_base, + // qo_len, kv_len, chunk_end, group_size, s_frag, tid, kv_head_idx); + // } + bool needs_mask = (MASK_MODE == MaskMode::kCustom) || + (MASK_MODE == MaskMode::kBlockExpanding && iter >= mask_iteration) || + (iter >= mask_iteration || iter < window_iteration); + if (needs_mask) { logits_mask(params, variant, /*batch_idx=*/0, qo_packed_idx_base, kv_idx_base, qo_len, kv_len, chunk_end, group_size, s_frag, tid, kv_head_idx); } @@ -1858,28 +1898,47 @@ __global__ __launch_bounds__(KTraits::NUM_THREADS) void BatchPrefillWithRaggedKV block.sync(); } - const 
uint32_t num_iterations = ceil_div( - (MASK_MODE == MaskMode::kCausal - ? min(chunk_size, - sub_if_greater_or_zero( - kv_len - qo_len + ceil_div(((qo_tile_idx + 1) * CTA_TILE_Q), group_size), - chunk_start)) - : chunk_size), - CTA_TILE_KV); + uint32_t num_iterations; + if constexpr (MASK_MODE == MaskMode::kBlockExpanding) { + // Block Expanding: Calculate kv_valid_end based on block boundaries + const uint32_t dllm_block_size = params.dllm_block_size; + const uint32_t q_offset = params.get_q_block_expanding_offset(request_idx); + const uint32_t q_tile_end = min(qo_len, ceil_div(((qo_tile_idx + 1) * CTA_TILE_Q), group_size)); + num_iterations = block_expanding_num_iterations( + q_tile_end, chunk_start, chunk_size, dllm_block_size, CTA_TILE_KV, q_offset); + } else if constexpr (MASK_MODE == MaskMode::kCausal) { + num_iterations = ceil_div( + min(chunk_size, + sub_if_greater_or_zero( + kv_len - qo_len + ceil_div(((qo_tile_idx + 1) * CTA_TILE_Q), group_size), + chunk_start)), + CTA_TILE_KV); + } else { + num_iterations = ceil_div(chunk_size, CTA_TILE_KV); + } const uint32_t window_iteration = ceil_div( sub_if_greater_or_zero(kv_len + ceil_div((qo_tile_idx + 1) * CTA_TILE_Q, group_size), qo_len + window_left + chunk_start), CTA_TILE_KV); - const uint32_t mask_iteration = - (MASK_MODE == MaskMode::kCausal - ? 
min(chunk_size, - sub_if_greater_or_zero( - kv_len + ceil_div((qo_tile_idx * CTA_TILE_Q), group_size) - qo_len, - chunk_start)) - : chunk_size) / - CTA_TILE_KV; + uint32_t mask_iteration; + if constexpr (MASK_MODE == MaskMode::kBlockExpanding) { + const uint32_t dllm_block_size = params.dllm_block_size; + const uint32_t q_offset = params.get_q_block_expanding_offset(request_idx); + const uint32_t kv_offset = params.get_kv_block_expanding_offset(request_idx); + const uint32_t q_tile_start = ceil_div((qo_tile_idx * CTA_TILE_Q), group_size); + mask_iteration = block_expanding_mask_iteration( + q_tile_start, chunk_start, chunk_size, dllm_block_size, CTA_TILE_KV, q_offset, kv_offset); + } else if constexpr (MASK_MODE == MaskMode::kCausal) { + mask_iteration = + min(chunk_size, + sub_if_greater_or_zero(kv_len + ceil_div((qo_tile_idx * CTA_TILE_Q), group_size) - qo_len, + chunk_start)) / + CTA_TILE_KV; + } else { + mask_iteration = chunk_size / CTA_TILE_KV; + } smem_t k_smem(smem_storage.k_smem), v_smem(smem_storage.v_smem); @@ -2213,14 +2272,22 @@ __device__ __forceinline__ void BatchPrefillWithPagedKVCacheDevice( uint32_t num_iterations = 0; if constexpr (MASK_MODE != MaskMode::kMultiItemScoring) { - num_iterations = ceil_div( - (MASK_MODE == MaskMode::kCausal - ? 
min(chunk_size, - sub_if_greater_or_zero( - kv_len - qo_len + ceil_div(((qo_tile_idx + 1) * CTA_TILE_Q), group_size), - chunk_start)) - : chunk_size), - CTA_TILE_KV); + if constexpr (MASK_MODE == MaskMode::kBlockExpanding) { + const uint32_t dllm_block_size = params.dllm_block_size; + const uint32_t q_offset = params.get_q_block_expanding_offset(request_idx); + const uint32_t q_tile_end = min(qo_len, ceil_div(((qo_tile_idx + 1) * CTA_TILE_Q), group_size)); + num_iterations = block_expanding_num_iterations( + q_tile_end, chunk_start, chunk_size, dllm_block_size, CTA_TILE_KV, q_offset); + } else if constexpr (MASK_MODE == MaskMode::kCausal) { + num_iterations = ceil_div( + min(chunk_size, + sub_if_greater_or_zero( + kv_len - qo_len + ceil_div(((qo_tile_idx + 1) * CTA_TILE_Q), group_size), + chunk_start)), + CTA_TILE_KV); + } else { + num_iterations = ceil_div(chunk_size, CTA_TILE_KV); + } } else if constexpr (MASK_MODE == MaskMode::kMultiItemScoring) { num_iterations_prefix = ceil_div( min(min(chunk_size, @@ -2253,14 +2320,24 @@ __device__ __forceinline__ void BatchPrefillWithPagedKVCacheDevice( qo_len + window_left + chunk_start), CTA_TILE_KV); - const uint32_t mask_iteration = - (MASK_MODE == MaskMode::kCausal || MASK_MODE == MaskMode::kMultiItemScoring - ? 
min(chunk_size, - sub_if_greater_or_zero( - kv_len + ceil_div((qo_tile_idx * CTA_TILE_Q), group_size) - qo_len, - chunk_start)) - : chunk_size) / - CTA_TILE_KV; + uint32_t mask_iteration; + if constexpr (MASK_MODE == MaskMode::kBlockExpanding) { + const uint32_t dllm_block_size = params.dllm_block_size; + const uint32_t q_offset = params.get_q_block_expanding_offset(request_idx); + const uint32_t kv_offset = params.get_kv_block_expanding_offset(request_idx); + const uint32_t q_tile_start = ceil_div((qo_tile_idx * CTA_TILE_Q), group_size); + mask_iteration = block_expanding_mask_iteration( + q_tile_start, chunk_start, chunk_size, dllm_block_size, CTA_TILE_KV, q_offset, kv_offset); + } else if constexpr (MASK_MODE == MaskMode::kCausal || MASK_MODE == MaskMode::kMultiItemScoring) { + mask_iteration = + min(chunk_size, + sub_if_greater_or_zero( + kv_len + ceil_div((qo_tile_idx * CTA_TILE_Q), group_size) - qo_len, + chunk_start)) / + CTA_TILE_KV; + } else { + mask_iteration = chunk_size / CTA_TILE_KV; + } #pragma unroll 1 for (uint32_t iter = 0; iter < num_iterations; diff --git a/include/flashinfer/utils.cuh b/include/flashinfer/utils.cuh index c7edf5ab57..8d60f126f0 100644 --- a/include/flashinfer/utils.cuh +++ b/include/flashinfer/utils.cuh @@ -178,6 +178,11 @@ __VA_ARGS__ \ break; \ } \ + case MaskMode::kBlockExpanding: { \ + constexpr MaskMode MASK_MODE = MaskMode::kBlockExpanding; \ + __VA_ARGS__ \ + break; \ + } \ default: { \ std::ostringstream err_msg; \ err_msg << "Unsupported mask_mode: " << int(mask_mode); \ @@ -518,6 +523,12 @@ __device__ __forceinline__ uint32_t dim4_offset(const uint32_t& dim_c, const uin template \ inline constexpr bool has_##member##_v = has_##member::value; +// DLLM Block Expanding related type traits (defined once here to avoid ODR violations) +DEFINE_HAS_MEMBER(dllm_block_size) +DEFINE_HAS_MEMBER(q_block_expanding_offset) +DEFINE_HAS_MEMBER(maybe_q_block_expanding_offset) // per-batch q_offset pointer for FA3 BatchPrefill 
+DEFINE_HAS_MEMBER(kv_block_expanding_offset) // kv_offset for Cascade Current Chunk +DEFINE_HAS_MEMBER(maybe_kv_block_expanding_offset) // per-batch kv_offset pointer for FA3 BatchPrefill } // namespace flashinfer #endif // FLASHINFER_UTILS_CUH_ diff --git a/tests/attention/test_dllm_cascade_vs_blockwise_extend_attention.py b/tests/attention/test_dllm_cascade_vs_blockwise_extend_attention.py new file mode 100644 index 0000000000..43d0ebd80d --- /dev/null +++ b/tests/attention/test_dllm_cascade_vs_blockwise_extend_attention.py @@ -0,0 +1,4781 @@ +"""DLLM Block-wise Mask Implementation Comparison + +Comparing three implementation approaches: +1. Cascade Attention (SGLang approach): Ragged + Paged + merge_state + - Ragged: Bidirectional attention within current block (causal=False) + - Paged: Access all previous cached blocks (causal=False) + - merge_state: Merge softmax states from both stages + +2. Batch Prefill + kBlockExpanding + q_offset (FlashInfer optimized approach) + - Single kernel launch + - Tile-level skip for invalid computations + +3. 
V2 Serial approach (reference baseline) + - Each chunk called independently + - Tile-level skip + +Mask rule: mask[q, k] = (q_global // B) >= (k_global // B) +""" + +import torch +import time +import math +import flashinfer +from flashinfer import ( + BatchPrefillWithPagedKVCacheWrapper, + BatchPrefillWithRaggedKVCacheWrapper, + merge_state, + single_prefill_with_kv_cache, +) +from flashinfer.dllm import ( + BatchBlockExtendPagedOffsetWrapper, + BatchBlockExtendRaggedOffsetWrapper, + block_extend_attention_with_offset, +) + + +def compute_block_extend_reference( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + dllm_block_size: int, + q_offset: int = 0, + sm_scale: float = None, +) -> torch.Tensor: + """ + Compute Block Extend Attention reference result using custom_mask + + Mask rule: mask[q, k] = ((q_local + q_offset) // B) >= (k // B) + + Args: + q: [qo_len, num_heads, head_dim] + k: [kv_len, num_kv_heads, head_dim] + v: [kv_len, num_kv_heads, head_dim] + dllm_block_size: DLLM block size + q_offset: Q's global starting position + sm_scale: softmax scale + + Returns: + output: [qo_len, num_heads, head_dim] + """ + qo_len = q.shape[0] + kv_len = k.shape[0] + head_dim = q.shape[-1] + device = q.device + + if sm_scale is None: + sm_scale = 1.0 / math.sqrt(head_dim) + + # Construct custom_mask + # q_global = q_local + q_offset + # mask[q, k] = (q_global // B) >= (k // B) + q_pos = torch.arange(qo_len, device=device) + q_offset + k_pos = torch.arange(kv_len, device=device) + q_block = q_pos.unsqueeze(1) // dllm_block_size # [qo_len, 1] + k_block = k_pos.unsqueeze(0) // dllm_block_size # [1, kv_len] + mask_2d = (q_block >= k_block).to(torch.uint8) # [qo_len, kv_len] + + return single_prefill_with_kv_cache( + q, k, v, + custom_mask=mask_2d, + sm_scale=sm_scale, + ) + +def test_incremental_batchprefill_step_by_step_with_cuda_graph( + num_requests: int = 4, + tokens_per_request: int = 256, + dllm_block_size: int = 32, + chunk_sizes: list = None, + num_heads: int 
= 32, + num_kv_heads: int = 8, + head_dim: int = 128, + warmup_iters: int = 10, + bench_iters: int = 100, + verbose: bool = False, +): + """ + Realistically simulate DLLM incremental Prefill step-by-step execution flow + CUDA Graph + + Key points: + 1. Must execute step by step, each chunk step depends on previous step's KV cache + 2. SGLang Cascade forces chunk_size = dllm_block_size + 3. BatchBlockExtend can use larger chunk_size, reducing number of steps + 4. Enable CUDA Graph to reduce CPU overhead and kernel launch latency + + CUDA Graph notes: + - plan() contains CPU-GPU synchronization, cannot execute during capture + - Must create independent wrapper for each step, complete plan before capture + - Only capture run() operations + """ + if chunk_sizes is None: + chunk_sizes = [32, 64, 128, 256] + + device = torch.device("cuda:0") + dtype = torch.float16 + sm_scale = 1.0 / (head_dim ** 0.5) + + # Baseline chunk_size = dllm_block_size + baseline_chunk_size = dllm_block_size + baseline_num_chunks = tokens_per_request // baseline_chunk_size + + print(f"\n{'='*80}") + print(f"Step-by-step Execution + CUDA Graph: DLLM Incremental Prefill Performance Comparison") + print(f"{'='*80}") + print(f"Configuration:") + print(f" num_requests = {num_requests}") + print(f" tokens_per_request = {tokens_per_request}") + print(f" dllm_block_size = {dllm_block_size}") + print(f" chunk_sizes to test = {chunk_sizes}") + print(f" num_heads = {num_heads}") + print(f" num_kv_heads = {num_kv_heads}") + print(f" head_dim = {head_dim}") + print(f"\nKey features:") + print(f" - Step-by-step execution: Each step must wait for previous step to complete") + print(f" - SGLang forces chunk_size = dllm_block_size = {dllm_block_size}") + print(f" - BatchBlockExtend can use larger chunk_size") + print(f" - CUDA Graph: Reduce CPU overhead and kernel launch latency") + + # Data preparation, generate complete Q, K, V for each request + all_qs = [torch.randn(tokens_per_request, num_heads, head_dim, 
dtype=dtype, device=device) + for _ in range(num_requests)] + all_ks = [torch.randn(tokens_per_request, num_kv_heads, head_dim, dtype=dtype, device=device) + for _ in range(num_requests)] + all_vs = [torch.randn(tokens_per_request, num_kv_heads, head_dim, dtype=dtype, device=device) + for _ in range(num_requests)] + + # Split each request into chunks + def split_chunks(tensor, chunk_size): + return [tensor[i*chunk_size:(i+1)*chunk_size] for i in range(tensor.shape[0] // chunk_size)] + + qs_chunks = [split_chunks(q, baseline_chunk_size) for q in all_qs] + ks_chunks = [split_chunks(k, baseline_chunk_size) for k in all_ks] + vs_chunks = [split_chunks(v, baseline_chunk_size) for v in all_vs] + + results = {} + + # Correctness verification (sample verification of first request) + print(f"\n{'='*80}") + print(f"Correctness verification (sample verification of request_0)") + print(f"{'='*80}") + print(f" Reference implementation: single_prefill_with_kv_cache + custom_mask") + print(f" Mask rule: mask[q,k] = ((q + offset) // B) >= (k // B)") + + # Sample verification of first request + req_idx = 0 + k_req = all_ks[req_idx] + v_req = all_vs[req_idx] + + # Cumulative KV buffer + k_cumul_verify = [k_req[:(i+1)*baseline_chunk_size] for i in range(baseline_num_chunks)] + v_cumul_verify = [v_req[:(i+1)*baseline_chunk_size] for i in range(baseline_num_chunks)] + + # Compute reference results + ref_outputs = [] + for step_idx in range(baseline_num_chunks): + q_offset = step_idx * baseline_chunk_size + ref_out = compute_block_extend_reference( + qs_chunks[req_idx][step_idx], + k_cumul_verify[step_idx], + v_cumul_verify[step_idx], + dllm_block_size=dllm_block_size, + q_offset=q_offset, + sm_scale=sm_scale, + ) + ref_outputs.append(ref_out) + + bbe_outputs = [] + workspace_verify = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device=device) + qo_indptr_verify = torch.tensor([0, baseline_chunk_size], dtype=torch.int32, device=device) + for step_idx in range(baseline_num_chunks): 
+ kv_len = (step_idx + 1) * baseline_chunk_size + kv_indptr = torch.tensor([0, kv_len], dtype=torch.int32, device=device) + q_offset_tensor = torch.tensor([step_idx * baseline_chunk_size], dtype=torch.int32, device=device) + + wrapper = BatchBlockExtendRaggedOffsetWrapper( + workspace_verify, kv_layout="NHD", dllm_block_size=dllm_block_size + ) + wrapper.plan( + qo_indptr=qo_indptr_verify, + kv_indptr=kv_indptr, + num_qo_heads=num_heads, + num_kv_heads=num_kv_heads, + head_dim=head_dim, + q_data_type=dtype, + sm_scale=sm_scale, + q_offsets=q_offset_tensor, + ) + bbe_out = wrapper.run(qs_chunks[req_idx][step_idx], k_cumul_verify[step_idx], v_cumul_verify[step_idx]) + bbe_outputs.append(bbe_out) + + # Verification + tol = 1e-2 + bbe_max_diff = max((bbe_outputs[i] - ref_outputs[i]).abs().max().item() for i in range(baseline_num_chunks)) + bbe_pass = bbe_max_diff < tol + print(f"\n [BBE] BatchBlockExtendRaggedOffsetWrapper:") + print(f" max_diff = {bbe_max_diff:.6f}, tolerance = {tol}") + print(f" {' PASS' if bbe_pass else ' FAIL'}") + + if not bbe_pass: + print(f"\n Correctness verification failed, but continue with performance test") + else: + print(f"\n BBE correctness verification passed") + + del workspace_verify, bbe_outputs, ref_outputs, k_cumul_verify, v_cumul_verify + torch.cuda.empty_cache() + + # Baseline 2: Custom Mask (BatchPrefillWithRaggedKVCacheWrapper + custom_mask) + CUDA Graph + print(f"\n{'='*80}") + print(f"[Baseline 2] Custom Mask (BatchPrefill + custom_mask)") + print(f"{'='*80}") + print(f" num_steps: {baseline_num_chunks}") + print(f" num_requests: {num_requests}") + print(f" Each step: BatchPrefillWithRaggedKVCacheWrapper + custom_mask") + print(f" Kernels per step: 1 (batch processes all requests)") + + # Pre-allocate Q buffers (concat all requests) + cm_q_buffers = [] + for step_idx in range(baseline_num_chunks): + q_list = [qs_chunks[req_idx][step_idx] for req_idx in range(num_requests)] + cm_q_buffers.append(torch.cat(q_list, dim=0)) + + # 
Pre-allocate cumulative KV buffers (concat all requests) + cm_k_buffers = [] + cm_v_buffers = [] + for step_idx in range(baseline_num_chunks): + kv_len = (step_idx + 1) * baseline_chunk_size + k_cumul_list = [all_ks[req_idx][:kv_len] for req_idx in range(num_requests)] + v_cumul_list = [all_vs[req_idx][:kv_len] for req_idx in range(num_requests)] + cm_k_buffers.append(torch.cat(k_cumul_list, dim=0)) + cm_v_buffers.append(torch.cat(v_cumul_list, dim=0)) + + # Construct flattened custom_mask (batch version) + # custom_mask shape: (sum(q_len[i] * k_len[i] for i in range(batch_size))) + # Each request's mask is the same, concat num_requests times + custom_mask_buffers = [] + for step_idx in range(baseline_num_chunks): + kv_len = (step_idx + 1) * baseline_chunk_size + q_offset = step_idx * baseline_chunk_size + # Construct single request's 2D mask: [q_len, kv_len] + q_pos = torch.arange(baseline_chunk_size, device=device) + q_offset + k_pos = torch.arange(kv_len, device=device) + q_block = q_pos.unsqueeze(1) // dllm_block_size + k_block = k_pos.unsqueeze(0) // dllm_block_size + mask_2d = (q_block >= k_block) # [q_len, kv_len], bool + # Flatten and repeat for batch + mask_flat = mask_2d.flatten() # [q_len * kv_len] + # All requests have the same mask, concat + batch_mask = mask_flat.repeat(num_requests) # [num_requests * q_len * kv_len] + custom_mask_buffers.append(batch_mask) + + # qo_indptr and kv_indptr + cm_qo_indptr = torch.tensor( + [i * baseline_chunk_size for i in range(num_requests + 1)], + dtype=torch.int32, device=device + ) + cm_kv_indptr_list = [] + for step_idx in range(baseline_num_chunks): + kv_len = (step_idx + 1) * baseline_chunk_size + cm_kv_indptr_list.append(torch.tensor( + [i * kv_len for i in range(num_requests + 1)], + dtype=torch.int32, device=device + )) + + # Create independent wrapper for each step and complete plan + # Note: custom_mask is only supported in FA2 backend, not FA3 + print(f" Creating wrappers and completing plan...") + 
cm_wrappers = [] + for step_idx in range(baseline_num_chunks): + wrapper = BatchPrefillWithRaggedKVCacheWrapper( + torch.empty(256 * 1024 * 1024, dtype=torch.uint8, device=device), + kv_layout="NHD", + backend="fa2", # custom_mask only supported in FA2 + ) + wrapper.plan( + qo_indptr=cm_qo_indptr, + kv_indptr=cm_kv_indptr_list[step_idx], + num_qo_heads=num_heads, + num_kv_heads=num_kv_heads, + head_dim_qk=head_dim, + custom_mask=custom_mask_buffers[step_idx], + causal=False, + sm_scale=sm_scale, + ) + cm_wrappers.append(wrapper) + + cm_output = torch.empty(num_requests * baseline_chunk_size, num_heads, head_dim, dtype=dtype, device=device) + + def run_custom_mask_pipeline(): + for step_idx in range(baseline_num_chunks): + cm_output.copy_(cm_wrappers[step_idx].run( + cm_q_buffers[step_idx], cm_k_buffers[step_idx], cm_v_buffers[step_idx] + )) + + if verbose: + print(f" Step flow preview:") + for step_id in range(baseline_num_chunks): + kv_len = (step_id + 1) * baseline_chunk_size + print(f" Step {step_id}: Q[{step_id*baseline_chunk_size}:{(step_id+1)*baseline_chunk_size}] attend to KV[0:{kv_len}]") + + # Warmup + print(f" Warmup...") + for _ in range(warmup_iters): + run_custom_mask_pipeline() + torch.cuda.synchronize() + + # CUDA Graph capture + print(f" Capturing CUDA Graph...") + cm_stream = torch.cuda.Stream() + with torch.cuda.stream(cm_stream): + run_custom_mask_pipeline() + cm_stream.synchronize() + + cm_graph = torch.cuda.CUDAGraph() + with torch.cuda.graph(cm_graph, stream=cm_stream): + run_custom_mask_pipeline() + + # Warmup with cuda_graph + for _ in range(warmup_iters): + cm_graph.replay() + torch.cuda.synchronize() + + # Benchmark + print(f" Benchmark (with cuda_graph)...") + start = time.perf_counter() + for _ in range(bench_iters): + cm_graph.replay() + torch.cuda.synchronize() + cm_cuda_graph_time = (time.perf_counter() - start) / bench_iters * 1000 + + results["custom_mask_baseline"] = { + "time_cuda_graph_ms": cm_cuda_graph_time, + "chunk_size": 
baseline_chunk_size, + "num_steps": baseline_num_chunks, + "method": "Custom Mask (Baseline 2)", + } + print(f" => cuda_graph: {cm_cuda_graph_time:.3f} ms ({cm_cuda_graph_time/baseline_num_chunks:.3f} ms/step × {baseline_num_chunks} steps)") + + del cm_graph, custom_mask_buffers, cm_wrappers + torch.cuda.empty_cache() + + # Baseline 1: SGLang Cascade (chunk_size = dllm_block_size) + CUDA Graph + print(f"\n{'='*80}") + print(f"[Baseline 1] SGLang Cascade (chunk_size = dllm_block_size = {baseline_chunk_size})") + print(f"{'='*80}") + print(f" num_steps: {baseline_num_chunks}") + print(f" Each step: BatchRagged(current chunk) + BatchRagged(prefix) + merge_state") + print(f" Kernels per step: 2-3") + + # Pre-allocate all buffers (for CUDA Graph) + q_current_buffers = [] + k_current_buffers = [] + v_current_buffers = [] + for step_idx in range(baseline_num_chunks): + q_list = [qs_chunks[req_idx][step_idx] for req_idx in range(num_requests)] + k_list = [ks_chunks[req_idx][step_idx] for req_idx in range(num_requests)] + v_list = [vs_chunks[req_idx][step_idx] for req_idx in range(num_requests)] + q_current_buffers.append(torch.cat(q_list, dim=0)) + k_current_buffers.append(torch.cat(k_list, dim=0)) + v_current_buffers.append(torch.cat(v_list, dim=0)) + + # Pre-allocate prefix KV buffer + k_prefix_buffers = [None] + v_prefix_buffers = [None] + for step_idx in range(1, baseline_num_chunks): + prefix_len = step_idx * baseline_chunk_size + k_prefix_list = [all_ks[req_idx][:prefix_len] for req_idx in range(num_requests)] + v_prefix_list = [all_vs[req_idx][:prefix_len] for req_idx in range(num_requests)] + k_prefix_buffers.append(torch.cat(k_prefix_list, dim=0)) + v_prefix_buffers.append(torch.cat(v_prefix_list, dim=0)) + + cascade_output = torch.empty(num_requests * baseline_chunk_size, num_heads, head_dim, dtype=dtype, device=device) + + # indptr + qo_indptr_current = torch.tensor( + [i * baseline_chunk_size for i in range(num_requests + 1)], + dtype=torch.int32, device=device 
+ ) + kv_indptr_prefix_list = [None] + for step_idx in range(1, baseline_num_chunks): + prefix_len = step_idx * baseline_chunk_size + kv_indptr_prefix_list.append(torch.tensor( + [i * prefix_len for i in range(num_requests + 1)], + dtype=torch.int32, device=device + )) + + # Create independent wrapper for each step and complete plan (critical!) + print(f" Creating wrappers and completing plan...") + cascade_wrappers_current = [] + cascade_wrappers_prefix = [] + + for step_idx in range(baseline_num_chunks): + # Wrapper for current chunk + wrapper_current = BatchPrefillWithRaggedKVCacheWrapper( + torch.empty(256 * 1024 * 1024, dtype=torch.uint8, device=device), kv_layout="NHD",backend = "fa3", + ) + wrapper_current.plan( + qo_indptr=qo_indptr_current, + kv_indptr=qo_indptr_current, + num_qo_heads=num_heads, + num_kv_heads=num_kv_heads, + head_dim_qk=head_dim, + causal=False, + sm_scale=sm_scale, + ) + cascade_wrappers_current.append(wrapper_current) + + # Wrapper for prefix (step 0 has no prefix) + if step_idx == 0: + cascade_wrappers_prefix.append(None) + else: + wrapper_prefix = BatchPrefillWithRaggedKVCacheWrapper( + torch.empty(256 * 1024 * 1024, dtype=torch.uint8, device=device), kv_layout="NHD",backend = "fa3", + ) + wrapper_prefix.plan( + qo_indptr=qo_indptr_current, + kv_indptr=kv_indptr_prefix_list[step_idx], + num_qo_heads=num_heads, + num_kv_heads=num_kv_heads, + head_dim_qk=head_dim, + causal=False, + sm_scale=sm_scale, + ) + cascade_wrappers_prefix.append(wrapper_prefix) + + o1_buffer = torch.empty(num_requests * baseline_chunk_size, num_heads, head_dim, dtype=dtype, device=device) + s1_buffer = torch.empty(num_requests * baseline_chunk_size, num_heads, dtype=torch.float32, device=device) + o2_buffer = torch.empty(num_requests * baseline_chunk_size, num_heads, head_dim, dtype=dtype, device=device) + s2_buffer = torch.empty(num_requests * baseline_chunk_size, num_heads, dtype=torch.float32, device=device) + + def run_cascade_pipeline(): + for step_idx in 
range(baseline_num_chunks): + q_batch = q_current_buffers[step_idx] + k_current = k_current_buffers[step_idx] + v_current = v_current_buffers[step_idx] + + if step_idx == 0: + cascade_output.copy_(cascade_wrappers_current[step_idx].run(q_batch, k_current, v_current)) + else: + o1, s1 = cascade_wrappers_current[step_idx].run_return_lse(q_batch, k_current, v_current) + o1_buffer.copy_(o1) + s1_buffer.copy_(s1) + + o2, s2 = cascade_wrappers_prefix[step_idx].run_return_lse( + q_batch, k_prefix_buffers[step_idx], v_prefix_buffers[step_idx] + ) + o2_buffer.copy_(o2) + s2_buffer.copy_(s2) + + o, _ = merge_state(o1_buffer, s1_buffer, o2_buffer, s2_buffer) + cascade_output.copy_(o) + + # Display step flow + if verbose: + print(f" Step flow preview:") + for step_id in range(baseline_num_chunks): + prefix_len = step_id * baseline_chunk_size + if step_id == 0: + print(f" Step {step_id}: current_chunk[{baseline_chunk_size}] only (no prefix)") + else: + print(f" Step {step_id}: current_chunk[{baseline_chunk_size}] + prefix[{prefix_len}] + merge") + + # Warmup + print(f" Warmup...") + for _ in range(warmup_iters): + run_cascade_pipeline() + torch.cuda.synchronize() + + # CUDA Graph capture + print(f" Capturing CUDA Graph...") + cascade_stream = torch.cuda.Stream() + with torch.cuda.stream(cascade_stream): + run_cascade_pipeline() + cascade_stream.synchronize() + + cascade_graph = torch.cuda.CUDAGraph() + with torch.cuda.graph(cascade_graph, stream=cascade_stream): + run_cascade_pipeline() + + # Warmup with cuda_graph + for _ in range(warmup_iters): + cascade_graph.replay() + torch.cuda.synchronize() + + # Benchmark + print(f" Benchmark (with cuda_graph)...") + start = time.perf_counter() + for _ in range(bench_iters): + cascade_graph.replay() + torch.cuda.synchronize() + cascade_cuda_graph_time = (time.perf_counter() - start) / bench_iters * 1000 + + results["cascade_baseline"] = { + "time_cuda_graph_ms": cascade_cuda_graph_time, + "chunk_size": baseline_chunk_size, + 
"num_steps": baseline_num_chunks, + "method": "SGLang Cascade (Baseline 1)", + } + print(f" => cuda_graph: {cascade_cuda_graph_time:.3f} ms ({cascade_cuda_graph_time/baseline_num_chunks:.3f} ms/step × {baseline_num_chunks} steps)") + + del cascade_wrappers_current, cascade_wrappers_prefix, cascade_graph + torch.cuda.empty_cache() + + # Comparison: BatchBlockExtend Ragged (different chunk_size) + CUDA Graph + + for test_chunk_size in chunk_sizes: + if tokens_per_request % test_chunk_size != 0: + print(f"\n[Skip] chunk_size={test_chunk_size} cannot divide tokens_per_request={tokens_per_request}") + continue + + num_chunks_bbe = tokens_per_request // test_chunk_size + + print(f"\n{'-'*60}") + print(f"[BatchBlockExtend Ragged] chunk_size = {test_chunk_size}") + print(f"{'-'*60}") + print(f" num_steps: {num_chunks_bbe} ({baseline_num_chunks - num_chunks_bbe} steps fewer than Baseline)") + print(f" Kernels per step: 1") + + # Split each request into chunks + qs_chunks_bbe = [split_chunks(q, test_chunk_size) for q in all_qs] + + # Pre-allocate all buffers + bbe_q_buffers = [] + for step_idx in range(num_chunks_bbe): + q_list = [qs_chunks_bbe[req_idx][step_idx] for req_idx in range(num_requests)] + bbe_q_buffers.append(torch.cat(q_list, dim=0)) + + bbe_k_buffers = [] + bbe_v_buffers = [] + for step_idx in range(num_chunks_bbe): + kv_len = (step_idx + 1) * test_chunk_size + k_cumul_list = [all_ks[req_idx][:kv_len] for req_idx in range(num_requests)] + v_cumul_list = [all_vs[req_idx][:kv_len] for req_idx in range(num_requests)] + bbe_k_buffers.append(torch.cat(k_cumul_list, dim=0)) + bbe_v_buffers.append(torch.cat(v_cumul_list, dim=0)) + + bbe_qo_indptr = torch.tensor( + [i * test_chunk_size for i in range(num_requests + 1)], + dtype=torch.int32, device=device + ) + bbe_kv_indptr_list = [] + bbe_q_offsets_list = [] + for step_idx in range(num_chunks_bbe): + kv_len = (step_idx + 1) * test_chunk_size + bbe_kv_indptr_list.append(torch.tensor( + [i * kv_len for i in 
range(num_requests + 1)], + dtype=torch.int32, device=device + )) + q_offset = step_idx * test_chunk_size + """ + bbe_q_offsets_list = [ + # step_idx=0: q_offset = 0 * 64 = 0 + tensor([0, 0, 0, 0], dtype=int32), # shape: (4,) + + # step_idx=1: q_offset = 1 * 64 = 64 + tensor([64, 64, 64, 64], dtype=int32), # shape: (4,) + + # step_idx=2: q_offset = 2 * 64 = 128 + tensor([128, 128, 128, 128], dtype=int32), # shape: (4,) + ] + """ + bbe_q_offsets_list.append(torch.full((num_requests,), q_offset, dtype=torch.int32, device=device)) + + bbe_output = torch.empty(num_requests * test_chunk_size, num_heads, head_dim, dtype=dtype, device=device) + + print(f" Creating wrappers and completing plan...") + bbe_wrappers = [] + for step_idx in range(num_chunks_bbe): + wrapper = BatchBlockExtendRaggedOffsetWrapper( + torch.empty(256 * 1024 * 1024, dtype=torch.uint8, device=device), + kv_layout="NHD", dllm_block_size=dllm_block_size + ) + wrapper.plan( + qo_indptr=bbe_qo_indptr, + kv_indptr=bbe_kv_indptr_list[step_idx], + num_qo_heads=num_heads, + num_kv_heads=num_kv_heads, + head_dim=head_dim, + q_data_type=dtype, + sm_scale=sm_scale, + q_offsets=bbe_q_offsets_list[step_idx], + ) + bbe_wrappers.append(wrapper) + + def run_bbe_pipeline(): + for step_idx in range(num_chunks_bbe): + bbe_output.copy_(bbe_wrappers[step_idx].run( + bbe_q_buffers[step_idx], bbe_k_buffers[step_idx], bbe_v_buffers[step_idx] + )) + + # Display step flow + if verbose: + print(f" Step flow preview:") + for step_id in range(num_chunks_bbe): + kv_len = (step_id + 1) * test_chunk_size + q_offset = step_id * test_chunk_size + print(f" Step {step_id}: Q[{q_offset}:{q_offset+test_chunk_size}] attend to KV[0:{kv_len}] (q_offset={q_offset})") + + # Warmup + print(f" Warmup...") + for _ in range(warmup_iters): + run_bbe_pipeline() + torch.cuda.synchronize() + + + # CUDA Graph capture + print(f" Capturing CUDA Graph...") + bbe_stream = torch.cuda.Stream() + with torch.cuda.stream(bbe_stream): + run_bbe_pipeline() + 
bbe_stream.synchronize() + + bbe_graph = torch.cuda.CUDAGraph() + with torch.cuda.graph(bbe_graph, stream=bbe_stream): + run_bbe_pipeline() + + # Warmup with cuda_graph + for _ in range(warmup_iters): + bbe_graph.replay() + torch.cuda.synchronize() + + # Benchmark + print(f" Benchmark (with cuda_graph)...") + start = time.perf_counter() + for _ in range(bench_iters): + bbe_graph.replay() + torch.cuda.synchronize() + bbe_cuda_graph_time = (time.perf_counter() - start) / bench_iters * 1000 + + results[f"bbe_chunk{test_chunk_size}"] = { + "time_cuda_graph_ms": bbe_cuda_graph_time, + "chunk_size": test_chunk_size, + "num_steps": num_chunks_bbe, + "method": "BatchBlockExtend Ragged", + } + print(f" => cuda_graph: {bbe_cuda_graph_time:.3f} ms ({bbe_cuda_graph_time/num_chunks_bbe:.3f} ms/step × {num_chunks_bbe} steps)") + + del bbe_wrappers, bbe_graph + torch.cuda.empty_cache() + + # Custom Mask different chunk_size test (chunk_size = dllm_block_size not required) + for test_chunk_size in chunk_sizes: + if tokens_per_request % test_chunk_size != 0: + print(f"\n[Skip] chunk_size={test_chunk_size} cannot divide tokens_per_request={tokens_per_request}") + continue + + # Skip already tested baseline chunk_size + if test_chunk_size == baseline_chunk_size: + continue + + num_chunks_cm = tokens_per_request // test_chunk_size + + print(f"\n{'-'*60}") + print(f"[Custom Mask] chunk_size = {test_chunk_size}") + print(f"{'-'*60}") + print(f" num_steps: {num_chunks_cm} ({baseline_num_chunks - num_chunks_cm} steps fewer than Baseline)") + print(f" Kernels per step: 1 (batch processes all requests)") + + # Split each request into chunks + qs_chunks_cm = [split_chunks(q, test_chunk_size) for q in all_qs] + + # Pre-allocate Q buffers (concat all requests) + cm_q_buffers_var = [] + for step_idx in range(num_chunks_cm): + q_list = [qs_chunks_cm[req_idx][step_idx] for req_idx in range(num_requests)] + cm_q_buffers_var.append(torch.cat(q_list, dim=0)) + + # Pre-allocate cumulative KV buffers 
(concat all requests) + cm_k_buffers_var = [] + cm_v_buffers_var = [] + for step_idx in range(num_chunks_cm): + kv_len = (step_idx + 1) * test_chunk_size + k_cumul_list = [all_ks[req_idx][:kv_len] for req_idx in range(num_requests)] + v_cumul_list = [all_vs[req_idx][:kv_len] for req_idx in range(num_requests)] + cm_k_buffers_var.append(torch.cat(k_cumul_list, dim=0)) + cm_v_buffers_var.append(torch.cat(v_cumul_list, dim=0)) + + # Construct flattened custom_mask (batch version) + # DLLM blockwise mask: mask[q, k] = ((q + q_offset) // B) >= (k // B) + cm_mask_buffers_var = [] + for step_idx in range(num_chunks_cm): + kv_len = (step_idx + 1) * test_chunk_size + q_offset = step_idx * test_chunk_size + # Construct single request's 2D mask: [q_len, kv_len] + q_pos = torch.arange(test_chunk_size, device=device) + q_offset + k_pos = torch.arange(kv_len, device=device) + q_block = q_pos.unsqueeze(1) // dllm_block_size + k_block = k_pos.unsqueeze(0) // dllm_block_size + mask_2d = (q_block >= k_block) # [q_len, kv_len], bool + # Flatten and repeat for batch + mask_flat = mask_2d.flatten() + batch_mask = mask_flat.repeat(num_requests) + cm_mask_buffers_var.append(batch_mask) + + # qo_indptr and kv_indptr + cm_qo_indptr_var = torch.tensor( + [i * test_chunk_size for i in range(num_requests + 1)], + dtype=torch.int32, device=device + ) + cm_kv_indptr_list_var = [] + for step_idx in range(num_chunks_cm): + kv_len = (step_idx + 1) * test_chunk_size + cm_kv_indptr_list_var.append(torch.tensor( + [i * kv_len for i in range(num_requests + 1)], + dtype=torch.int32, device=device + )) + + # Create independent wrapper for each step and complete plan + print(f" Creating wrappers and completing plan...") + cm_wrappers_var = [] + for step_idx in range(num_chunks_cm): + wrapper = BatchPrefillWithRaggedKVCacheWrapper( + torch.empty(256 * 1024 * 1024, dtype=torch.uint8, device=device), + kv_layout="NHD", + backend="fa2", # custom_mask only supported in FA2 + ) + wrapper.plan( + 
qo_indptr=cm_qo_indptr_var, + kv_indptr=cm_kv_indptr_list_var[step_idx], + num_qo_heads=num_heads, + num_kv_heads=num_kv_heads, + head_dim_qk=head_dim, + custom_mask=cm_mask_buffers_var[step_idx], + causal=False, + sm_scale=sm_scale, + ) + cm_wrappers_var.append(wrapper) + + cm_output_var = torch.empty(num_requests * test_chunk_size, num_heads, head_dim, dtype=dtype, device=device) + + def run_cm_pipeline_var(): + for step_idx in range(num_chunks_cm): + cm_output_var.copy_(cm_wrappers_var[step_idx].run( + cm_q_buffers_var[step_idx], cm_k_buffers_var[step_idx], cm_v_buffers_var[step_idx] + )) + + if verbose: + print(f" Step flow preview:") + for step_id in range(num_chunks_cm): + kv_len = (step_id + 1) * test_chunk_size + q_offset = step_id * test_chunk_size + print(f" Step {step_id}: Q[{q_offset}:{q_offset+test_chunk_size}] attend to KV[0:{kv_len}]") + + # Warmup + print(f" Warmup...") + for _ in range(warmup_iters): + run_cm_pipeline_var() + torch.cuda.synchronize() + + # CUDA Graph capture + print(f" Capturing CUDA Graph...") + cm_stream_var = torch.cuda.Stream() + with torch.cuda.stream(cm_stream_var): + run_cm_pipeline_var() + cm_stream_var.synchronize() + + cm_graph_var = torch.cuda.CUDAGraph() + with torch.cuda.graph(cm_graph_var, stream=cm_stream_var): + run_cm_pipeline_var() + + # Warmup with cuda_graph + for _ in range(warmup_iters): + cm_graph_var.replay() + torch.cuda.synchronize() + + # Benchmark + print(f" Benchmark (with cuda_graph)...") + start = time.perf_counter() + for _ in range(bench_iters): + cm_graph_var.replay() + torch.cuda.synchronize() + cm_var_time = (time.perf_counter() - start) / bench_iters * 1000 + + results[f"cm_chunk{test_chunk_size}"] = { + "time_cuda_graph_ms": cm_var_time, + "chunk_size": test_chunk_size, + "num_steps": num_chunks_cm, + "method": "Custom Mask", + } + print(f" => cuda_graph: {cm_var_time:.3f} ms ({cm_var_time/num_chunks_cm:.3f} ms/step × {num_chunks_cm} steps)") + + del cm_graph_var, cm_mask_buffers_var, 
cm_wrappers_var + torch.cuda.empty_cache() + + print(f"\n{'='*80}") + print(f"Results Summary (Step-by-step Execution + CUDA Graph)") + print(f"{'='*80}") + + cm_baseline_time = results["custom_mask_baseline"]["time_cuda_graph_ms"] + cascade_baseline_time = results["cascade_baseline"]["time_cuda_graph_ms"] + + + print(f"\nNotes:") + print(f" - Baseline 1: SGLang Cascade (BatchPrefill + merge_state)") + print(f" - Baseline 2: Custom Mask (BatchPrefill + custom_mask)") + print(f" - vs Base1: Speedup relative to SGLang Cascade") + print(f" - vs Base2: Speedup relative to Custom Mask") + + print(f"\n{'Method':<40} | {'chunk':>6} | {'steps':>6} | {'cuda_graph(ms)':>10} | {'ms/step':>10} | {'vs Base1':>10} | {'vs Base2':>10}") + print(f"{'-'*40}-+-{'-'*6}-+-{'-'*6}-+-{'-'*10}-+-{'-'*10}-+-{'-'*10}-+-{'-'*10}") + + # Baseline 1: SGLang Cascade + r = results["cascade_baseline"] + print(f"{'[Baseline 1] SGLang Cascade':<40} | {r['chunk_size']:>6} | {r['num_steps']:>6} | {r['time_cuda_graph_ms']:>10.3f} | {r['time_cuda_graph_ms']/r['num_steps']:>10.3f} | {'1.00x':>10} | {cm_baseline_time/cascade_baseline_time:>9.2f}x") + + # Baseline 2: Custom Mask + r = results["custom_mask_baseline"] + print(f"{'[Baseline 2] Custom Mask':<40} | {r['chunk_size']:>6} | {r['num_steps']:>6} | {r['time_cuda_graph_ms']:>10.3f} | {r['time_cuda_graph_ms']/r['num_steps']:>10.3f} | {cascade_baseline_time/cm_baseline_time:>9.2f}x | {'1.00x':>10}") + + + cm_keys = [k for k in results.keys() if k.startswith("cm_chunk")] + cm_keys_sorted = sorted(cm_keys, key=lambda k: results[k]["chunk_size"]) + for key in cm_keys_sorted: + r = results[key] + speedup_vs_cascade = cascade_baseline_time / r["time_cuda_graph_ms"] + speedup_vs_cm = cm_baseline_time / r["time_cuda_graph_ms"] + print(f"{'Custom Mask':<40} | {r['chunk_size']:>6} | {r['num_steps']:>6} | {r['time_cuda_graph_ms']:>10.3f} | {r['time_cuda_graph_ms']/r['num_steps']:>10.3f} | {speedup_vs_cascade:>9.2f}x | {speedup_vs_cm:>9.2f}x") + + + bbe_keys = [k 
for k in results.keys() if k.startswith("bbe_chunk")] + bbe_keys_sorted = sorted(bbe_keys, key=lambda k: results[k]["chunk_size"]) + for key in bbe_keys_sorted: + r = results[key] + speedup_vs_cascade = cascade_baseline_time / r["time_cuda_graph_ms"] + speedup_vs_cm = cm_baseline_time / r["time_cuda_graph_ms"] + print(f"{'BatchBlockExtend Ragged':<40} | {r['chunk_size']:>6} | {r['num_steps']:>6} | {r['time_cuda_graph_ms']:>10.3f} | {r['time_cuda_graph_ms']/r['num_steps']:>10.3f} | {speedup_vs_cascade:>9.2f}x | {speedup_vs_cm:>9.2f}x") + + + return results + + +def test_incremental_singlereq_prefill_step_by_step_with_cuda_graph( + tokens_per_request: int = 512, + dllm_block_size: int = 32, + chunk_sizes: list = None, + num_heads: int = 32, + num_kv_heads: int = 8, + head_dim: int = 128, + warmup_iters: int = 10, + bench_iters: int = 100, + verbose: bool = False, +): + """ + Single request incremental Prefill step-by-step execution test + CUDA Graph + + Scenario description: + - Incremental prefill of single request + - Must execute step by step: chunk1 can only be computed after chunk0 completes (pipeline dependency) + - SGLang uses BatchPrefill interface even with batch_size=1 + + Baseline (SGLang 3-stage DLLM Cascade): + 1. BatchPrefillWithRaggedKVCacheWrapper: current chunk (causal=False) + 2. BatchPrefillWithRaggedKVCacheWrapper: prefix KV (causal=False) + 3. merge_state: merge results from both parts + Each step: 2-3 kernel launches + + Comparison: + 1. block_extend_attention_with_offset + CUDA Graph + 2. 
BatchBlockExtendRaggedOffsetWrapper + CUDA Graph + Each step: 1 kernel launch + """ + if chunk_sizes is None: + chunk_sizes = [32, 64, 128, 256] + + device = torch.device("cuda:0") + dtype = torch.float16 + sm_scale = 1.0 / (head_dim ** 0.5) + + # Baseline: chunk_size = dllm_block_size (SGLang constraint) + baseline_chunk_size = dllm_block_size + num_chunks = tokens_per_request // baseline_chunk_size + + print(f"\n{'='*80}") + print(f"Single Request Incremental Prefill Step-by-step + CUDA Graph") + print(f"{'='*80}") + print(f"Configuration:") + print(f" tokens_per_request = {tokens_per_request}") + print(f" dllm_block_size = {dllm_block_size}") + print(f" chunk_sizes to test = {chunk_sizes}") + print(f" num_heads = {num_heads}") + print(f" num_kv_heads = {num_kv_heads}") + print(f" head_dim = {head_dim}") + print(f"\nScenario description:") + print(f" - Single request (batch_size=1), but uses BatchPrefill interface (SGLang approach)") + print(f" - Step-by-step execution: chunk1 can only be computed after chunk0 completes (pipeline dependency)") + print(f" - num_steps = {num_chunks}") + + # Data preparation + # Single request's complete Q, K, V + q_full = torch.randn(tokens_per_request, num_heads, head_dim, dtype=dtype, device=device) + k_full = torch.randn(tokens_per_request, num_kv_heads, head_dim, dtype=dtype, device=device) + v_full = torch.randn(tokens_per_request, num_kv_heads, head_dim, dtype=dtype, device=device) + + # Split into chunks + def split_chunks(tensor, chunk_size): + return [tensor[i*chunk_size:(i+1)*chunk_size] for i in range(tensor.shape[0] // chunk_size)] + + qs_chunks = split_chunks(q_full, baseline_chunk_size) + ks_chunks = split_chunks(k_full, baseline_chunk_size) + vs_chunks = split_chunks(v_full, baseline_chunk_size) + + # Cumulative KV buffer + k_cumul_list = [torch.cat([k_full[:(i+1)*baseline_chunk_size]], dim=0) for i in range(num_chunks)] + v_cumul_list = [torch.cat([v_full[:(i+1)*baseline_chunk_size]], dim=0) for i in 
range(num_chunks)] + + results = {} + + # Correctness verification + print(f"\n{'='*80}") + print(f"Correctness Verification") + print(f"{'='*80}") + print(f" Reference implementation: single_prefill_with_kv_cache + custom_mask") + print(f" Mask rule: mask[q,k] = ((q + offset) // B) >= (k // B)") + + # Compute reference results (each chunk computed independently) + ref_outputs = [] + for step_idx in range(num_chunks): + q_offset = step_idx * baseline_chunk_size + ref_out = compute_block_extend_reference( + qs_chunks[step_idx], + k_cumul_list[step_idx], + v_cumul_list[step_idx], + dllm_block_size=dllm_block_size, + q_offset=q_offset, + sm_scale=sm_scale, + ) + ref_outputs.append(ref_out) + + # Compute V2 results and verify + v2_outputs = [] + for step_idx in range(num_chunks): + q_offset = step_idx * baseline_chunk_size + v2_out = block_extend_attention_with_offset( + qs_chunks[step_idx], + k_cumul_list[step_idx], + v_cumul_list[step_idx], + dllm_block_size=dllm_block_size, + q_offset=q_offset, + sm_scale=sm_scale, + backend="fa2", + ) + v2_outputs.append(v2_out) + + # Verify V2 + v2_max_diff = max((v2_outputs[i] - ref_outputs[i]).abs().max().item() for i in range(num_chunks)) + tol = 1e-3 + v2_pass = v2_max_diff < tol + print(f"\n [V2] block_extend_attention_with_offset:") + print(f" max_diff = {v2_max_diff:.6f}, tolerance = {tol}") + print(f" {' PASS' if v2_pass else ' FAIL'}") + + # Compute BBE results and verify + bbe_outputs = [] + workspace_verify = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device=device) + qo_indptr_verify = torch.tensor([0, baseline_chunk_size], dtype=torch.int32, device=device) + for step_idx in range(num_chunks): + kv_len = (step_idx + 1) * baseline_chunk_size + kv_indptr = torch.tensor([0, kv_len], dtype=torch.int32, device=device) + q_offset_tensor = torch.tensor([step_idx * baseline_chunk_size], dtype=torch.int32, device=device) + + wrapper = BatchBlockExtendRaggedOffsetWrapper( + workspace_verify, kv_layout="NHD", 
dllm_block_size=dllm_block_size + ) + wrapper.plan( + qo_indptr=qo_indptr_verify, + kv_indptr=kv_indptr, + num_qo_heads=num_heads, + num_kv_heads=num_kv_heads, + head_dim=head_dim, + q_data_type=dtype, + sm_scale=sm_scale, + q_offsets=q_offset_tensor, + ) + bbe_out = wrapper.run(qs_chunks[step_idx], k_cumul_list[step_idx], v_cumul_list[step_idx]) + bbe_outputs.append(bbe_out) + + # Verify BBE + bbe_max_diff = max((bbe_outputs[i] - ref_outputs[i]).abs().max().item() for i in range(num_chunks)) + bbe_pass = bbe_max_diff < tol + print(f"\n [BBE] BatchBlockExtendRaggedOffsetWrapper:") + print(f" max_diff = {bbe_max_diff:.6f}, tolerance = {tol}") + print(f" {' PASS' if bbe_pass else ' FAIL'}") + + if not (v2_pass and bbe_pass): + print(f"\n Correctness verification failed, but continue with performance test") + else: + print(f"\n All methods passed correctness verification") + + del workspace_verify, bbe_outputs, v2_outputs, ref_outputs + torch.cuda.empty_cache() + + + # Baseline 2: Custom Mask (native single_prefill + custom_mask) + CUDA Graph + print(f"\n{'='*80}") + print(f"[Baseline 2] Custom Mask (single_prefill + custom_mask)") + print(f"{'='*80}") + print(f" num_steps: {num_chunks}") + print(f" Each step: single_prefill_with_kv_cache + custom_mask") + print(f" Kernels per step: 1 (but needs to construct mask tensor)") + + + custom_mask_buffers = [] + for step_idx in range(num_chunks): + kv_len = (step_idx + 1) * baseline_chunk_size + q_offset = step_idx * baseline_chunk_size + # mask[q, k] = ((q + offset) // B) >= (k // B) + q_pos = torch.arange(baseline_chunk_size, device=device) + q_offset + k_pos = torch.arange(kv_len, device=device) + q_block = q_pos.unsqueeze(1) // dllm_block_size + k_block = k_pos.unsqueeze(0) // dllm_block_size + mask_2d = (q_block >= k_block).to(torch.uint8) + custom_mask_buffers.append(mask_2d) + + cm_output = torch.empty(baseline_chunk_size, num_heads, head_dim, dtype=dtype, device=device) + + # Note: custom_mask is only supported in 
FA2 backend, not FA3 + def run_custom_mask_pipeline(): + for step_idx in range(num_chunks): + cm_output.copy_(single_prefill_with_kv_cache( + qs_chunks[step_idx], + k_cumul_list[step_idx], + v_cumul_list[step_idx], + custom_mask=custom_mask_buffers[step_idx], + sm_scale=sm_scale, + backend="fa2", # custom_mask only supported in FA2 + )) + + + if verbose: + print(f" Step flow preview:") + for step_id in range(num_chunks): + kv_len = (step_id + 1) * baseline_chunk_size + print(f" Step {step_id}: Q[{step_id*baseline_chunk_size}:{(step_id+1)*baseline_chunk_size}] attend to KV[0:{kv_len}]") + + # Warmup + print(f" Warmup...") + for _ in range(warmup_iters): + run_custom_mask_pipeline() + torch.cuda.synchronize() + + # CUDA Graph capture + print(f" Capturing CUDA Graph...") + cm_stream = torch.cuda.Stream() + with torch.cuda.stream(cm_stream): + run_custom_mask_pipeline() + cm_stream.synchronize() + + cm_graph = torch.cuda.CUDAGraph() + with torch.cuda.graph(cm_graph, stream=cm_stream): + run_custom_mask_pipeline() + + # Warmup with cuda_graph + for _ in range(warmup_iters): + cm_graph.replay() + torch.cuda.synchronize() + + # Benchmark + print(f" Benchmark (with cuda_graph)...") + start = time.perf_counter() + for _ in range(bench_iters): + cm_graph.replay() + torch.cuda.synchronize() + cm_cuda_graph_time = (time.perf_counter() - start) / bench_iters * 1000 + + results["custom_mask_baseline"] = { + "time_cuda_graph_ms": cm_cuda_graph_time, + "chunk_size": baseline_chunk_size, + "num_steps": num_chunks, + "method": "Custom Mask (Baseline 2)", + } + print(f" => cuda_graph: {cm_cuda_graph_time:.3f} ms ({cm_cuda_graph_time/num_chunks:.3f} ms/step × {num_chunks} steps)") + + del cm_graph, custom_mask_buffers + torch.cuda.empty_cache() + + + # Custom Mask different chunk_size test (chunk_size = dllm_block_size not required) + for test_chunk_size in chunk_sizes: + if tokens_per_request % test_chunk_size != 0: + print(f"\n[Skip] chunk_size={test_chunk_size} cannot divide 
tokens_per_request={tokens_per_request}") + continue + + # Skip already tested baseline chunk_size + if test_chunk_size == baseline_chunk_size: + continue + + num_steps_cm = tokens_per_request // test_chunk_size + + print(f"\n{'-'*60}") + print(f"[Custom Mask] chunk_size = {test_chunk_size}") + print(f"{'-'*60}") + print(f" num_steps: {num_steps_cm} ({num_chunks - num_steps_cm} steps fewer than Baseline)") + print(f" Kernels per step: 1") + + # Split into chunks + qs_cm = split_chunks(q_full, test_chunk_size) + + # Cumulative KV buffer + k_cumul_cm = [k_full[:(i+1)*test_chunk_size].clone() for i in range(num_steps_cm)] + v_cumul_cm = [v_full[:(i+1)*test_chunk_size].clone() for i in range(num_steps_cm)] + + # Pre-allocate custom_mask buffers (different mask for each step) + cm_mask_buffers = [] + for step_idx in range(num_steps_cm): + kv_len = (step_idx + 1) * test_chunk_size + q_offset = step_idx * test_chunk_size + # mask[q, k] = ((q + offset) // B) >= (k // B) + q_pos = torch.arange(test_chunk_size, device=device) + q_offset + k_pos = torch.arange(kv_len, device=device) + q_block = q_pos.unsqueeze(1) // dllm_block_size + k_block = k_pos.unsqueeze(0) // dllm_block_size + mask_2d = (q_block >= k_block).to(torch.uint8) + cm_mask_buffers.append(mask_2d) + + # Pre-allocate output buffer + cm_output_var = torch.empty(test_chunk_size, num_heads, head_dim, dtype=dtype, device=device) + + def run_cm_pipeline(): + for step_idx in range(num_steps_cm): + cm_output_var.copy_(single_prefill_with_kv_cache( + qs_cm[step_idx], k_cumul_cm[step_idx], v_cumul_cm[step_idx], + custom_mask=cm_mask_buffers[step_idx], + sm_scale=sm_scale, + backend="fa2", # custom_mask only supported in FA2 + )) + + # Display step flow + if verbose: + print(f" Step flow preview:") + for step_id in range(num_steps_cm): + kv_len = (step_id + 1) * test_chunk_size + q_offset = step_id * test_chunk_size + print(f" Step {step_id}: Q[{q_offset}:{q_offset+test_chunk_size}] attend to KV[0:{kv_len}]") + + # Warmup 
+ print(f" Warmup...") + for _ in range(warmup_iters): + run_cm_pipeline() + torch.cuda.synchronize() + + # CUDA Graph capture + print(f" Capturing CUDA Graph...") + cm_stream_var = torch.cuda.Stream() + with torch.cuda.stream(cm_stream_var): + run_cm_pipeline() + cm_stream_var.synchronize() + + cm_graph_var = torch.cuda.CUDAGraph() + with torch.cuda.graph(cm_graph_var, stream=cm_stream_var): + run_cm_pipeline() + + # Warmup with cuda_graph + for _ in range(warmup_iters): + cm_graph_var.replay() + torch.cuda.synchronize() + + # Benchmark + print(f" Benchmark (with cuda_graph)...") + start = time.perf_counter() + for _ in range(bench_iters): + cm_graph_var.replay() + torch.cuda.synchronize() + cm_var_time = (time.perf_counter() - start) / bench_iters * 1000 + + results[f"cm_chunk{test_chunk_size}"] = { + "time_cuda_graph_ms": cm_var_time, + "chunk_size": test_chunk_size, + "num_steps": num_steps_cm, + "method": "Custom Mask", + } + print(f" => cuda_graph: {cm_var_time:.3f} ms ({cm_var_time/num_steps_cm:.3f} ms/step × {num_steps_cm} steps)") + + del cm_graph_var, cm_mask_buffers + torch.cuda.empty_cache() + + # Baseline 1: SGLang 3-stage DLLM Cascade (BatchPrefill interface, batch_size=1) + print(f"\n{'='*80}") + print(f"[Baseline 1] SGLang 3-stage Cascade (BatchPrefill, batch_size=1)") + print(f"{'='*80}") + print(f" num_steps: {num_chunks}") + print(f" Each step: BatchPrefill(current chunk) + BatchPrefill(prefix) + merge_state") + print(f" Kernels per step: 2-3") + + # indptr for batch_size=1 + qo_indptr_chunk = torch.tensor([0, baseline_chunk_size], dtype=torch.int32, device=device) + + # workspace size: + workspace_size = 16 * 1024 * 1024 + + # Create independent wrapper for each step and complete plan + print(f" Creating wrappers and completing plan...") + cascade_wrappers_current = [] + cascade_wrappers_prefix = [] + kv_indptr_prefix_list = [None] # step 0 has no prefix + + for step_idx in range(num_chunks): + # Wrapper for current chunk + wrapper_current = 
BatchPrefillWithRaggedKVCacheWrapper( + torch.empty(workspace_size, dtype=torch.uint8, device=device), kv_layout="NHD", backend="fa3", + ) + wrapper_current.plan( + qo_indptr=qo_indptr_chunk, + kv_indptr=qo_indptr_chunk, # current chunk KV length = Q length + num_qo_heads=num_heads, + num_kv_heads=num_kv_heads, + head_dim_qk=head_dim, + causal=False, + sm_scale=sm_scale, + ) + cascade_wrappers_current.append(wrapper_current) + + # Wrapper for prefix (step 0 has no prefix) + if step_idx == 0: + cascade_wrappers_prefix.append(None) + else: + prefix_len = step_idx * baseline_chunk_size + kv_indptr_prefix = torch.tensor([0, prefix_len], dtype=torch.int32, device=device) + kv_indptr_prefix_list.append(kv_indptr_prefix) + + wrapper_prefix = BatchPrefillWithRaggedKVCacheWrapper( + torch.empty(workspace_size, dtype=torch.uint8, device=device), kv_layout="NHD", backend="fa3", + ) + wrapper_prefix.plan( + qo_indptr=qo_indptr_chunk, + kv_indptr=kv_indptr_prefix, + num_qo_heads=num_heads, + num_kv_heads=num_kv_heads, + head_dim_qk=head_dim, + causal=False, + sm_scale=sm_scale, + ) + cascade_wrappers_prefix.append(wrapper_prefix) + + # Pre-allocate buffer + cascade_output = torch.empty(baseline_chunk_size, num_heads, head_dim, dtype=dtype, device=device) + o1_buffer = torch.empty(baseline_chunk_size, num_heads, head_dim, dtype=dtype, device=device) + s1_buffer = torch.empty(baseline_chunk_size, num_heads, dtype=torch.float32, device=device) + o2_buffer = torch.empty(baseline_chunk_size, num_heads, head_dim, dtype=dtype, device=device) + s2_buffer = torch.empty(baseline_chunk_size, num_heads, dtype=torch.float32, device=device) + + # Pre-allocate prefix KV buffer + k_prefix_buffers = [None] + v_prefix_buffers = [None] + for step_idx in range(1, num_chunks): + prefix_len = step_idx * baseline_chunk_size + k_prefix_buffers.append(k_full[:prefix_len].clone()) + v_prefix_buffers.append(v_full[:prefix_len].clone()) + + def run_cascade_pipeline(): + for step_idx in range(num_chunks): 
+ q_chunk = qs_chunks[step_idx] + k_chunk = ks_chunks[step_idx] + v_chunk = vs_chunks[step_idx] + + if step_idx == 0: + cascade_output.copy_(cascade_wrappers_current[step_idx].run(q_chunk, k_chunk, v_chunk)) + else: + o1, s1 = cascade_wrappers_current[step_idx].run_return_lse(q_chunk, k_chunk, v_chunk) + o1_buffer.copy_(o1) + s1_buffer.copy_(s1) + + o2, s2 = cascade_wrappers_prefix[step_idx].run_return_lse( + q_chunk, k_prefix_buffers[step_idx], v_prefix_buffers[step_idx] + ) + o2_buffer.copy_(o2) + s2_buffer.copy_(s2) + + o, _ = merge_state(o1_buffer, s1_buffer, o2_buffer, s2_buffer) + cascade_output.copy_(o) + + + if verbose: + print(f" Step flow preview:") + for step_id in range(num_chunks): + prefix_len = step_id * baseline_chunk_size + if step_id == 0: + print(f" Step {step_id}: current_chunk[{baseline_chunk_size}] only (no prefix)") + else: + print(f" Step {step_id}: current_chunk[{baseline_chunk_size}] + prefix[{prefix_len}] + merge") + + # Warmup + print(f" Warmup...") + for _ in range(warmup_iters): + run_cascade_pipeline() + torch.cuda.synchronize() + + # CUDA Graph capture + print(f" Capturing CUDA Graph...") + cascade_stream = torch.cuda.Stream() + with torch.cuda.stream(cascade_stream): + run_cascade_pipeline() + cascade_stream.synchronize() + + cascade_graph = torch.cuda.CUDAGraph() + with torch.cuda.graph(cascade_graph, stream=cascade_stream): + run_cascade_pipeline() + + # Warmup with cuda_graph + for _ in range(warmup_iters): + cascade_graph.replay() + torch.cuda.synchronize() + + # Benchmark + print(f" Benchmark (with cuda_graph)...") + start = time.perf_counter() + for _ in range(bench_iters): + cascade_graph.replay() + torch.cuda.synchronize() + cascade_cuda_graph_time = (time.perf_counter() - start) / bench_iters * 1000 + + results["cascade_baseline"] = { + "time_cuda_graph_ms": cascade_cuda_graph_time, + "chunk_size": baseline_chunk_size, + "num_steps": num_chunks, + "method": "SGLang Cascade (Baseline 1)", + } + print(f" => cuda_graph: 
{cascade_cuda_graph_time:.3f} ms ({cascade_cuda_graph_time/num_chunks:.3f} ms/step × {num_chunks} steps)") + + del cascade_wrappers_current, cascade_wrappers_prefix, cascade_graph + torch.cuda.empty_cache() + + # Method 1: block_extend_attention_with_offset + CUDA Graph + for test_chunk_size in chunk_sizes: + if tokens_per_request % test_chunk_size != 0: + print(f"\n[Skip] chunk_size={test_chunk_size} cannot divide tokens_per_request={tokens_per_request}") + continue + + num_steps = tokens_per_request // test_chunk_size + + print(f"\n{'-'*60}") + print(f"[V2] block_extend_attention_with_offset (chunk_size={test_chunk_size})") + print(f"{'-'*60}") + print(f" num_steps: {num_steps} ({num_chunks - num_steps} steps fewer than Baseline)") + print(f" Kernels per step: 1") + + # Split into chunks + qs_v2 = split_chunks(q_full, test_chunk_size) + + # Pre-allocate output buffer + v2_output = torch.empty(test_chunk_size, num_heads, head_dim, dtype=dtype, device=device) + + # Cumulative KV buffer + k_cumul_v2 = [k_full[:(i+1)*test_chunk_size].clone() for i in range(num_steps)] + v_cumul_v2 = [v_full[:(i+1)*test_chunk_size].clone() for i in range(num_steps)] + + def run_v2_pipeline(): + for step_idx in range(num_steps): + v2_output.copy_(block_extend_attention_with_offset( + qs_v2[step_idx], k_cumul_v2[step_idx], v_cumul_v2[step_idx], + dllm_block_size=dllm_block_size, + q_offset=step_idx * test_chunk_size, + sm_scale=sm_scale, + backend="fa2", + )) + + # Display step flow + if verbose: + print(f" Step flow preview:") + for step_id in range(num_steps): + kv_len = (step_id + 1) * test_chunk_size + q_offset = step_id * test_chunk_size + print(f" Step {step_id}: Q[{q_offset}:{q_offset+test_chunk_size}] attend to KV[0:{kv_len}] (q_offset={q_offset})") + + # Warmup + print(f" Warmup...") + for _ in range(warmup_iters): + run_v2_pipeline() + torch.cuda.synchronize() + + # CUDA Graph capture + print(f" Capturing CUDA Graph...") + v2_stream = torch.cuda.Stream() + with 
torch.cuda.stream(v2_stream): + run_v2_pipeline() + v2_stream.synchronize() + + v2_graph = torch.cuda.CUDAGraph() + with torch.cuda.graph(v2_graph, stream=v2_stream): + run_v2_pipeline() + + # Warmup with cuda_graph + for _ in range(warmup_iters): + v2_graph.replay() + torch.cuda.synchronize() + + # Benchmark + print(f" Benchmark (with cuda_graph)...") + start = time.perf_counter() + for _ in range(bench_iters): + v2_graph.replay() + torch.cuda.synchronize() + v2_cuda_graph_time = (time.perf_counter() - start) / bench_iters * 1000 + + results[f"v2_chunk{test_chunk_size}"] = { + "time_cuda_graph_ms": v2_cuda_graph_time, + "chunk_size": test_chunk_size, + "num_steps": num_steps, + "method": "V2 + CUDA Graph", + } + print(f" => cuda_graph: {v2_cuda_graph_time:.3f} ms ({v2_cuda_graph_time/num_steps:.3f} ms/step × {num_steps} steps)") + + del v2_graph + torch.cuda.empty_cache() + + + # Method 2: BatchBlockExtendRaggedOffsetWrapper + CUDA Graph + for test_chunk_size in chunk_sizes: + if tokens_per_request % test_chunk_size != 0: + continue + + num_steps = tokens_per_request // test_chunk_size + + print(f"\n{'-'*60}") + print(f"[BBE] BatchBlockExtendRaggedOffsetWrapper (chunk_size={test_chunk_size})") + print(f"{'-'*60}") + print(f" num_steps: {num_steps} ({num_chunks - num_steps} steps fewer than Baseline)") + print(f" Kernels per step: 1") + + # Split into chunks + qs_bbe = split_chunks(q_full, test_chunk_size) + + # indptr for batch_size=1 + qo_indptr_bbe = torch.tensor([0, test_chunk_size], dtype=torch.int32, device=device) + + # Cumulative KV buffer + k_cumul_bbe = [k_full[:(i+1)*test_chunk_size].clone() for i in range(num_steps)] + v_cumul_bbe = [v_full[:(i+1)*test_chunk_size].clone() for i in range(num_steps)] + + # Create independent wrapper for each step and complete plan + print(f" Creating wrappers and completing plan...") + + workspace_size = 128 * 1024 * 1024 # 128MB - BBE uses JIT which requires larger workspace + bbe_wrappers = [] + for step_idx in 
range(num_steps): + kv_len = (step_idx + 1) * test_chunk_size + kv_indptr = torch.tensor([0, kv_len], dtype=torch.int32, device=device) + q_offset = torch.tensor([step_idx * test_chunk_size], dtype=torch.int32, device=device) + + wrapper = BatchBlockExtendRaggedOffsetWrapper( + torch.empty(workspace_size, dtype=torch.uint8, device=device), + kv_layout="NHD", dllm_block_size=dllm_block_size + ) + wrapper.plan( + qo_indptr=qo_indptr_bbe, + kv_indptr=kv_indptr, + num_qo_heads=num_heads, + num_kv_heads=num_kv_heads, + head_dim=head_dim, + q_data_type=dtype, + sm_scale=sm_scale, + q_offsets=q_offset, + ) + bbe_wrappers.append(wrapper) + + # Pre-allocate output buffer + bbe_output = torch.empty(test_chunk_size, num_heads, head_dim, dtype=dtype, device=device) + + def run_bbe_pipeline(): + for step_idx in range(num_steps): + bbe_output.copy_(bbe_wrappers[step_idx].run( + qs_bbe[step_idx], k_cumul_bbe[step_idx], v_cumul_bbe[step_idx] + )) + + # Display step flow + if verbose: + print(f" Step flow preview:") + for step_id in range(num_steps): + kv_len = (step_id + 1) * test_chunk_size + q_offset = step_id * test_chunk_size + print(f" Step {step_id}: Q[{q_offset}:{q_offset+test_chunk_size}] attend to KV[0:{kv_len}] (q_offset={q_offset})") + + # Warmup + print(f" Warmup...") + for _ in range(warmup_iters): + run_bbe_pipeline() + torch.cuda.synchronize() + + # CUDA Graph capture + print(f" Capturing CUDA Graph...") + bbe_stream = torch.cuda.Stream() + with torch.cuda.stream(bbe_stream): + run_bbe_pipeline() + bbe_stream.synchronize() + + bbe_graph = torch.cuda.CUDAGraph() + with torch.cuda.graph(bbe_graph, stream=bbe_stream): + run_bbe_pipeline() + + # Warmup with cuda_graph + for _ in range(warmup_iters): + bbe_graph.replay() + torch.cuda.synchronize() + + # Benchmark + print(f" Benchmark (with cuda_graph)...") + start = time.perf_counter() + for _ in range(bench_iters): + bbe_graph.replay() + torch.cuda.synchronize() + bbe_cuda_graph_time = (time.perf_counter() - start) / 
bench_iters * 1000 + + results[f"bbe_chunk{test_chunk_size}"] = { + "time_cuda_graph_ms": bbe_cuda_graph_time, + "chunk_size": test_chunk_size, + "num_steps": num_steps, + "method": "BBE Ragged + CUDA Graph", + } + print(f" => cuda_graph: {bbe_cuda_graph_time:.3f} ms ({bbe_cuda_graph_time/num_steps:.3f} ms/step × {num_steps} steps)") + + del bbe_wrappers, bbe_graph + torch.cuda.empty_cache() + + print(f"\n{'='*80}") + print(f"Results Summary (Single Request Step-by-step + CUDA Graph)") + print(f"{'='*80}") + + cm_baseline_r = results.get("custom_mask_baseline", {}) + cascade_baseline_r = results.get("cascade_baseline", {}) + cascade_skipped = cascade_baseline_r.get("skipped", False) + + cm_baseline_time = cm_baseline_r.get("time_cuda_graph_ms", float('nan')) + cascade_baseline_time = cascade_baseline_r.get("time_cuda_graph_ms", float('nan')) + + # Header explanation + print(f"\nNotes:") + print(f" - Baseline 1: SGLang Cascade (BatchPrefill + merge_state)") + print(f" - Baseline 2: Custom Mask (single_prefill + custom_mask)") + print(f" - vs Base1: Speedup relative to SGLang Cascade") + print(f" - vs Base2: Speedup relative to Custom Mask") + + print(f"\n{'Method':<45} | {'chunk':>6} | {'steps':>6} | {'cuda_graph(ms)':>10} | {'ms/step':>10} | {'vs Base1':>10} | {'vs Base2':>10}") + print(f"{'-'*45}-+-{'-'*6}-+-{'-'*6}-+-{'-'*10}-+-{'-'*10}-+-{'-'*10}-+-{'-'*10}") + + # Baseline 1: SGLang Cascade + if cascade_skipped: + print(f"{'[Baseline 1] SGLang Cascade':<45} | {cascade_baseline_r['chunk_size']:>6} | {cascade_baseline_r['num_steps']:>6} | {'SKIPPED':>10} | {'-':>10} | {'-':>10} | {'-':>10}") + else: + r = cascade_baseline_r + print(f"{'[Baseline 1] SGLang Cascade':<45} | {r['chunk_size']:>6} | {r['num_steps']:>6} | {r['time_cuda_graph_ms']:>10.3f} | {r['time_cuda_graph_ms']/r['num_steps']:>10.3f} | {'1.00x':>10} | {cm_baseline_time/cascade_baseline_time:>9.2f}x") + + # Baseline 2: Custom Mask + r = cm_baseline_r + if cascade_skipped: + vs_base1_str = "-" + else: 
+ vs_base1_str = f"{cascade_baseline_time/cm_baseline_time:>9.2f}x" + print(f"{'[Baseline 2] Custom Mask':<45} | {r['chunk_size']:>6} | {r['num_steps']:>6} | {r['time_cuda_graph_ms']:>10.3f} | {r['time_cuda_graph_ms']/r['num_steps']:>10.3f} | {vs_base1_str:>10} | {'1.00x':>10}") + + # Custom Mask different chunk_size results (sorted by chunk_size ascending) + cm_keys = sorted([k for k in results.keys() if k.startswith("cm_chunk")], + key=lambda k: results[k]["chunk_size"]) + for key in cm_keys: + r = results[key] + speedup_vs_cm = cm_baseline_time / r["time_cuda_graph_ms"] + if cascade_skipped: + speedup_vs_cascade_str = "-" + else: + speedup_vs_cascade = cascade_baseline_time / r["time_cuda_graph_ms"] + speedup_vs_cascade_str = f"{speedup_vs_cascade:>9.2f}x" + print(f"{'Custom Mask':<45} | {r['chunk_size']:>6} | {r['num_steps']:>6} | {r['time_cuda_graph_ms']:>10.3f} | {r['time_cuda_graph_ms']/r['num_steps']:>10.3f} | {speedup_vs_cascade_str:>10} | {speedup_vs_cm:>9.2f}x") + + # V2 results (sorted by chunk_size ascending) + v2_keys = sorted([k for k in results.keys() if k.startswith("v2_chunk")], + key=lambda k: results[k]["chunk_size"]) + for key in v2_keys: + r = results[key] + speedup_vs_cm = cm_baseline_time / r["time_cuda_graph_ms"] + if cascade_skipped: + speedup_vs_cascade_str = "-" + else: + speedup_vs_cascade = cascade_baseline_time / r["time_cuda_graph_ms"] + speedup_vs_cascade_str = f"{speedup_vs_cascade:>9.2f}x" + print(f"{'V2 block_extend_attention':<45} | {r['chunk_size']:>6} | {r['num_steps']:>6} | {r['time_cuda_graph_ms']:>10.3f} | {r['time_cuda_graph_ms']/r['num_steps']:>10.3f} | {speedup_vs_cascade_str:>10} | {speedup_vs_cm:>9.2f}x") + + # BBE results (sorted by chunk_size ascending) + bbe_keys = sorted([k for k in results.keys() if k.startswith("bbe_chunk")], + key=lambda k: results[k]["chunk_size"]) + for key in bbe_keys: + r = results[key] + speedup_vs_cm = cm_baseline_time / r["time_cuda_graph_ms"] + if cascade_skipped: + 
speedup_vs_cascade_str = "-" + else: + speedup_vs_cascade = cascade_baseline_time / r["time_cuda_graph_ms"] + speedup_vs_cascade_str = f"{speedup_vs_cascade:>9.2f}x" + print(f"{'BBE BatchBlockExtendRagged':<45} | {r['chunk_size']:>6} | {r['num_steps']:>6} | {r['time_cuda_graph_ms']:>10.3f} | {r['time_cuda_graph_ms']/r['num_steps']:>10.3f} | {speedup_vs_cascade_str:>10} | {speedup_vs_cm:>9.2f}x") + + + return results + + +def test_fa2_fa3_block_extending_vs_causal( + num_requests: int = 4, + chunk_sizes: list = None, + tokens_per_request: int = 2048, + dllm_block_size: int = 32, + num_heads: int = 32, + num_kv_heads: int = 8, + head_dim: int = 128, + warmup_iters: int = 10, + bench_iters: int = 100, + verbose: bool = False, +): + """ + FA2 vs FA3 BlockExtend Mask vs FA3 Causal Mask Performance Comparison + + Test scenario (Incremental Prefill): + - Each request's total tokens is fixed (tokens_per_request) + - Execute incremental prefill in multiple steps by chunk_size + - Each step accumulates KV, implementing Block Extend Mask + - Use CUDA Graph to reduce kernel launch overhead + + Comparison methods: + 1. FA3 Causal Mask (baseline) - BatchPrefillWithRaggedKVCacheWrapper + 2. FA2 BlockExtend Mask - BatchBlockExtendRaggedOffsetWrapper (backend="fa2") + 3. 
FA3 BlockExtend Mask - BatchBlockExtendRaggedOffsetWrapper (backend="fa3") + + Mask rules: + - Causal: mask[q, k] = (q + q_offset) >= k + - BlockExtend: mask[q, k] = ((q + q_offset) // B) >= (k // B) + """ + if chunk_sizes is None: + chunk_sizes = [32, 64, 128, 256] + + device = torch.device("cuda:0") + dtype = torch.float16 + sm_scale = 1.0 / (head_dim ** 0.5) + + print(f"\n{'='*90}") + print(f"FA2 vs FA3 BlockExtend vs Causal - Incremental Prefill Performance Comparison") + print(f"{'='*90}") + print(f"Configuration:") + print(f" num_requests = {num_requests}") + print(f" tokens_per_request = {tokens_per_request}") + print(f" dllm_block_size = {dllm_block_size}") + print(f" chunk_sizes = {chunk_sizes}") + print(f" num_heads = {num_heads}") + print(f" num_kv_heads = {num_kv_heads}") + print(f" head_dim = {head_dim}") + print(f"\nScenario description:") + print(f" - Fixed total tokens = {tokens_per_request}") + print(f" - Execute incremental prefill in multiple steps by chunk_size") + print(f" - Use CUDA Graph to reduce overhead") + + # Data preparation, generate complete Q, K, V for each request + all_qs = [torch.randn(tokens_per_request, num_heads, head_dim, dtype=dtype, device=device) + for _ in range(num_requests)] + all_ks = [torch.randn(tokens_per_request, num_kv_heads, head_dim, dtype=dtype, device=device) + for _ in range(num_requests)] + all_vs = [torch.randn(tokens_per_request, num_kv_heads, head_dim, dtype=dtype, device=device) + for _ in range(num_requests)] + + def split_chunks(tensor, chunk_size): + return [tensor[i*chunk_size:(i+1)*chunk_size] for i in range(tensor.shape[0] // chunk_size)] + + results = {} + + for chunk_size in chunk_sizes: + if tokens_per_request % chunk_size != 0: + print(f"\n[Skip] chunk_size={chunk_size} cannot divide tokens_per_request={tokens_per_request}") + continue + + num_steps = tokens_per_request // chunk_size + + print(f"\n{'-'*90}") + print(f"chunk_size = {chunk_size}, num_steps = {num_steps}") + print(f"{'-'*90}") + + 
# Split each request into chunks + qs_chunks = [split_chunks(q, chunk_size) for q in all_qs] + + # Pre-allocate all step buffers + # Q buffers: Q concat for each step + q_buffers = [] + for step_idx in range(num_steps): + q_list = [qs_chunks[req_idx][step_idx] for req_idx in range(num_requests)] + q_buffers.append(torch.cat(q_list, dim=0)) + + # KV buffers: cumulative K, V + k_buffers = [] + v_buffers = [] + for step_idx in range(num_steps): + kv_len = (step_idx + 1) * chunk_size + k_cumul_list = [all_ks[req_idx][:kv_len] for req_idx in range(num_requests)] + v_cumul_list = [all_vs[req_idx][:kv_len] for req_idx in range(num_requests)] + k_buffers.append(torch.cat(k_cumul_list, dim=0)) + v_buffers.append(torch.cat(v_cumul_list, dim=0)) + + # indptrs + qo_indptr = torch.tensor( + [i * chunk_size for i in range(num_requests + 1)], + dtype=torch.int32, device=device + ) + kv_indptr_list = [] + q_offsets_list = [] + for step_idx in range(num_steps): + kv_len = (step_idx + 1) * chunk_size + kv_indptr_list.append(torch.tensor( + [i * kv_len for i in range(num_requests + 1)], + dtype=torch.int32, device=device + )) + q_offset = step_idx * chunk_size + q_offsets_list.append(torch.full((num_requests,), q_offset, dtype=torch.int32, device=device)) + + workspace_size = 256 * 1024 * 1024 + output_buffer = torch.empty(num_requests * chunk_size, num_heads, head_dim, dtype=dtype, device=device) + + # [0] Precision verification: FA2 vs FA3 BlockExtend + print(f" [Precision verification] Comparing FA2 vs FA3 BlockExtend output...") + + # Create temporary wrappers for precision verification + fa2_verify_wrapper = BatchBlockExtendRaggedOffsetWrapper( + torch.empty(workspace_size, dtype=torch.uint8, device=device), + kv_layout="NHD", + dllm_block_size=dllm_block_size, + backend="fa2", + ) + fa3_verify_wrapper = BatchBlockExtendRaggedOffsetWrapper( + torch.empty(workspace_size, dtype=torch.uint8, device=device), + kv_layout="NHD", + dllm_block_size=dllm_block_size, + backend="fa3", + ) 
+ + # Verify precision for each step + max_diff_all_steps = 0.0 + for step_idx in range(num_steps): + fa2_verify_wrapper.plan( + qo_indptr=qo_indptr, + kv_indptr=kv_indptr_list[step_idx], + num_qo_heads=num_heads, + num_kv_heads=num_kv_heads, + head_dim=head_dim, + q_data_type=dtype, + sm_scale=sm_scale, + q_offsets=q_offsets_list[step_idx], + ) + fa3_verify_wrapper.plan( + qo_indptr=qo_indptr, + kv_indptr=kv_indptr_list[step_idx], + num_qo_heads=num_heads, + num_kv_heads=num_kv_heads, + head_dim=head_dim, + q_data_type=dtype, + sm_scale=sm_scale, + q_offsets=q_offsets_list[step_idx], + ) + + fa2_out = fa2_verify_wrapper.run(q_buffers[step_idx], k_buffers[step_idx], v_buffers[step_idx]) + fa3_out = fa3_verify_wrapper.run(q_buffers[step_idx], k_buffers[step_idx], v_buffers[step_idx]) + + step_max_diff = (fa2_out - fa3_out).abs().max().item() + max_diff_all_steps = max(max_diff_all_steps, step_max_diff) + + if verbose: + print(f" step {step_idx}: max_diff = {step_max_diff:.6f}") + + # fp16 precision ~0.001 difference is normal + precision_ok = max_diff_all_steps < 0.01 + status = "PASS" if precision_ok else " FAIL" + print(f" [Precision verification] FA2 vs FA3 max_diff = {max_diff_all_steps:.6f} {status}") + + if not precision_ok: + print(f" [Warning] FA2 and FA3 BlockExtend output difference too large, performance data may not be reliable!") + + del fa2_verify_wrapper, fa3_verify_wrapper + torch.cuda.empty_cache() + + # [1] FA3 Causal Mask (baseline) + print(f" [FA3 Causal] Creating wrappers...") + causal_wrappers = [] + for step_idx in range(num_steps): + wrapper = BatchPrefillWithRaggedKVCacheWrapper( + torch.empty(workspace_size, dtype=torch.uint8, device=device), + kv_layout="NHD", + backend="fa3", + ) + wrapper.plan( + qo_indptr=qo_indptr, + kv_indptr=kv_indptr_list[step_idx], + num_qo_heads=num_heads, + num_kv_heads=num_kv_heads, + head_dim_qk=head_dim, + causal=True, + sm_scale=sm_scale, + ) + causal_wrappers.append(wrapper) + + def run_causal_pipeline(): + 
for step_idx in range(num_steps): + output_buffer.copy_(causal_wrappers[step_idx].run( + q_buffers[step_idx], k_buffers[step_idx], v_buffers[step_idx] + )) + + # Warmup + for _ in range(warmup_iters): + run_causal_pipeline() + torch.cuda.synchronize() + + # CUDA Graph capture + causal_stream = torch.cuda.Stream() + with torch.cuda.stream(causal_stream): + run_causal_pipeline() + causal_stream.synchronize() + + causal_graph = torch.cuda.CUDAGraph() + with torch.cuda.graph(causal_graph, stream=causal_stream): + run_causal_pipeline() + + # Warmup with cuda_graph + for _ in range(warmup_iters): + causal_graph.replay() + torch.cuda.synchronize() + + # Benchmark + start = time.perf_counter() + for _ in range(bench_iters): + causal_graph.replay() + torch.cuda.synchronize() + fa3_causal_time = (time.perf_counter() - start) / bench_iters * 1000 + + print(f" [FA3 Causal] {fa3_causal_time:.3f} ms ({fa3_causal_time/num_steps:.3f} ms/step × {num_steps} steps)") + + del causal_wrappers, causal_graph + torch.cuda.empty_cache() + + # [2] FA2 BlockExtend Mask + print(f" [FA2 BlockExp] Creating wrappers...") + fa2_be_wrappers = [] + for step_idx in range(num_steps): + wrapper = BatchBlockExtendRaggedOffsetWrapper( + torch.empty(workspace_size, dtype=torch.uint8, device=device), + kv_layout="NHD", + dllm_block_size=dllm_block_size, + backend="fa2", + ) + wrapper.plan( + qo_indptr=qo_indptr, + kv_indptr=kv_indptr_list[step_idx], + num_qo_heads=num_heads, + num_kv_heads=num_kv_heads, + head_dim=head_dim, + q_data_type=dtype, + sm_scale=sm_scale, + q_offsets=q_offsets_list[step_idx], + ) + fa2_be_wrappers.append(wrapper) + + def run_fa2_be_pipeline(): + for step_idx in range(num_steps): + output_buffer.copy_(fa2_be_wrappers[step_idx].run( + q_buffers[step_idx], k_buffers[step_idx], v_buffers[step_idx] + )) + + # Warmup + for _ in range(warmup_iters): + run_fa2_be_pipeline() + torch.cuda.synchronize() + + # CUDA Graph capture + fa2_stream = torch.cuda.Stream() + with 
torch.cuda.stream(fa2_stream): + run_fa2_be_pipeline() + fa2_stream.synchronize() + + fa2_graph = torch.cuda.CUDAGraph() + with torch.cuda.graph(fa2_graph, stream=fa2_stream): + run_fa2_be_pipeline() + + # Warmup with cuda_graph + for _ in range(warmup_iters): + fa2_graph.replay() + torch.cuda.synchronize() + + # Benchmark + start = time.perf_counter() + for _ in range(bench_iters): + fa2_graph.replay() + torch.cuda.synchronize() + fa2_be_time = (time.perf_counter() - start) / bench_iters * 1000 + + speedup_fa2_vs_causal = fa3_causal_time / fa2_be_time + print(f" [FA2 BlockExp] {fa2_be_time:.3f} ms ({fa2_be_time/num_steps:.3f} ms/step, {speedup_fa2_vs_causal:.2f}x vs Causal)") + + del fa2_be_wrappers, fa2_graph + torch.cuda.empty_cache() + + # ================================================================ + # [3] FA3 BlockExtend Mask + # ================================================================ + print(f" [FA3 BlockExp] Creating wrappers...") + fa3_be_wrappers = [] + for step_idx in range(num_steps): + wrapper = BatchBlockExtendRaggedOffsetWrapper( + torch.empty(workspace_size, dtype=torch.uint8, device=device), + kv_layout="NHD", + dllm_block_size=dllm_block_size, + backend="fa3", + ) + wrapper.plan( + qo_indptr=qo_indptr, + kv_indptr=kv_indptr_list[step_idx], + num_qo_heads=num_heads, + num_kv_heads=num_kv_heads, + head_dim=head_dim, + q_data_type=dtype, + sm_scale=sm_scale, + q_offsets=q_offsets_list[step_idx], + ) + fa3_be_wrappers.append(wrapper) + + def run_fa3_be_pipeline(): + for step_idx in range(num_steps): + output_buffer.copy_(fa3_be_wrappers[step_idx].run( + q_buffers[step_idx], k_buffers[step_idx], v_buffers[step_idx] + )) + + # Warmup + for _ in range(warmup_iters): + run_fa3_be_pipeline() + torch.cuda.synchronize() + + # CUDA Graph capture + fa3_stream = torch.cuda.Stream() + with torch.cuda.stream(fa3_stream): + run_fa3_be_pipeline() + fa3_stream.synchronize() + + fa3_graph = torch.cuda.CUDAGraph() + with torch.cuda.graph(fa3_graph, 
stream=fa3_stream): + run_fa3_be_pipeline() + + # Warmup with cuda_graph + for _ in range(warmup_iters): + fa3_graph.replay() + torch.cuda.synchronize() + + # Benchmark + start = time.perf_counter() + for _ in range(bench_iters): + fa3_graph.replay() + torch.cuda.synchronize() + fa3_be_time = (time.perf_counter() - start) / bench_iters * 1000 + + speedup_fa3_vs_causal = fa3_causal_time / fa3_be_time + speedup_fa3_vs_fa2 = fa2_be_time / fa3_be_time + print(f" [FA3 BlockExp] {fa3_be_time:.3f} ms ({fa3_be_time/num_steps:.3f} ms/step, {speedup_fa3_vs_causal:.2f}x vs Causal, {speedup_fa3_vs_fa2:.2f}x vs FA2)") + + results[f"chunk{chunk_size}"] = { + "chunk_size": chunk_size, + "num_steps": num_steps, + "fa2_fa3_max_diff": max_diff_all_steps, + "precision_ok": precision_ok, + "fa3_causal_ms": fa3_causal_time, + "fa2_be_ms": fa2_be_time, + "fa3_be_ms": fa3_be_time, + "speedup_fa2_vs_causal": speedup_fa2_vs_causal, + "speedup_fa3_vs_causal": speedup_fa3_vs_causal, + "speedup_fa3_vs_fa2": speedup_fa3_vs_fa2, + } + + del fa3_be_wrappers, fa3_graph + torch.cuda.empty_cache() + + + print(f"\n{'='*90}") + print(f"Results Summary (num_requests={num_requests}, tokens_per_request={tokens_per_request}, dllm_block_size={dllm_block_size})") + print(f"{'='*90}") + + print(f"\n{'chunk':>8} | {'steps':>5} | {'FA3 Causal':>12} | {'FA2 BlockExp':>12} | {'FA3 BlockExp':>12} | {'FA2/Causal':>10} | {'FA3/Causal':>10} | {'FA3/FA2':>10}") + print(f"{'-'*8}-+-{'-'*5}-+-{'-'*12}-+-{'-'*12}-+-{'-'*12}-+-{'-'*10}-+-{'-'*10}-+-{'-'*10}") + + for key in sorted(results.keys(), key=lambda k: results[k]["chunk_size"]): + r = results[key] + print(f"{r['chunk_size']:>8} | {r['num_steps']:>5} | {r['fa3_causal_ms']:>10.3f}ms | {r['fa2_be_ms']:>10.3f}ms | {r['fa3_be_ms']:>10.3f}ms | {r['speedup_fa2_vs_causal']:>9.2f}x | {r['speedup_fa3_vs_causal']:>9.2f}x | {r['speedup_fa3_vs_fa2']:>9.2f}x") + + print(f"\nNotes:") + print(f" - Scenario: Incremental Prefill, fixed total tokens = {tokens_per_request}, execute 
in multiple steps by chunk_size") + print(f" - FA2/Causal: Speedup of FA2 BlockExtend relative to FA3 Causal (>1 means FA2 BE is faster)") + print(f" - FA3/Causal: Speedup of FA3 BlockExtend relative to FA3 Causal (>1 means FA3 BE is faster)") + print(f" - FA3/FA2: Speedup of FA3 BlockExtend relative to FA2 BlockExtend (>1 means FA3 is faster)") + print(f" - BlockExtend mask has less computation than Causal mask (tile-level skip), should theoretically be faster") + + return results + + +def test_dllm_precision_vs_custom_mask_fa2( + verbose: bool = True, + test_dtypes: list = None, +): + """ + DLLM Component Precision Test: Comparison with native Custom Mask FA2 implementation + + Reference implementation (Ground Truth): + - Single request: single_prefill_with_kv_cache + custom_mask (FA2) + - Multi-request: BatchPrefillWithRaggedKVCacheWrapper + custom_mask (FA2) + + Tested components (three DLLM components imported at lines 31-33): + 1. BatchBlockExtendRaggedOffsetWrapper + 2. BatchBlockExtendPagedOffsetWrapper + 3. 
block_extend_attention_with_offset + + Test coverage: + - Data types: fp16, bf16 + - Different dllm_block_size: [16, 32, 64, 128] + - Different qo_len: [32, 64, 128, 256] + - Different kv_len: [64, 128, 256, 512, 1024] + - Different q_offset: [0, 32, 64, 128] + - Different num_heads / num_kv_heads combinations + - Different head_dim: [64, 128] + + Mask rule: mask[q, k] = ((q_local + q_offset) // B) >= (k // B) + """ + device = torch.device("cuda:0") + backends = ["fa2", "fa3"] + + + if test_dtypes is None: + test_dtypes = [torch.float16, torch.bfloat16] + + # Precision tolerance for different data types + dtype_tolerances = { + torch.float16: 1e-2, + torch.bfloat16: 2e-2, # FA3 bf16 tile accumulation order differs from FA2; + # max_diff up to ~2 ULP (0.015625) is expected + } + + dtype_names = { + torch.float16: "fp16", + torch.bfloat16: "bf16", + } + + # Test parameter combinations + test_configs = [ + # Basic tests: different dllm_block_size + {"dllm_block_size": 16, "qo_len": 64, "kv_len": 128, "q_offset": 0, "num_heads": 32, "num_kv_heads": 8, "head_dim": 128}, + {"dllm_block_size": 32, "qo_len": 64, "kv_len": 128, "q_offset": 0, "num_heads": 32, "num_kv_heads": 8, "head_dim": 128}, + {"dllm_block_size": 64, "qo_len": 64, "kv_len": 128, "q_offset": 0, "num_heads": 32, "num_kv_heads": 8, "head_dim": 128}, + {"dllm_block_size": 128, "qo_len": 128, "kv_len": 256, "q_offset": 0, "num_heads": 32, "num_kv_heads": 8, "head_dim": 128}, + + # Different qo_len + {"dllm_block_size": 32, "qo_len": 32, "kv_len": 128, "q_offset": 0, "num_heads": 32, "num_kv_heads": 8, "head_dim": 128}, + {"dllm_block_size": 32, "qo_len": 128, "kv_len": 256, "q_offset": 0, "num_heads": 32, "num_kv_heads": 8, "head_dim": 128}, + {"dllm_block_size": 32, "qo_len": 256, "kv_len": 512, "q_offset": 0, "num_heads": 32, "num_kv_heads": 8, "head_dim": 128}, + + # Different kv_len + {"dllm_block_size": 32, "qo_len": 64, "kv_len": 64, "q_offset": 0, "num_heads": 32, "num_kv_heads": 8, "head_dim": 128}, 
+ {"dllm_block_size": 32, "qo_len": 64, "kv_len": 256, "q_offset": 0, "num_heads": 32, "num_kv_heads": 8, "head_dim": 128}, + {"dllm_block_size": 32, "qo_len": 64, "kv_len": 512, "q_offset": 0, "num_heads": 32, "num_kv_heads": 8, "head_dim": 128}, + {"dllm_block_size": 32, "qo_len": 64, "kv_len": 1024, "q_offset": 0, "num_heads": 32, "num_kv_heads": 8, "head_dim": 128}, + + # Different q_offset (simulating different steps in incremental prefill) + {"dllm_block_size": 32, "qo_len": 64, "kv_len": 128, "q_offset": 32, "num_heads": 32, "num_kv_heads": 8, "head_dim": 128}, + {"dllm_block_size": 32, "qo_len": 64, "kv_len": 192, "q_offset": 64, "num_heads": 32, "num_kv_heads": 8, "head_dim": 128}, + {"dllm_block_size": 32, "qo_len": 64, "kv_len": 256, "q_offset": 128, "num_heads": 32, "num_kv_heads": 8, "head_dim": 128}, + {"dllm_block_size": 32, "qo_len": 64, "kv_len": 320, "q_offset": 192, "num_heads": 32, "num_kv_heads": 8, "head_dim": 128}, + + # Different head configurations (MHA vs GQA vs MQA) + {"dllm_block_size": 32, "qo_len": 64, "kv_len": 128, "q_offset": 0, "num_heads": 32, "num_kv_heads": 32, "head_dim": 128}, # MHA + {"dllm_block_size": 32, "qo_len": 64, "kv_len": 128, "q_offset": 0, "num_heads": 32, "num_kv_heads": 4, "head_dim": 128}, # GQA-8 + {"dllm_block_size": 32, "qo_len": 64, "kv_len": 128, "q_offset": 0, "num_heads": 32, "num_kv_heads": 1, "head_dim": 128}, # MQA + + # Different head_dim + {"dllm_block_size": 32, "qo_len": 64, "kv_len": 128, "q_offset": 0, "num_heads": 32, "num_kv_heads": 8, "head_dim": 64}, + {"dllm_block_size": 32, "qo_len": 64, "kv_len": 128, "q_offset": 0, "num_heads": 16, "num_kv_heads": 4, "head_dim": 256}, + + # Boundary condition tests + {"dllm_block_size": 32, "qo_len": 1, "kv_len": 32, "q_offset": 0, "num_heads": 32, "num_kv_heads": 8, "head_dim": 128}, # single query + {"dllm_block_size": 32, "qo_len": 32, "kv_len": 32, "q_offset": 0, "num_heads": 32, "num_kv_heads": 8, "head_dim": 128}, # qo_len == kv_len == block_size + 
{"dllm_block_size": 64, "qo_len": 33, "kv_len": 97, "q_offset": 17, "num_heads": 32, "num_kv_heads": 8, "head_dim": 128}, # non-aligned boundary + + # Long sequence tests + {"dllm_block_size": 32, "qo_len": 128, "kv_len": 2048, "q_offset": 0, "num_heads": 32, "num_kv_heads": 8, "head_dim": 128}, + {"dllm_block_size": 64, "qo_len": 256, "kv_len": 4096, "q_offset": 0, "num_heads": 32, "num_kv_heads": 8, "head_dim": 128}, + ] + + # Multi-request test configurations + multi_req_configs = [ + # Basic multi-request tests + {"num_requests": 2, "dllm_block_size": 32, "qo_len": 64, "kv_len": 128, "q_offset": 0, "num_heads": 32, "num_kv_heads": 8, "head_dim": 128}, + {"num_requests": 4, "dllm_block_size": 32, "qo_len": 64, "kv_len": 128, "q_offset": 0, "num_heads": 32, "num_kv_heads": 8, "head_dim": 128}, + {"num_requests": 8, "dllm_block_size": 32, "qo_len": 64, "kv_len": 128, "q_offset": 0, "num_heads": 32, "num_kv_heads": 8, "head_dim": 128}, + + # Different dllm_block_size + {"num_requests": 4, "dllm_block_size": 16, "qo_len": 64, "kv_len": 128, "q_offset": 0, "num_heads": 32, "num_kv_heads": 8, "head_dim": 128}, + {"num_requests": 4, "dllm_block_size": 64, "qo_len": 64, "kv_len": 128, "q_offset": 0, "num_heads": 32, "num_kv_heads": 8, "head_dim": 128}, + + # Different q_offset (simulating incremental prefill step) + {"num_requests": 4, "dllm_block_size": 32, "qo_len": 64, "kv_len": 128, "q_offset": 32, "num_heads": 32, "num_kv_heads": 8, "head_dim": 128}, + {"num_requests": 4, "dllm_block_size": 32, "qo_len": 64, "kv_len": 192, "q_offset": 64, "num_heads": 32, "num_kv_heads": 8, "head_dim": 128}, + {"num_requests": 4, "dllm_block_size": 32, "qo_len": 64, "kv_len": 256, "q_offset": 128, "num_heads": 32, "num_kv_heads": 8, "head_dim": 128}, + + # Different head configurations + {"num_requests": 4, "dllm_block_size": 32, "qo_len": 64, "kv_len": 128, "q_offset": 0, "num_heads": 32, "num_kv_heads": 32, "head_dim": 128}, # MHA + {"num_requests": 4, "dllm_block_size": 32, 
"qo_len": 64, "kv_len": 128, "q_offset": 0, "num_heads": 32, "num_kv_heads": 4, "head_dim": 128}, # GQA-8 + {"num_requests": 4, "dllm_block_size": 32, "qo_len": 64, "kv_len": 128, "q_offset": 0, "num_heads": 32, "num_kv_heads": 1, "head_dim": 128}, # MQA + + # Long sequence multi-request tests + {"num_requests": 4, "dllm_block_size": 32, "qo_len": 128, "kv_len": 1024, "q_offset": 0, "num_heads": 32, "num_kv_heads": 8, "head_dim": 128}, + {"num_requests": 2, "dllm_block_size": 64, "qo_len": 256, "kv_len": 2048, "q_offset": 0, "num_heads": 32, "num_kv_heads": 8, "head_dim": 128}, + ] + + for dtype in test_dtypes: + dtype_name = dtype_names[dtype] + tol = dtype_tolerances[dtype] + + print(f"\n{'='*120}") + print(f"DLLM Component Precision Test: Comparing with Native Custom Mask FA2 [{dtype_name.upper()}]") + print(f"{'='*120}") + print(f"Reference implementation: single_prefill_with_kv_cache / BatchPrefillWithRaggedKVCacheWrapper + custom_mask (FA2)") + print(f"Backends under test: FA2, FA3 (DLLM BlockWise supports both backends)") + print(f"Data type: {dtype_name}") + print(f"Mask rule: mask[q, k] = ((q_local + q_offset) // B) >= (k // B)") + print(f"Precision tolerance: {tol}") + + print(f"\n{'='*100}") + print(f"[Part 1] Single-request Precision Test (FA2 & FA3 backends) [{dtype_name}]") + print(f"{'='*100}") + print(f"Reference implementation: single_prefill_with_kv_cache + custom_mask (FA2)") + print(f"Objects under test:") + print(f" 1. BatchBlockExtendRaggedOffsetWrapper (batch_size=1) - FA2 backend") + print(f" 2. BatchBlockExtendRaggedOffsetWrapper (batch_size=1) - FA3 backend") + print(f" 3. 
block_extend_attention_with_offset") + + single_req_results = [] + + for cfg_idx, cfg in enumerate(test_configs): + dllm_block_size = cfg["dllm_block_size"] + qo_len = cfg["qo_len"] + kv_len = cfg["kv_len"] + q_offset = cfg["q_offset"] + num_heads = cfg["num_heads"] + num_kv_heads = cfg["num_kv_heads"] + head_dim = cfg["head_dim"] + sm_scale = 1.0 / math.sqrt(head_dim) + + # Generate test data + q = torch.randn(qo_len, num_heads, head_dim, dtype=dtype, device=device) + k = torch.randn(kv_len, num_kv_heads, head_dim, dtype=dtype, device=device) + v = torch.randn(kv_len, num_kv_heads, head_dim, dtype=dtype, device=device) + + # Reference implementation: single_prefill_with_kv_cache + custom_mask (FA2) + # Build custom_mask: mask[q, k] = ((q_local + q_offset) // B) >= (k // B) + q_pos = torch.arange(qo_len, device=device) + q_offset + k_pos = torch.arange(kv_len, device=device) + q_block = q_pos.unsqueeze(1) // dllm_block_size # [qo_len, 1] + k_block = k_pos.unsqueeze(0) // dllm_block_size # [1, kv_len] + mask_2d = (q_block >= k_block).to(torch.uint8) # [qo_len, kv_len] + + ref_output = single_prefill_with_kv_cache( + q, k, v, + custom_mask=mask_2d, + sm_scale=sm_scale, + backend="fa2", + ) + + result = { + "config_idx": cfg_idx, + "dllm_block_size": dllm_block_size, + "qo_len": qo_len, + "kv_len": kv_len, + "q_offset": q_offset, + "num_heads": num_heads, + "num_kv_heads": num_kv_heads, + "head_dim": head_dim, + } + + # Object under test 1 & 2: BatchBlockExtendRaggedOffsetWrapper (FA2 and FA3 backends) + qo_indptr = torch.tensor([0, qo_len], dtype=torch.int32, device=device) + kv_indptr = torch.tensor([0, kv_len], dtype=torch.int32, device=device) + q_offset_tensor = torch.tensor([q_offset], dtype=torch.int32, device=device) + + for backend in backends: + workspace = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device=device) + wrapper = BatchBlockExtendRaggedOffsetWrapper( + workspace, kv_layout="NHD", dllm_block_size=dllm_block_size, backend=backend + ) + 
wrapper.plan( + qo_indptr=qo_indptr, + kv_indptr=kv_indptr, + num_qo_heads=num_heads, + num_kv_heads=num_kv_heads, + head_dim=head_dim, + q_data_type=dtype, + sm_scale=sm_scale, + q_offsets=q_offset_tensor, + ) + bbe_output = wrapper.run(q, k, v) + + # Calculate precision differences + bbe_diff = (bbe_output - ref_output).abs().max().item() + bbe_mean_diff = (bbe_output - ref_output).abs().mean().item() + bbe_pass = bbe_diff < tol + + result[f"bbe_{backend}_max_diff"] = bbe_diff + result[f"bbe_{backend}_mean_diff"] = bbe_mean_diff + result[f"bbe_{backend}_pass"] = bbe_pass + + del workspace, wrapper + + # Object under test 3: block_extend_attention_with_offset (backend="fa2") + v2_output = block_extend_attention_with_offset( + q, k, v, + dllm_block_size=dllm_block_size, + q_offset=q_offset, + sm_scale=sm_scale, + backend="fa2", + ) + + # Calculate V2 precision differences + v2_diff = (v2_output - ref_output).abs().max().item() + v2_mean_diff = (v2_output - ref_output).abs().mean().item() + v2_pass = v2_diff < tol + + result["v2_max_diff"] = v2_diff + result["v2_mean_diff"] = v2_mean_diff + result["v2_pass"] = v2_pass + + single_req_results.append(result) + + if verbose: + fa2_status = "PASS" if result["bbe_fa2_pass"] else "FAIL" + fa3_status = "PASS" if result["bbe_fa3_pass"] else "FAIL" + v2_status = "PASS" if v2_pass else "FAIL" + print(f"\n [Test {cfg_idx:02d}] B={dllm_block_size:3d}, qo={qo_len:4d}, kv={kv_len:4d}, " + f"q_off={q_offset:3d}, heads={num_heads}/{num_kv_heads}, dim={head_dim}") + print(f" BBE-FA2: max_diff={result['bbe_fa2_max_diff']:.6f}, mean_diff={result['bbe_fa2_mean_diff']:.6f} [{fa2_status}]") + print(f" BBE-FA3: max_diff={result['bbe_fa3_max_diff']:.6f}, mean_diff={result['bbe_fa3_mean_diff']:.6f} [{fa3_status}]") + print(f" V2: max_diff={v2_diff:.6f}, mean_diff={v2_mean_diff:.6f} [{v2_status}]") + + torch.cuda.empty_cache() + + # Single-request test summary + print(f"\n{'-'*100}") + print(f"[Single-request Precision Test Summary] 
[{dtype_name}]") + print(f"{'-'*100}") + + total_tests = len(single_req_results) + + for backend in backends: + pass_count = sum(1 for r in single_req_results if r[f"bbe_{backend}_pass"]) + max_diff_all = max(r[f"bbe_{backend}_max_diff"] for r in single_req_results) + mean_diff_all = sum(r[f"bbe_{backend}_mean_diff"] for r in single_req_results) / total_tests + print(f" BatchBlockExtendRaggedOffsetWrapper ({backend.upper()}): {pass_count}/{total_tests} PASS") + print(f" max_diff (all tests): {max_diff_all:.6f}") + print(f" mean_diff (avg): {mean_diff_all:.6f}") + + v2_pass_count = sum(1 for r in single_req_results if r["v2_pass"]) + v2_max_diff_all = max(r["v2_max_diff"] for r in single_req_results) + v2_mean_diff_all = sum(r["v2_mean_diff"] for r in single_req_results) / total_tests + print(f" block_extend_attention_with_offset: {v2_pass_count}/{total_tests} PASS") + print(f" max_diff (all tests): {v2_max_diff_all:.6f}") + print(f" mean_diff (avg): {v2_mean_diff_all:.6f}") + + # Failed test details + for backend in backends: + failed = [r for r in single_req_results if not r[f"bbe_{backend}_pass"]] + if failed: + print(f"\n [BBE-{backend.upper()} Failed Test Details]") + for r in failed: + print(f" Test {r['config_idx']:02d}: B={r['dllm_block_size']}, qo={r['qo_len']}, kv={r['kv_len']}, " + f"q_off={r['q_offset']}, max_diff={r[f'bbe_{backend}_max_diff']:.6f}") + + failed_v2 = [r for r in single_req_results if not r["v2_pass"]] + if failed_v2: + print(f"\n [V2 Failed Test Details]") + for r in failed_v2: + print(f" Test {r['config_idx']:02d}: B={r['dllm_block_size']}, qo={r['qo_len']}, kv={r['kv_len']}, " + f"q_off={r['q_offset']}, max_diff={r['v2_max_diff']:.6f}") + + # ═══════════════════════════════════════════════════════════════════════════════ + # Part 2: Multi-request Precision Test (FA2 & FA3 backends) + # ═══════════════════════════════════════════════════════════════════════════════ + print(f"\n{'='*100}") + print(f"[Part 2] Multi-request Precision Test 
(FA2 & FA3 backends) [{dtype_name}]") + print(f"{'='*100}") + print(f"Reference implementation: BatchPrefillWithRaggedKVCacheWrapper + custom_mask (FA2)") + print(f"Objects under test:") + print(f" 1. BatchBlockExtendRaggedOffsetWrapper - FA2 backend") + print(f" 2. BatchBlockExtendRaggedOffsetWrapper - FA3 backend") + + multi_req_results = [] + + for cfg_idx, cfg in enumerate(multi_req_configs): + num_requests = cfg["num_requests"] + dllm_block_size = cfg["dllm_block_size"] + qo_len = cfg["qo_len"] + kv_len = cfg["kv_len"] + q_offset = cfg["q_offset"] + num_heads = cfg["num_heads"] + num_kv_heads = cfg["num_kv_heads"] + head_dim = cfg["head_dim"] + sm_scale = 1.0 / math.sqrt(head_dim) + + q_list = [torch.randn(qo_len, num_heads, head_dim, dtype=dtype, device=device) for _ in range(num_requests)] + k_list = [torch.randn(kv_len, num_kv_heads, head_dim, dtype=dtype, device=device) for _ in range(num_requests)] + v_list = [torch.randn(kv_len, num_kv_heads, head_dim, dtype=dtype, device=device) for _ in range(num_requests)] + + q_batch = torch.cat(q_list, dim=0) + k_batch = torch.cat(k_list, dim=0) + v_batch = torch.cat(v_list, dim=0) + + # Build mask + q_pos = torch.arange(qo_len, device=device) + q_offset + k_pos = torch.arange(kv_len, device=device) + q_block = q_pos.unsqueeze(1) // dllm_block_size + k_block = k_pos.unsqueeze(0) // dllm_block_size + mask_2d = (q_block >= k_block) + mask_flat = mask_2d.flatten() + batch_mask = mask_flat.repeat(num_requests) + + qo_indptr = torch.tensor([i * qo_len for i in range(num_requests + 1)], dtype=torch.int32, device=device) + kv_indptr = torch.tensor([i * kv_len for i in range(num_requests + 1)], dtype=torch.int32, device=device) + + # Reference implementation + ref_wrapper = BatchPrefillWithRaggedKVCacheWrapper( + torch.empty(256 * 1024 * 1024, dtype=torch.uint8, device=device), + kv_layout="NHD", backend="fa2", + ) + ref_wrapper.plan( + qo_indptr=qo_indptr, kv_indptr=kv_indptr, + num_qo_heads=num_heads, 
num_kv_heads=num_kv_heads, + head_dim_qk=head_dim, q_data_type=dtype, + custom_mask=batch_mask, causal=False, sm_scale=sm_scale, + ) + ref_output = ref_wrapper.run(q_batch, k_batch, v_batch) + + q_offsets = torch.full((num_requests,), q_offset, dtype=torch.int32, device=device) + result = { + "config_idx": cfg_idx, "num_requests": num_requests, + "dllm_block_size": dllm_block_size, "qo_len": qo_len, "kv_len": kv_len, + "q_offset": q_offset, "num_heads": num_heads, + "num_kv_heads": num_kv_heads, "head_dim": head_dim, + } + + for backend in backends: + bbe_wrapper = BatchBlockExtendRaggedOffsetWrapper( + torch.empty(256 * 1024 * 1024, dtype=torch.uint8, device=device), + kv_layout="NHD", dllm_block_size=dllm_block_size, backend=backend + ) + bbe_wrapper.plan( + qo_indptr=qo_indptr, kv_indptr=kv_indptr, + num_qo_heads=num_heads, num_kv_heads=num_kv_heads, + head_dim=head_dim, q_data_type=dtype, + sm_scale=sm_scale, q_offsets=q_offsets, + ) + bbe_output = bbe_wrapper.run(q_batch, k_batch, v_batch) + + bbe_diff = (bbe_output - ref_output).abs().max().item() + bbe_mean_diff = (bbe_output - ref_output).abs().mean().item() + bbe_pass = bbe_diff < tol + + result[f"bbe_{backend}_max_diff"] = bbe_diff + result[f"bbe_{backend}_mean_diff"] = bbe_mean_diff + result[f"bbe_{backend}_pass"] = bbe_pass + del bbe_wrapper + + multi_req_results.append(result) + + if verbose: + fa2_status = "PASS" if result["bbe_fa2_pass"] else "FAIL" + fa3_status = "PASS" if result["bbe_fa3_pass"] else "FAIL" + print(f"\n [Test {cfg_idx:02d}] reqs={num_requests}, B={dllm_block_size:3d}, qo={qo_len:4d}, kv={kv_len:4d}, " + f"q_off={q_offset:3d}, heads={num_heads}/{num_kv_heads}, dim={head_dim}") + print(f" BBE-FA2: max_diff={result['bbe_fa2_max_diff']:.6f} [{fa2_status}]") + print(f" BBE-FA3: max_diff={result['bbe_fa3_max_diff']:.6f} [{fa3_status}]") + + del ref_wrapper + torch.cuda.empty_cache() + + # Multi-request test summary + print(f"\n[Multi-request Precision Test Summary] [{dtype_name}]") + 
total_tests = len(multi_req_results) + for backend in backends: + pass_count = sum(1 for r in multi_req_results if r[f"bbe_{backend}_pass"]) + max_diff_all = max(r[f"bbe_{backend}_max_diff"] for r in multi_req_results) + print(f" BBE ({backend.upper()}): {pass_count}/{total_tests} PASS, max_diff={max_diff_all:.6f}") + + for backend in backends: + failed = [r for r in multi_req_results if not r[f"bbe_{backend}_pass"]] + if failed: + print(f"\n [BBE-{backend.upper()} Failed Details]") + for r in failed: + print(f" Test {r['config_idx']:02d}: reqs={r['num_requests']}, B={r['dllm_block_size']}, max_diff={r[f'bbe_{backend}_max_diff']:.6f}") + + # Part 3: Paged KV Cache Test + print(f"\n[Part 3] Paged KV Cache Precision Test [{dtype_name}]") + + page_size = 16 + paged_configs = [ + {"dllm_block_size": 32, "qo_len": 64, "kv_len": 128, "q_offset": 0, "num_heads": 32, "num_kv_heads": 8, "head_dim": 128}, + {"dllm_block_size": 32, "qo_len": 64, "kv_len": 128, "q_offset": 32, "num_heads": 32, "num_kv_heads": 8, "head_dim": 128}, + {"dllm_block_size": 64, "qo_len": 128, "kv_len": 512, "q_offset": 0, "num_heads": 32, "num_kv_heads": 8, "head_dim": 128}, + {"dllm_block_size": 32, "qo_len": 64, "kv_len": 128, "q_offset": 0, "num_heads": 32, "num_kv_heads": 4, "head_dim": 128}, + {"dllm_block_size": 32, "qo_len": 128, "kv_len": 1024, "q_offset": 0, "num_heads": 32, "num_kv_heads": 8, "head_dim": 128}, + ] + + paged_results = [] + + for cfg_idx, cfg in enumerate(paged_configs): + dllm_block_size = cfg["dllm_block_size"] + qo_len = cfg["qo_len"] + kv_len = cfg["kv_len"] + q_offset = cfg["q_offset"] + num_heads = cfg["num_heads"] + num_kv_heads = cfg["num_kv_heads"] + head_dim = cfg["head_dim"] + sm_scale = 1.0 / math.sqrt(head_dim) + + q = torch.randn(qo_len, num_heads, head_dim, dtype=dtype, device=device) + num_pages = (kv_len + page_size - 1) // page_size + kv_data = torch.randn(num_pages, 2, page_size, num_kv_heads, head_dim, dtype=dtype, device=device) + + paged_kv_indices = 
torch.arange(num_pages, dtype=torch.int32, device=device) + paged_kv_indptr = torch.tensor([0, num_pages], dtype=torch.int32, device=device) + last_page_len = kv_len - (num_pages - 1) * page_size if kv_len % page_size != 0 else page_size + paged_kv_last_page_len = torch.tensor([last_page_len], dtype=torch.int32, device=device) + + k_continuous = kv_data[:, 0, :, :, :].reshape(-1, num_kv_heads, head_dim)[:kv_len] + v_continuous = kv_data[:, 1, :, :, :].reshape(-1, num_kv_heads, head_dim)[:kv_len] + + # Reference implementation + q_pos = torch.arange(qo_len, device=device) + q_offset + k_pos = torch.arange(kv_len, device=device) + q_block = q_pos.unsqueeze(1) // dllm_block_size + k_block = k_pos.unsqueeze(0) // dllm_block_size + mask_2d = (q_block >= k_block).to(torch.uint8) + + ref_output = single_prefill_with_kv_cache( + q, k_continuous, v_continuous, + custom_mask=mask_2d, sm_scale=sm_scale, backend="fa2", + ) + + result = { + "config_idx": cfg_idx, "dllm_block_size": dllm_block_size, + "qo_len": qo_len, "kv_len": kv_len, "q_offset": q_offset, + "num_heads": num_heads, "num_kv_heads": num_kv_heads, + "head_dim": head_dim, "page_size": page_size, + } + + qo_indptr = torch.tensor([0, qo_len], dtype=torch.int32, device=device) + q_offsets = torch.tensor([q_offset], dtype=torch.int32, device=device) + + for backend in backends: + paged_wrapper = BatchBlockExtendPagedOffsetWrapper( + torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device=device), + kv_layout="NHD", dllm_block_size=dllm_block_size, backend=backend + ) + paged_wrapper.plan( + qo_indptr=qo_indptr, paged_kv_indptr=paged_kv_indptr, + paged_kv_indices=paged_kv_indices, paged_kv_last_page_len=paged_kv_last_page_len, + num_qo_heads=num_heads, num_kv_heads=num_kv_heads, + head_dim=head_dim, page_size=page_size, + q_data_type=dtype, sm_scale=sm_scale, q_offsets=q_offsets, + ) + paged_output = paged_wrapper.run(q, kv_data) + + paged_diff = (paged_output - ref_output).abs().max().item() + paged_pass = paged_diff 
< tol + result[f"paged_{backend}_max_diff"] = paged_diff + result[f"paged_{backend}_pass"] = paged_pass + del paged_wrapper + + paged_results.append(result) + + if verbose: + fa2_status = "PASS" if result["paged_fa2_pass"] else "FAIL" + fa3_status = "PASS" if result["paged_fa3_pass"] else "FAIL" + print(f" [Test {cfg_idx:02d}] B={dllm_block_size:3d}, qo={qo_len:4d}, kv={kv_len:4d}, " + f"q_off={q_offset:3d} - FA2:{fa2_status}, FA3:{fa3_status}") + + torch.cuda.empty_cache() + + # Paged test summary + print(f"\n[Paged KV Cache Precision Test Summary] [{dtype_name}]") + for backend in backends: + pass_count = sum(1 for r in paged_results if r[f"paged_{backend}_pass"]) + max_diff_all = max(r[f"paged_{backend}_max_diff"] for r in paged_results) + print(f" Paged ({backend.upper()}): {pass_count}/{len(paged_results)} PASS, max_diff={max_diff_all:.6f}") + + for backend in backends: + failed = [r for r in paged_results if not r[f"paged_{backend}_pass"]] + if failed: + print(f"\n [Paged-{backend.upper()} Failed Details]") + for r in failed: + print(f" Test {r['config_idx']:02d}: B={r['dllm_block_size']}, max_diff={r[f'paged_{backend}_max_diff']:.6f}") + + # Summary + print(f"\n{'='*100}") + print(f"Precision Test Summary [{dtype_name}]") + print(f"{'='*100}") + + print(f"\n Single-request tests:") + for backend in backends: + pass_count = sum(1 for r in single_req_results if r[f"bbe_{backend}_pass"]) + print(f" BBE ({backend.upper()}): {pass_count}/{len(single_req_results)} PASS") + v2_pass_count = sum(1 for r in single_req_results if r["v2_pass"]) + print(f" V2: {v2_pass_count}/{len(single_req_results)} PASS") + + print(f"\n Multi-request tests:") + for backend in backends: + pass_count = sum(1 for r in multi_req_results if r[f"bbe_{backend}_pass"]) + print(f" BBE ({backend.upper()}): {pass_count}/{len(multi_req_results)} PASS") + + print(f"\n Paged tests:") + for backend in backends: + pass_count = sum(1 for r in paged_results if r[f"paged_{backend}_pass"]) + print(f" 
Paged ({backend.upper()}): {pass_count}/{len(paged_results)} PASS") + + # Overall results + all_single_pass = all( + r["bbe_fa2_pass"] and r["bbe_fa3_pass"] and r["v2_pass"] + for r in single_req_results + ) + all_multi_pass = all( + r["bbe_fa2_pass"] and r["bbe_fa3_pass"] + for r in multi_req_results + ) + all_paged_pass = all( + r["paged_fa2_pass"] and r["paged_fa3_pass"] + for r in paged_results + ) + + overall_pass = all_single_pass and all_multi_pass and all_paged_pass + overall_status = "ALL TESTS PASSED" if overall_pass else "SOME TESTS FAILED" + print(f"\n Overall results: {overall_status}") + + # FA2 vs FA3 comparison + fa2_single_max = max(r["bbe_fa2_max_diff"] for r in single_req_results) + fa3_single_max = max(r["bbe_fa3_max_diff"] for r in single_req_results) + fa2_multi_max = max(r["bbe_fa2_max_diff"] for r in multi_req_results) + fa3_multi_max = max(r["bbe_fa3_max_diff"] for r in multi_req_results) + fa2_paged_max = max(r["paged_fa2_max_diff"] for r in paged_results) + fa3_paged_max = max(r["paged_fa3_max_diff"] for r in paged_results) + + print(f"\n FA2 vs FA3 max_diff:") + print(f" Single-request: FA2={fa2_single_max:.6f}, FA3={fa3_single_max:.6f}") + print(f" Multi-request: FA2={fa2_multi_max:.6f}, FA3={fa3_multi_max:.6f}") + print(f" Paged: FA2={fa2_paged_max:.6f}, FA3={fa3_paged_max:.6f}") + + return { + "overall_pass": overall_pass, + } + + +def test_cascade_interfaces_perf( + num_requests: int = 4, + tokens_per_request: int = 512, + dllm_block_size: int = 32, + chunk_sizes: list = None, + num_heads: int = 32, + num_kv_heads: int = 8, + head_dim: int = 128, + page_size: int = 16, + warmup_iters: int = 10, + bench_iters: int = 100, + verbose: bool = False, + backend: str = "fa2", +): + """ + Compare performance of two Cascade interfaces (Step by Step incremental Prefill scenario) + + Interfaces under test: + 1. batch_block_extend_cascade: Uses Block Extend mask + - Supports chunk_size != dllm_block_size + 2. 
sglang_style_cascade_attention: Uses Causal mask (SGLang native style) + - Requires chunk_size == dllm_block_size + + Test scenario: + - Real incremental Prefill: each step depends on previous step's KV Cache + - Each step: Q attends to (current_chunk KV + prefix KV) + - Uses Paged KV Cache to store prefix + + Key points: + - When chunk_size == dllm_block_size, both should produce identical results + - batch_block_extend_cascade can use larger chunk_size + """ + from flashinfer.dllm import ( + batch_block_extend_cascade, + sglang_style_cascade_attention, + ) + + if chunk_sizes is None: + chunk_sizes = [dllm_block_size] # Default: only test chunk_size == dllm_block_size + + device = torch.device("cuda:0") + dtype = torch.float16 + sm_scale = 1.0 / (head_dim ** 0.5) + + print(f"\n{'='*90}") + print(f"Cascade Interface Performance Comparison: Step by Step Incremental Prefill") + print(f"{'='*90}") + print(f"Configuration:") + print(f" num_requests = {num_requests}") + print(f" tokens_per_request = {tokens_per_request}") + print(f" dllm_block_size = {dllm_block_size}") + print(f" chunk_sizes = {chunk_sizes}") + print(f" num_heads = {num_heads}") + print(f" num_kv_heads = {num_kv_heads}") + print(f" head_dim = {head_dim}") + print(f" page_size = {page_size}") + print(f" backend = {backend}") + print(f"\nScenario description:") + print(f" - Step by Step incremental Prefill: each step depends on previous step's KV Cache") + print(f" - Current Chunk: Ragged KV (contiguous memory)") + print(f" - Prefix: Paged KV Cache") + print(f" - Uses CUDA Graph to reduce overhead") + + # Generate complete Q, K, V for each request + all_qs = [torch.randn(tokens_per_request, num_heads, head_dim, dtype=dtype, device=device) + for _ in range(num_requests)] + all_ks = [torch.randn(tokens_per_request, num_kv_heads, head_dim, dtype=dtype, device=device) + for _ in range(num_requests)] + all_vs = [torch.randn(tokens_per_request, num_kv_heads, head_dim, dtype=dtype, device=device) + for _ in 
range(num_requests)] + + def split_chunks(tensor, chunk_size): + return [tensor[i*chunk_size:(i+1)*chunk_size] for i in range(tensor.shape[0] // chunk_size)] + + results = {} + + for chunk_size in chunk_sizes: + if tokens_per_request % chunk_size != 0: + print(f"\n[Skip] chunk_size={chunk_size} doesn't evenly divide tokens_per_request={tokens_per_request}") + continue + + num_steps = tokens_per_request // chunk_size + + print(f"\n{'-'*90}") + print(f"chunk_size = {chunk_size}, num_steps = {num_steps}") + print(f"{'-'*90}") + + # Split chunks for each request + qs_chunks = [split_chunks(q, chunk_size) for q in all_qs] + ks_chunks = [split_chunks(k, chunk_size) for k in all_ks] + vs_chunks = [split_chunks(v, chunk_size) for v in all_vs] + + # Pre-allocate buffers for all steps + # Current chunk Q, K, V (concatenate all requests) + q_current_buffers = [] + k_current_buffers = [] + v_current_buffers = [] + for step_idx in range(num_steps): + q_list = [qs_chunks[req_idx][step_idx] for req_idx in range(num_requests)] + k_list = [ks_chunks[req_idx][step_idx] for req_idx in range(num_requests)] + v_list = [vs_chunks[req_idx][step_idx] for req_idx in range(num_requests)] + q_current_buffers.append(torch.cat(q_list, dim=0)) + k_current_buffers.append(torch.cat(k_list, dim=0)) + v_current_buffers.append(torch.cat(v_list, dim=0)) + + # Paged KV Cache setup (prefix storage) + # Calculate maximum number of pages needed + max_prefix_len = (num_steps - 1) * chunk_size + max_pages_per_request = (max_prefix_len + page_size - 1) // page_size if max_prefix_len > 0 else 0 + total_max_pages = num_requests * max_pages_per_request + + # Allocate Paged KV Cache + if total_max_pages > 0: + paged_kv_cache = torch.randn( + total_max_pages, 2, page_size, num_kv_heads, head_dim, + dtype=dtype, device=device + ) + else: + paged_kv_cache = None + + # Prepare paged kv parameters for each step + paged_kv_params_list = [] + for step_idx in range(num_steps): + prefix_len = step_idx * chunk_size + if 
prefix_len == 0: + # First step has no prefix + paged_kv_params_list.append(None) + else: + pages_per_request = (prefix_len + page_size - 1) // page_size + total_pages = num_requests * pages_per_request + + paged_kv_indptr = torch.tensor( + [i * pages_per_request for i in range(num_requests + 1)], + dtype=torch.int32, device=device + ) + paged_kv_indices = torch.arange(total_pages, dtype=torch.int32, device=device) + last_page_len = prefix_len % page_size if prefix_len % page_size != 0 else page_size + paged_kv_last_page_len = torch.full( + (num_requests,), last_page_len, dtype=torch.int32, device=device + ) + + paged_kv_params_list.append({ + "paged_kv_cache": paged_kv_cache[:total_pages] if paged_kv_cache is not None else None, + "paged_kv_indptr": paged_kv_indptr, + "paged_kv_indices": paged_kv_indices, + "paged_kv_last_page_len": paged_kv_last_page_len, + }) + + # indptrs + qo_indptr = torch.tensor( + [i * chunk_size for i in range(num_requests + 1)], + dtype=torch.int32, device=device + ) + kv_curr_indptr = torch.tensor( + [i * chunk_size for i in range(num_requests + 1)], + dtype=torch.int32, device=device + ) + + # q_offsets and kv_offsets (required by block_extend_cascade) + q_offsets_list = [] + for step_idx in range(num_steps): + prefix_len = step_idx * chunk_size + q_offsets_list.append(torch.full((num_requests,), prefix_len, dtype=torch.int32, device=device)) + + # Workspace buffer (shared) + workspace_buffer = torch.empty(256 * 1024 * 1024, dtype=torch.uint8, device=device) + + # Output buffer + output_buffer = torch.empty(num_requests * chunk_size, num_heads, head_dim, dtype=dtype, device=device) + + # Precision verification (when chunk_size == dllm_block_size) + if chunk_size == dllm_block_size: + print(f" [Precision Verification] chunk_size == dllm_block_size, comparing outputs of both interfaces...") + + max_diff_all_steps = 0.0 + for step_idx in range(num_steps): + q_batch = q_current_buffers[step_idx] + k_current = k_current_buffers[step_idx] + 
v_current = v_current_buffers[step_idx] + paged_params = paged_kv_params_list[step_idx] + + # batch_block_extend_cascade (function internally determines has_prefix) + be_out = batch_block_extend_cascade( + q=q_batch, + k_current=k_current, + v_current=v_current, + qo_indptr=qo_indptr, + kv_curr_indptr=kv_curr_indptr, + paged_kv_cache=paged_params["paged_kv_cache"] if paged_params else None, + paged_kv_indptr=paged_params["paged_kv_indptr"] if paged_params else None, + paged_kv_indices=paged_params["paged_kv_indices"] if paged_params else None, + paged_kv_last_page_len=paged_params["paged_kv_last_page_len"] if paged_params else None, + page_size=page_size, + dllm_block_size=dllm_block_size, + q_offsets=q_offsets_list[step_idx], + kv_offsets=q_offsets_list[step_idx], # Cascade scenario: kv_offset == q_offset == prefix_len + workspace_buffer=workspace_buffer, + sm_scale=sm_scale, + backend=backend, + ) + + # sglang_style_cascade_attention (function internally determines has_prefix) + sg_out = sglang_style_cascade_attention( + q=q_batch, + k_current=k_current, + v_current=v_current, + qo_indptr=qo_indptr, + kv_curr_indptr=kv_curr_indptr, + paged_kv_cache=paged_params["paged_kv_cache"] if paged_params else None, + paged_kv_indptr=paged_params["paged_kv_indptr"] if paged_params else None, + paged_kv_indices=paged_params["paged_kv_indices"] if paged_params else None, + paged_kv_last_page_len=paged_params["paged_kv_last_page_len"] if paged_params else None, + page_size=page_size, + workspace_buffer=workspace_buffer, + sm_scale=sm_scale, + backend=backend, + ) + + step_diff = (be_out - sg_out).abs().max().item() + max_diff_all_steps = max(max_diff_all_steps, step_diff) + + if verbose: + print(f" step {step_idx}: max_diff = {step_diff:.6f}") + + precision_ok = max_diff_all_steps < 0.01 + status = " PASS" if precision_ok else " FAIL" + print(f" [Precision Verification] max_diff = {max_diff_all_steps:.6f} {status}") + + + # [1] batch_block_extend_cascade performance test + 
print(f" [batch_block_extend_cascade] Performance test (without CUDA Graph)...") + + # ========== Segmented timing: measure Python overhead vs Kernel overhead ========== + if verbose: + import time as time_module + + # 1) Measure single complete call + torch.cuda.synchronize() + t0 = time_module.perf_counter() + _ = batch_block_extend_cascade( + q=q_current_buffers[0], + k_current=k_current_buffers[0], + v_current=v_current_buffers[0], + qo_indptr=qo_indptr, + kv_curr_indptr=kv_curr_indptr, + dllm_block_size=dllm_block_size, + q_offsets=q_offsets_list[0], + kv_offsets=q_offsets_list[0], + workspace_buffer=workspace_buffer, + sm_scale=sm_scale, + backend=backend, + ) + torch.cuda.synchronize() + be_single_call = (time_module.perf_counter() - t0) * 1000 + print(f" [Segmented Timing] BE single call: {be_single_call:.3f} ms (includes Wrapper creation + plan + run)") + + # 2) Measure SG single complete call + torch.cuda.synchronize() + t0 = time_module.perf_counter() + _ = sglang_style_cascade_attention( + q=q_current_buffers[0], + k_current=k_current_buffers[0], + v_current=v_current_buffers[0], + qo_indptr=qo_indptr, + kv_curr_indptr=kv_curr_indptr, + workspace_buffer=workspace_buffer, + sm_scale=sm_scale, + backend=backend, + ) + torch.cuda.synchronize() + sg_single_call = (time_module.perf_counter() - t0) * 1000 + print(f" [Segmented Timing] SG single call: {sg_single_call:.3f} ms (includes Wrapper creation + plan + run)") + print(f" [Segmented Timing] Single call diff: BE={be_single_call:.3f}ms, SG={sg_single_call:.3f}ms, diff={sg_single_call-be_single_call:.3f}ms") + + # 3) Measure pure kernel time with reused Wrapper (excluding Wrapper creation and plan overhead) + print(f" [Segmented Timing] Measuring pure run() time with reused Wrapper...") + + # BE: Create and plan once + from flashinfer.dllm.batch_block_extend import ( + BatchBlockExtendRaggedOffsetWrapper, + ) + be_wrapper = BatchBlockExtendRaggedOffsetWrapper( + workspace_buffer.clone(), + kv_layout="NHD", 
+ dllm_block_size=dllm_block_size, + backend=backend, + ) + be_wrapper.plan( + qo_indptr=qo_indptr, + kv_indptr=kv_curr_indptr, + num_qo_heads=num_heads, + num_kv_heads=num_kv_heads, + head_dim=head_dim, + q_data_type=q_current_buffers[0].dtype, + sm_scale=sm_scale, + q_offsets=q_offsets_list[0], + kv_offsets=q_offsets_list[0], + ) + + # SG: Create and plan once + from flashinfer.prefill import BatchPrefillWithRaggedKVCacheWrapper + sg_wrapper = BatchPrefillWithRaggedKVCacheWrapper( + workspace_buffer.clone(), + kv_layout="NHD", + backend=backend, + ) + sg_wrapper.plan( + qo_indptr=qo_indptr, + kv_indptr=kv_curr_indptr, + num_qo_heads=num_heads, + num_kv_heads=num_kv_heads, + head_dim_qk=head_dim, + head_dim_vo=head_dim, + q_data_type=q_current_buffers[0].dtype, + causal=False, # Same as BE: use non-causal (fully visible) + ) + + # BE: Only measure run() time + for _ in range(10): # warmup + _ = be_wrapper.run(q_current_buffers[0], k_current_buffers[0], v_current_buffers[0]) + torch.cuda.synchronize() + t0 = time_module.perf_counter() + for _ in range(100): + _ = be_wrapper.run(q_current_buffers[0], k_current_buffers[0], v_current_buffers[0]) + torch.cuda.synchronize() + be_run_only = (time_module.perf_counter() - t0) / 100 * 1000 + + # SG: Only measure run() time + for _ in range(10): # warmup + _ = sg_wrapper.run(q_current_buffers[0], k_current_buffers[0], v_current_buffers[0]) + torch.cuda.synchronize() + t0 = time_module.perf_counter() + for _ in range(100): + _ = sg_wrapper.run(q_current_buffers[0], k_current_buffers[0], v_current_buffers[0]) + torch.cuda.synchronize() + sg_run_only = (time_module.perf_counter() - t0) / 100 * 1000 + + print(f" [Segmented Timing] BE run() only: {be_run_only:.3f} ms") + print(f" [Segmented Timing] SG run() only: {sg_run_only:.3f} ms") + print(f" [Segmented Timing] run() diff: BE={be_run_only:.3f}ms, SG={sg_run_only:.3f}ms, diff={sg_run_only-be_run_only:.3f}ms") + if abs(sg_run_only - be_run_only) < 0.01: + print(f" [Conclusion] 
Kernel-level performance is comparable, diff comes from Python overhead") + else: + print(f" [Conclusion] Kernel-level diff exists, possibly due to different mask_mode implementations") + + def run_be_cascade_pipeline(): + for step_idx in range(num_steps): + q_batch = q_current_buffers[step_idx] + k_current = k_current_buffers[step_idx] + v_current = v_current_buffers[step_idx] + paged_params = paged_kv_params_list[step_idx] + + # Function internally determines has_prefix + output_buffer.copy_(batch_block_extend_cascade( + q=q_batch, + k_current=k_current, + v_current=v_current, + qo_indptr=qo_indptr, + kv_curr_indptr=kv_curr_indptr, + paged_kv_cache=paged_params["paged_kv_cache"] if paged_params else None, + paged_kv_indptr=paged_params["paged_kv_indptr"] if paged_params else None, + paged_kv_indices=paged_params["paged_kv_indices"] if paged_params else None, + paged_kv_last_page_len=paged_params["paged_kv_last_page_len"] if paged_params else None, + page_size=page_size, + dllm_block_size=dllm_block_size, + q_offsets=q_offsets_list[step_idx], + kv_offsets=q_offsets_list[step_idx], # Cascade scenario: kv_offset == q_offset == prefix_len + workspace_buffer=workspace_buffer, + sm_scale=sm_scale, + backend=backend, + )) + + # Warmup + for _ in range(warmup_iters): + run_be_cascade_pipeline() + torch.cuda.synchronize() + + # Benchmark (without CUDA Graph) + start = time.perf_counter() + for _ in range(bench_iters): + run_be_cascade_pipeline() + torch.cuda.synchronize() + be_time = (time.perf_counter() - start) / bench_iters * 1000 + + print(f" => {be_time:.3f} ms ({be_time/num_steps:.3f} ms/step × {num_steps} steps)") + + torch.cuda.empty_cache() + + # [2] sglang_style_cascade_attention performance test + print(f" [sglang_style_cascade_attention] Performance test (without CUDA Graph)...") + + def run_sg_cascade_pipeline(): + for step_idx in range(num_steps): + q_batch = q_current_buffers[step_idx] + k_current = k_current_buffers[step_idx] + v_current = 
v_current_buffers[step_idx] + paged_params = paged_kv_params_list[step_idx] + + # Function internally determines has_prefix + output_buffer.copy_(sglang_style_cascade_attention( + q=q_batch, + k_current=k_current, + v_current=v_current, + qo_indptr=qo_indptr, + kv_curr_indptr=kv_curr_indptr, + paged_kv_cache=paged_params["paged_kv_cache"] if paged_params else None, + paged_kv_indptr=paged_params["paged_kv_indptr"] if paged_params else None, + paged_kv_indices=paged_params["paged_kv_indices"] if paged_params else None, + paged_kv_last_page_len=paged_params["paged_kv_last_page_len"] if paged_params else None, + page_size=page_size, + workspace_buffer=workspace_buffer, + sm_scale=sm_scale, + backend=backend, + )) + + # Warmup + for _ in range(warmup_iters): + run_sg_cascade_pipeline() + torch.cuda.synchronize() + + # Benchmark (without CUDA Graph) + start = time.perf_counter() + for _ in range(bench_iters): + run_sg_cascade_pipeline() + torch.cuda.synchronize() + sg_time = (time.perf_counter() - start) / bench_iters * 1000 + + print(f" => {sg_time:.3f} ms ({sg_time/num_steps:.3f} ms/step × {num_steps} steps)") + + + speedup = sg_time / be_time if be_time > 0 else 0 + if speedup > 1: + print(f" => batch_block_extend_cascade is faster by {speedup:.2f}x") + else: + print(f" => sglang_style_cascade_attention is faster by {1/speedup:.2f}x") + + results[f"chunk{chunk_size}"] = { + "chunk_size": chunk_size, + "num_steps": num_steps, + "be_cascade_ms": be_time, + "sg_cascade_ms": sg_time, + "speedup_be_over_sg": speedup, + } + + torch.cuda.empty_cache() + + + # Results summary + print(f"\n{'='*90}") + print(f"Results Summary (num_requests={num_requests}, tokens_per_request={tokens_per_request}, dllm_block_size={dllm_block_size})") + print(f"{'='*90}") + + print(f"\n{'chunk':>8} | {'steps':>6} | {'BE Cascade':>14} | {'SG Cascade':>14} | {'BE/SG':>10}") + print(f"{'-'*8}-+-{'-'*6}-+-{'-'*14}-+-{'-'*14}-+-{'-'*10}") + + for key in sorted(results.keys(), key=lambda k: 
results[k]["chunk_size"]): + r = results[key] + print(f"{r['chunk_size']:>8} | {r['num_steps']:>6} | {r['be_cascade_ms']:>12.3f}ms | {r['sg_cascade_ms']:>12.3f}ms | {r['speedup_be_over_sg']:>9.2f}x") + + print(f"\nNotes:") + print(f" - BE Cascade: batch_block_extend_cascade (Block Extend mask)") + print(f" - SG Cascade: sglang_style_cascade_attention (Causal mask)") + print(f" - BE/SG: Speed ratio of batch_block_extend_cascade vs sglang_style") + print(f" (>1 means BE is faster, <1 means SG is faster)") + print(f" - When chunk_size == dllm_block_size, causal mask = block_extend mask") + print(f" - batch_block_extend_cascade supports chunk_size != dllm_block_size") + + return results + + +def test_cascade_interfaces_perf_with_cuda_graph( + num_requests: int = 4, + tokens_per_request: int = 512, + dllm_block_size: int = 32, + chunk_sizes: list = None, + num_heads: int = 32, + num_kv_heads: int = 8, + head_dim: int = 128, + page_size: int = 16, + warmup_iters: int = 10, + bench_iters: int = 100, + verbose: bool = False, + backend: str = "fa2", +): + """ + Compare performance of two Cascade interfaces (with CUDA Graph optimization) + + Differences from test_cascade_interfaces_perf: + ═══════════════════════════════════════════════════════════════════════════════ + - This function: Uses CUDA Graph to capture run() operations, reducing Python/launch overhead + - Original function: Each call includes Wrapper creation + plan() + run() + + CUDA Graph implementation notes: + ═══════════════════════════════════════════════════════════════════════════════ + 1. plan() contains CPU-GPU synchronization, cannot be executed during Graph capture + 2. Pre-create independent Wrappers for each step and complete plan() + 3. CUDA Graph only captures run() operations + 4. Each step has different paged_kv configuration, requires independent Wrapper + + Interfaces under test: + 1. batch_block_extend_cascade (via Wrapper.run()) + 2. 
sglang_style_cascade_attention (via Wrapper.run()) + """ + from flashinfer.dllm import ( + BatchBlockExtendRaggedOffsetWrapper, + BatchBlockExtendPagedOffsetWrapper + ) + from flashinfer.prefill import ( + BatchPrefillWithRaggedKVCacheWrapper, + BatchPrefillWithPagedKVCacheWrapper, + ) + from flashinfer.cascade import merge_state + + if chunk_sizes is None: + chunk_sizes = [dllm_block_size] + + device = torch.device("cuda:0") + dtype = torch.float16 + sm_scale = 1.0 / (head_dim ** 0.5) + + print(f"\n{'='*90}") + print(f"Cascade Interface Performance Comparison: Step by Step Incremental Prefill (CUDA Graph Version)") + print(f"{'='*90}") + print(f"Configuration:") + print(f" num_requests = {num_requests}") + print(f" tokens_per_request = {tokens_per_request}") + print(f" dllm_block_size = {dllm_block_size}") + print(f" chunk_sizes = {chunk_sizes}") + print(f" num_heads = {num_heads}") + print(f" num_kv_heads = {num_kv_heads}") + print(f" head_dim = {head_dim}") + print(f" page_size = {page_size}") + print(f" backend = {backend}") + print(f"\nCUDA Graph Optimization:") + print(f" - Pre-create Wrappers for each step and complete plan()") + print(f" - CUDA Graph only captures run() operations") + print(f" - Reduces Python/launch overhead") + + # Generate complete Q, K, V for each request + all_qs = [torch.randn(tokens_per_request, num_heads, head_dim, dtype=dtype, device=device) + for _ in range(num_requests)] + all_ks = [torch.randn(tokens_per_request, num_kv_heads, head_dim, dtype=dtype, device=device) + for _ in range(num_requests)] + all_vs = [torch.randn(tokens_per_request, num_kv_heads, head_dim, dtype=dtype, device=device) + for _ in range(num_requests)] + + def split_chunks(tensor, chunk_size): + return [tensor[i*chunk_size:(i+1)*chunk_size] for i in range(tensor.shape[0] // chunk_size)] + + results = {} + + for chunk_size in chunk_sizes: + if tokens_per_request % chunk_size != 0: + print(f"\n[Skip] chunk_size={chunk_size} doesn't evenly divide 
tokens_per_request={tokens_per_request}") + continue + + num_steps = tokens_per_request // chunk_size + + print(f"\n{'-'*90}") + print(f"chunk_size = {chunk_size}, num_steps = {num_steps}") + print(f"{'-'*90}") + + # Split chunks for each request + qs_chunks = [split_chunks(q, chunk_size) for q in all_qs] + ks_chunks = [split_chunks(k, chunk_size) for k in all_ks] + vs_chunks = [split_chunks(v, chunk_size) for v in all_vs] + + # Pre-allocate buffers for all steps + q_current_buffers = [] + k_current_buffers = [] + v_current_buffers = [] + for step_idx in range(num_steps): + q_list = [qs_chunks[req_idx][step_idx] for req_idx in range(num_requests)] + k_list = [ks_chunks[req_idx][step_idx] for req_idx in range(num_requests)] + v_list = [vs_chunks[req_idx][step_idx] for req_idx in range(num_requests)] + q_current_buffers.append(torch.cat(q_list, dim=0)) + k_current_buffers.append(torch.cat(k_list, dim=0)) + v_current_buffers.append(torch.cat(v_list, dim=0)) + + # Paged KV Cache setup + max_prefix_len = (num_steps - 1) * chunk_size + max_pages_per_request = (max_prefix_len + page_size - 1) // page_size if max_prefix_len > 0 else 0 + total_max_pages = num_requests * max_pages_per_request + + if total_max_pages > 0: + paged_kv_cache = torch.randn( + total_max_pages, 2, page_size, num_kv_heads, head_dim, + dtype=dtype, device=device + ) + else: + paged_kv_cache = None + + # Prepare paged kv parameters for each step + paged_kv_params_list = [] + for step_idx in range(num_steps): + prefix_len = step_idx * chunk_size + if prefix_len == 0: + paged_kv_params_list.append(None) + else: + pages_per_request = (prefix_len + page_size - 1) // page_size + total_pages = num_requests * pages_per_request + + paged_kv_indptr = torch.tensor( + [i * pages_per_request for i in range(num_requests + 1)], + dtype=torch.int32, device=device + ) + paged_kv_indices = torch.arange(total_pages, dtype=torch.int32, device=device) + last_page_len = prefix_len % page_size if prefix_len % page_size != 0 
else page_size + paged_kv_last_page_len = torch.full( + (num_requests,), last_page_len, dtype=torch.int32, device=device + ) + + paged_kv_params_list.append({ + "paged_kv_cache": paged_kv_cache[:total_pages] if paged_kv_cache is not None else None, + "paged_kv_indptr": paged_kv_indptr, + "paged_kv_indices": paged_kv_indices, + "paged_kv_last_page_len": paged_kv_last_page_len, + }) + + # indptrs + qo_indptr = torch.tensor( + [i * chunk_size for i in range(num_requests + 1)], + dtype=torch.int32, device=device + ) + kv_curr_indptr = torch.tensor( + [i * chunk_size for i in range(num_requests + 1)], + dtype=torch.int32, device=device + ) + + # q_offsets + q_offsets_list = [] + for step_idx in range(num_steps): + prefix_len = step_idx * chunk_size + q_offsets_list.append(torch.full((num_requests,), prefix_len, dtype=torch.int32, device=device)) + + # Output buffers + output_buffer = torch.empty(num_requests * chunk_size, num_heads, head_dim, dtype=dtype, device=device) + be_output_buffers = [torch.empty_like(output_buffer) for _ in range(num_steps)] + sg_output_buffers = [torch.empty_like(output_buffer) for _ in range(num_steps)] + + # LSE buffers for merge (only needed for step > 0) + be_lse_ragged = [torch.empty(num_requests * chunk_size, num_heads, dtype=torch.float32, device=device) for _ in range(num_steps)] + be_lse_paged = [torch.empty(num_requests * chunk_size, num_heads, dtype=torch.float32, device=device) for _ in range(num_steps)] + sg_lse_ragged = [torch.empty(num_requests * chunk_size, num_heads, dtype=torch.float32, device=device) for _ in range(num_steps)] + sg_lse_paged = [torch.empty(num_requests * chunk_size, num_heads, dtype=torch.float32, device=device) for _ in range(num_steps)] + + # Pre-create BE Wrappers (independent for each step) + print(f" [Preparation] Pre-creating {num_steps} step Wrappers for BE Cascade...") + be_ragged_wrappers = [] + be_paged_wrappers = [] + + for step_idx in range(num_steps): + prefix_len = step_idx * chunk_size + + # 
Ragged Wrapper (Current Chunk) + ws_ragged = torch.empty(256 * 1024 * 1024, dtype=torch.uint8, device=device) + ragged_wrapper = BatchBlockExtendRaggedOffsetWrapper( + ws_ragged, + kv_layout="NHD", + dllm_block_size=dllm_block_size, + backend=backend, + ) + ragged_wrapper.plan( + qo_indptr=qo_indptr, + kv_indptr=kv_curr_indptr, + num_qo_heads=num_heads, + num_kv_heads=num_kv_heads, + head_dim=head_dim, + q_data_type=dtype, + sm_scale=sm_scale, + q_offsets=q_offsets_list[step_idx], + kv_offsets=q_offsets_list[step_idx], + ) + be_ragged_wrappers.append(ragged_wrapper) + + # Paged Wrapper (Prefix) - only needed for step > 0 + # Cascade scenario: Q's block >= all prefix blocks, so mask is all 1s, use full attention + if prefix_len > 0: + paged_params = paged_kv_params_list[step_idx] + ws_paged = torch.empty(256 * 1024 * 1024, dtype=torch.uint8, device=device) + # Use native BatchPrefillWithPagedKVCacheWrapper (causal=False) instead of BlockExtend + # Because Prefix mask is all 1s, no additional mask computation needed + paged_wrapper = BatchPrefillWithPagedKVCacheWrapper( + ws_paged, + kv_layout="NHD", + backend=backend, + ) + paged_wrapper.plan( + qo_indptr=qo_indptr, + paged_kv_indptr=paged_params["paged_kv_indptr"], + paged_kv_indices=paged_params["paged_kv_indices"], + paged_kv_last_page_len=paged_params["paged_kv_last_page_len"], + num_qo_heads=num_heads, + num_kv_heads=num_kv_heads, + head_dim_qk=head_dim, + head_dim_vo=head_dim, + page_size=page_size, + q_data_type=dtype, + causal=False, # Prefix is fully visible + ) + be_paged_wrappers.append(paged_wrapper) + else: + be_paged_wrappers.append(None) + + # Pre-create SG Wrappers (independent for each step) + print(f" [Preparation] Pre-creating {num_steps} step Wrappers for SG Cascade...") + sg_ragged_wrappers = [] + sg_paged_wrappers = [] + + for step_idx in range(num_steps): + prefix_len = step_idx * chunk_size + + # Ragged Wrapper (Current Chunk, causal=True for SGLang style) + ws_ragged = torch.empty(256 * 1024 
* 1024, dtype=torch.uint8, device=device) + ragged_wrapper = BatchPrefillWithRaggedKVCacheWrapper( + ws_ragged, + kv_layout="NHD", + backend=backend, + ) + ragged_wrapper.plan( + qo_indptr=qo_indptr, + kv_indptr=kv_curr_indptr, + num_qo_heads=num_heads, + num_kv_heads=num_kv_heads, + head_dim_qk=head_dim, + head_dim_vo=head_dim, + q_data_type=dtype, + causal=True, # SGLang: Current Chunk uses causal=True + ) + sg_ragged_wrappers.append(ragged_wrapper) + + # Paged Wrapper (Prefix, causal=False) - only needed for step > 0 + if prefix_len > 0: + paged_params = paged_kv_params_list[step_idx] + ws_paged = torch.empty(256 * 1024 * 1024, dtype=torch.uint8, device=device) + paged_wrapper = BatchPrefillWithPagedKVCacheWrapper( + ws_paged, + kv_layout="NHD", + backend=backend, + ) + paged_wrapper.plan( + qo_indptr=qo_indptr, + paged_kv_indptr=paged_params["paged_kv_indptr"], + paged_kv_indices=paged_params["paged_kv_indices"], + paged_kv_last_page_len=paged_params["paged_kv_last_page_len"], + num_qo_heads=num_heads, + num_kv_heads=num_kv_heads, + head_dim_qk=head_dim, + head_dim_vo=head_dim, + page_size=page_size, + q_data_type=dtype, + causal=False, + ) + sg_paged_wrappers.append(paged_wrapper) + else: + sg_paged_wrappers.append(None) + + torch.cuda.synchronize() + + # BE Cascade performance test (CUDA Graph) + print(f" [BE Cascade] Performance test (CUDA Graph)...") + + def run_be_cascade_with_wrappers(): + for step_idx in range(num_steps): + q = q_current_buffers[step_idx] + k = k_current_buffers[step_idx] + v = v_current_buffers[step_idx] + + if step_idx == 0: + # No prefix, only ragged needed + be_output_buffers[step_idx].copy_( + be_ragged_wrappers[step_idx].run(q, k, v) + ) + else: + # Has prefix: ragged + paged + merge + o1, s1 = be_ragged_wrappers[step_idx].run(q, k, v, return_lse=True) + be_lse_ragged[step_idx].copy_(s1) + + paged_params = paged_kv_params_list[step_idx] + o2, s2 = be_paged_wrappers[step_idx].run(q, paged_params["paged_kv_cache"], return_lse=True) + 
be_lse_paged[step_idx].copy_(s2) + + o, _ = merge_state(o1, s1, o2, s2) + be_output_buffers[step_idx].copy_(o) + + # Warmup + for _ in range(warmup_iters): + run_be_cascade_with_wrappers() + torch.cuda.synchronize() + + # Capture CUDA Graph + stream = torch.cuda.Stream() + with torch.cuda.stream(stream): + be_graph = torch.cuda.CUDAGraph() + with torch.cuda.graph(be_graph, stream=stream): + run_be_cascade_with_wrappers() + torch.cuda.synchronize() + + # Benchmark with CUDA Graph + start = time.perf_counter() + for _ in range(bench_iters): + be_graph.replay() + torch.cuda.synchronize() + be_cuda_graph_time = (time.perf_counter() - start) / bench_iters * 1000 + + # Benchmark without CUDA Graph (for comparison) + start = time.perf_counter() + for _ in range(bench_iters): + run_be_cascade_with_wrappers() + torch.cuda.synchronize() + be_no_cuda_graph_time = (time.perf_counter() - start) / bench_iters * 1000 + + print(f" => CUDA Graph: {be_cuda_graph_time:.3f} ms ({be_cuda_graph_time/num_steps:.3f} ms/step × {num_steps} steps)") + print(f" => No Graph: {be_no_cuda_graph_time:.3f} ms ({be_no_cuda_graph_time/num_steps:.3f} ms/step × {num_steps} steps)") + if be_no_cuda_graph_time > 0: + cuda_graph_speedup = be_no_cuda_graph_time / be_cuda_graph_time + if cuda_graph_speedup > 1: + print(f" => CUDA Graph speedup: {cuda_graph_speedup:.2f}x") + else: + print(f" => CUDA Graph no speedup (probably dominated by kernel time)") + + # SG Cascade performance test (CUDA Graph) + print(f" [SG Cascade] Performance test (CUDA Graph)...") + + def run_sg_cascade_with_wrappers(): + for step_idx in range(num_steps): + q = q_current_buffers[step_idx] + k = k_current_buffers[step_idx] + v = v_current_buffers[step_idx] + + if step_idx == 0: + # No prefix, only ragged needed + sg_output_buffers[step_idx].copy_( + sg_ragged_wrappers[step_idx].run(q, k, v) + ) + else: + # Has prefix: ragged + paged + merge + o1, s1 = sg_ragged_wrappers[step_idx].run(q, k, v, return_lse=True) + 
sg_lse_ragged[step_idx].copy_(s1) + + paged_params = paged_kv_params_list[step_idx] + o2, s2 = sg_paged_wrappers[step_idx].run(q, paged_params["paged_kv_cache"], return_lse=True) + sg_lse_paged[step_idx].copy_(s2) + + o, _ = merge_state(o1, s1, o2, s2) + sg_output_buffers[step_idx].copy_(o) + + # Warmup + for _ in range(warmup_iters): + run_sg_cascade_with_wrappers() + torch.cuda.synchronize() + + # Capture CUDA Graph + with torch.cuda.stream(stream): + sg_graph = torch.cuda.CUDAGraph() + with torch.cuda.graph(sg_graph, stream=stream): + run_sg_cascade_with_wrappers() + torch.cuda.synchronize() + + # Benchmark with CUDA Graph + start = time.perf_counter() + for _ in range(bench_iters): + sg_graph.replay() + torch.cuda.synchronize() + sg_cuda_graph_time = (time.perf_counter() - start) / bench_iters * 1000 + + # Benchmark without CUDA Graph + start = time.perf_counter() + for _ in range(bench_iters): + run_sg_cascade_with_wrappers() + torch.cuda.synchronize() + sg_no_cuda_graph_time = (time.perf_counter() - start) / bench_iters * 1000 + + print(f" => CUDA Graph: {sg_cuda_graph_time:.3f} ms ({sg_cuda_graph_time/num_steps:.3f} ms/step × {num_steps} steps)") + print(f" => No Graph: {sg_no_cuda_graph_time:.3f} ms ({sg_no_cuda_graph_time/num_steps:.3f} ms/step × {num_steps} steps)") + if sg_no_cuda_graph_time > 0: + cuda_graph_speedup = sg_no_cuda_graph_time / sg_cuda_graph_time + if cuda_graph_speedup > 1: + print(f" => CUDA Graph speedup: {cuda_graph_speedup:.2f}x") + else: + print(f" => CUDA Graph no speedup") + + # Compare BE vs SG (CUDA Graph) + if be_cuda_graph_time > 0 and sg_cuda_graph_time > 0: + speedup = sg_cuda_graph_time / be_cuda_graph_time + if speedup > 1: + print(f" [Comparison] BE Cascade is faster by {speedup:.2f}x (CUDA Graph)") + else: + print(f" [Comparison] SG Cascade is faster by {1/speedup:.2f}x (CUDA Graph)") + + results[f"chunk{chunk_size}"] = { + "chunk_size": chunk_size, + "num_steps": num_steps, + "be_cuda_graph_ms": be_cuda_graph_time, + 
"be_no_cuda_graph_ms": be_no_cuda_graph_time, + "sg_cuda_graph_ms": sg_cuda_graph_time, + "sg_no_cuda_graph_ms": sg_no_cuda_graph_time, + "speedup_be_over_sg_cuda_graph": sg_cuda_graph_time / be_cuda_graph_time if be_cuda_graph_time > 0 else 0, + } + + # Cleanup + del be_graph, sg_graph + torch.cuda.empty_cache() + + # ================================================================ + # Results summary + # ================================================================ + print(f"\n{'='*90}") + print(f"Results Summary (CUDA Graph Version)") + print(f"{'='*90}") + + print(f"\n{'chunk':>8} | {'steps':>6} | {'BE(cuda_graph)':>10} | {'BE(No)':>10} | {'SG(cuda_graph)':>10} | {'SG(No)':>10} | {'BE/SG':>8}") + print(f"{'-'*8}-+-{'-'*6}-+-{'-'*10}-+-{'-'*10}-+-{'-'*10}-+-{'-'*10}-+-{'-'*8}") + + for key in sorted(results.keys(), key=lambda k: results[k]["chunk_size"]): + r = results[key] + print(f"{r['chunk_size']:>8} | {r['num_steps']:>6} | {r['be_cuda_graph_ms']:>8.3f}ms | {r['be_no_cuda_graph_ms']:>8.3f}ms | " + f"{r['sg_cuda_graph_ms']:>8.3f}ms | {r['sg_no_cuda_graph_ms']:>8.3f}ms | {r['speedup_be_over_sg_cuda_graph']:>7.2f}x") + + print(f"\nNotes:") + print(f" - BE(cuda_graph): batch_block_extend Wrapper.run() with CUDA Graph") + print(f" - BE(No): batch_block_extend Wrapper.run() without CUDA Graph") + print(f" - SG(cuda_graph): sglang_style Wrapper.run() with CUDA Graph") + print(f" - SG(No): sglang_style Wrapper.run() without CUDA Graph") + print(f" - BE/SG: Speed ratio of BE vs SG (>1 means BE is faster)") + print(f" - CUDA Graph optimization: pre-plan(), only capture run()") + + return results + + +def test_heterogeneous_prefix_batch( + verbose: bool = True, + backend: str = "fa2", +): + """ + Heterogeneous prefix test: different requests have different prefix lengths + + Scenario reproduction: + - Req 0: Already prefilled, has prefix (kv_len=128, q_offset=64) + - Req 1: New request, no prefix (kv_len=32, q_offset=0) + - Both requests concatenated together for 
batch block-extend attention operator + + Test purposes: + 1. Verify whether operator supports heterogeneous kv_len input + 2. Check for out-of-bounds memory access issues + 3. Verify precision correctness + + Reference implementation: Each request computed independently with custom_mask, then concatenated + """ + from flashinfer.dllm import BatchBlockExtendRaggedOffsetWrapper + from flashinfer.prefill import single_prefill_with_kv_cache + + device = torch.device("cuda:0") + dtype = torch.float16 + tol = 1e-2 + + print(f"\n{'='*100}") + print(f"Heterogeneous Prefix Test: Different requests have different prefix lengths") + print(f"{'='*100}") + print(f"Test backend: {backend}") + print(f"Precision tolerance: {tol}") + + # Test configuration: Heterogeneous prefix scenarios + # Each config is a list of requests, each request has different qo_len, kv_len, q_offset + test_configs = [ + # Scenario 1: One with prefix, one without + { + "name": "Req0(has_prefix) + Req1(no_prefix)", + "dllm_block_size": 32, + "num_heads": 32, + "num_kv_heads": 8, + "head_dim": 128, + "requests": [ + {"qo_len": 64, "kv_len": 128, "q_offset": 64}, # Req0: step_2, already has 64 tokens prefix + {"qo_len": 32, "kv_len": 32, "q_offset": 0}, # Req1: step_0, no prefix + ], + }, + # Scenario 2: Three requests, different steps + { + "name": "Req0(step_3) + Req1(step_0) + Req2(step_1)", + "dllm_block_size": 32, + "num_heads": 32, + "num_kv_heads": 8, + "head_dim": 128, + "requests": [ + {"qo_len": 32, "kv_len": 128, "q_offset": 96}, # Req0: step_3 + {"qo_len": 32, "kv_len": 32, "q_offset": 0}, # Req1: step_0 + {"qo_len": 32, "kv_len": 64, "q_offset": 32}, # Req2: step_1 + ], + }, + # Scenario 3: Two requests, larger kv_len difference + { + "name": "Req0(kv=256) + Req1(kv=32)", + "dllm_block_size": 32, + "num_heads": 32, + "num_kv_heads": 8, + "head_dim": 128, + "requests": [ + {"qo_len": 32, "kv_len": 256, "q_offset": 224}, # Req0: very long prefix + {"qo_len": 32, "kv_len": 32, "q_offset": 0}, # 
Req1: no prefix + ], + }, + # Scenario 4: Two requests, qo_len also different + { + "name": "Req0(qo=64,kv=128) + Req1(qo=32,kv=32)", + "dllm_block_size": 32, + "num_heads": 32, + "num_kv_heads": 8, + "head_dim": 128, + "requests": [ + {"qo_len": 64, "kv_len": 128, "q_offset": 64}, # Req0 + {"qo_len": 32, "kv_len": 32, "q_offset": 0}, # Req1 + ], + }, + # Scenario 5: Four requests, mixed scenario + { + "name": "4_requests_mixed_scenario", + "dllm_block_size": 64, + "num_heads": 32, + "num_kv_heads": 8, + "head_dim": 128, + "requests": [ + {"qo_len": 64, "kv_len": 256, "q_offset": 192}, # Req0: step_3 + {"qo_len": 64, "kv_len": 64, "q_offset": 0}, # Req1: step_0 + {"qo_len": 64, "kv_len": 128, "q_offset": 64}, # Req2: step_1 + {"qo_len": 64, "kv_len": 192, "q_offset": 128}, # Req3: step_2 + ], + }, + # Scenario 6: Similar to SGLang batch inference - long prompt + short prompt + { + "name": "Long_short_prompt_mix(512_vs_32)", + "dllm_block_size": 32, + "num_heads": 32, + "num_kv_heads": 8, + "head_dim": 128, + "requests": [ + {"qo_len": 512, "kv_len": 512, "q_offset": 0}, # Req0: long prompt (e.g., math problem) + {"qo_len": 32, "kv_len": 32, "q_offset": 0}, # Req1: short prompt (e.g., "Say hello") + ], + }, + # Scenario 7: Extreme difference - super long vs super short + { + "name": "Extreme_difference(1024_vs_16)", + "dllm_block_size": 32, + "num_heads": 32, + "num_kv_heads": 8, + "head_dim": 128, + "requests": [ + {"qo_len": 1024, "kv_len": 1024, "q_offset": 0}, # Req0: super long prompt + {"qo_len": 16, "kv_len": 16, "q_offset": 0}, # Req1: super short prompt + ], + }, + # Scenario 8: Three requests, increasing lengths + { + "name": "Three_requests_increasing_length(64,256,512)", + "dllm_block_size": 64, + "num_heads": 32, + "num_kv_heads": 8, + "head_dim": 128, + "requests": [ + {"qo_len": 64, "kv_len": 64, "q_offset": 0}, # Req0: short + {"qo_len": 256, "kv_len": 256, "q_offset": 0}, # Req1: medium + {"qo_len": 512, "kv_len": 512, "q_offset": 0}, # Req2: long + 
], + }, + # Scenario 9: Mixed prefill stages + different prompt lengths + { + "name": "Mixed_prefill_stages+different_prompt_lengths", + "dllm_block_size": 32, + "num_heads": 32, + "num_kv_heads": 8, + "head_dim": 128, + "requests": [ + {"qo_len": 32, "kv_len": 512, "q_offset": 480}, # Req0: long prompt, step_15 + {"qo_len": 32, "kv_len": 32, "q_offset": 0}, # Req1: short prompt, step_0 + {"qo_len": 32, "kv_len": 128, "q_offset": 96}, # Req2: medium prompt, step_3 + ], + }, + ] + + results = [] + all_pass = True + + for cfg_idx, cfg in enumerate(test_configs): + dllm_block_size = cfg["dllm_block_size"] + num_heads = cfg["num_heads"] + num_kv_heads = cfg["num_kv_heads"] + head_dim = cfg["head_dim"] + requests = cfg["requests"] + num_requests = len(requests) + sm_scale = 1.0 / math.sqrt(head_dim) + + print(f"\n [Test {cfg_idx:02d}] {cfg['name']}") + print(f" B={dllm_block_size}, heads={num_heads}/{num_kv_heads}, dim={head_dim}") + + # Generate data for each request + qs = [] + ks = [] + vs = [] + for req in requests: + qo_len = req["qo_len"] + kv_len = req["kv_len"] + qs.append(torch.randn(qo_len, num_heads, head_dim, dtype=dtype, device=device)) + ks.append(torch.randn(kv_len, num_kv_heads, head_dim, dtype=dtype, device=device)) + vs.append(torch.randn(kv_len, num_kv_heads, head_dim, dtype=dtype, device=device)) + + # Reference implementation: compute each request independently + ref_outputs = [] + for req_idx, req in enumerate(requests): + qo_len = req["qo_len"] + kv_len = req["kv_len"] + q_offset = req["q_offset"] + + # Build custom_mask + q_pos = torch.arange(qo_len, device=device) + q_offset + k_pos = torch.arange(kv_len, device=device) + q_block = q_pos.unsqueeze(1) // dllm_block_size + k_block = k_pos.unsqueeze(0) // dllm_block_size + mask_2d = (q_block >= k_block).to(torch.uint8) + + ref_output = single_prefill_with_kv_cache( + qs[req_idx], ks[req_idx], vs[req_idx], + custom_mask=mask_2d, + sm_scale=sm_scale, + backend="fa2", + ) + 
ref_outputs.append(ref_output) + + # Concatenate reference outputs + ref_output_cat = torch.cat(ref_outputs, dim=0) + + # Build batch input + q_cat = torch.cat(qs, dim=0) + k_cat = torch.cat(ks, dim=0) + v_cat = torch.cat(vs, dim=0) + + # Build indptr + qo_lens = [req["qo_len"] for req in requests] + kv_lens = [req["kv_len"] for req in requests] + q_offsets_list = [req["q_offset"] for req in requests] + + qo_indptr = torch.tensor([0] + list(torch.cumsum(torch.tensor(qo_lens), dim=0).numpy()), dtype=torch.int32, device=device) + kv_indptr = torch.tensor([0] + list(torch.cumsum(torch.tensor(kv_lens), dim=0).numpy()), dtype=torch.int32, device=device) + q_offsets = torch.tensor(q_offsets_list, dtype=torch.int32, device=device) + + if verbose: + print(f" qo_indptr: {qo_indptr.tolist()}") + print(f" kv_indptr: {kv_indptr.tolist()}") + print(f" q_offsets: {q_offsets.tolist()}") + + # Object under test: BatchBlockExtendRaggedOffsetWrapper + try: + workspace = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device=device) + wrapper = BatchBlockExtendRaggedOffsetWrapper( + workspace, kv_layout="NHD", dllm_block_size=dllm_block_size, backend=backend + ) + wrapper.plan( + qo_indptr=qo_indptr, + kv_indptr=kv_indptr, + num_qo_heads=num_heads, + num_kv_heads=num_kv_heads, + head_dim=head_dim, + q_data_type=dtype, + sm_scale=sm_scale, + q_offsets=q_offsets, + ) + bbe_output = wrapper.run(q_cat, k_cat, v_cat) + + # Calculate precision differences + max_diff = (bbe_output - ref_output_cat).abs().max().item() + mean_diff = (bbe_output - ref_output_cat).abs().mean().item() + passed = max_diff < tol + + status = "PASS" if passed else "FAIL" + print(f" BBE-{backend.upper()}: max_diff={max_diff:.6f}, mean_diff={mean_diff:.6f} [{status}]") + + if not passed: + all_pass = False + # Print detailed diff for each request + start_idx = 0 + for req_idx, req in enumerate(requests): + qo_len = req["qo_len"] + end_idx = start_idx + qo_len + req_diff = (bbe_output[start_idx:end_idx] - 
ref_outputs[req_idx]).abs().max().item() + print(f" Req{req_idx} (qo={qo_len}, kv={req['kv_len']}, q_off={req['q_offset']}): max_diff={req_diff:.6f}") + start_idx = end_idx + + results.append({ + "config_idx": cfg_idx, + "name": cfg["name"], + "max_diff": max_diff, + "passed": passed, + "error": None, + }) + + del workspace, wrapper + + except Exception as e: + print(f" BBE-{backend.upper()}: ERROR - {str(e)}") + results.append({ + "config_idx": cfg_idx, + "name": cfg["name"], + "max_diff": float('inf'), + "passed": False, + "error": str(e), + }) + all_pass = False + + torch.cuda.empty_cache() + + # Summary + print(f"\n{'='*100}") + print(f"Heterogeneous Prefix Test Summary") + print(f"{'='*100}") + + passed_count = sum(1 for r in results if r["passed"]) + total_count = len(results) + print(f" Passed: {passed_count}/{total_count}") + + if not all_pass: + print(f"\n Failed tests:") + for r in results: + if not r["passed"]: + if r["error"]: + print(f" - [{r['config_idx']:02d}] {r['name']}: ERROR - {r['error']}") + else: + print(f" - [{r['config_idx']:02d}] {r['name']}: max_diff={r['max_diff']:.6f}") + + return results + + +def test_cascade_current_chunk_batch( + verbose: bool = True, + backend: str = "fa2", +): + """ + Test complete three-stage Cascade Attention (simulating SGLang DLLM flow) + + Three-stage flow: + Stage 1 (prefix): BatchBlockExtendRaggedOffsetWrapper computes prefix + - K/V: [0, prefix_len) + - kv_offset = 0 + Stage 2 (current chunk): BatchBlockExtendRaggedOffsetWrapper computes current chunk + - K/V: [prefix_len, prefix_len + chunk_len) (only current chunk) + - kv_offset = prefix_len + Stage 3 (merge): merge_state(o1, s1, o2, s2) + + Mask rule: mask[q, k] = (q_global // B) >= (k_global // B) + + Key points: + - Stage 2's K/V doesn't start from 0, needs kv_offset + - Uses blockwise extend mask (not causal mask) + """ + device = "cuda" + dtype = torch.bfloat16 + tol = 0.01 if backend == "fa3" else 0.01 + + print(f"\n{'='*100}") + print(f"Three-stage 
Cascade Attention Test (prefix + current_chunk + merge)") + print(f"{'='*100}") + print(f"Test backend: {backend}") + print(f"Precision tolerance: {tol}") + + # Test configuration: each request has prefix_len and chunk_len + test_configs = [ + # Scenario 1: Two requests, one with prefix, one without + { + "name": "Req0(has_prefix) + Req1(no_prefix)", + "dllm_block_size": 32, + "num_heads": 32, + "num_kv_heads": 8, + "head_dim": 128, + "requests": [ + {"prefix_len": 64, "chunk_len": 32}, # Req0: step_2, prefix [0,64), chunk [64,96) + {"prefix_len": 0, "chunk_len": 32}, # Req1: step_0, no prefix, chunk [0,32) + ], + }, + # Scenario 2: Three requests, different steps + { + "name": "Req0(step_3) + Req1(step_0) + Req2(step_1)", + "dllm_block_size": 32, + "num_heads": 32, + "num_kv_heads": 8, + "head_dim": 128, + "requests": [ + {"prefix_len": 96, "chunk_len": 32}, # Req0: step_3 + {"prefix_len": 0, "chunk_len": 32}, # Req1: step_0 + {"prefix_len": 32, "chunk_len": 32}, # Req2: step_1 + ], + }, + # Scenario 3: Large prefix + { + "name": "Req0(large_prefix=256) + Req1(no_prefix)", + "dllm_block_size": 32, + "num_heads": 32, + "num_kv_heads": 8, + "head_dim": 128, + "requests": [ + {"prefix_len": 256, "chunk_len": 32}, # Req0: step_8 + {"prefix_len": 0, "chunk_len": 32}, # Req1: step_0 + ], + }, + # Scenario 4: chunk_len != block_size + { + "name": "Req0(chunk=64) + Req1(chunk=32)", + "dllm_block_size": 32, + "num_heads": 32, + "num_kv_heads": 8, + "head_dim": 128, + "requests": [ + {"prefix_len": 64, "chunk_len": 64}, # Req0: 2 blocks chunk + {"prefix_len": 0, "chunk_len": 32}, # Req1: 1 block chunk + ], + }, + # Scenario 5: Four requests mixed + { + "name": "4_requests_mixed(step_0,1,2,3)", + "dllm_block_size": 64, + "num_heads": 32, + "num_kv_heads": 8, + "head_dim": 128, + "requests": [ + {"prefix_len": 0, "chunk_len": 64}, # Req0: step_0 + {"prefix_len": 64, "chunk_len": 64}, # Req1: step_1 + {"prefix_len": 128, "chunk_len": 64}, # Req2: step_2 + {"prefix_len": 192, 
"chunk_len": 64}, # Req3: step_3 + ], + }, + ] + + results = [] + all_pass = True + + for cfg_idx, cfg in enumerate(test_configs): + dllm_block_size = cfg["dllm_block_size"] + num_heads = cfg["num_heads"] + num_kv_heads = cfg["num_kv_heads"] + head_dim = cfg["head_dim"] + requests = cfg["requests"] + sm_scale = 1.0 / math.sqrt(head_dim) + + print(f"\n [Test {cfg_idx:02d}] {cfg['name']}") + print(f" B={dllm_block_size}, heads={num_heads}/{num_kv_heads}, dim={head_dim}") + + # Generate data for each request + # Each request has: Q(chunk_len), K_prefix(prefix_len), V_prefix(prefix_len), K_chunk(chunk_len), V_chunk(chunk_len) + qs = [] # Q: current chunk's query + k_prefixes = [] # K_prefix: K of prefix part + v_prefixes = [] # V_prefix: V of prefix part + k_chunks = [] # K_chunk: K of current chunk + v_chunks = [] # V_chunk: V of current chunk + + for req in requests: + prefix_len = req["prefix_len"] + chunk_len = req["chunk_len"] + qs.append(torch.randn(chunk_len, num_heads, head_dim, dtype=dtype, device=device)) + if prefix_len > 0: + k_prefixes.append(torch.randn(prefix_len, num_kv_heads, head_dim, dtype=dtype, device=device)) + v_prefixes.append(torch.randn(prefix_len, num_kv_heads, head_dim, dtype=dtype, device=device)) + else: + k_prefixes.append(None) + v_prefixes.append(None) + k_chunks.append(torch.randn(chunk_len, num_kv_heads, head_dim, dtype=dtype, device=device)) + v_chunks.append(torch.randn(chunk_len, num_kv_heads, head_dim, dtype=dtype, device=device)) + + # Reference implementation: compute each request independently (full KV + blockwise mask) + ref_outputs = [] + for req_idx, req in enumerate(requests): + prefix_len = req["prefix_len"] + chunk_len = req["chunk_len"] + total_kv_len = prefix_len + chunk_len + + # Concatenate full K/V + if prefix_len > 0: + k_full = torch.cat([k_prefixes[req_idx], k_chunks[req_idx]], dim=0) + v_full = torch.cat([v_prefixes[req_idx], v_chunks[req_idx]], dim=0) + else: + k_full = k_chunks[req_idx] + v_full = 
v_chunks[req_idx] + + # Construct blockwise extend mask + # Q: [prefix_len, prefix_len + chunk_len) + # K: [0, prefix_len + chunk_len) + q_offset = prefix_len + q_pos = torch.arange(chunk_len, device=device) + q_offset + k_pos = torch.arange(total_kv_len, device=device) # K starts from 0 + q_block = q_pos.unsqueeze(1) // dllm_block_size + k_block = k_pos.unsqueeze(0) // dllm_block_size + mask_2d = (q_block >= k_block).to(torch.uint8) + + ref_output = single_prefill_with_kv_cache( + qs[req_idx], k_full, v_full, + custom_mask=mask_2d, + sm_scale=sm_scale, + backend="fa2", + ) + ref_outputs.append(ref_output) + + ref_output_cat = torch.cat(ref_outputs, dim=0) + + if verbose: + for req_idx, req in enumerate(requests): + prefix_len = req["prefix_len"] + chunk_len = req["chunk_len"] + print(f" Req{req_idx}: prefix_len={prefix_len}, chunk_len={chunk_len}, q_offset={prefix_len}, kv_offset={prefix_len}") + + # Target under test: three-stage Cascade Attention + try: + cascade_outputs = [] + workspace = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device=device) + + for req_idx, req in enumerate(requests): + prefix_len = req["prefix_len"] + chunk_len = req["chunk_len"] + q = qs[req_idx] + + # Stage 1: prefix (if exists) + if prefix_len > 0: + # Construct prefix indptr + prefix_qo_indptr = torch.tensor([0, chunk_len], dtype=torch.int32, device=device) + prefix_kv_indptr = torch.tensor([0, prefix_len], dtype=torch.int32, device=device) + prefix_q_offsets = torch.tensor([prefix_len], dtype=torch.int32, device=device) # Q's global position + prefix_kv_offsets = torch.tensor([0], dtype=torch.int32, device=device) # prefix K/V starts from 0 + + prefix_wrapper = BatchBlockExtendRaggedOffsetWrapper( + workspace, kv_layout="NHD", dllm_block_size=dllm_block_size, backend=backend + ) + prefix_wrapper.plan( + qo_indptr=prefix_qo_indptr, + kv_indptr=prefix_kv_indptr, + num_qo_heads=num_heads, + num_kv_heads=num_kv_heads, + head_dim=head_dim, + q_data_type=dtype, + sm_scale=sm_scale, 
+ q_offsets=prefix_q_offsets, + kv_offsets=prefix_kv_offsets, + ) + o1, s1 = prefix_wrapper.run(q, k_prefixes[req_idx], v_prefixes[req_idx], return_lse=True) + del prefix_wrapper + else: + o1 = None + s1 = None + + # Stage 2: current chunk + chunk_qo_indptr = torch.tensor([0, chunk_len], dtype=torch.int32, device=device) + chunk_kv_indptr = torch.tensor([0, chunk_len], dtype=torch.int32, device=device) + chunk_q_offsets = torch.tensor([prefix_len], dtype=torch.int32, device=device) # Q's global position + chunk_kv_offsets = torch.tensor([prefix_len], dtype=torch.int32, device=device) # chunk K/V's global position + + chunk_wrapper = BatchBlockExtendRaggedOffsetWrapper( + workspace, kv_layout="NHD", dllm_block_size=dllm_block_size, backend=backend + ) + chunk_wrapper.plan( + qo_indptr=chunk_qo_indptr, + kv_indptr=chunk_kv_indptr, + num_qo_heads=num_heads, + num_kv_heads=num_kv_heads, + head_dim=head_dim, + q_data_type=dtype, + sm_scale=sm_scale, + q_offsets=chunk_q_offsets, + kv_offsets=chunk_kv_offsets, + ) + o2, s2 = chunk_wrapper.run(q, k_chunks[req_idx], v_chunks[req_idx], return_lse=True) + del chunk_wrapper + + # Stage 3: merge + if o1 is not None: + o_merged, _ = merge_state(o1, s1, o2, s2) + cascade_outputs.append(o_merged) + else: + cascade_outputs.append(o2) + + cascade_output_cat = torch.cat(cascade_outputs, dim=0) + + # Compute precision difference + max_diff = (cascade_output_cat - ref_output_cat).abs().max().item() + mean_diff = (cascade_output_cat - ref_output_cat).abs().mean().item() + passed = max_diff < tol + + status = "PASS" if passed else "FAIL" + print(f" Cascade-{backend.upper()}: max_diff={max_diff:.6f}, mean_diff={mean_diff:.6f} [{status}]") + + if not passed: + all_pass = False + + results.append({ + "config_idx": cfg_idx, + "name": cfg["name"], + "max_diff": max_diff, + "passed": passed, + "error": None, + }) + + del workspace + + except Exception as e: + import traceback + print(f" Cascade-{backend.upper()}: ERROR - {str(e)}") + if 
verbose: + traceback.print_exc() + results.append({ + "config_idx": cfg_idx, + "name": cfg["name"], + "max_diff": float('inf'), + "passed": False, + "error": str(e), + }) + all_pass = False + + torch.cuda.empty_cache() + + # Summary + print(f"\n{'='*100}") + print(f"Three-stage Cascade Attention Test Summary") + print(f"{'='*100}") + + passed_count = sum(1 for r in results if r["passed"]) + total_count = len(results) + print(f" Passed: {passed_count}/{total_count}") + + if not all_pass: + print(f"\n Failed tests:") + for r in results: + if not r["passed"]: + if r["error"]: + print(f" - [{r['config_idx']:02d}] {r['name']}: ERROR - {r['error']}") + else: + print(f" - [{r['config_idx']:02d}] {r['name']}: max_diff={r['max_diff']:.6f}") + + return results + + +def test_cascade_precision_alignment( + verbose: bool = True, +): + """ + Test block_extend_cascade precision alignment (single request version) + + Comparison targets: + 1. block_extend_cascade: uses block_extend_attention_with_offset (Current Chunk) + 2. Reference implementation: uses single_prefill_with_kv_cache(causal=True) (Current Chunk) + + When chunk_size == dllm_block_size: + - Causal mask ≡ Block Extend mask + - The outputs of both implementations should be completely consistent + + Test coverage: + 1. Different dllm_block_size: [32, 64, 128] + 2. Different num_steps (with/without prefix) + 3. Different head configurations (MHA, GQA, MQA) + 4. 
Different head_dim + """ + from flashinfer.dllm import block_extend_cascade + from flashinfer.cascade import merge_state_in_place + from flashinfer.prefill import single_prefill_with_kv_cache + + device = torch.device("cuda:0") + dtype = torch.float16 + tol = 1e-2 # Precision tolerance + + print(f"\n{'='*100}") + print(f"Single Request Cascade Precision Alignment Test: block_extend_cascade vs custom_mask Reference Implementation") + print(f"{'='*100}") + print(f"Test condition: chunk_size == dllm_block_size") + print(f"Note: Block Extend mask != Causal mask (all visible within block, not lower triangular)") + print(f"Precision tolerance: {tol}") + + # Test configurations + test_configs = [ + # Basic tests: different dllm_block_size + {"dllm_block_size": 32, "num_steps": 4, "num_heads": 32, "num_kv_heads": 8, "head_dim": 128}, + {"dllm_block_size": 64, "num_steps": 4, "num_heads": 32, "num_kv_heads": 8, "head_dim": 128}, + {"dllm_block_size": 128, "num_steps": 2, "num_heads": 32, "num_kv_heads": 8, "head_dim": 128}, + + # More steps (longer sequences) + {"dllm_block_size": 32, "num_steps": 8, "num_heads": 32, "num_kv_heads": 8, "head_dim": 128}, + {"dllm_block_size": 64, "num_steps": 8, "num_heads": 32, "num_kv_heads": 8, "head_dim": 128}, + + # Different head configurations (MHA, GQA, MQA) + {"dllm_block_size": 32, "num_steps": 4, "num_heads": 32, "num_kv_heads": 32, "head_dim": 128}, # MHA + {"dllm_block_size": 32, "num_steps": 4, "num_heads": 32, "num_kv_heads": 4, "head_dim": 128}, # GQA-8 + {"dllm_block_size": 32, "num_steps": 4, "num_heads": 32, "num_kv_heads": 1, "head_dim": 128}, # MQA + + # Different head_dim + {"dllm_block_size": 32, "num_steps": 4, "num_heads": 32, "num_kv_heads": 8, "head_dim": 64}, + {"dllm_block_size": 32, "num_steps": 4, "num_heads": 16, "num_kv_heads": 4, "head_dim": 256}, + + # Boundary tests: single step (no prefix) + {"dllm_block_size": 32, "num_steps": 1, "num_heads": 32, "num_kv_heads": 8, "head_dim": 128}, + {"dllm_block_size": 
64, "num_steps": 1, "num_heads": 32, "num_kv_heads": 8, "head_dim": 128}, + ] + + results = [] + + for cfg_idx, cfg in enumerate(test_configs): + dllm_block_size = cfg["dllm_block_size"] + num_steps = cfg["num_steps"] + num_heads = cfg["num_heads"] + num_kv_heads = cfg["num_kv_heads"] + head_dim = cfg["head_dim"] + + chunk_size = dllm_block_size # Key: chunk_size == dllm_block_size + tokens_per_request = num_steps * chunk_size + sm_scale = 1.0 / math.sqrt(head_dim) + + # Generate test data for full sequence + q_full = torch.randn(tokens_per_request, num_heads, head_dim, dtype=dtype, device=device) + k_full = torch.randn(tokens_per_request, num_kv_heads, head_dim, dtype=dtype, device=device) + v_full = torch.randn(tokens_per_request, num_kv_heads, head_dim, dtype=dtype, device=device) + + def split_chunks(tensor, chunk_size): + return [tensor[i*chunk_size:(i+1)*chunk_size] for i in range(tensor.shape[0] // chunk_size)] + + qs_chunks = split_chunks(q_full, chunk_size) + ks_chunks = split_chunks(k_full, chunk_size) + vs_chunks = split_chunks(v_full, chunk_size) + + max_diff = 0.0 + step_diffs = [] + + for step_idx in range(num_steps): + q_current = qs_chunks[step_idx] + k_current = ks_chunks[step_idx] + v_current = vs_chunks[step_idx] + + # Prefix (KV from all previous chunks) + if step_idx > 0: + k_prefix = k_full[:step_idx * chunk_size] + v_prefix = v_full[:step_idx * chunk_size] + else: + k_prefix = None + v_prefix = None + + # Target under test: block_extend_cascade + be_out = block_extend_cascade( + q=q_current, + k_current=k_current, + v_current=v_current, + k_prefix=k_prefix, + v_prefix=v_prefix, + dllm_block_size=dllm_block_size, + sm_scale=sm_scale, + return_lse=False, + backend="fa2", + ) + + # Reference: compute Block Extend Attention using custom_mask + # Note: Block Extend mask != Causal mask + # Block Extend: mask[q,k] = ((q+offset)//B) >= (k//B) + # When chunk_size == dllm_block_size, all positions within chunk are visible (not lower triangular) + 
prefix_len = step_idx * chunk_size + if k_prefix is None: + # No prefix, directly use Block Extend mask + ref_out = compute_block_extend_reference( + q_current, k_current, v_current, + dllm_block_size=dllm_block_size, + q_offset=prefix_len, + sm_scale=sm_scale, + ) + else: + # With prefix: Current Chunk (Block Extend) + Prefix (fully visible) + merge + # Current chunk: q_offset = prefix_len, kv starts from prefix_len + + # Construct Block Extend mask for current chunk + qo_len = q_current.shape[0] + kv_len = k_current.shape[0] + q_pos = torch.arange(qo_len, device=q_current.device) + prefix_len + k_pos = torch.arange(kv_len, device=q_current.device) + prefix_len # kv also starts from prefix_len + q_block = q_pos.unsqueeze(1) // dllm_block_size + k_block = k_pos.unsqueeze(0) // dllm_block_size + mask_current = (q_block >= k_block).to(torch.uint8) + + o1, s1 = single_prefill_with_kv_cache( + q_current, k_current, v_current, + custom_mask=mask_current, + sm_scale=sm_scale, + return_lse=True, + ) + + # Prefix: q_offset = prefix_len, kv starts from 0 (fully visible) + o2, s2 = single_prefill_with_kv_cache( + q_current, k_prefix, v_prefix, + causal=False, + sm_scale=sm_scale, + return_lse=True, + ) + merge_state_in_place(o1, s1, o2, s2) + ref_out = o1 + + step_diff = (be_out - ref_out).abs().max().item() + step_diffs.append(step_diff) + max_diff = max(max_diff, step_diff) + + test_pass = max_diff < tol + + result = { + "config_idx": cfg_idx, + "dllm_block_size": dllm_block_size, + "num_steps": num_steps, + "num_heads": num_heads, + "num_kv_heads": num_kv_heads, + "head_dim": head_dim, + "max_diff": max_diff, + "step_diffs": step_diffs, + "pass": test_pass, + } + results.append(result) + + if verbose: + status = "PASS" if test_pass else "FAIL" + print(f"\n [Test {cfg_idx:02d}] B={dllm_block_size:3d}, steps={num_steps}, " + f"heads={num_heads}/{num_kv_heads}, dim={head_dim}") + print(f" max_diff={max_diff:.6f} [{status}]") + if not test_pass: + print(f" step_diffs: 
{[f'{d:.6f}' for d in step_diffs]}") + + torch.cuda.empty_cache() + + print(f"\n{'='*100}") + print(f"Precision Alignment Test Summary") + print(f"{'='*100}") + + total_tests = len(results) + pass_count = sum(1 for r in results if r["pass"]) + + print(f"\n Total tests: {total_tests}") + print(f" Passed: {pass_count}/{total_tests} PASS") + print(f" max_diff (all tests): {max(r['max_diff'] for r in results):.6f}") + + failed = [r for r in results if not r["pass"]] + if failed: + print(f"\n [Failed Test Details]") + for r in failed: + print(f" Test {r['config_idx']:02d}: B={r['dllm_block_size']}, steps={r['num_steps']}, " + f"heads={r['num_heads']}/{r['num_kv_heads']}, dim={r['head_dim']}, max_diff={r['max_diff']:.6f}") + + overall_pass = all(r["pass"] for r in results) + overall_status = "ALL TESTS PASSED" if overall_pass else "SOME TESTS FAILED" + print(f"\n Overall Result: {overall_status}") + + return { + "results": results, + "overall_pass": overall_pass, + } + + +def test_sglang_vs_block_extend_cascade( + num_steps: int = 4, + dllm_block_size: int = 64, + num_heads: int = 32, + num_kv_heads: int = 8, + head_dim: int = 128, + page_size: int = 16, + warmup_iters: int = 10, + bench_iters: int = 100, + verbose: bool = True, + backend: str = "fa2", +): + """ + sglang_style_cascade_attention vs block_extend_cascade precision and performance comparison + + Key Design: + ═══════════════════════════════════════════════════════════════════════════════ + Ensure completely identical inputs: + 1. Use the same Q, K_current, V_current and K_prefix, V_prefix data + 2. sglang_style_cascade_attention: K_prefix/V_prefix converted to Paged KV Cache format + 3. 
block_extend_cascade: K_prefix/V_prefix uses contiguous storage format + + Comparison targets: + ═══════════════════════════════════════════════════════════════════════════════ + - sglang_style_cascade_attention (batch version): + * Current Chunk: BatchPrefillWithRaggedKVCacheWrapper (causal=False) + * Prefix: BatchPrefillWithPagedKVCacheWrapper (causal=False) + * Uses Paged KV Cache to store prefix + + - block_extend_cascade (single request version): + * Current Chunk: block_extend_attention_with_offset (Block Extend mask) + * Prefix: single_prefill_with_kv_cache (causal=False) + * Uses contiguous memory to store prefix + + Applicable conditions: + ═══════════════════════════════════════════════════════════════════════════════ + chunk_size == dllm_block_size (when causal mask = block_extend mask) + """ + from flashinfer.dllm import ( + sglang_style_cascade_attention, + block_extend_cascade, + ) + import time as time_module + + device = torch.device("cuda:0") + dtype = torch.float16 + tol = 1e-2 # Precision tolerance + + chunk_size = dllm_block_size # Key: chunk_size == dllm_block_size + tokens_per_request = num_steps * chunk_size + sm_scale = 1.0 / math.sqrt(head_dim) + + print(f"\n{'='*100}") + print(f"sglang_style_cascade_attention vs block_extend_cascade Precision and Performance Comparison") + print(f"{'='*100}") + print(f"Configuration:") + print(f" num_steps = {num_steps}") + print(f" dllm_block_size = {dllm_block_size}") + print(f" chunk_size = {chunk_size} (= dllm_block_size)") + print(f" tokens_per_request = {tokens_per_request}") + print(f" num_heads = {num_heads}") + print(f" num_kv_heads = {num_kv_heads}") + print(f" head_dim = {head_dim}") + print(f" page_size = {page_size}") + print(f" backend = {backend}") + print(f" Precision tolerance = {tol}") + print(f"\nComparison implementations:") + print(f" sglang_style_cascade_attention: Ragged (causal=False) + Paged (causal=False) + merge_state") + print(f" block_extend_cascade: BlockExtend (with offset) + 
single_prefill (causal=False) + merge_state") + + q_full = torch.randn(tokens_per_request, num_heads, head_dim, dtype=dtype, device=device) + k_full = torch.randn(tokens_per_request, num_kv_heads, head_dim, dtype=dtype, device=device) + v_full = torch.randn(tokens_per_request, num_kv_heads, head_dim, dtype=dtype, device=device) + + def split_chunks(tensor, chunk_size): + return [tensor[i*chunk_size:(i+1)*chunk_size] for i in range(tensor.shape[0] // chunk_size)] + + qs_chunks = split_chunks(q_full, chunk_size) + ks_chunks = split_chunks(k_full, chunk_size) + vs_chunks = split_chunks(v_full, chunk_size) + + # Workspace buffer for sglang_style + workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device=device) + + def create_paged_kv_cache_from_prefix(k_prefix, v_prefix, page_size): + """ + Convert contiguous K/V prefix to Paged KV Cache format + + Args: + k_prefix: [prefix_len, num_kv_heads, head_dim] + v_prefix: [prefix_len, num_kv_heads, head_dim] + page_size: Page size + + Returns: + paged_kv_cache: [num_pages, 2, page_size, num_kv_heads, head_dim] + paged_kv_indptr: [2] - [0, num_pages] + paged_kv_indices: [num_pages] - [0, 1, ..., num_pages-1] + paged_kv_last_page_len: [1] - Valid length of last page + """ + prefix_len = k_prefix.size(0) + num_kv_heads = k_prefix.size(1) + head_dim = k_prefix.size(2) + device = k_prefix.device + dtype = k_prefix.dtype + + # Calculate number of pages needed + num_pages = (prefix_len + page_size - 1) // page_size + last_page_len = prefix_len - (num_pages - 1) * page_size if num_pages > 0 else 0 + + # Create Paged KV Cache: [num_pages, 2, page_size, num_kv_heads, head_dim] + paged_kv_cache = torch.zeros( + num_pages, 2, page_size, num_kv_heads, head_dim, + dtype=dtype, device=device + ) + + # Fill data + for page_idx in range(num_pages): + start = page_idx * page_size + end = min(start + page_size, prefix_len) + actual_len = end - start + + # K: paged_kv_cache[page_idx, 0, :actual_len, :, :] + 
paged_kv_cache[page_idx, 0, :actual_len, :, :] = k_prefix[start:end] + # V: paged_kv_cache[page_idx, 1, :actual_len, :, :] + paged_kv_cache[page_idx, 1, :actual_len, :, :] = v_prefix[start:end] + + # indptr, indices, last_page_len + paged_kv_indptr = torch.tensor([0, num_pages], dtype=torch.int32, device=device) + paged_kv_indices = torch.arange(num_pages, dtype=torch.int32, device=device) + paged_kv_last_page_len = torch.tensor([last_page_len], dtype=torch.int32, device=device) + + return paged_kv_cache, paged_kv_indptr, paged_kv_indices, paged_kv_last_page_len + + print(f"\n{'='*100}") + print(f"Precision Comparison (Step by Step)") + print(f"{'='*100}") + + max_diff_all_steps = 0.0 + step_diffs = [] + + for step_idx in range(num_steps): + q_current = qs_chunks[step_idx] + k_current = ks_chunks[step_idx] + v_current = vs_chunks[step_idx] + + # Prefix (KV from all previous chunks) + if step_idx > 0: + k_prefix = k_full[:step_idx * chunk_size] + v_prefix = v_full[:step_idx * chunk_size] + else: + k_prefix = None + v_prefix = None + + be_out = block_extend_cascade( + q=q_current, + k_current=k_current, + v_current=v_current, + k_prefix=k_prefix, + v_prefix=v_prefix, + dllm_block_size=dllm_block_size, + sm_scale=sm_scale, + return_lse=False, + backend=backend, + ) + + # Construct batch_size=1 batch parameters + qo_indptr = torch.tensor([0, chunk_size], dtype=torch.int32, device=device) + kv_curr_indptr = torch.tensor([0, chunk_size], dtype=torch.int32, device=device) + + if k_prefix is not None: + # Convert contiguous prefix to Paged KV Cache format + paged_kv_cache, paged_kv_indptr, paged_kv_indices, paged_kv_last_page_len = \ + create_paged_kv_cache_from_prefix(k_prefix, v_prefix, page_size) + + sg_out = sglang_style_cascade_attention( + q=q_current, + k_current=k_current, + v_current=v_current, + qo_indptr=qo_indptr, + kv_curr_indptr=kv_curr_indptr, + paged_kv_cache=paged_kv_cache, + paged_kv_indptr=paged_kv_indptr, + paged_kv_indices=paged_kv_indices, + 
paged_kv_last_page_len=paged_kv_last_page_len, + page_size=page_size, + workspace_buffer=workspace_buffer, + sm_scale=sm_scale, + return_lse=False, + backend=backend, + ) + else: + # No prefix (first chunk) + sg_out = sglang_style_cascade_attention( + q=q_current, + k_current=k_current, + v_current=v_current, + qo_indptr=qo_indptr, + kv_curr_indptr=kv_curr_indptr, + paged_kv_cache=None, + paged_kv_indptr=None, + paged_kv_indices=None, + paged_kv_last_page_len=None, + page_size=page_size, + workspace_buffer=workspace_buffer, + sm_scale=sm_scale, + return_lse=False, + backend=backend, + ) + + # Compute difference + step_diff = (be_out - sg_out).abs().max().item() + step_diffs.append(step_diff) + max_diff_all_steps = max(max_diff_all_steps, step_diff) + + prefix_len = step_idx * chunk_size + if verbose: + print(f" Step {step_idx}: prefix_len={prefix_len:4d}, curr_len={chunk_size}, max_diff={step_diff:.6f}") + + precision_ok = max_diff_all_steps < tol + status = "PASS" if precision_ok else "FAIL" + print(f"\n [Precision Summary] max_diff (all steps) = {max_diff_all_steps:.6f} [{status}]") + print(f"\n{'='*100}") + print(f"Performance Comparison (measuring last Step: step={num_steps-1}, prefix_len={(num_steps-1)*chunk_size})") + print(f"{'='*100}") + + # Use last step for performance testing (longest prefix) + test_step = num_steps - 1 + q_current = qs_chunks[test_step] + k_current = ks_chunks[test_step] + v_current = vs_chunks[test_step] + k_prefix = k_full[:test_step * chunk_size] + v_prefix = v_full[:test_step * chunk_size] + prefix_len = test_step * chunk_size + + # Prepare Paged KV Cache (create only once) + paged_kv_cache, paged_kv_indptr, paged_kv_indices, paged_kv_last_page_len = \ + create_paged_kv_cache_from_prefix(k_prefix, v_prefix, page_size) + qo_indptr = torch.tensor([0, chunk_size], dtype=torch.int32, device=device) + kv_curr_indptr = torch.tensor([0, chunk_size], dtype=torch.int32, device=device) + + def run_be_cascade(): + return block_extend_cascade( 
+ q=q_current, + k_current=k_current, + v_current=v_current, + k_prefix=k_prefix, + v_prefix=v_prefix, + dllm_block_size=dllm_block_size, + sm_scale=sm_scale, + return_lse=False, + backend=backend, + ) + + # Warmup + for _ in range(warmup_iters): + run_be_cascade() + torch.cuda.synchronize() + + # Benchmark + start = time_module.perf_counter() + for _ in range(bench_iters): + run_be_cascade() + torch.cuda.synchronize() + be_time = (time_module.perf_counter() - start) / bench_iters * 1000 + + def run_sg_cascade(): + return sglang_style_cascade_attention( + q=q_current, + k_current=k_current, + v_current=v_current, + qo_indptr=qo_indptr, + kv_curr_indptr=kv_curr_indptr, + paged_kv_cache=paged_kv_cache, + paged_kv_indptr=paged_kv_indptr, + paged_kv_indices=paged_kv_indices, + paged_kv_last_page_len=paged_kv_last_page_len, + page_size=page_size, + workspace_buffer=workspace_buffer, + sm_scale=sm_scale, + return_lse=False, + backend=backend, + ) + + # Warmup + for _ in range(warmup_iters): + run_sg_cascade() + torch.cuda.synchronize() + + # Benchmark + start = time_module.perf_counter() + for _ in range(bench_iters): + run_sg_cascade() + torch.cuda.synchronize() + sg_time = (time_module.perf_counter() - start) / bench_iters * 1000 + + print(f"\n Test parameters: chunk_size={chunk_size}, prefix_len={prefix_len}") + print(f" block_extend_cascade: {be_time:.3f} ms") + print(f" sglang_style_cascade_attention: {sg_time:.3f} ms") + + if be_time < sg_time: + speedup = sg_time / be_time + print(f" block_extend_cascade faster by {speedup:.2f}x") + else: + speedup = be_time / sg_time + print(f" sglang_style_cascade_attention faster by {speedup:.2f}x") + + + return { + "precision_ok": precision_ok, + "max_diff": max_diff_all_steps, + "step_diffs": step_diffs, + "be_time_ms": be_time, + "sg_time_ms": sg_time, + } + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser(description="Cascade vs Batch Comparison Test") + 
parser.add_argument("--batch-prefill", action="store_true", help="Multi Request Batch Prefill comparison (variable chunk_size)") + parser.add_argument("--step-by-step", action="store_true", help="Real step-by-step execution comparison (simulating pipeline dependencies)") + parser.add_argument("--step-by-step-cuda_graph", action="store_true", help="Real step-by-step execution comparison + CUDA Graph") + parser.add_argument("--single-req-cuda_graph", action="store_true", help="Single request step-by-step execution + CUDA Graph (pipeline dependencies)") + parser.add_argument("--fa2-fa3-be", action="store_true", help="FA2 vs FA3 BlockExtend vs Causal performance comparison") + parser.add_argument("--cascade-perf", action="store_true", help="Cascade interface performance comparison (batch_block_extend_cascade vs sglang_style)") + parser.add_argument("--cascade-perf-cuda_graph", action="store_true", help="Cascade interface performance comparison + CUDA Graph (pre-create Wrapper, only capture run)") + parser.add_argument("--cascade-precision", action="store_true", help="Cascade interface precision alignment test (sglang_style vs block_extend)") + parser.add_argument("--sglang-vs-be", action="store_true", help="sglang_style_cascade vs block_extend_cascade precision and performance comparison (equal inputs)") + parser.add_argument("--heterogeneous-prefix", action="store_true", help="Heterogeneous prefix test: different requests have different prefix lengths") + parser.add_argument("--cascade-chunk", action="store_true", help="Cascade Current Chunk test: K/V only has current block, requires kv_offset") + parser.add_argument("--tvm-ffi-slice-bug", action="store_true", help="TVM FFI slice tensor bug reproduction test") + parser.add_argument("--cuda_graph-reuse-bug", action="store_true", help="Test CUDA Graph mode wrapper reuse bug (exposes q_offsets address change issue)") + parser.add_argument("--precision-test", action="store_true", help="DLLM component precision test 
(compare with native Custom Mask FA2)") + parser.add_argument("--dtype", type=str, default="fp16", choices=["fp16", "bf16", "all"], + help="Data type for precision test: fp16, bf16, or all (test both)") + parser.add_argument("--verbose", "-v", action="store_true", help="Show detailed Step flow preview information") + parser.add_argument("--num_chunks", type=int, default=8) + parser.add_argument("--chunk_len", type=int, default=32, help="Chunk length (prefill seq_len)") + parser.add_argument("--dllm_block_size", type=int, default=None, help="DLLM block size (default = chunk_len)") + parser.add_argument("--num_heads", type=int, default=32) + parser.add_argument("--num_kv_heads", type=int, default=8) + parser.add_argument("--head_dim", type=int, default=128) + parser.add_argument("--total_tokens", type=int, default=256, help="Total tokens (for fair/width mode)") + parser.add_argument("--batch_size", type=int, default=8, help="Batch size (for --multi mode)") + parser.add_argument("--num_requests", type=int, default=4, help="Number of requests") + parser.add_argument("--tokens_per_request", type=int, default=256, help="Tokens per request") + parser.add_argument("--kv_len", type=int, default=2048, help="Total tokens (for --fa2-fa3-be mode, i.e., tokens_per_request)") + parser.add_argument("--chunk_sizes", type=str, default="32,64,128,256,512", help="Chunk sizes list, comma separated") + parser.add_argument("--backend", type=str, default="fa2", choices=["auto", "fa2", "fa3"], help="Backend implementation: auto/fa2/fa3") + args = parser.parse_args() + + dllm_bs = args.dllm_block_size if args.dllm_block_size is not None else args.chunk_len + + if args.fa2_fa3_be: + # FA2 vs FA3 BlockExtend vs Causal performance comparison (incremental Prefill scenario) + chunk_sizes = [int(x) for x in args.chunk_sizes.split(",")] + test_fa2_fa3_block_extending_vs_causal( + num_requests=args.num_requests, + chunk_sizes=chunk_sizes, + tokens_per_request=args.kv_len, # Use kv_len parameter as 
tokens_per_request + dllm_block_size=dllm_bs, + num_heads=args.num_heads, + num_kv_heads=args.num_kv_heads, + head_dim=args.head_dim, + verbose=args.verbose, + ) + elif args.step_by_step_cuda_graph: + # Multi-request real step-by-step execution comparison + CUDA Graph + test_incremental_batchprefill_step_by_step_with_cuda_graph( + num_requests=args.num_requests, + tokens_per_request=args.tokens_per_request, + dllm_block_size=dllm_bs, + num_heads=args.num_heads, + head_dim=args.head_dim, + verbose=args.verbose, + ) + elif args.single_req_cuda_graph: + # Single request step-by-step execution + CUDA Graph (pipeline dependencies) + test_incremental_singlereq_prefill_step_by_step_with_cuda_graph( + tokens_per_request=args.tokens_per_request, + dllm_block_size=dllm_bs, + num_heads=args.num_heads, + head_dim=args.head_dim, + verbose=args.verbose, + ) + elif args.precision_test: + if args.dtype == "all": + print("\n" + "="*120) + print("Running FP16 precision test...") + print("="*120) + test_dllm_precision_vs_custom_mask_fa2( + verbose=args.verbose, + test_dtypes=[torch.float16], + ) + print("\n" + "="*120) + print("Running BF16 precision test...") + print("="*120) + test_dllm_precision_vs_custom_mask_fa2( + verbose=args.verbose, + test_dtypes=[torch.bfloat16], + ) + else: + if args.dtype == "fp16": + test_dtypes = [torch.float16] + else: # bf16 + test_dtypes = [torch.bfloat16] + test_dllm_precision_vs_custom_mask_fa2( + verbose=args.verbose, + test_dtypes=test_dtypes, + ) + elif args.cascade_perf: + # Multi-request end-to-end performance test, baseline is sglang_style_cascade_attention + chunk_sizes = [int(x) for x in args.chunk_sizes.split(",")] + test_cascade_interfaces_perf( + num_requests=args.num_requests, + tokens_per_request=args.tokens_per_request, + dllm_block_size=dllm_bs, + chunk_sizes=chunk_sizes, + num_heads=args.num_heads, + num_kv_heads=args.num_kv_heads, + head_dim=args.head_dim, + verbose=args.verbose, + backend=args.backend, + ) + elif 
args.cascade_perf_cuda_graph: + # Multi-request CUDA Graph end-to-end performance test, baseline is sglang_style_cascade_attention CUDA Graph + chunk_sizes = [int(x) for x in args.chunk_sizes.split(",")] + test_cascade_interfaces_perf_with_cuda_graph( + num_requests=args.num_requests, + tokens_per_request=args.tokens_per_request, + dllm_block_size=dllm_bs, + chunk_sizes=chunk_sizes, + num_heads=args.num_heads, + num_kv_heads=args.num_kv_heads, + head_dim=args.head_dim, + verbose=args.verbose, + backend=args.backend, + ) + elif args.cascade_precision: + # Pure precision test, verify numerical consistency between block_extend_cascade and sglang_style_cascade_attention + test_cascade_precision_alignment( + verbose=args.verbose, + ) + elif args.sglang_vs_be: + # Single request cascade performance test, baseline is sglang_style_cascade_attention + num_steps = args.tokens_per_request // dllm_bs + test_sglang_vs_block_extend_cascade( + num_steps=num_steps, + dllm_block_size=dllm_bs, + num_heads=args.num_heads, + num_kv_heads=args.num_kv_heads, + head_dim=args.head_dim, + verbose=args.verbose, + backend=args.backend, + ) + elif args.heterogeneous_prefix: + # Heterogeneous prefix test: different requests have different prefix lengths + for be in ["fa2", "fa3"]: + test_heterogeneous_prefix_batch( + verbose=args.verbose, + backend=be, + ) + + elif args.cascade_chunk: + # Cascade Current Chunk test: K/V only has current block, requires kv_offset + for be in ["fa2", "fa3"]: + test_cascade_current_chunk_batch( + verbose=args.verbose, + backend=be, + ) \ No newline at end of file diff --git a/tests/attention/test_dllm_vs_flex_attention.py b/tests/attention/test_dllm_vs_flex_attention.py new file mode 100644 index 0000000000..8d80fc759d --- /dev/null +++ b/tests/attention/test_dllm_vs_flex_attention.py @@ -0,0 +1,1262 @@ +"""FlashInfer Block Extend vs PyTorch Flex Attention Performance Comparison + +Compares three implementations: +1. 
BatchBlockExtendRaggedOffsetWrapper (FlashInfer, ragged KV) +2. BatchBlockExtendPagedOffsetWrapper (FlashInfer, paged KV) +3. torch.nn.attention.flex_attention + create_block_mask (PyTorch native) + +Mask rule (Block Extend): + q_global = q_offset + q_idx + kv_global = kv_offset + kv_idx + mask[q, k] = (q_global // dllm_block_size) >= (kv_global // dllm_block_size) + +Flex Attention KV Cache format: + Q: (B, Hq, L, E) - dense BHSD + K: (B, Hkv, S, E) - dense BHSD + V: (B, Hkv, S, Ev) - dense BHSD + No paged KV cache, just regular dense tensors +""" + +import torch +import time +import math +import sys + +# ============================================================ +# FlashInfer imports +# ============================================================ +try: + from flashinfer import single_prefill_with_kv_cache + from flashinfer.dllm import ( + BatchBlockExtendPagedOffsetWrapper, + BatchBlockExtendRaggedOffsetWrapper, + ) + HAS_FLASHINFER = True +except ImportError as e: + HAS_FLASHINFER = False + print(f"[WARN] flashinfer not available: {e}") + print(" Will skip FlashInfer benchmarks") +except Exception as e: + HAS_FLASHINFER = False + print(f"[ERROR] flashinfer import failed with unexpected error: {e}") + print(" Will skip FlashInfer benchmarks") + +# ============================================================ +# Flex Attention imports (requires PyTorch >= 2.5) +# ============================================================ +try: + from torch.nn.attention.flex_attention import ( + flex_attention, + create_block_mask, + ) + HAS_FLEX_ATTENTION = True +except ImportError: + HAS_FLEX_ATTENTION = False + print("[WARN] flex_attention not available (requires PyTorch >= 2.5)") + + +# ============================================================ +# Reference implementation +# ============================================================ +def compute_block_extend_reference( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + dllm_block_size: int, + q_offset: int = 0, + 
sm_scale: float = None, +) -> torch.Tensor: + """Reference: single_prefill_with_kv_cache + custom_mask""" + qo_len = q.shape[0] + kv_len = k.shape[0] + head_dim = q.shape[-1] + device = q.device + + if sm_scale is None: + sm_scale = 1.0 / math.sqrt(head_dim) + + q_pos = torch.arange(qo_len, device=device) + q_offset + k_pos = torch.arange(kv_len, device=device) + q_block = q_pos.unsqueeze(1) // dllm_block_size + k_block = k_pos.unsqueeze(0) // dllm_block_size + mask_2d = (q_block >= k_block).to(torch.uint8) + + return single_prefill_with_kv_cache( + q, k, v, custom_mask=mask_2d, sm_scale=sm_scale, + ) + + +# ============================================================ +# Flex Attention helper: build block_extend mask_mod +# ============================================================ +def make_block_extend_mask_mod(dllm_block_size: int, q_offset: int = 0): + """ + Returns the mask_mod function used by flex_attention + + mask_mod(b, h, q_idx, kv_idx) -> bool + True = allow attend, False = mask out + """ + def block_extend_mask(b, h, q_idx, kv_idx): + q_global = q_idx + q_offset + q_blk = q_global // dllm_block_size + kv_blk = kv_idx // dllm_block_size + return q_blk >= kv_blk + return block_extend_mask + + +# ============================================================ +# Memory utility +# ============================================================ +def get_memory_stats(device=None): + """Get current GPU memory stats in MB.""" + if device is None: + device = torch.device("cuda:0") + torch.cuda.synchronize(device) + allocated = torch.cuda.memory_allocated(device) / 1024**2 # MB + reserved = torch.cuda.memory_reserved(device) / 1024**2 # MB + max_allocated = torch.cuda.max_memory_allocated(device) / 1024**2 # MB + return { + "allocated_mb": allocated, + "reserved_mb": reserved, + "max_allocated_mb": max_allocated, + } + +def reset_peak_memory(device=None): + """Reset peak memory stats.""" + if device is None: + device = torch.device("cuda:0") + 
torch.cuda.reset_peak_memory_stats(device) + +def measure_memory_fn(fn, warmup_iters=5): + """Measure peak memory usage of a callable. + + Returns: + dict with keys: + - peak_allocated_mb: Peak allocated memory during execution + - peak_reserved_mb: Peak reserved memory + - baseline_allocated_mb: Memory before execution + """ + # Warmup + for _ in range(warmup_iters): + fn() + torch.cuda.synchronize() + torch.cuda.empty_cache() + + # Measure baseline + baseline = get_memory_stats() + reset_peak_memory() + + # Execute and measure peak + fn() + torch.cuda.synchronize() + + peak = get_memory_stats() + + return { + "baseline_allocated_mb": baseline["allocated_mb"], + "peak_allocated_mb": peak["max_allocated_mb"], + "peak_reserved_mb": peak["reserved_mb"], + "memory_increase_mb": peak["max_allocated_mb"] - baseline["allocated_mb"], + } + + +# ============================================================ +# Benchmark utility +# ============================================================ +def benchmark_fn(fn, warmup_iters=20, bench_iters=100, label=""): + """Benchmark a callable, return average time in ms.""" + for _ in range(warmup_iters): + fn() + torch.cuda.synchronize() + + start = time.perf_counter() + for _ in range(bench_iters): + fn() + torch.cuda.synchronize() + elapsed_ms = (time.perf_counter() - start) / bench_iters * 1000 + return elapsed_ms + + +def benchmark_with_cuda_graph(fn, warmup_iters=20, bench_iters=100, label=""): + """Benchmark with CUDA Graph capture, return average time in ms.""" + # warmup + for _ in range(warmup_iters): + fn() + torch.cuda.synchronize() + + # capture + stream = torch.cuda.Stream() + with torch.cuda.stream(stream): + fn() + stream.synchronize() + + graph = torch.cuda.CUDAGraph() + with torch.cuda.graph(graph, stream=stream): + fn() + + # warmup cuda_graph + for _ in range(warmup_iters): + graph.replay() + torch.cuda.synchronize() + + # bench + start = time.perf_counter() + for _ in range(bench_iters): + graph.replay() + 
torch.cuda.synchronize() + elapsed_ms = (time.perf_counter() - start) / bench_iters * 1000 + + del graph + return elapsed_ms + + +# ============================================================ +# Main benchmark +# ============================================================ +def bench_flashinfer_vs_flex_attention( + num_requests: int = 4, + total_kv_len: int = 2048, + qo_len: int = 256, + dllm_block_size: int = 32, + num_heads: int = 32, + num_kv_heads: int = 8, + head_dim: int = 128, + page_size: int = 16, + dtype: torch.dtype = torch.float16, + warmup_iters: int = 20, + bench_iters: int = 100, + verify: bool = True, +): + """ + Main benchmark: FlashInfer Block Extend (Ragged + Paged) vs Flex Attention + + Scenario: Each request has Q length = qo_len, KV length = total_kv_len + q_offset = total_kv_len - qo_len (simulates the last step of incremental prefill) + """ + device = torch.device("cuda:0") + sm_scale = 1.0 / math.sqrt(head_dim) + q_offset = total_kv_len - qo_len + + print(f"\n{'='*80}") + print(f"FlashInfer Block Extend vs PyTorch Flex Attention") + print(f"{'='*80}") + print(f" num_requests = {num_requests}") + print(f" total_kv_len = {total_kv_len}") + print(f" qo_len = {qo_len}") + print(f" q_offset = {q_offset}") + print(f" dllm_block_size = {dllm_block_size}") + print(f" num_heads = {num_heads}") + print(f" num_kv_heads = {num_kv_heads}") + print(f" head_dim = {head_dim}") + print(f" page_size = {page_size}") + print(f" dtype = {dtype}") + print() + + results = {} + + # =========================================================== + # Data Preparation + # =========================================================== + # Per-request tensors + all_q = [torch.randn(qo_len, num_heads, head_dim, dtype=dtype, device=device) + for _ in range(num_requests)] + all_k = [torch.randn(total_kv_len, num_kv_heads, head_dim, dtype=dtype, device=device) + for _ in range(num_requests)] + all_v = [torch.randn(total_kv_len, num_kv_heads, head_dim, dtype=dtype, device=device) 
+ for _ in range(num_requests)] + + # Ragged layout: concat along token dim (NHD format) + q_ragged = torch.cat(all_q, dim=0) # [B*qo_len, H, D] + k_ragged = torch.cat(all_k, dim=0) # [B*kv_len, Hkv, D] + v_ragged = torch.cat(all_v, dim=0) + + # Flex Attention layout: BHSD + # Q: (B, Hq, qo_len, D) K/V: (B, Hkv, kv_len, D) + q_bhsd = torch.stack(all_q, dim=0).permute(0, 2, 1, 3).contiguous() + k_bhsd = torch.stack(all_k, dim=0).permute(0, 2, 1, 3).contiguous() + v_bhsd = torch.stack(all_v, dim=0).permute(0, 2, 1, 3).contiguous() + # q_bhsd: [B, Hq, qo_len, D], k_bhsd: [B, Hkv, kv_len, D] + + # Paged KV cache preparation + num_pages_per_req = (total_kv_len + page_size - 1) // page_size + total_pages = num_pages_per_req * num_requests + # kv_layout="NHD" -> paged_kv_cache: (total_pages, 2, page_size, num_kv_heads, head_dim) + paged_kv_cache = torch.zeros( + total_pages, 2, page_size, num_kv_heads, head_dim, + dtype=dtype, device=device, + ) + # Fill pages from all_k/all_v + paged_kv_indices_list = [] + paged_kv_last_page_lens = [] + for req_idx in range(num_requests): + for page_idx in range(num_pages_per_req): + global_page = req_idx * num_pages_per_req + page_idx + start = page_idx * page_size + end = min(start + page_size, total_kv_len) + length = end - start + # all_k shape: (kv_len, num_kv_heads, head_dim) — already NHD + # page slot: (page_size, num_kv_heads, head_dim) — NHD + paged_kv_cache[global_page, 0, :length, :, :] = all_k[req_idx][start:end] + paged_kv_cache[global_page, 1, :length, :, :] = all_v[req_idx][start:end] + paged_kv_indices_list.append(global_page) + last_page_len = total_kv_len - (num_pages_per_req - 1) * page_size + paged_kv_last_page_lens.append(last_page_len) + + paged_kv_indices = torch.tensor(paged_kv_indices_list, dtype=torch.int32, device=device) + paged_kv_indptr = torch.tensor( + [i * num_pages_per_req for i in range(num_requests + 1)], + dtype=torch.int32, device=device, + ) + paged_kv_last_page_len = 
torch.tensor(paged_kv_last_page_lens, dtype=torch.int32, device=device) + + # indptr for ragged + qo_indptr = torch.tensor( + [i * qo_len for i in range(num_requests + 1)], + dtype=torch.int32, device=device, + ) + kv_indptr = torch.tensor( + [i * total_kv_len for i in range(num_requests + 1)], + dtype=torch.int32, device=device, + ) + q_offsets = torch.full((num_requests,), q_offset, dtype=torch.int32, device=device) + + # =========================================================== + # 1. Correctness Validation (single request, against reference) + # =========================================================== + if verify and HAS_FLASHINFER: + print(f"{'='*60}") + print(f"Correctness Validation (request 0)") + print(f"{'='*60}") + + ref_out = compute_block_extend_reference( + all_q[0], all_k[0], all_v[0], + dllm_block_size=dllm_block_size, + q_offset=q_offset, + sm_scale=sm_scale, + ) + + tol = 1e-2 + + # Ragged Offset + ws = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device=device) + single_qo_indptr = torch.tensor([0, qo_len], dtype=torch.int32, device=device) + single_kv_indptr = torch.tensor([0, total_kv_len], dtype=torch.int32, device=device) + single_q_offsets = torch.tensor([q_offset], dtype=torch.int32, device=device) + + ragged_wrapper = BatchBlockExtendRaggedOffsetWrapper( + ws, kv_layout="NHD", dllm_block_size=dllm_block_size, + ) + ragged_wrapper.plan( + qo_indptr=single_qo_indptr, kv_indptr=single_kv_indptr, + num_qo_heads=num_heads, num_kv_heads=num_kv_heads, head_dim=head_dim, + q_data_type=dtype, sm_scale=sm_scale, q_offsets=single_q_offsets, + ) + ragged_out = ragged_wrapper.run(all_q[0], all_k[0], all_v[0]) + ragged_diff = (ragged_out - ref_out).abs().max().item() + ragged_pass = ragged_diff < tol + print(f" [Ragged Offset] max_diff={ragged_diff:.6f} {'PASS' if ragged_pass else 'FAIL'}") + assert ragged_pass, f"Ragged Offset max_diff={ragged_diff} exceeds tolerance {tol}" + + # Paged Offset + single_paged_indptr = torch.tensor([0, 
num_pages_per_req], dtype=torch.int32, device=device) + single_paged_indices = torch.arange(num_pages_per_req, dtype=torch.int32, device=device) + single_paged_last = torch.tensor([paged_kv_last_page_lens[0]], dtype=torch.int32, device=device) + + paged_wrapper = BatchBlockExtendPagedOffsetWrapper( + ws, kv_layout="NHD", dllm_block_size=dllm_block_size, + ) + paged_wrapper.plan( + qo_indptr=single_qo_indptr, paged_kv_indptr=single_paged_indptr, + paged_kv_indices=single_paged_indices, paged_kv_last_page_len=single_paged_last, + num_qo_heads=num_heads, num_kv_heads=num_kv_heads, head_dim=head_dim, + page_size=page_size, q_data_type=dtype, sm_scale=sm_scale, + q_offsets=single_q_offsets, + ) + paged_out = paged_wrapper.run(all_q[0], paged_kv_cache) + paged_diff = (paged_out - ref_out).abs().max().item() + paged_pass = paged_diff < tol + print(f" [Paged Offset] max_diff={paged_diff:.6f} {'PASS' if paged_pass else 'FAIL'}") + assert paged_pass, f"Paged Offset max_diff={paged_diff} exceeds tolerance {tol}" + + # Flex Attention + if HAS_FLEX_ATTENTION: + mask_mod = make_block_extend_mask_mod(dllm_block_size, q_offset) + block_mask = create_block_mask( + mask_mod, B=1, H=1, Q_LEN=qo_len, KV_LEN=total_kv_len, device=device, + ) + # single request, BHSD format + q_single = all_q[0].unsqueeze(0).permute(0, 2, 1, 3).contiguous() + k_single = all_k[0].unsqueeze(0).permute(0, 2, 1, 3).contiguous() + v_single = all_v[0].unsqueeze(0).permute(0, 2, 1, 3).contiguous() + + flex_out_bhsd = flex_attention( + q_single, k_single, v_single, + block_mask=block_mask, scale=sm_scale, + enable_gqa=(num_heads != num_kv_heads), + ) + # convert back: (1, H, L, D) -> (L, H, D) + flex_out = flex_out_bhsd.squeeze(0).permute(1, 0, 2).contiguous() + flex_diff = (flex_out - ref_out).abs().max().item() + flex_pass = flex_diff < tol + print(f" [Flex Attention] max_diff={flex_diff:.6f} {'PASS' if flex_pass else 'FAIL'}") + assert flex_pass, f"Flex Attention max_diff={flex_diff} exceeds tolerance {tol}" 
+ + del ws, ragged_wrapper, paged_wrapper + torch.cuda.empty_cache() + print() + + # =========================================================== + # 2. FlashInfer Ragged Offset Benchmark + # =========================================================== + if HAS_FLASHINFER: + print(f"{'='*60}") + print(f"[Bench] FlashInfer Ragged Offset (batch={num_requests})") + print(f"{'='*60}") + + ws_ragged = torch.empty(256 * 1024 * 1024, dtype=torch.uint8, device=device) + ragged_bench_wrapper = BatchBlockExtendRaggedOffsetWrapper( + ws_ragged, kv_layout="NHD", dllm_block_size=dllm_block_size, + ) + ragged_bench_wrapper.plan( + qo_indptr=qo_indptr, kv_indptr=kv_indptr, + num_qo_heads=num_heads, num_kv_heads=num_kv_heads, head_dim=head_dim, + q_data_type=dtype, sm_scale=sm_scale, q_offsets=q_offsets, + ) + + ragged_out_buf = torch.empty( + num_requests * qo_len, num_heads, head_dim, dtype=dtype, device=device, + ) + + def run_ragged(): + ragged_out_buf.copy_( + ragged_bench_wrapper.run(q_ragged, k_ragged, v_ragged) + ) + + # no-cuda_graph benchmark + t_ragged = benchmark_fn(run_ragged, warmup_iters, bench_iters) + print(f" No cuda_graph: {t_ragged:.3f} ms") + + # cuda_graph benchmark + t_ragged_cuda_graph = benchmark_with_cuda_graph(run_ragged, warmup_iters, bench_iters) + print(f" With cuda_graph: {t_ragged_cuda_graph:.3f} ms") + + # Memory measurement + mem_ragged = measure_memory_fn(run_ragged, warmup_iters=5) + print(f" Memory: Peak={mem_ragged['peak_allocated_mb']:.1f} MB, " + f"Increase={mem_ragged['memory_increase_mb']:.1f} MB") + + results["ragged_offset"] = { + "no_cuda_graph_ms": t_ragged, + "cuda_graph_ms": t_ragged_cuda_graph, + "memory": mem_ragged, + } + del ws_ragged, ragged_bench_wrapper + torch.cuda.empty_cache() + print() + + # =========================================================== + # 3. 
FlashInfer Paged Offset Benchmark + # =========================================================== + if HAS_FLASHINFER: + print(f"{'='*60}") + print(f"[Bench] FlashInfer Paged Offset (batch={num_requests})") + print(f"{'='*60}") + + ws_paged = torch.empty(256 * 1024 * 1024, dtype=torch.uint8, device=device) + paged_bench_wrapper = BatchBlockExtendPagedOffsetWrapper( + ws_paged, kv_layout="NHD", dllm_block_size=dllm_block_size, + ) + paged_bench_wrapper.plan( + qo_indptr=qo_indptr, paged_kv_indptr=paged_kv_indptr, + paged_kv_indices=paged_kv_indices, + paged_kv_last_page_len=paged_kv_last_page_len, + num_qo_heads=num_heads, num_kv_heads=num_kv_heads, head_dim=head_dim, + page_size=page_size, q_data_type=dtype, sm_scale=sm_scale, + q_offsets=q_offsets, + ) + + paged_out_buf = torch.empty( + num_requests * qo_len, num_heads, head_dim, dtype=dtype, device=device, + ) + + def run_paged(): + paged_out_buf.copy_( + paged_bench_wrapper.run(q_ragged, paged_kv_cache) + ) + + t_paged = benchmark_fn(run_paged, warmup_iters, bench_iters) + print(f" No cuda_graph: {t_paged:.3f} ms") + + t_paged_cuda_graph = benchmark_with_cuda_graph(run_paged, warmup_iters, bench_iters) + print(f" With cuda_graph: {t_paged_cuda_graph:.3f} ms") + + # Memory measurement + mem_paged = measure_memory_fn(run_paged, warmup_iters=5) + print(f" Memory: Peak={mem_paged['peak_allocated_mb']:.1f} MB, " + f"Increase={mem_paged['memory_increase_mb']:.1f} MB") + + results["paged_offset"] = { + "no_cuda_graph_ms": t_paged, + "cuda_graph_ms": t_paged_cuda_graph, + "memory": mem_paged, + } + del ws_paged, paged_bench_wrapper + torch.cuda.empty_cache() + print() + + # =========================================================== + # 4. 
Flex Attention Benchmark + # =========================================================== + if HAS_FLEX_ATTENTION: + print(f"{'='*60}") + print(f"[Bench] PyTorch Flex Attention (batch={num_requests})") + print(f"{'='*60}") + print(f" KV format: dense BHSD Q({num_requests},{num_heads},{qo_len},{head_dim})") + print(f" K({num_requests},{num_kv_heads},{total_kv_len},{head_dim})") + print(f" V({num_requests},{num_kv_heads},{total_kv_len},{head_dim})") + + # Create block_mask - all requests share the same mask pattern + mask_mod = make_block_extend_mask_mod(dllm_block_size, q_offset) + block_mask = create_block_mask( + mask_mod, B=num_requests, H=1, + Q_LEN=qo_len, KV_LEN=total_kv_len, device=device, + ) + + use_gqa = (num_heads != num_kv_heads) + flex_out_buf = torch.empty( + num_requests, num_heads, qo_len, head_dim, dtype=dtype, device=device, + ) + + # ---------- flex_attention (no compile) ---------- + # Skip no-compile for large sequences (materializes full QxKV score matrix - OOM) + if qo_len * total_kv_len <= 4096 * 4096: + def run_flex_no_compile(): + flex_out_buf.copy_( + flex_attention( + q_bhsd, k_bhsd, v_bhsd, + block_mask=block_mask, scale=sm_scale, + enable_gqa=use_gqa, + ) + ) + + t_flex_no_compile = benchmark_fn( + run_flex_no_compile, warmup_iters, bench_iters, + ) + print(f" No compile: {t_flex_no_compile:.3f} ms") + + # Memory measurement + mem_flex_no_compile = measure_memory_fn(run_flex_no_compile, warmup_iters=5) + print(f" Memory (no compile): Peak={mem_flex_no_compile['peak_allocated_mb']:.1f} MB, " + f"Increase={mem_flex_no_compile['memory_increase_mb']:.1f} MB") + + results["flex_no_compile"] = { + "no_cuda_graph_ms": t_flex_no_compile, + "memory": mem_flex_no_compile, + } + else: + print(f" No compile: SKIPPED (seq too large, would OOM)") + + # ---------- flex_attention (compiled) ---------- + # Recreate compile instance + reset dynamo cache each time to avoid cross-tier shape cumulative recompilation + torch._dynamo.reset() + 
_flex_attention_compiled = torch.compile(flex_attention, dynamic=False) + + def run_flex_compiled(): + flex_out_buf.copy_( + _flex_attention_compiled( + q_bhsd, k_bhsd, v_bhsd, + block_mask=block_mask, scale=sm_scale, + enable_gqa=use_gqa, + ) + ) + + t_flex_compiled = benchmark_fn( + run_flex_compiled, warmup_iters, bench_iters, + ) + print(f" Compiled: {t_flex_compiled:.3f} ms") + + # Memory measurement + mem_flex = measure_memory_fn(run_flex_compiled, warmup_iters=5) + print(f" Memory (compiled): Peak={mem_flex['peak_allocated_mb']:.1f} MB, " + f"Increase={mem_flex['memory_increase_mb']:.1f} MB") + + results["flex_compiled"] = { + "no_cuda_graph_ms": t_flex_compiled, + "memory": mem_flex, + } + + # ---------- flex_attention (compiled + reduce-overhead / internal cuda_graph) ---------- + torch._dynamo.reset() + _flex_attention_reduce_overhead = torch.compile(flex_attention, dynamic=False, mode="reduce-overhead") + + def run_flex_reduce_overhead(): + flex_out_buf.copy_( + _flex_attention_reduce_overhead( + q_bhsd, k_bhsd, v_bhsd, + block_mask=block_mask, scale=sm_scale, + enable_gqa=use_gqa, + ) + ) + + t_flex_reduce = benchmark_fn( + run_flex_reduce_overhead, warmup_iters, bench_iters, + ) + print(f" Compiled (reduce-overhead): {t_flex_reduce:.3f} ms") + results["flex_reduce_overhead"] = {"no_cuda_graph_ms": t_flex_reduce} + + # ---------- flex_attention (compiled + manual CUDA Graph) ---------- + try: + t_flex_cuda_graph = benchmark_with_cuda_graph( + run_flex_compiled, warmup_iters, bench_iters, + ) + print(f" Compiled + CUDA Graph: {t_flex_cuda_graph:.3f} ms") + results["flex_compiled"]["cuda_graph_ms"] = t_flex_cuda_graph + except Exception as e: + print(f" Compiled + CUDA Graph: FAILED ({e})") + + # =========================================================== + # 5. 
Summary + # =========================================================== + print(f"\n{'='*80}") + print(f"Summary (batch={num_requests}, qo_len={qo_len}, kv_len={total_kv_len}, " + f"block_size={dllm_block_size})") + print(f"{'='*80}") + print(f" {'Method':<40} | {'No cuda_graph (ms)':>12} | {'With cuda_graph (ms)':>12} | {'Mem Incr (MB)':>14}") + print(f" {'-'*40}-+-{'-'*12}-+-{'-'*12}-+-{'-'*14}") + + for key, label in [ + ("ragged_offset", "FlashInfer Ragged Offset"), + ("paged_offset", "FlashInfer Paged Offset"), + ("flex_no_compile", "Flex Attention (no compile)"), + ("flex_compiled", "Flex Attention (compiled)"), + ("flex_reduce_overhead", "Flex Attention (reduce-overhead)"), + ]: + if key in results: + r = results[key] + no_cuda_graph = f"{r.get('no_cuda_graph_ms', 0):.3f}" if 'no_cuda_graph_ms' in r else "N/A" + cuda_graph = f"{r.get('cuda_graph_ms', 0):.3f}" if 'cuda_graph_ms' in r else "N/A" + mem = r.get('memory', {}) + mem_incr = f"{mem['memory_increase_mb']:.1f}" if mem and 'memory_increase_mb' in mem else "N/A" + print(f" {label:<40} | {no_cuda_graph:>12} | {cuda_graph:>12} | {mem_incr:>14}") + + # Speedup vs flex_compiled + if "flex_compiled" in results and "ragged_offset" in results: + flex_t = results["flex_compiled"]["no_cuda_graph_ms"] + ragged_t = results["ragged_offset"]["no_cuda_graph_ms"] + paged_t = results.get("paged_offset", {}).get("no_cuda_graph_ms", 0) + print(f"\n Speedup (vs Flex compiled, no cuda_graph):") + print(f" Ragged Offset: {flex_t / ragged_t:.2f}x") + if paged_t > 0: + print(f" Paged Offset: {flex_t / paged_t:.2f}x") + + if "flex_compiled" in results and "ragged_offset" in results and "cuda_graph_ms" in results["ragged_offset"]: + ragged_cuda_graph = results["ragged_offset"]["cuda_graph_ms"] + paged_cuda_graph = results.get("paged_offset", {}).get("cuda_graph_ms", 0) + flex_t = results["flex_compiled"]["no_cuda_graph_ms"] + print(f"\n Speedup (FlashInfer cuda_graph vs Flex compiled):") + print(f" Ragged Offset cuda_graph: 
{flex_t / ragged_cuda_graph:.2f}x") + if paged_cuda_graph > 0: + print(f" Paged Offset cuda_graph: {flex_t / paged_cuda_graph:.2f}x") + + return results + + +# ============================================================ +# Sweep across different sequence lengths +# ============================================================ +def bench_sweep_seq_lengths( + num_requests: int = 4, + dllm_block_size: int = 32, + num_heads: int = 32, + num_kv_heads: int = 8, + head_dim: int = 128, + page_size: int = 16, + dtype: torch.dtype = torch.float16, +): + """ + Seven-tier context length test: + 1K / 2K / 4K / 8K / 16K / 24K / 32K + + Each tier uses fixed chunk_size=256, tests per-chunk latency for the last chunk + (i.e., q_offset = kv_len - 256, simulating the last step of incremental prefill) + """ + chunk_size = 256 + configs = [ + # (total_kv_len, qo_len, batch, label) + (1024, chunk_size, num_requests, "1K"), + (2048, chunk_size, num_requests, "2K"), + (4096, chunk_size, num_requests, "4K"), + (8192, chunk_size, num_requests, "8K"), + (16384, chunk_size, num_requests, "16K"), + (24576, chunk_size, min(num_requests, 2), "24K"), + (32768, chunk_size, 1, "32K"), + ] + + all_results = {} + for total_kv_len, qo_len, batch, desc in configs: + num_chunks = total_kv_len // qo_len + q_offset = total_kv_len - qo_len + tag = f"kv{total_kv_len}_q{qo_len}" + print(f"\n{'#'*80}") + print(f"# Tier: {desc}") + print(f"# batch={batch}, kv_len={total_kv_len}, chunk_size={qo_len}, " + f"num_chunks={num_chunks}, q_offset={q_offset}") + print(f"{'#'*80}") + r = bench_flashinfer_vs_flex_attention( + num_requests=batch, + total_kv_len=total_kv_len, + qo_len=qo_len, + dllm_block_size=dllm_block_size, + num_heads=num_heads, + num_kv_heads=num_kv_heads, + head_dim=head_dim, + page_size=page_size, + dtype=dtype, + warmup_iters=10, + bench_iters=50, + verify=True, + ) + all_results[tag] = r + + # Final comparison table + print(f"\n\n{'='*120}") + print(f"Context Length Comparison (chunk={chunk_size}, 
block_size={dllm_block_size})") + print(f"{'='*120}") + header = (f" {'Tier':<20} | {'KV Len':>8} | {'Batch':>5} | {'Chunks':>6} | " + f"{'Ragged(ms)':>12} | {'Ragged cuda_graph':>12} | {'Paged(ms)':>12} | " + f"{'Flex(ms)':>12} | {'Speedup':>10}") + print(header) + print(f" {'-'*20}-+-{'-'*8}-+-{'-'*5}-+-{'-'*6}-+-{'-'*12}-+-{'-'*12}-+-{'-'*12}-+-{'-'*12}-+-{'-'*10}") + + for total_kv_len, qo_len, batch, desc in configs: + tag = f"kv{total_kv_len}_q{qo_len}" + r = all_results.get(tag, {}) + num_chunks = total_kv_len // qo_len + ragged = r.get("ragged_offset", {}).get("no_cuda_graph_ms", 0) + ragged_cuda_graph = r.get("ragged_offset", {}).get("cuda_graph_ms", 0) + paged = r.get("paged_offset", {}).get("no_cuda_graph_ms", 0) + flex = r.get("flex_compiled", {}).get("no_cuda_graph_ms", 0) + speedup = f"{flex / ragged_cuda_graph:.2f}x" if ragged_cuda_graph > 0 and flex > 0 else "N/A" + print(f" {desc:<20} | {total_kv_len:>8} | {batch:>5} | {num_chunks:>6} | " + f"{ragged:>12.3f} | {ragged_cuda_graph:>12.3f} | {paged:>12.3f} | " + f"{flex:>12.3f} | {speedup:>10}") + + # Note about full prefill timing + print(f"\n Note: Above shows per-chunk latency for the last chunk (max q_offset)") + print(f" In actual full prefill, earlier chunks have shorter kv_len, thus lower latency") + print(f" 24K/32K tiers use reduced batch to avoid OOM") + + return all_results + + +# ============================================================ +# Full Prefill Seven-tier Context Length Test +# ============================================================ +def bench_full_prefill_four_tiers( + dllm_block_size: int = 32, + num_heads: int = 32, + num_kv_heads: int = 8, + head_dim: int = 128, + page_size: int = 16, + dtype: torch.dtype = torch.float16, +): + """ + Full Prefill scenario: qo_len = kv_len, q_offset = 0 + Process the entire sequence at once, simulating initial prefill + + Seven tiers: 1K / 2K / 4K / 8K / 16K / 24K / 32K + Long sequences use batch=1 to avoid OOM + """ + configs = [ + # 
(seq_len, batch, label) + (1024, 4, "1K"), + (2048, 4, "2K"), + (4096, 4, "4K"), + (8192, 4, "8K"), + (16384, 1, "16K"), + (24576, 1, "24K"), + (32768, 1, "32K"), + ] + + all_results = {} + for seq_len, batch, desc in configs: + tag = f"seq{seq_len}_b{batch}" + print(f"\n{'#'*80}") + print(f"# Full Prefill | {desc}") + print(f"# batch={batch}, seq_len={seq_len} (qo_len=kv_len={seq_len}, q_offset=0)") + print(f"{'#'*80}") + r = bench_flashinfer_vs_flex_attention( + num_requests=batch, + total_kv_len=seq_len, + qo_len=seq_len, # Full prefill: Q and KV have same length + dllm_block_size=dllm_block_size, + num_heads=num_heads, + num_kv_heads=num_kv_heads, + head_dim=head_dim, + page_size=page_size, + dtype=dtype, + warmup_iters=5, + bench_iters=20, + verify=(seq_len <= 4096), # Skip verify for long sequences (custom_mask too large) + ) + all_results[tag] = r + + # Summary table + print(f"\n\n{'='*110}") + print(f"Full Prefill Context Comparison (block_size={dllm_block_size})") + print(f"{'='*110}") + print(f" {'Tier':<18} | {'Seq Len':>8} | {'Batch':>5} | " + f"{'Ragged':>10} | {'Ragged cuda_graph':>10} | {'Paged':>10} | " + f"{'Flex compiled':>14} | {'Speedup':>10}") + print(f" {'-'*18}-+-{'-'*8}-+-{'-'*5}-+-" + f"{'-'*10}-+-{'-'*10}-+-{'-'*10}-+-" + f"{'-'*14}-+-{'-'*10}") + + for seq_len, batch, desc in configs: + tag = f"seq{seq_len}_b{batch}" + r = all_results.get(tag, {}) + ragged = r.get("ragged_offset", {}).get("no_cuda_graph_ms", 0) + ragged_cuda_graph = r.get("ragged_offset", {}).get("cuda_graph_ms", 0) + paged = r.get("paged_offset", {}).get("no_cuda_graph_ms", 0) + flex = r.get("flex_compiled", {}).get("no_cuda_graph_ms", 0) + speedup = f"{flex / ragged:.2f}x" if ragged > 0 and flex > 0 else "N/A" + print(f" {desc:<18} | {seq_len:>8} | {batch:>5} | " + f"{ragged:>10.3f} | {ragged_cuda_graph:>10.3f} | {paged:>10.3f} | " + f"{flex:>14.3f} | {speedup:>10}") + + print(f"\n Note: Long/extra-long contexts use batch=1 (to avoid OOM)") + + return all_results + + +# 
============================================================ +# Block size alignment effect test +# ============================================================ +def bench_block_size_sweep( + num_requests: int = 4, + total_kv_len: int = 4096, + qo_len: int = 512, + num_heads: int = 32, + num_kv_heads: int = 8, + head_dim: int = 128, + page_size: int = 16, + dtype: torch.dtype = torch.float16, +): + """ + Test the effect of different dllm_block_size on performance + + Flex Attention's Triton tile is fixed at 128x128: + - block_size=128: Perfect alignment, each tile is either all FULL or all SKIP + - block_size<128: Diagonal tiles become PARTIAL, requiring per-element mask check + - block_size>128: Tile granularity is finer than mask granularity, also produces PARTIAL + + FlashInfer's block extend kernel internally skips by dllm_block_size granularity, + not constrained by the 128 tile size. + """ + block_sizes = [32, 64, 128, 256] + q_offset = total_kv_len - qo_len + + print(f"\n{'='*100}") + print(f"Block Size Alignment Effect Test") + print(f" kv_len={total_kv_len}, qo_len={qo_len}, q_offset={q_offset}") + print(f" Flex Attention Triton tile = 128x128 (hardcoded)") + print(f"{'='*100}") + + # Precompute tile distribution for each block_size + num_q_tiles = (qo_len + 127) // 128 + num_kv_tiles = (total_kv_len + 127) // 128 + total_tiles = num_q_tiles * num_kv_tiles + + all_results = {} + for bs in block_sizes: + # Count tile type distribution + full, skip, partial = 0, 0, 0 + for qi in range(num_q_tiles): + for ki in range(num_kv_tiles): + q_start = qi * 128 + q_offset + q_end = min(q_start + 128, q_offset + qo_len) + k_start = ki * 128 + k_end = min(k_start + 128, total_kv_len) + # Check if mask in tile is all True / all False + q_blk_min = q_start // bs + q_blk_max = (q_end - 1) // bs + k_blk_min = k_start // bs + k_blk_max = (k_end - 1) // bs + if q_blk_min >= k_blk_max: # Q min block >= KV max block - all True + full += 1 + elif q_blk_max < k_blk_min: # Q max 
block < KV min block - all False + skip += 1 + else: + partial += 1 + + print(f"\n{'#'*80}") + print(f"# dllm_block_size = {bs}") + print(f"# 128x128 tile distribution: FULL={full}, SKIP={skip}, PARTIAL={partial} (total {total_tiles})") + print(f"# PARTIAL ratio: {partial/total_tiles*100:.1f}%") + print(f"{'#'*80}") + + r = bench_flashinfer_vs_flex_attention( + num_requests=num_requests, + total_kv_len=total_kv_len, + qo_len=qo_len, + dllm_block_size=bs, + num_heads=num_heads, + num_kv_heads=num_kv_heads, + head_dim=head_dim, + page_size=page_size, + dtype=dtype, + warmup_iters=10, + bench_iters=50, + verify=(bs <= 128), + ) + all_results[bs] = r + + # Summary table + print(f"\n\n{'='*130}") + print(f"Block Size Alignment Effect Summary (kv_len={total_kv_len}, qo_len={qo_len}, batch={num_requests})") + print(f"{'='*130}") + print(f" {'block_size':>10} | {'PARTIAL%':>8} | " + f"{'Ragged(ms)':>10} | {'Ragged CG':>10} | {'Paged(ms)':>10} | " + f"{'Flex compiled':>14} | {'Ragged Mem':>10} | {'Flex Mem':>10} | " + f"{'Speedup':>10}") + print(f" {'-'*10}-+-{'-'*8}-+-" + f"{'-'*10}-+-{'-'*10}-+-{'-'*10}-+-" + f"{'-'*14}-+-{'-'*10}-+-{'-'*10}-+-{'-'*10}") + + for bs in block_sizes: + r = all_results.get(bs, {}) + ragged = r.get("ragged_offset", {}).get("no_cuda_graph_ms", 0) + ragged_cg = r.get("ragged_offset", {}).get("cuda_graph_ms", 0) + paged = r.get("paged_offset", {}).get("no_cuda_graph_ms", 0) + flex = r.get("flex_compiled", {}).get("no_cuda_graph_ms", 0) + ragged_mem = r.get("ragged_offset", {}).get("memory", {}).get("memory_increase_mb", 0) + flex_mem = r.get("flex_compiled", {}).get("memory", {}).get("memory_increase_mb", 0) + speedup = f"{flex / ragged:.2f}x" if ragged > 0 and flex > 0 else "N/A" + + # Calculate partial ratio + partial_count = 0 + for qi in range(num_q_tiles): + for ki in range(num_kv_tiles): + q_start = qi * 128 + q_offset + q_end = min(q_start + 128, q_offset + qo_len) + k_start = ki * 128 + k_end = min(k_start + 128, total_kv_len) + q_blk_min 
= q_start // bs + q_blk_max = (q_end - 1) // bs + k_blk_min = k_start // bs + k_blk_max = (k_end - 1) // bs + if not (q_blk_min >= k_blk_max or q_blk_max < k_blk_min): + partial_count += 1 + pct = f"{partial_count/total_tiles*100:.1f}%" + + print(f" {bs:>10} | {pct:>8} | " + f"{ragged:>10.3f} | {ragged_cg:>10.3f} | {paged:>10.3f} | " + f"{flex:>14.3f} | {ragged_mem:>10.1f} | {flex_mem:>10.1f} | {speedup:>10}") + + return all_results + + +# ============================================================ +# Full Prefill Total Memory + Performance Comparison: Sweep context length x dllm_block_size +# ============================================================ +def bench_total_memory_comparison( + num_requests: int = 4, + qo_len: int = 512, + dllm_block_size: int = 32, + num_heads: int = 32, + num_kv_heads: int = 8, + head_dim: int = 128, + page_size: int = 16, + dtype: torch.dtype = torch.float16, +): + """ + Full Prefill comprehensive comparison: qo_len = kv_len, q_offset = 0 + Sweep 6 context lengths x 4 dllm_block_sizes + Measure for each combination: FlashInfer and Flex Attention latency(ms) + peak total memory(MB) + Flex Attention BLOCK_SIZE stays default, only mask logic granularity varies + Long sequences automatically reduce batch to avoid OOM + """ + device = torch.device("cuda:0") + sm_scale = 1.0 / math.sqrt(head_dim) + + # (seq_len, batch) - reduce batch for long sequences to avoid OOM + configs = [ + (2048, num_requests), + (4096, num_requests), + (8192, num_requests), + (16384, 1), + (24576, 1), + (32768, 1), + ] + dllm_block_sizes = [32, 64, 128, 256] + warmup_iters, bench_iters = 10, 50 + + print(f"\n{'='*140}") + print(f"Full Prefill: FlashInfer Ragged vs Flex compiled Comprehensive Comparison") + print(f"{'='*140}") + print(f" Scenario: qo_len = kv_len (full prefill), q_offset = 0") + print(f" heads={num_heads}/{num_kv_heads}, head_dim={head_dim}") + print(f" seq_lens: {[c[0] for c in configs]}") + print(f" batch: {[c[1] for c in configs]}") + print(f" 
dllm_block_sizes: {dllm_block_sizes}") + print(f" Flex BLOCK_SIZE: default (kernel decides)") + print() + + # Probe minimum workspace (max seq_len, batch=1) + min_ws_mb = 256 + if HAS_FLASHINFER: + max_seq = configs[-1][0] + probe_batch = configs[-1][1] + print(f" Probing minimum workspace (seq_len={max_seq}, batch={probe_batch}) ...") + _q = torch.randn(probe_batch * max_seq, num_heads, head_dim, dtype=dtype, device=device) + _k = torch.randn(probe_batch * max_seq, num_kv_heads, head_dim, dtype=dtype, device=device) + _v = torch.randn(probe_batch * max_seq, num_kv_heads, head_dim, dtype=dtype, device=device) + _qo = torch.tensor([i * max_seq for i in range(probe_batch + 1)], dtype=torch.int32, device=device) + _kv = torch.tensor([i * max_seq for i in range(probe_batch + 1)], dtype=torch.int32, device=device) + _qoff = torch.zeros(probe_batch, dtype=torch.int32, device=device) + for try_mb in [1, 2, 4, 8, 16, 32, 64, 128, 256]: + try: + _ws = torch.empty(try_mb * 1024 * 1024, dtype=torch.uint8, device=device) + _w = BatchBlockExtendRaggedOffsetWrapper(_ws, kv_layout="NHD", dllm_block_size=dllm_block_sizes[0]) + _w.plan(qo_indptr=_qo, kv_indptr=_kv, + num_qo_heads=num_heads, num_kv_heads=num_kv_heads, head_dim=head_dim, + q_data_type=dtype, sm_scale=sm_scale, q_offsets=_qoff) + _w.run(_q, _k, _v) + torch.cuda.synchronize(device) + min_ws_mb = try_mb + del _ws, _w + break + except Exception: + torch.cuda.empty_cache() + continue + del _q, _k, _v, _qo, _kv, _qoff + torch.cuda.empty_cache() + print(f" Minimum usable workspace: {min_ws_mb} MB") + print() + + def _reset(): + torch.cuda.empty_cache() + torch.cuda.synchronize(device) + torch.cuda.reset_peak_memory_stats(device) + return torch.cuda.memory_allocated(device) / 1024**2 + + # {(seq_len, batch, dllm_bs): {"fi_ms", "fi_peak", "flex_ms", "flex_peak"}} + all_results = {} + + for seq_len, batch in configs: + # Full prefill: qo_len = kv_len = seq_len, q_offset = 0 + q_offset = 0 + print(f" --- seq_len={seq_len}, 
batch={batch} (full prefill, q_offset=0) ---") + + for dbs in dllm_block_sizes: + key = (seq_len, batch, dbs) + entry = {"batch": batch} + + # ===== FlashInfer Ragged ===== + if HAS_FLASHINFER: + base = _reset() + q = torch.randn(batch * seq_len, num_heads, head_dim, dtype=dtype, device=device) + k = torch.randn(batch * seq_len, num_kv_heads, head_dim, dtype=dtype, device=device) + v = torch.randn(batch * seq_len, num_kv_heads, head_dim, dtype=dtype, device=device) + qo_indptr = torch.tensor([i * seq_len for i in range(batch + 1)], dtype=torch.int32, device=device) + kv_indptr = torch.tensor([i * seq_len for i in range(batch + 1)], dtype=torch.int32, device=device) + q_offsets = torch.zeros(batch, dtype=torch.int32, device=device) + ws = torch.empty(min_ws_mb * 1024 * 1024, dtype=torch.uint8, device=device) + + wrapper = BatchBlockExtendRaggedOffsetWrapper(ws, kv_layout="NHD", dllm_block_size=dbs) + wrapper.plan( + qo_indptr=qo_indptr, kv_indptr=kv_indptr, + num_qo_heads=num_heads, num_kv_heads=num_kv_heads, head_dim=head_dim, + q_data_type=dtype, sm_scale=sm_scale, q_offsets=q_offsets, + ) + + def _run_fi(): + wrapper.run(q, k, v) + + for _ in range(warmup_iters): + _run_fi() + torch.cuda.synchronize(device) + fi_peak = torch.cuda.max_memory_allocated(device) / 1024**2 - base + + # no cuda_graph latency + t0 = time.perf_counter() + for _ in range(bench_iters): + _run_fi() + torch.cuda.synchronize(device) + fi_ms = (time.perf_counter() - t0) / bench_iters * 1000 + + # cuda_graph latency + try: + fi_cg_ms = benchmark_with_cuda_graph(_run_fi, warmup_iters, bench_iters) + entry["fi_cg_ms"] = fi_cg_ms + except Exception as e: + print(f" [CUDA Graph capture failed] {e}") + + entry["fi_ms"] = fi_ms + entry["fi_peak"] = fi_peak + del q, k, v, ws, wrapper, qo_indptr, kv_indptr, q_offsets + + # ===== Flex Attention (compiled, default BLOCK_SIZE) ===== + if HAS_FLEX_ATTENTION: + base = _reset() + q = torch.randn(batch, num_heads, seq_len, head_dim, dtype=dtype, 
device=device) + k = torch.randn(batch, num_kv_heads, seq_len, head_dim, dtype=dtype, device=device) + v = torch.randn(batch, num_kv_heads, seq_len, head_dim, dtype=dtype, device=device) + + mask_mod = make_block_extend_mask_mod(dbs, q_offset) + block_mask = create_block_mask( + mask_mod, B=batch, H=1, + Q_LEN=seq_len, KV_LEN=seq_len, device=device, + ) + use_gqa = (num_heads != num_kv_heads) + + torch._dynamo.reset() + _compiled = torch.compile(flex_attention, dynamic=False) + + try: + for _ in range(warmup_iters): + _compiled(q, k, v, block_mask=block_mask, scale=sm_scale, enable_gqa=use_gqa) + torch.cuda.synchronize(device) + flex_peak = torch.cuda.max_memory_allocated(device) / 1024**2 - base + + t0 = time.perf_counter() + for _ in range(bench_iters): + _compiled(q, k, v, block_mask=block_mask, scale=sm_scale, enable_gqa=use_gqa) + torch.cuda.synchronize(device) + flex_ms = (time.perf_counter() - t0) / bench_iters * 1000 + + entry["flex_ms"] = flex_ms + entry["flex_peak"] = flex_peak + except Exception as e: + print(f" [dllm_bs={dbs}, seq={seq_len}] Flex failed: {e}") + + del q, k, v, block_mask, _compiled + + all_results[key] = entry + + # Progress line + fi_cg = f"CG={entry['fi_cg_ms']:.3f}ms" if "fi_cg_ms" in entry else "CG=N/A" + fi_s = f"{entry.get('fi_ms', 0):.3f}ms/{fi_cg}/{entry.get('fi_peak', 0):.0f}MB" + fx_s = f"{entry.get('flex_ms', 0):.3f}ms/{entry.get('flex_peak', 0):.0f}MB" if "flex_ms" in entry else "N/A" + print(f" dllm_bs={dbs:<3} FI={fi_s:<32} Flex={fx_s}") + + # ===== Summary table ===== + print(f"\n{'='*170}") + print(f"Full Prefill Summary: FlashInfer Ragged vs Flex compiled") + print(f" Scenario: qo_len = kv_len, q_offset = 0, FI workspace={min_ws_mb}MB, Flex BLOCK_SIZE=default(kernel decides)") + print(f"{'='*170}") + print(f" {'seq_len':>8} | {'batch':>5} | {'dllm_bs':>7} | {'FI(ms)':>8} | {'FI CG(ms)':>10} | {'FI peak(MB)':>12} | {'Flex(ms)':>9} | {'Flex peak(MB)':>14} | {'Speedup':>8} | {'CG Speedup':>9} | {'Mem Save':>8}") + print(f" 
{'-'*8}-+-{'-'*5}-+-{'-'*7}-+-{'-'*8}-+-{'-'*10}-+-{'-'*12}-+-{'-'*9}-+-{'-'*14}-+-{'-'*8}-+-{'-'*9}-+-{'-'*8}") + + fi_wins_perf, fi_wins_cg, fi_wins_mem, total_cmp = 0, 0, 0, 0 + + for seq_len, batch in configs: + for dbs in dllm_block_sizes: + e = all_results.get((seq_len, batch, dbs), {}) + fi_ms = e.get("fi_ms", 0) + fi_cg = e.get("fi_cg_ms", 0) + fi_pk = e.get("fi_peak", 0) + fx_ms = e.get("flex_ms", 0) + fx_pk = e.get("flex_peak", 0) + + if fi_ms > 0 and fx_ms > 0: + ratio = f"{fx_ms / fi_ms:.2f}x" + cg_ratio = f"{fx_ms / fi_cg:.2f}x" if fi_cg > 0 else "N/A" + mem_save = f"{(1 - fi_pk / fx_pk) * 100:+.0f}%" if fx_pk > 0 else "N/A" + total_cmp += 1 + if fi_ms < fx_ms: + fi_wins_perf += 1 + if fi_cg > 0 and fi_cg < fx_ms: + fi_wins_cg += 1 + if fi_pk < fx_pk: + fi_wins_mem += 1 + else: + ratio = "N/A" + cg_ratio = "N/A" + mem_save = "N/A" + + fi_cg_s = f"{fi_cg:>10.3f}" if fi_cg > 0 else f"{'N/A':>10}" + fx_ms_s = f"{fx_ms:>9.3f}" if fx_ms > 0 else f"{'N/A':>9}" + fx_pk_s = f"{fx_pk:>14.1f}" if fx_pk > 0 else f"{'N/A':>14}" + + print(f" {seq_len:>8} | {batch:>5} | {dbs:>7} | {fi_ms:>8.3f} | {fi_cg_s} | {fi_pk:>12.1f} | {fx_ms_s} | {fx_pk_s} | {ratio:>8} | {cg_ratio:>9} | {mem_save:>8}") + + # Statistics + print(f"\n Statistics ({total_cmp} valid comparisons):") + if total_cmp > 0: + print(f" Performance (no CG): FlashInfer wins {fi_wins_perf}/{total_cmp} cases") + print(f" Performance (FI CG): FlashInfer wins {fi_wins_cg}/{total_cmp} cases") + print(f" Memory: FlashInfer wins {fi_wins_mem}/{total_cmp} cases") + print(f"\n Note: 16K+ sequences use batch=1 to avoid OOM") + + return all_results + + +if __name__ == "__main__": + import argparse + parser = argparse.ArgumentParser(description="FlashInfer vs Flex Attention Benchmark") + parser.add_argument("--sweep", action="store_true", help="Run sweep across seq lengths (multi-chunk)") + parser.add_argument("--single-chunk", action="store_true", help="Single chunk, four context length tiers") + 
parser.add_argument("--full-prefill", action="store_true", help="Full prefill (qo=kv), four context length tiers") + parser.add_argument("--block-size-sweep", action="store_true", help="Sweep dllm_block_size (32/64/128/256) alignment effect") + parser.add_argument("--memory-compare", action="store_true", help="Total GPU memory comparison from clean state") + parser.add_argument("--num-requests", type=int, default=4) + parser.add_argument("--kv-len", type=int, default=4096) + parser.add_argument("--qo-len", type=int, default=512) + parser.add_argument("--block-size", type=int, default=32) + parser.add_argument("--num-heads", type=int, default=32) + parser.add_argument("--num-kv-heads", type=int, default=8) + parser.add_argument("--head-dim", type=int, default=128) + parser.add_argument("--page-size", type=int, default=16) + parser.add_argument("--no-verify", action="store_true") + args = parser.parse_args() + + if args.sweep: + bench_sweep_seq_lengths( + num_requests=args.num_requests, + dllm_block_size=args.block_size, + num_heads=args.num_heads, + num_kv_heads=args.num_kv_heads, + head_dim=args.head_dim, + page_size=args.page_size, + ) + elif args.full_prefill: + bench_full_prefill_four_tiers( + dllm_block_size=args.block_size, + num_heads=args.num_heads, + num_kv_heads=args.num_kv_heads, + head_dim=args.head_dim, + page_size=args.page_size, + ) + elif args.block_size_sweep: + bench_block_size_sweep( + num_requests=args.num_requests, + total_kv_len=args.kv_len, + qo_len=args.qo_len, + num_heads=args.num_heads, + num_kv_heads=args.num_kv_heads, + head_dim=args.head_dim, + page_size=args.page_size, + ) + elif args.memory_compare: + bench_total_memory_comparison( + num_requests=args.num_requests, + qo_len=args.qo_len, + dllm_block_size=args.block_size, + num_heads=args.num_heads, + num_kv_heads=args.num_kv_heads, + head_dim=args.head_dim, + page_size=args.page_size, + ) + else: + bench_flashinfer_vs_flex_attention( + num_requests=args.num_requests, + 
total_kv_len=args.kv_len, + qo_len=args.qo_len, + dllm_block_size=args.block_size, + num_heads=args.num_heads, + num_kv_heads=args.num_kv_heads, + head_dim=args.head_dim, + page_size=args.page_size, + verify=not args.no_verify, + )