
Commit b809821

PerkzZheng and claude authored
[Fmha] update trtllm-gen FMHA cubins and sync headers for context SWA fix (#3089)
## 📌 Description

The branch has 2 commits:

1. Update the trtllm-gen FMHA cubins to fix the context SWA page-skip bug (updates the artifacts.py path and checksum).
2. Sync the trtllm FMHA headers with the latest trtllm-gen (from PR #2711), cherry-picking the header changes to match the new cubin MetaInfo struct.

## 🚀 Pull Request Checklist

### ✅ Pre-commit Checks

- [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/).

## 🧪 Tests

- [x] Tests have been added or updated as needed.
- [x] All tests are passing (`unittest`, etc.).

## Summary by CodeRabbit

* **Bug Fixes**
  * Fixed sparse-attention truncation so the sequence-length top-K is applied correctly when sparse attention is enabled.
* **Improvements**
  * Standardized sparse-attention parameter naming and selection logic to make behavior more consistent across launches and kernel choices.
  * Skip incompatible kernel variants during runtime kernel loading to avoid incorrect selections.
* **Chores**
  * Updated FMHA runtime artifact paths and their checksums for validation and downloads.

---------

Co-authored-by: Claude Opus 4.7 <noreply@anthropic.com>
1 parent (c9eb3cd) · commit b809821

2 files changed: 19 additions & 8 deletions

### flashinfer/artifacts.py

2 additions & 2 deletions
```diff
@@ -135,7 +135,7 @@ class ArtifactPath:
     When compiling new cubins for backend directories, update the corresponding path.
     """

-    TRTLLM_GEN_FMHA: str = "82f4c77d9cf83e3fcf105feda4ce3445100ab491/fmha/trtllm-gen/"
+    TRTLLM_GEN_FMHA: str = "134850621dbbd55ed6b0c3fa7c29b968136c05ef/fmha/trtllm-gen/"
     TRTLLM_GEN_BMM: str = (
         "39a9d28268f43475a757d5700af135e1e58c9849/batched_gemm-5ee61af-2b9855b/"
     )
@@ -155,7 +155,7 @@ class CheckSumHash:
     """

     TRTLLM_GEN_FMHA: str = (
-        "56c95fbe5d1b5d0d9ded7706e1c0b7ebf0582d9cfd2f9382acd878b6b9d58c89"
+        "2be32ce1949ab0b1e637c27f128b77c41d6753a36cb9c0e1a97acb2d3d44ae5f"
     )
     TRTLLM_GEN_BMM: str = (
         "db06db7f36a2a9395a2041ff6ac016fe664874074413a2ed90797f91ef17e0f6"
```
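The pinned value in `CheckSumHash` is a 64-hex-digit digest (SHA-256, judging by its length) that the downloader compares against the fetched cubin bundle. Below is a minimal C++ sketch of that verification step, assuming OpenSSL's EVP API; `verifyChecksum` and the streaming loop are illustrative stand-ins, not FlashInfer's actual download path (which is Python code in artifacts.py):

```cpp
// Sketch: verify a downloaded artifact against a pinned SHA-256 checksum.
// Assumes OpenSSL (link with -lcrypto); verifyChecksum is a hypothetical helper.
#include <openssl/evp.h>

#include <cstdio>
#include <fstream>
#include <string>

bool verifyChecksum(const std::string& path, const std::string& expectedHex) {
  std::ifstream file(path, std::ios::binary);
  if (!file) return false;

  EVP_MD_CTX* ctx = EVP_MD_CTX_new();
  EVP_DigestInit_ex(ctx, EVP_sha256(), nullptr);

  // Stream the file through the digest in 64 KiB chunks.
  char buf[1 << 16];
  while (file.read(buf, sizeof(buf)) || file.gcount() > 0) {
    EVP_DigestUpdate(ctx, buf, static_cast<size_t>(file.gcount()));
  }

  unsigned char digest[EVP_MAX_MD_SIZE];
  unsigned int len = 0;
  EVP_DigestFinal_ex(ctx, digest, &len);
  EVP_MD_CTX_free(ctx);

  // Hex-encode the digest and compare with the pinned value.
  char hex[2 * EVP_MAX_MD_SIZE + 1];
  for (unsigned int i = 0; i < len; ++i) std::snprintf(hex + 2 * i, 3, "%02x", digest[i]);
  return expectedHex == std::string(hex, 2 * len);
}
```

Pinning both the path (a content-addressed directory prefix) and the digest means a stale or tampered artifact fails loudly at download time rather than misbehaving at kernel launch.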

### include/flashinfer/trtllm/fmha/fmhaKernels.cuh

17 additions & 6 deletions
```diff
@@ -20,6 +20,7 @@

 #include <cfloat>
 #include <cstdint>
+#include <cstring>
 #include <cuda/std/cfloat>
 #include <iterator>
 #include <memory>
```
```diff
@@ -112,6 +113,13 @@ class TllmGenFmhaKernel {
     for (unsigned int i = 0; i < mKernelMetaCount; ++i) {
       auto const& kernelMeta = mKernelMeta[i];
       IKL_LOG_DEBUG("Checking tllmgen attention kernel %s", kernelMeta.mFuncName);
+      // Skip SageAttention kernels: they share the same hashID as their non-sage
+      // counterparts (sage block sizes are not part of the hash), which causes
+      // false "hash conflict" failures. SageAttention is not exposed through the
+      // flashinfer interface, so dropping these entries is safe.
+      if (kernelMeta.mFuncName != nullptr && std::strstr(kernelMeta.mFuncName, "Sage") != nullptr) {
+        continue;
+      }
       if (isSMCompatible(mSM, kernelMeta.mSM) && kernelMeta.mDataTypeQ == mDtypeQ &&
           kernelMeta.mDataTypeKv == mDtypeKv && kernelMeta.mDataTypeO == mDtypeOut) {
         // Store metadata for later use.
```
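The comment in this hunk describes a classic key-collision problem: the hashID omits the sage block sizes, so a Sage kernel and its non-Sage sibling key to the same map entry, and inserting both looks like a conflict. A small sketch of that failure mode and the name-based skip follows; the struct fields (`mHeadDim`, `mSageBlockSize`) are hypothetical stand-ins for the real MetaInfo layout:

```cpp
#include <cstdint>
#include <cstring>
#include <unordered_map>

struct KernelMeta {
  const char* mFuncName;
  int mHeadDim;        // hypothetical; stands in for the fields that ARE hashed
  int mSageBlockSize;  // hypothetical; 0 for non-Sage kernels, NOT hashed
};

// The hash key ignores mSageBlockSize, so a Sage kernel and its non-Sage
// sibling with identical hashed fields collapse to the same hashID.
uint64_t hashId(const KernelMeta& m) { return static_cast<uint64_t>(m.mHeadDim); }

void buildMetaMap(const KernelMeta* metas, unsigned count,
                  std::unordered_map<uint64_t, unsigned>& metaMap) {
  for (unsigned i = 0; i < count; ++i) {
    // Skip Sage variants by name, as the diff does; otherwise the emplace
    // below fails (a false "hash conflict") for every Sage/non-Sage pair.
    if (metas[i].mFuncName != nullptr && std::strstr(metas[i].mFuncName, "Sage") != nullptr) {
      continue;
    }
    bool inserted = metaMap.emplace(hashId(metas[i]), i).second;
    (void)inserted;  // with the skip in place, false here signals a genuine conflict
  }
}
```

Filtering by function name is a pragmatic fix while SageAttention stays outside the flashinfer interface; the cleaner long-term fix would fold the sage block sizes into the hash itself.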
```diff
@@ -443,13 +451,15 @@ class TllmGenFmhaKernel {

     // Enable the CgaSmemReduction if the numCtasPerSeqKv <= 16 as the maximum cluster dimension
     // is 16. Only the swapsMmaAbForGeneration kernel supports the CgaSmemReduction for now.
-    // CgaSmemReduction exceeds the shared memory limit for MLA decode with tileSizeQ >= 32
-    // (headDimQk=576 requires more smem than the device allows for that tile size).
+    // headDimV >= 512 is excluded: the current trtllm-gen cubin ships no SwapsMmaAb
+    // CgaSmemReduction kernels at headDimV >= 512 (covers both MLA headDimQk=576/V=512 and
+    // non-MLA H=512), and for tileSizeQ >= 32 the CGA variant also exceeds the device smem
+    // limit. This guard can be narrowed once trtllm-gen ships a cubin with the
+    // tileSizeQ>=32 + headDimPerCtaV>=512 skip predicate.
     if (!isDsv3MinLatencyMode && numCtasPerSeqKv > 1 && numCtasPerSeqKv <= 16 &&
         isSwapsMmaAbForGenerationKernel(selectKernelParams.mKernelType) &&
         isGmemReduction(selectKernelParams.mMultiCtasKvMode) &&
-        !selectKernelParams.mForceGmemReduction &&
-        (!isMlaGenKernel(params) || selectKernelParams.mTileSizeQ < 32)) {
+        !selectKernelParams.mForceGmemReduction && params.mHeadDimV < 512) {
       selectKernelParams.mMultiCtasKvMode = MultiCtasKvMode::CgaSmemReduction;
       // Need to select a different kernel.
       selectKernelParams.mSelectNewKernel = true;
```
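Read in isolation, the new gate is a pure predicate over the runner and selection params: the change replaces the MLA/tileSizeQ special case with a single `headDimV < 512` bound. A sketch under assumed, pared-down types; the `isSwapsMmaAbForGenerationKernel` kernel-type check is elided, and only the fields the condition reads are kept:

```cpp
// Sketch of the CgaSmemReduction gate as a standalone predicate.
// Types are hypothetical reductions of the real params structs.
enum class MultiCtasKvMode { GmemReduction, CgaSmemReduction };

struct RunnerParams {
  int mHeadDimV;  // head dimension of V; the only runner field the gate reads
};

struct SelectKernelParams {
  MultiCtasKvMode mMultiCtasKvMode;
  bool mForceGmemReduction;
};

// Assumes the kernel type is already known to be a SwapsMmaAb generation kernel.
bool shouldUseCgaSmemReduction(const RunnerParams& params, const SelectKernelParams& sel,
                               int numCtasPerSeqKv, bool isDsv3MinLatencyMode) {
  return !isDsv3MinLatencyMode &&
         numCtasPerSeqKv > 1 && numCtasPerSeqKv <= 16 &&  // max CGA cluster dim is 16
         sel.mMultiCtasKvMode == MultiCtasKvMode::GmemReduction &&
         !sel.mForceGmemReduction &&
         params.mHeadDimV < 512;  // cubin ships no SwapsMmaAb CGA kernels at headDimV >= 512
}
```

Gating on `mHeadDimV` rather than `isMlaGenKernel` catches both the MLA case (headDimQk=576/V=512) and the non-MLA H=512 case with one condition, which is why the old two-part clause could be dropped.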
```diff
@@ -864,12 +874,13 @@ class TllmGenFmhaKernel {
     // Hash the runner params.
     auto [hashId, info] = hashFromRunnerParams(params, selectKernelParams);
     auto const findMetaIter = mKernelMetaMap.find(hashId);
-    // The meta index.
-    auto const metaIndex = findMetaIter->second;

     // Add debug info when kernels are not found.
     FLASHINFER_CHECK(findMetaIter != mKernelMetaMap.end(), "Trtllm-gen kernels not found: " + info);

+    // The meta index.
+    auto const metaIndex = findMetaIter->second;
+
     // Load the function if not found.
     if (mFunctions.find(hashId) == mFunctions.end()) {
       // Load the kernel on-demand.
```
