[Fmha] update trtllm-gen FMHA cubins and sync headers for context SWA fix #3089

Merged
saltyminty merged 3 commits into flashinfer-ai:main from PerkzZheng:perkzz/update-trtllm-gen-fmha-cubins
Apr 23, 2026

Conversation

@PerkzZheng (Contributor) commented Apr 16, 2026

📌 Description

The branch has 2 commits:

  1. Update the trtllm-gen FMHA cubins to fix context SWA page skipping — updates the artifacts.py path and checksum.
  2. Sync the trtllm FMHA headers with the latest trtllm-gen (from PR #2711, "feat: Add DiT-oriented kernels where Qk (Bmm1) type can be reinterpreted into Int8 or BFloat16") — cherry-picks the header changes to match the new cubin MetaInfo struct.

🔍 Related Issues

🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or via my preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

If you are unsure about how to set up pre-commit, see the pre-commit documentation.

🧪 Tests

  • Tests have been added or updated as needed.
  • All tests are passing (unittest, etc.).

Reviewer Notes

Summary by CodeRabbit

  • Bug Fixes

    • Fixed sparse-attention truncation so sequence-length top-K is applied correctly when sparse-attention is enabled.
  • Improvements

    • Standardized sparse-attention parameter naming and selection logic to make behavior more consistent across launches and kernel choices.
    • Skip incompatible kernel variants during runtime kernel loading to avoid incorrect selections.
  • Chores

    • Updated FMHA runtime artifact paths and their checksums for validation and downloads.

@coderabbitai (Bot) commented Apr 16, 2026

Note

Reviews paused

It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior by changing the reviews.auto_review.auto_pause_after_reviewed_commits setting.

Use the following commands to manage reviews:

  • @coderabbitai resume to resume automatic reviews.
  • @coderabbitai review to trigger a single review.

📝 Walkthrough

Renamed the sparse Top-K field (mSparseMlaTopK → mSparseAttnTopK) and updated its uses, switched kernel sparse checks to (kernelMeta.mSparseAttn != 0), added detection to skip Sage kernels, and updated the FMHA artifact subpath and its checksum.

Changes

  • Sparse parameter & kernel logic (csrc/fmhaReduction.cu, include/flashinfer/trtllm/fmha/kernelParams.h, include/flashinfer/trtllm/fmha/fmhaKernels.cuh): Replaced mSparseMlaTopK with mSparseAttnTopK in KernelParams and updated its population; the kernel clamps seqLenKv against params.mSparseAttnTopK; kernel metadata/hash and launch selection now use (kernelMeta.mSparseAttn != 0); and fmhaKernels.cuh skips Sage kernels via std::strstr(..., "Sage"). A sketch of these kernel-side changes follows below.
  • FMHA artifact path & checksum (flashinfer/artifacts.py): Updated the ArtifactPath.TRTLLM_GEN_FMHA subdirectory hash and the CheckSumHash.TRTLLM_GEN_FMHA SHA256 to point to the new FMHA binary location.
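Concretely, the kernel-side changes reduce to a top-K clamp and a name filter. A minimal sketch, assuming simplified stand-in structs (KernelParamsSketch and KernelMetaSketch are illustrative, not the real definitions):

```cpp
#include <algorithm>
#include <cstring>

// Illustrative stand-ins; the real structs carry many more fields.
struct KernelParamsSketch { int mSparseAttnTopK; };  // renamed from mSparseMlaTopK
struct KernelMetaSketch { int mSparseAttn; const char* mFuncName; };

// With sparse attention enabled, only the top-K KV positions are kept.
int clampSeqLenKv(int seqLenKv, const KernelParamsSketch& params) {
  return params.mSparseAttnTopK > 0 ? std::min(seqLenKv, params.mSparseAttnTopK)
                                    : seqLenKv;
}

// Runtime kernel loading: sparse support is an int tested with != 0, and
// SageAttention variants are skipped by name.
bool shouldLoadKernel(const KernelMetaSketch& meta) {
  bool isSparse = (meta.mSparseAttn != 0);  // feeds the hash/selection logic
  (void)isSparse;
  return std::strstr(meta.mFuncName, "Sage") == nullptr;
}
```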

Sequence Diagram(s)

(omitted)

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~20 minutes

Possibly related PRs

Suggested labels

run-ci, op: attention

Suggested reviewers

  • sricketts
  • aleozlx
  • yongwww
  • yzh119
  • cyx-6
  • samuellees
  • bkryu
  • yyihuang
  • kahyunnam
  • jimmyzho
  • nv-yunzheq

Poem

🐰 I hopped through headers, kernels, and art,

Renamed Top‑K and gave Sage a part,
Kernels now check the sparse-attn sign,
New checksums fetched, the binaries align,
A joyful hop — the code looks fine! 🥕

🚥 Pre-merge checks | ✅ 5
✅ Passed checks (5 passed)

  • Title check ✅ Passed: The title accurately summarizes the main changes (updating trtllm-gen FMHA cubins and syncing headers for a context SWA fix), which aligns with the commit objectives.
  • Description check ✅ Passed: The description covers the main objective (two commits: cubin updates and header sync) and marks the pre-commit and test checkboxes as done, satisfying the template requirements.
  • Docstring Coverage ✅ Passed: Docstring coverage is 80.00%, which meets the required threshold of 80.00%.
  • Linked Issues check ✅ Passed: Skipped because no linked issues were found for this pull request.
  • Out of Scope Changes check ✅ Passed: Skipped because no linked issues were found for this pull request.


@PerkzZheng (Contributor, Author)

/bot run

@flashinfer-bot (Collaborator)

GitLab MR !560 has been created, and the CI pipeline #48699694 is currently running. I'll report back once the pipeline job completes.

@gemini-code-assist (Bot) left a comment

Code Review

This pull request enables support for separate K and V data types and integrates SageAttention scaling factors and block sizes into the TRT-LLM FMHA runner and kernels. Key updates include modifying the kernel cache key, extending interfaces for distinct K/V types, and adding INT8 support for query and key tensors. Feedback recommends optimizing the integer log2 calculation in the kernel factory via bit manipulation and extending INT8 support to the value tensor check for consistency.

Comment thread on include/flashinfer/trtllm/fmha/fmhaKernels.cuh (outdated)
Comment thread on include/flashinfer/trtllm/fmha/fmhaRunner.cuh (outdated)
@coderabbitai (Bot) left a comment

Actionable comments posted: 3

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
include/flashinfer/trtllm/fmha/kernelParams.h (1)

701-770: ⚠️ Potential issue | 🟠 Major

Finish the K/V dtype split in the FP4/TMEM branches.

This block still makes storeTransformedKvInTmem, numEltsDivisor, and the FP4-SF path depend on kernelMeta.mDataTypeKv. That keeps the old “K and V always share a dtype” assumption alive, so mixed cases like K=FP16, V=E2M1 or K=E2M1, V=FP16 will build the wrong descriptor/SF state.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@include/flashinfer/trtllm/fmha/kernelParams.h` around lines 701 - 770, The
code assumes K and V share dtype by using kernelMeta.mDataTypeKv in places that
must be per-tensor; split the logic into K- and V-specific variants: e.g.,
introduce storeTransformedKInTmem and storeTransformedVInTmem (instead of
storeTransformedKvInTmem), numEltsDivisorK/V (instead of numEltsDivisor), and
compute reshapeFactor/shape/stride/tileShape for K and V by calling
makeTmaShapeStrideKv with kernelMeta.mDataTypeK and kernelMeta.mDataTypeV
respectively; ensure the FP4 dequant/scaling path and the buildNdTmaDescriptor
unpack4b/swizzled flags use the corresponding per-tensor booleans, and adjust
the tileShapeV conversion to use get_size_in_bits(kernelMeta.mDataTypeK) vs
get_size_in_bits(kernelMeta.mDataTypeV) as currently done but guarded by the new
per-tensor storeTransformed flags; update any downstream FP4-SF handling to
reference kernelMeta.mDataTypeK and kernelMeta.mDataTypeV separately.
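A minimal sketch of the per-tensor split the comment asks for, using the reviewer's suggested names; the Dtype enum and isE2m1 predicate are hypothetical stand-ins for the real dtype machinery:

```cpp
// Hypothetical dtype enum and FP4 predicate; the real code has its own.
enum class Dtype { Fp16, Bf16, E2m1 };
inline bool isE2m1(Dtype t) { return t == Dtype::E2m1; }

struct KernelMetaSketch { Dtype mDataTypeK; Dtype mDataTypeV; };

// Decide transform/packing state per tensor so mixed cases such as
// K=FP16, V=E2M1 no longer inherit K's decision.
void configureKvDescriptors(const KernelMetaSketch& kernelMeta) {
  bool storeTransformedKInTmem = isE2m1(kernelMeta.mDataTypeK);
  bool storeTransformedVInTmem = isE2m1(kernelMeta.mDataTypeV);
  int numEltsDivisorK = storeTransformedKInTmem ? 2 : 1;  // FP4 packs 2 elts/byte
  int numEltsDivisorV = storeTransformedVInTmem ? 2 : 1;
  // TMA shapes/strides and FP4-SF state would then be built per tensor from
  // the matching flag and dtype, never from a shared mDataTypeKv.
  (void)numEltsDivisorK;
  (void)numEltsDivisorV;
}
```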
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@include/flashinfer/trtllm/fmha/fmhaKernels.cuh`:
- Around line 1028-1049: The current sageEncode() inside hashID() compresses
block sizes with log2f, causing different positive sizes (e.g., 8 and 12) to
collide; replace that encoding with the actual size (clamped to 4 bits) so the
4-bit fields uniquely represent the block count. Concretely, change the lambda
used in hashID (sageEncode) to return n > 0 ? static_cast<uint64_t>(n) & 0xF : 0
(or otherwise clamp n to [0,15]) and use those values in the same shift
positions so the fields for numEltsPerSageAttnBlkQ/K/P/V store the raw (clamped)
sizes instead of log2(size)+1. (A sketch of this encoding follows the list.)
- Around line 248-261: The SageAttention override values (ptrSageAttnSfsQ/K/P/V
and mLogNumEltsPerSageAttnBlk*) are applied to kernelParams via the local
sageParamEncode lambda but are lost when KernelParams::setKernelParams() is
later called for the CGA→GMEM fallback; reapply the same overrides immediately
after any call to KernelParams::setKernelParams() (or move the Sage override
logic inside KernelParams::setKernelParams()) so that
kernelParams.ptrSageAttnSfsQ/K/P/V and
kernelParams.mLogNumEltsPerSageAttnBlkQ/K/P/V are set regardless of fallback
path (reference: kernelParams, sageParamEncode, KernelParams::setKernelParams).

In `@include/flashinfer/trtllm/fmha/kernelParams.h`:
- Line 663: The code calls getDevicePtrs(options,
get_size_in_bits(kernelMeta.mDataTypeK)) but getDevicePtrs applies a single
bitsPerElt to Q, K, and V, producing incorrect PackedQkv/ContiguousKv offsets
for mixed dtypes; change the implementation to derive offsets using each
tensor's own element size (use get_size_in_bits(kernelMeta.mDataTypeQ),
get_size_in_bits(kernelMeta.mDataTypeK),
get_size_in_bits(kernelMeta.mDataTypeV)) or update getDevicePtrs to accept
per-tensor bit sizes (or reject mixed-width layouts early), and then compute
PackedQkv and ContiguousKv offsets using those per-tensor sizes so Q, K, V
offsets are correct for mixed-width cases.
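To make the first comment concrete, a sketch of the suggested 4-bit encoding (the shift positions in hashID stay as they are and are omitted here):

```cpp
#include <cstdint>

// Store the masked raw block size (0..15) in the 4-bit hash field instead of
// log2(size)+1, so distinct sizes such as 8 and 12 no longer collide.
inline uint64_t sageEncode(int n) {
  return n > 0 ? static_cast<uint64_t>(n) & 0xF : 0ull;
}
```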


ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: b4f429f9-18fd-426c-a934-a2a51554bbd1

📥 Commits

Reviewing files that changed from the base of the PR and between a99ee72 and 16dbc0b.

📒 Files selected for processing (7)
  • csrc/fmhaReduction.cu
  • csrc/trtllm_fmha_kernel_launcher.cu
  • flashinfer/artifacts.py
  • include/flashinfer/trtllm/fmha/fmhaKernels.cuh
  • include/flashinfer/trtllm/fmha/fmhaRunner.cuh
  • include/flashinfer/trtllm/fmha/fmhaRunnerParams.h
  • include/flashinfer/trtllm/fmha/kernelParams.h

Comment thread on include/flashinfer/trtllm/fmha/fmhaKernels.cuh (outdated)
Comment thread on include/flashinfer/trtllm/fmha/fmhaKernels.cuh (outdated)
Comment thread on include/flashinfer/trtllm/fmha/kernelParams.h (outdated)
@PerkzZheng force-pushed the perkzz/update-trtllm-gen-fmha-cubins branch from 16dbc0b to 212e310 on April 17, 2026 01:21
Comment thread on flashinfer/artifacts.py (outdated)
@PerkzZheng (Contributor, Author)

/bot run

@coderabbitai (Bot) left a comment

Actionable comments posted: 3

♻️ Duplicate comments (1)
include/flashinfer/trtllm/fmha/fmhaKernels.cuh (1)

289-309: ⚠️ Potential issue | 🟠 Major

SageAttention overrides are lost on the CGA→GMEM fallback re-build.

setKernelParams is called again on Line 304-305 and the resulting kernelParams overwrites the SageAttention pointers/log block sizes set at Lines 254-261, so any request that hits this fallback will launch with null ptrSageAttnSfs* and zero mLogNumEltsPerSageAttnBlk*. Please either move the override block into a helper and re-invoke it here, or push the override logic into KernelParams::setKernelParams so it is applied unconditionally.

🔧 Proposed fix
-        kernelParams = KernelParams::setKernelParams(
-            params, kernelMeta, ctaLaunchParams.mMaxNumCtasQ, ctaLaunchParams.mMaxNumCtasKv);
+        kernelParams = KernelParams::setKernelParams(
+            params, kernelMeta, ctaLaunchParams.mMaxNumCtasQ, ctaLaunchParams.mMaxNumCtasKv);
+        kernelParams.ptrSageAttnSfsQ = params.ptrSageAttnSfsQ;
+        kernelParams.ptrSageAttnSfsK = params.ptrSageAttnSfsK;
+        kernelParams.ptrSageAttnSfsP = params.ptrSageAttnSfsP;
+        kernelParams.ptrSageAttnSfsV = params.ptrSageAttnSfsV;
+        kernelParams.mLogNumEltsPerSageAttnBlkQ = sageParamEncode(kernelMeta.mNumEltsPerSageAttnBlkQ);
+        kernelParams.mLogNumEltsPerSageAttnBlkK = sageParamEncode(kernelMeta.mNumEltsPerSageAttnBlkK);
+        kernelParams.mLogNumEltsPerSageAttnBlkP = sageParamEncode(kernelMeta.mNumEltsPerSageAttnBlkP);
+        kernelParams.mLogNumEltsPerSageAttnBlkV = sageParamEncode(kernelMeta.mNumEltsPerSageAttnBlkV);
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@include/flashinfer/trtllm/fmha/fmhaKernels.cuh` around lines 289 - 309, The
SageAttention overrides (ptrSageAttnSfs* and mLogNumEltsPerSageAttnBlk* stored
in kernelParams) are lost when kernelParams is rebuilt by
KernelParams::setKernelParams during the CGA→GMEM fallback; fix by reapplying
the SageAttention override logic immediately after kernelParams is reassigned
here (i.e., call the same helper that set the Sage pointers/log-block sizes
earlier) or refactor KernelParams::setKernelParams to carry over or initialize
the SageAttention fields unconditionally so they are preserved after
setKernelParams; locate the reassign site of kernelParams, the original override
block that sets ptrSageAttnSfs*/mLogNumEltsPerSageAttnBlk*, and either invoke
that helper after setKernelParams or update KernelParams::setKernelParams to
perform the override.
🧹 Nitpick comments (3)
include/flashinfer/trtllm/fmha/kernelParams.h (1)

717-726: reshapeFactorKv is derived from K's bit width only.

When K and V have different element sizes, this factor is applied identically to both the K and V TMA shapes (via the reshapeFactor argument of makeTmaShapeStrideKv). If V has a different byte width, the resulting V shape may not land on the 128B box width this heuristic targets, even though tileShapeV[0] is rescaled later. Worth a one-line comment documenting that K drives the reshape and V byte-parity is restored only through the tileShape rescale at Lines 753-759.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@include/flashinfer/trtllm/fmha/kernelParams.h` around lines 717 - 726, Add a
one-line comment next to the reshapeFactorKv calculation (symbols:
reshapeFactorKv, canReshapeTmaKv, get_size_in_bits, kernelMeta.mDataTypeK)
explaining that the reshape factor is computed from K's element bit-width only
and will be applied to both K and V TMA shapes via makeTmaShapeStrideKv, and
that any byte-width mismatch for V is corrected later by rescaling tileShapeV
(tileShapeV rescale logic around makeTmaShapeStrideKv usage) to restore 128B box
alignment.
include/flashinfer/trtllm/fmha/fmhaRunner.cuh (2)

38-47: Datatype asymmetry between K and V is intentional — worth documenting.

K now accepts INT8 while V does not. This is consistent with the available cubins (V kernels aren't built for INT8), but the asymmetry is surprising to readers and will bite anyone who writes TllmGenFmhaRunner(q, DATA_TYPE_INT8, o) via the 3-arg convenience ctor (Line 57) — it will hit the V validation with DATA_TYPE_INT8 and throw despite INT8 being "supported for K". A brief comment here, or an earlier guard in the convenience ctor, would make this much less confusing.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@include/flashinfer/trtllm/fmha/fmhaRunner.cuh` around lines 38 - 47, The K/V
datatype asymmetry (mDtypeK allowing DATA_TYPE_INT8 while mDtypeV does not) is
intentional but surprising; add a brief clarifying comment next to the
FLASHINFER_CHECKs referencing mDtypeK and mDtypeV explaining that V kernels are
not built for INT8, and also add an explicit early-guard in the 3-arg
convenience constructor TllmGenFmhaRunner(...) (or the function that forwards
those args) that either converts an INT8 V request to a supported V type or
throws a clear error before reaching the mDtypeV validation so callers like
TllmGenFmhaRunner(q, DATA_TYPE_INT8, o) fail with a descriptive message about V
not supporting INT8.

56-58: Convenience constructor drops SageAttention block sizes.

The 3-argument form forwards to the primary constructor with all four numEltsPerSageAttnBlk* defaulted to 0, i.e. SageAttention-off. That's fine for the existing call site in csrc/trtllm_fmha_kernel_launcher.cu, but if a future caller needs SageAttention they must use the 5-arg form — worth a comment so nobody assumes the 3-arg ctor quietly enables Sage when the runtime params carry Sage pointers.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@include/flashinfer/trtllm/fmha/fmhaRunner.cuh` around lines 56 - 58, The
3-argument convenience constructor TllmGenFmhaRunner(dtypeQ,dtypeKv,dtypeOut)
silently disables SageAttention by forwarding to the main ctor with all
numEltsPerSageAttnBlk* parameters set to 0; add a brief comment next to this
constructor (or update its docstring) stating that it intentionally defaults all
numEltsPerSageAttnBlk* to 0 and therefore SageAttention is off, and that callers
who need SageAttention must use the 5-argument constructor that accepts explicit
numEltsPerSageAttnBlk* values and Sage pointers.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@include/flashinfer/trtllm/fmha/fmhaKernels.cuh`:
- Around line 249-253: The lambda sageParamEncode uses GCC/Clang-only
__builtin_ctz which breaks MSVC builds; modify sageParamEncode to detect
compiler and use a portable fallback: keep the existing ternary guard (blockSize
== 0) but replace the direct __builtin_ctz call with a small wrapper function
(e.g., count_trailing_zeros) that at compile-time selects __builtin_ctz for
GCC/Clang, _BitScanForward/_tzcnt/_lzcnt for MSVC, and a simple loop fallback
otherwise; ensure you still call FLASHINFER_CHECK for power-of-two validation
and return 0 for blockSize==0 to preserve behavior. (A portable sketch follows the list.)

In `@include/flashinfer/trtllm/fmha/kernelParams.h`:
- Around line 750-759: The if-block adjusting tileShapeV has inconsistent
indentation: the body lines inside the if (the comment and the tileShapeV[0]
assignment) and the closing brace are indented with 4 spaces instead of the
file's 2-space style. Fix by reindenting the comment, the tileShapeV[0]
calculation, and the closing brace to use 2-space indentation so the block
around symbols tileShapeK, tileShapeV, storeTransformedKvInTmem, and
kernelMeta.mDataTypeK / kernelMeta.mDataTypeV matches the file's formatting
standard (or run pre-commit/clang-format to apply the same change).
- Around line 753-759: Add a divisibility check before scaling tileShapeV[0] in
the block guarded by storeTransformedKvInTmem and kernelMeta.mDataTypeK !=
kernelMeta.mDataTypeV to prevent silent integer truncation: compute numerator =
int64(tileShapeV[0]) * get_size_in_bits(kernelMeta.mDataTypeK) and denominator =
get_size_in_bits(kernelMeta.mDataTypeV), assert (numerator % denominator == 0)
using FLASHINFER_CHECK with a clear error message, then assign tileShapeV[0] =
static_cast<uint32_t>(numerator / denominator); this ensures exact divisibility
when adjusting tileShapeV[0] for differing K/V bit widths.
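As a concrete rendering of the first inline comment, a self-contained sketch of the portable wrapper (count_trailing_zeros is the reviewer's suggested name, not an existing symbol):

```cpp
#if defined(_MSC_VER)
#include <intrin.h>
#endif

// Portable trailing-zero count for a power-of-two block size. Callers must
// still guard blockSize == 0, since the intrinsics are undefined for zero.
inline int count_trailing_zeros(unsigned int x) {
#if defined(__GNUC__) || defined(__clang__)
  return __builtin_ctz(x);           // GCC/Clang intrinsic
#elif defined(_MSC_VER)
  unsigned long idx;
  _BitScanForward(&idx, x);          // MSVC intrinsic
  return static_cast<int>(idx);
#else
  int n = 0;                         // plain loop fallback
  while ((x & 1u) == 0u) { x >>= 1; ++n; }
  return n;
#endif
}
```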


ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: f8be8e90-4a85-4d2b-91eb-15dad98a33c3

📥 Commits

Reviewing files that changed from the base of the PR and between 16dbc0b and 212e310.

📒 Files selected for processing (5)
  • csrc/fmhaReduction.cu
  • include/flashinfer/trtllm/fmha/fmhaKernels.cuh
  • include/flashinfer/trtllm/fmha/fmhaRunner.cuh
  • include/flashinfer/trtllm/fmha/fmhaRunnerParams.h
  • include/flashinfer/trtllm/fmha/kernelParams.h
🚧 Files skipped from review as they are similar to previous changes (2)
  • csrc/fmhaReduction.cu
  • include/flashinfer/trtllm/fmha/fmhaRunnerParams.h

Comment thread on include/flashinfer/trtllm/fmha/fmhaKernels.cuh (outdated)
Comment thread on include/flashinfer/trtllm/fmha/kernelParams.h (outdated)
Comment thread on include/flashinfer/trtllm/fmha/kernelParams.h (outdated)
@flashinfer-bot (Collaborator)

GitLab MR !560 has been updated with latest changes, and the CI pipeline #48744929 is currently running. I'll report back once the pipeline job completes.

Comment thread on include/flashinfer/trtllm/fmha/kernelParams.h
Comment thread on include/flashinfer/trtllm/fmha/fmhaKernels.cuh (outdated)
@PerkzZheng force-pushed the perkzz/update-trtllm-gen-fmha-cubins branch from 212e310 to 6c56ac8 on April 20, 2026 02:39
@PerkzZheng (Contributor, Author)

/bot run

@flashinfer-bot (Collaborator)

GitLab MR !560 has been updated with latest changes, and the CI pipeline #48947668 is currently running. I'll report back once the pipeline job completes.

@PerkzZheng force-pushed the perkzz/update-trtllm-gen-fmha-cubins branch from 6c56ac8 to 31d2b46 on April 20, 2026 13:19
@PerkzZheng (Contributor, Author)

/bot run

@flashinfer-bot (Collaborator)

GitLab MR !560 has been updated with latest changes, and the CI pipeline #48988940 is currently running. I'll report back once the pipeline job completes.

@coderabbitai (Bot) left a comment

🧹 Nitpick comments (2)
include/flashinfer/trtllm/fmha/fmhaKernels.cuh (2)

197-204: hashID parameter name still sparseMla while metadata field is now mSparseAttn.

Line 202 correctly casts the new int mSparseAttn to a bool via != 0, but the underlying hashID(..., bool sparseMla, ...) signature and the bit-55 comment at line 181 still reference the old sparseMla naming. Functionally correct (bit semantics unchanged: any non-zero sparse-attn value collapses to the same bit), just a naming consistency nit if you want the header to match the renamed field.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@include/flashinfer/trtllm/fmha/fmhaKernels.cuh` around lines 197 - 204, The
hashID function still uses the old parameter name sparseMla and the bit-55
comment refers to sparseMla even though KernelMeta renamed the field to
mSparseAttn; update the hashID signature/parameter name to reflect mSparseAttn
(or sparseAttn) and adjust the bit-55 comment to mention mSparseAttn (or
sparseAttn) so naming is consistent with KernelMeta and the cast (mSparseAttn !=
0) remains functionally identical.

116-122: Sage kernel substring filter: Consider tightening to word-boundary match for robustness.

The substring match std::strstr(mFuncName, "Sage") will skip any kernel whose symbol name contains the literal "Sage" anywhere in it. The comment correctly notes that SageAttention kernels share the same hashID as their non-Sage counterparts (block sizes are intentionally excluded from the hash), so filtering is necessary to avoid false conflicts.

However, since kernel names come from externally-generated TensorRT LLM metadata, consider using a more specific match pattern (e.g., "_Sage" or "Sage_" prefix/word boundary) to reduce the risk of accidentally filtering a non-SageAttention kernel if future naming conventions introduce the substring elsewhere.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@include/flashinfer/trtllm/fmha/fmhaKernels.cuh` around lines 116 - 122, The
current filter uses std::strstr(kernelMeta.mFuncName, "Sage") which may match
unintended symbols; tighten the check in the loop that examines
kernelMeta.mFuncName to only exclude true SageAttention kernels by matching a
stricter pattern (e.g., check for substrings like "_Sage", "Sage_", or the full
token "SageAttention") instead of a raw "Sage" substring. Update the condition
around kernelMeta.mFuncName to use these more specific tests (or convert to
std::string and use find/regex for word-boundary matching) so only intended Sage
kernels are skipped.
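A sketch of the tightened filter; isSageKernelName is a hypothetical helper, and the exact token patterns would need validating against the real trtllm-gen symbol names:

```cpp
#include <string_view>

// Require a token boundary around "Sage" so a future kernel name that merely
// contains the substring elsewhere is not skipped by accident.
inline bool isSageKernelName(std::string_view name) {
  return name.find("_Sage") != std::string_view::npos ||
         name.find("Sage_") != std::string_view::npos ||
         name.rfind("Sage", 0) == 0;  // name starts with "Sage"
}
```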

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 2b93f353-a0cd-4fff-b5e9-c06aaf817f34

📥 Commits

Reviewing files that changed from the base of the PR and between 6c56ac8 and 31d2b46.

📒 Files selected for processing (3)
  • csrc/fmhaReduction.cu
  • include/flashinfer/trtllm/fmha/fmhaKernels.cuh
  • include/flashinfer/trtllm/fmha/kernelParams.h
🚧 Files skipped from review as they are similar to previous changes (1)
  • csrc/fmhaReduction.cu

@PerkzZheng (Contributor, Author)

/bot run

@flashinfer-bot (Collaborator)

GitLab MR !560 has been updated with latest changes, and the CI pipeline #49157734 is currently running. I'll report back once the pipeline job completes.

@PerkzZheng (Contributor, Author)

/bot run

@flashinfer-bot (Collaborator)

GitLab MR !560 has been created, and the CI pipeline #49167948 is currently running. I'll report back once the pipeline job completes.

@PerkzZheng (Contributor, Author)

/bot run

@flashinfer-bot (Collaborator)

GitLab MR !560 has been updated with latest changes, and the CI pipeline #49179138 is currently running. I'll report back once the pipeline job completes.

PerkzZheng and others added 3 commits April 23, 2026 03:34
Minimal header changes to match the new trtllm-gen FMHA cubin MetaInfo
struct layout:

- TllmGenFmhaKernelMetaInfo: renamed mSparseMla (bool) -> mSparseAttn
  (int). Callers convert to bool via `!= 0`.
- KernelParams (GPU-side struct): renamed mSparseMlaTopK -> mSparseAttnTopK
  and moved immediately after mSkipSoftmaxThresholdScaleFactor to match
  the layout expected by the new kernels.

The K/V dtype split (mDataTypeKv -> mDataTypeK/V) and SageAttention block
size fields present in the new struct are layout-compatible but not used,
so no code changes are needed for those -- existing references to
mDataTypeKv still compile since the cubin-supplied struct keeps that
field alongside the new mDataTypeK/V.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
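An illustrative fragment of the rename this commit describes (not the full struct definition; mSparseAttnTopK itself lives in the GPU-side KernelParams):

```cpp
struct TllmGenFmhaKernelMetaInfo {
  // was: bool mSparseMla;
  int mSparseAttn;  // non-zero means the kernel supports sparse attention
  // ... remaining fields unchanged by this commit ...
};

// Call sites convert explicitly, matching the new int-valued field:
// bool isSparse = (kernelMeta.mSparseAttn != 0);
```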
New cubins from cubin_publishing pipeline 49098275 fix the
gpt-oss-120b B200 TP=2 prefix caching corruption (NVBug 5922676).

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
The trtllm-gen cubin at commit 134850621 ships zero SwapsMmaAb
CgaSmemReduction kernels at headDimV >= 512. This covers both:

- MLA decode (headDimQk=576, headDimV=512): the prior cubin shipped 234
  MLA CGA kernels; 134850621 strips them due to an over-broad
  `headDim >= 512 && SwapsMmaAb && CGA -> skip` rule in trtllm-gen's
  ExportCubin.
- Non-MLA headDim=512 (head_dim_512 tests added in flashinfer-ai#2959): same skip
  predicate drops these too.

Promoting to CgaSmemReduction in either regime misses the kernel map,
so gate the promotion on `params.mHeadDimV < 512`. Additionally, for
tileSizeQ >= 32 the CGA variant would exceed the 232KB Blackwell smem
limit, so the guard is correct on the hardware side as well.

This guard can be narrowed once trtllm-gen ships a cubin with the
tileSizeQ>=32 + headDimPerCtaV>=512 skip predicate (trtllm-gen MR !928).

Also moves the kernel-not-found check ahead of the iterator dereference
in loadKernel to avoid UB on miss.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
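A sketch of the two guards this commit adds; the map and type names are illustrative, not the real fmhaKernels.cuh symbols:

```cpp
#include <cstdint>
#include <stdexcept>
#include <unordered_map>

// Gate CgaSmemReduction promotion: cubin 134850621 ships no SwapsMmaAb CGA
// kernels at headDimV >= 512, and tileSizeQ >= 32 would exceed Blackwell's
// 232KB shared-memory limit in any case.
inline bool allowCgaPromotion(int headDimV) { return headDimV < 512; }

// loadKernel: test for a miss before dereferencing the iterator (avoids UB).
struct KernelInfoSketch {};
inline const KernelInfoSketch& loadKernel(
    const std::unordered_map<uint64_t, KernelInfoSketch>& kernelMap,
    uint64_t hashId) {
  auto it = kernelMap.find(hashId);
  if (it == kernelMap.end()) {
    throw std::runtime_error("trtllm-gen FMHA kernel not found for hash id");
  }
  return it->second;  // safe: the miss was handled above
}
```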
@PerkzZheng force-pushed the perkzz/update-trtllm-gen-fmha-cubins branch from 8b003a4 to 365a9ad on April 23, 2026 03:50
@PerkzZheng (Contributor, Author)

/bot run

@flashinfer-bot (Collaborator)

GitLab MR !560 has been updated with latest changes, and the CI pipeline #49256243 is currently running. I'll report back once the pipeline job completes.

@saltyminty merged commit b809821 into flashinfer-ai:main on Apr 23, 2026, with 42 of 43 checks passed.
@kmrao-nv added the "v0.6.10 release blocker" label on Apr 24, 2026.

Labels

run-ci, v0.6.10 release blocker (label for 0.6.10)
