Skip to content
Open
Show file tree
Hide file tree
Changes from 11 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 32 additions & 0 deletions csrc/batch_prefill_customize_config.jinja
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,22 @@ struct RaggedParams {
// Number of KV tokens for request `batch_idx`, computed as the difference of
// consecutive entries in the KV indptr (CSR-style offsets) array.
__host__ __device__ __forceinline__ uint32_t get_kv_len(uint32_t batch_idx) const {
  const uint32_t begin = kv_indptr[batch_idx];
  const uint32_t end = kv_indptr[batch_idx + 1];
  return end - begin;
}

// Per-request query block-expanding offset. Returns 0 unless the optional
// `maybe_q_block_expanding_offset` parameter was declared for this template
// instantiation and a non-null pointer was supplied at runtime.
__host__ __device__ __forceinline__ uint32_t get_q_block_expanding_offset(uint32_t batch_idx) const {
{% if 'maybe_q_block_expanding_offset' in additional_params_decl %}
  if (maybe_q_block_expanding_offset == nullptr) {
    return 0;
  }
  return maybe_q_block_expanding_offset[batch_idx];
{% else %}
  return 0;
{% endif %}
}

// Per-request KV block-expanding offset. Returns 0 unless the optional
// `maybe_kv_block_expanding_offset` parameter was declared for this template
// instantiation and a non-null pointer was supplied at runtime.
__host__ __device__ __forceinline__ uint32_t get_kv_block_expanding_offset(uint32_t batch_idx) const {
{% if 'maybe_kv_block_expanding_offset' in additional_params_decl %}
  if (maybe_kv_block_expanding_offset == nullptr) {
    return 0;
  }
  return maybe_kv_block_expanding_offset[batch_idx];
{% else %}
  return 0;
{% endif %}
}
};

struct PagedParams {
Expand Down Expand Up @@ -116,6 +132,22 @@ struct PagedParams {
// Number of KV tokens for request `batch_idx`, delegated to the paged KV
// cache's per-request length query.
__host__ __device__ __forceinline__ uint32_t get_kv_len(uint32_t batch_idx) const {
return paged_kv.get_length(batch_idx);
}

// Per-request query block-expanding offset. Returns 0 unless the optional
// `maybe_q_block_expanding_offset` parameter was declared for this template
// instantiation and a non-null pointer was supplied at runtime.
__host__ __device__ __forceinline__ uint32_t get_q_block_expanding_offset(uint32_t batch_idx) const {
{% if 'maybe_q_block_expanding_offset' in additional_params_decl %}
  if (maybe_q_block_expanding_offset == nullptr) {
    return 0;
  }
  return maybe_q_block_expanding_offset[batch_idx];
{% else %}
  return 0;
{% endif %}
}

// Per-request KV block-expanding offset. Returns 0 unless the optional
// `maybe_kv_block_expanding_offset` parameter was declared for this template
// instantiation and a non-null pointer was supplied at runtime.
__host__ __device__ __forceinline__ uint32_t get_kv_block_expanding_offset(uint32_t batch_idx) const {
{% if 'maybe_kv_block_expanding_offset' in additional_params_decl %}
  if (maybe_kv_block_expanding_offset == nullptr) {
    return 0;
  }
  return maybe_kv_block_expanding_offset[batch_idx];
{% else %}
  return 0;
{% endif %}
}
};

{{ variant_decl }}
50 changes: 50 additions & 0 deletions csrc/batch_prefill_sm90_customize_config.jinja
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,31 @@ struct RaggedParams {
int window_left;

bool causal;

// Block Expanding support
// Number of query/output tokens for request `batch_idx`, read from the
// per-request qo_lens array.
__host__ __device__ __forceinline__ uint32_t get_qo_len(uint32_t batch_idx) const {
return qo_lens[batch_idx];
}

// Number of KV tokens for request `batch_idx`, read from the per-request
// kv_lens array.
__host__ __device__ __forceinline__ uint32_t get_kv_len(uint32_t batch_idx) const {
return kv_lens[batch_idx];
}

// Per-request query block-expanding offset (SM90 path: the optional pointer
// lives inside `additional_params`). Returns 0 when the parameter is not
// declared for this instantiation, or when the pointer is null.
__host__ __device__ __forceinline__ uint32_t get_q_block_expanding_offset(uint32_t batch_idx) const {
{% if 'maybe_q_block_expanding_offset' in additional_params_decl %}
  const auto* offsets = additional_params.maybe_q_block_expanding_offset;
  return offsets ? offsets[batch_idx] : 0;
{% else %}
  return 0;
{% endif %}
}

// Per-request KV block-expanding offset (SM90 path: the optional pointer
// lives inside `additional_params`). Returns 0 when the parameter is not
// declared for this instantiation, or when the pointer is null.
__host__ __device__ __forceinline__ uint32_t get_kv_block_expanding_offset(uint32_t batch_idx) const {
{% if 'maybe_kv_block_expanding_offset' in additional_params_decl %}
  const auto* offsets = additional_params.maybe_kv_block_expanding_offset;
  return offsets ? offsets[batch_idx] : 0;
{% else %}
  return 0;
{% endif %}
}
};

struct PagedParams {
Expand Down Expand Up @@ -117,6 +142,31 @@ struct PagedParams {
int window_left;

bool causal;

// Block Expanding support
// Number of query/output tokens for request `batch_idx`, read from the
// per-request qo_lens array.
__host__ __device__ __forceinline__ uint32_t get_qo_len(uint32_t batch_idx) const {
return qo_lens[batch_idx];
}

// Number of KV tokens for request `batch_idx`, read from the per-request
// kv_lens array.
__host__ __device__ __forceinline__ uint32_t get_kv_len(uint32_t batch_idx) const {
return kv_lens[batch_idx];
}

// Per-request query block-expanding offset (SM90 path: the optional pointer
// lives inside `additional_params`). Returns 0 when the parameter is not
// declared for this instantiation, or when the pointer is null.
__host__ __device__ __forceinline__ uint32_t get_q_block_expanding_offset(uint32_t batch_idx) const {
{% if 'maybe_q_block_expanding_offset' in additional_params_decl %}
  const auto* offsets = additional_params.maybe_q_block_expanding_offset;
  return offsets ? offsets[batch_idx] : 0;
{% else %}
  return 0;
{% endif %}
}

// Per-request KV block-expanding offset (SM90 path: the optional pointer
// lives inside `additional_params`). Returns 0 when the parameter is not
// declared for this instantiation, or when the pointer is null.
__host__ __device__ __forceinline__ uint32_t get_kv_block_expanding_offset(uint32_t batch_idx) const {
{% if 'maybe_kv_block_expanding_offset' in additional_params_decl %}
  const auto* offsets = additional_params.maybe_kv_block_expanding_offset;
  return offsets ? offsets[batch_idx] : 0;
{% else %}
  return 0;
{% endif %}
}
};

{{ variant_decl }}
20 changes: 20 additions & 0 deletions csrc/single_prefill_customize_config.jinja
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,26 @@ struct Params {
// KV sequence length; single-prefill has exactly one request, so the scalar
// kv_len member is returned and `batch_idx` is not consulted.
__host__ __device__ __forceinline__ uint32_t get_kv_len(uint32_t batch_idx) const {
return kv_len;
}

// SinglePrefill: q_block_expanding_offset support
// If q_block_expanding_offset parameter is provided, use it; otherwise return 0
__host__ __device__ __forceinline__ uint32_t get_q_block_expanding_offset(uint32_t batch_idx) const {
{% if has_q_block_expanding_offset %}
  // Single request: a scalar offset, so `batch_idx` is not consulted.
  const uint32_t offset = static_cast<uint32_t>(q_block_expanding_offset);
  return offset;
{% else %}
  return 0;
{% endif %}
}

// SinglePrefill: kv_block_expanding_offset support (for Cascade Current Chunk)
// If kv_block_expanding_offset parameter is provided, use it; otherwise return 0
__host__ __device__ __forceinline__ uint32_t get_kv_block_expanding_offset(uint32_t batch_idx) const {
{% if has_kv_block_expanding_offset %}
  // Single request: a scalar offset, so `batch_idx` is not consulted.
  const uint32_t offset = static_cast<uint32_t>(kv_block_expanding_offset);
  return offset;
{% else %}
  return 0;
{% endif %}
}
};

{{ variant_decl }}
29 changes: 28 additions & 1 deletion csrc/single_prefill_sm90_customize_config.jinja
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,33 @@ struct Params {
int window_left;

bool causal;

// Block Expanding support
// Query/output sequence length; single-prefill has exactly one request, so
// the scalar qo_len member is returned and `batch_idx` is not consulted.
__host__ __device__ __forceinline__ uint32_t get_qo_len(uint32_t batch_idx) const {
return qo_len;
}

// KV sequence length; single-prefill has exactly one request, so the scalar
// kv_len member is returned and `batch_idx` is not consulted.
__host__ __device__ __forceinline__ uint32_t get_kv_len(uint32_t batch_idx) const {
return kv_len;
}

// SinglePrefill: q_block_expanding_offset support
__host__ __device__ __forceinline__ uint32_t get_q_block_expanding_offset(uint32_t batch_idx) const {
{% if has_q_block_expanding_offset %}
  // SM90 path: the scalar offset lives inside `additional_params`; the single
  // request means `batch_idx` is not consulted.
  const uint32_t offset = static_cast<uint32_t>(additional_params.q_block_expanding_offset);
  return offset;
{% else %}
  return 0;
{% endif %}
}

// SinglePrefill: kv_block_expanding_offset support (for Cascade Current Chunk)
__host__ __device__ __forceinline__ uint32_t get_kv_block_expanding_offset(uint32_t batch_idx) const {
{% if has_kv_block_expanding_offset %}
  // SM90 path: the scalar offset lives inside `additional_params`; the single
  // request means `batch_idx` is not consulted.
  const uint32_t offset = static_cast<uint32_t>(additional_params.kv_block_expanding_offset);
  return offset;
{% else %}
  return 0;
{% endif %}
}
};

{{ variant_decl }}
{{ variant_decl }}
1 change: 1 addition & 0 deletions flashinfer/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from .version import __version__ as __version__
from .version import __git_version__ as __git_version__

from . import dllm as dllm

from . import jit as jit
from .activation import gelu_and_mul as gelu_and_mul
Expand Down
37 changes: 37 additions & 0 deletions flashinfer/dllm/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
"""Public surface of the ``flashinfer.dllm`` subpackage.

Re-exports the single-prefill and batch-prefill block-extend entry points,
the offset-capable wrappers, and the variant declaration strings used to
instantiate the customized kernels.
"""

from .block_extend import (
    block_extend_attention_with_offset,
    block_extend_cascade,
    get_block_extend_module_with_offset,
    BLOCK_EXTEND_V2_WITH_OFFSET_VARIANT_DECL,
    BLOCK_EXTEND_V3_WITH_OFFSET_VARIANT_DECL,
)

from .batch_block_extend import (
    BatchBlockExtendPagedOffsetWrapper,
    BatchBlockExtendRaggedOffsetWrapper,
    batch_block_extend_cascade,
    sglang_style_cascade_attention,
    _BATCH_BE_OFFSET_VARIANT_DECL,
    _BATCH_BE_OFFSET_VARIANT_DECL_FA3,
)

# Sorted isort-style (case-sensitive ASCII order) to satisfy Ruff RUF022.
# Covers: single-prefill with offset (FA2/FA3 auto-select), cascade + block
# extend (SGLang-style: causal + merge_state), batch-prefill offset wrappers,
# and the batch offset variant declarations.
__all__ = [
    "BLOCK_EXTEND_V2_WITH_OFFSET_VARIANT_DECL",
    "BLOCK_EXTEND_V3_WITH_OFFSET_VARIANT_DECL",
    "BatchBlockExtendPagedOffsetWrapper",
    "BatchBlockExtendRaggedOffsetWrapper",
    "_BATCH_BE_OFFSET_VARIANT_DECL",
    "_BATCH_BE_OFFSET_VARIANT_DECL_FA3",
    "batch_block_extend_cascade",
    "block_extend_attention_with_offset",
    "block_extend_cascade",
    "get_block_extend_module_with_offset",
    "sglang_style_cascade_attention",
]
Comment on lines +21 to +37
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor

Sort __all__ to satisfy Ruff.

Ruff flags this list with RUF022; please apply the project’s __all__ sorting convention or the lint step may fail.

🛠️ Proposed fix
 __all__ = [
-    # Single Prefill with offset (FA2/FA3 auto-select)
-    "block_extend_attention_with_offset",
-    "get_block_extend_module_with_offset",
     "BLOCK_EXTEND_V2_WITH_OFFSET_VARIANT_DECL",
     "BLOCK_EXTEND_V3_WITH_OFFSET_VARIANT_DECL",
-    # Cascade + block extend (SGLang style: causal + merge_state)
-    "block_extend_cascade",
-    "batch_block_extend_cascade",
-    "sglang_style_cascade_attention",
-    # Batch Prefill with offset versions
     "BatchBlockExtendPagedOffsetWrapper",
     "BatchBlockExtendRaggedOffsetWrapper",
-    # Batch Offset variant declarations
     "_BATCH_BE_OFFSET_VARIANT_DECL",
     "_BATCH_BE_OFFSET_VARIANT_DECL_FA3",
+    "batch_block_extend_cascade",
+    "block_extend_attention_with_offset",
+    "block_extend_cascade",
+    "get_block_extend_module_with_offset",
+    "sglang_style_cascade_attention",
 ]
🧰 Tools
🪛 Ruff (0.15.10)

[warning] 21-37: __all__ is not sorted

Apply an isort-style sorting to __all__

(RUF022)

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `flashinfer/dllm/__init__.py` around lines 21-37, Ruff flags the `__all__`
list ordering (RUF022). Sort the entries of `__all__` alphabetically
(isort-style, case-sensitive ASCII order) so that the exported names —
"BLOCK_EXTEND_V2_WITH_OFFSET_VARIANT_DECL",
"BLOCK_EXTEND_V3_WITH_OFFSET_VARIANT_DECL",
"BatchBlockExtendPagedOffsetWrapper", "BatchBlockExtendRaggedOffsetWrapper",
"_BATCH_BE_OFFSET_VARIANT_DECL", "_BATCH_BE_OFFSET_VARIANT_DECL_FA3",
"batch_block_extend_cascade", "block_extend_attention_with_offset",
"block_extend_cascade", "get_block_extend_module_with_offset", and
"sglang_style_cascade_attention" — appear in the required sorted order;
update the `__all__` declaration in flashinfer/dllm/__init__.py accordingly
so the RUF022 lint check passes.

Loading