[Fmha] update trtllm-gen FMHA cubins and sync headers for context SWA fix #3089
📝 Walkthrough
Renamed sparse Top-K field (…). Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~20 minutes.
🚥 Pre-merge checks: ✅ Passed checks (5 passed)
/bot run
Code Review
This pull request enables support for separate K and V data types and integrates SageAttention scaling factors and block sizes into the TRT-LLM FMHA runner and kernels. Key updates include modifying the kernel cache key, extending interfaces for distinct K/V types, and adding INT8 support for query and key tensors. Feedback recommends optimizing the integer log2 calculation in the kernel factory via bit manipulation and extending INT8 support to the value tensor check for consistency.
Actionable comments posted: 3
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
include/flashinfer/trtllm/fmha/kernelParams.h (1)
701-770: ⚠️ Potential issue | 🟠 Major
Finish the K/V dtype split in the FP4/TMEM branches.
This block still makes `storeTransformedKvInTmem`, `numEltsDivisor`, and the FP4-SF path depend on `kernelMeta.mDataTypeKv`. That keeps the old "K and V always share a dtype" assumption alive, so mixed cases like `K=FP16, V=E2M1` or `K=E2M1, V=FP16` will build the wrong descriptor/SF state.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@include/flashinfer/trtllm/fmha/kernelParams.h` around lines 701 - 770, The code assumes K and V share dtype by using kernelMeta.mDataTypeKv in places that must be per-tensor; split the logic into K- and V-specific variants: e.g., introduce storeTransformedKInTmem and storeTransformedVInTmem (instead of storeTransformedKvInTmem), numEltsDivisorK/V (instead of numEltsDivisor), and compute reshapeFactor/shape/stride/tileShape for K and V by calling makeTmaShapeStrideKv with kernelMeta.mDataTypeK and kernelMeta.mDataTypeV respectively; ensure the FP4 dequant/scaling path and the buildNdTmaDescriptor unpack4b/swizzled flags use the corresponding per-tensor booleans, and adjust the tileShapeV conversion to use get_size_in_bits(kernelMeta.mDataTypeK) vs get_size_in_bits(kernelMeta.mDataTypeV) as currently done but guarded by the new per-tensor storeTransformed flags; update any downstream FP4-SF handling to reference kernelMeta.mDataTypeK and kernelMeta.mDataTypeV separately.
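The per-tensor split the comment asks for can be sketched as follows. This is a minimal illustration, not the actual header: `Dtype`, the stand-in `get_size_in_bits`, the divisor values, and the flag names are assumptions inferred from the review text.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical stand-ins for the header's types; real names/values may differ.
enum class Dtype { Fp16, E2m1 };

// E2M1 is a 4-bit FP4 format in this sketch; FP16 is 16 bits.
constexpr uint32_t get_size_in_bits(Dtype d) { return d == Dtype::E2m1 ? 4u : 16u; }

struct KvTmemFlags {
  bool storeTransformedKInTmem;  // replaces the shared storeTransformedKvInTmem
  bool storeTransformedVInTmem;
  uint32_t numEltsDivisorK;      // replaces the shared numEltsDivisor
  uint32_t numEltsDivisorV;
};

// Derive the TMEM/divisor state from each tensor's own dtype instead of a
// single mDataTypeKv, so mixed cases like K=FP16, V=E2M1 get distinct flags.
inline KvTmemFlags splitKvFlags(Dtype dtypeK, Dtype dtypeV) {
  KvTmemFlags f{};
  f.storeTransformedKInTmem = (dtypeK == Dtype::E2m1);
  f.storeTransformedVInTmem = (dtypeV == Dtype::E2m1);
  // Illustrative: FP4 packs two elements per byte, hence the larger divisor.
  f.numEltsDivisorK = f.storeTransformedKInTmem ? 2u : 1u;
  f.numEltsDivisorV = f.storeTransformedVInTmem ? 2u : 1u;
  return f;
}
```

With a split like this, the FP4-SF path and the `buildNdTmaDescriptor` flags can each consume the flag for the tensor they actually describe.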
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@include/flashinfer/trtllm/fmha/fmhaKernels.cuh`:
- Around line 1028-1049: The current sageEncode() inside hashID() compresses
block sizes with log2f, causing different positive sizes (e.g., 8 and 12) to
collide; replace that encoding with the actual size (clamped to 4 bits) so the
4-bit fields uniquely represent the block count. Concretely, change the lambda
used in hashID (sageEncode) to return n > 0 ? static_cast<uint64_t>(n) & 0xF : 0
(or otherwise clamp n to [0,15]) and use those values in the same shift
positions so the fields for numEltsPerSageAttnBlkQ/K/P/V store the raw (clamped)
sizes instead of log2(size)+1.
- Around line 248-261: The SageAttention override values (ptrSageAttnSfsQ/K/P/V
and mLogNumEltsPerSageAttnBlk*) are applied to kernelParams via the local
sageParamEncode lambda but are lost when KernelParams::setKernelParams() is
later called for the CGA→GMEM fallback; reapply the same overrides immediately
after any call to KernelParams::setKernelParams() (or move the Sage override
logic inside KernelParams::setKernelParams()) so that
kernelParams.ptrSageAttnSfsQ/K/P/V and
kernelParams.mLogNumEltsPerSageAttnBlkQ/K/P/V are set regardless of fallback
path (reference: kernelParams, sageParamEncode, KernelParams::setKernelParams).
In `@include/flashinfer/trtllm/fmha/kernelParams.h`:
- Line 663: The code calls getDevicePtrs(options,
get_size_in_bits(kernelMeta.mDataTypeK)) but getDevicePtrs applies a single
bitsPerElt to Q, K, and V, producing incorrect PackedQkv/ContiguousKv offsets
for mixed dtypes; change the implementation to derive offsets using each
tensor's own element size (use get_size_in_bits(kernelMeta.mDataTypeQ),
get_size_in_bits(kernelMeta.mDataTypeK),
get_size_in_bits(kernelMeta.mDataTypeV)) or update getDevicePtrs to accept
per-tensor bit sizes (or reject mixed-width layouts early), and then compute
PackedQkv and ContiguousKv offsets using those per-tensor sizes so Q, K, V
offsets are correct for mixed-width cases.
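The hash-collision point above can be made concrete with a small sketch of both encodings (the lambda shape is reconstructed from the prompt; names are illustrative): floor(log2(8)) and floor(log2(12)) are both 3, so the old scheme maps block sizes 8 and 12 to the same 4-bit field value, while the clamped raw size keeps them distinct.

```cpp
#include <cassert>
#include <cmath>
#include <cstdint>

// Old encoding as described above: compress with log2f, +1 so 0 means "off".
// Truncation of the float result is what collapses 8 and 12 together.
inline uint64_t sageEncodeLog2(int n) {
  return n > 0 ? static_cast<uint64_t>(std::log2f(static_cast<float>(n))) + 1 : 0;
}

// Suggested fix: store the raw block size, masked to the 4-bit field.
inline uint64_t sageEncodeRaw(int n) {
  return n > 0 ? static_cast<uint64_t>(n) & 0xF : 0;
}
```

The raw values then go into the same shift positions of the hashID so the `numEltsPerSageAttnBlkQ/K/P/V` fields stay unique per block size (up to 15).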
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: b4f429f9-18fd-426c-a934-a2a51554bbd1
📒 Files selected for processing (7)
- csrc/fmhaReduction.cu
- csrc/trtllm_fmha_kernel_launcher.cu
- flashinfer/artifacts.py
- include/flashinfer/trtllm/fmha/fmhaKernels.cuh
- include/flashinfer/trtllm/fmha/fmhaRunner.cuh
- include/flashinfer/trtllm/fmha/fmhaRunnerParams.h
- include/flashinfer/trtllm/fmha/kernelParams.h
16dbc0b to 212e310 (Compare)
/bot run
Actionable comments posted: 3
♻️ Duplicate comments (1)
include/flashinfer/trtllm/fmha/fmhaKernels.cuh (1)
289-309: ⚠️ Potential issue | 🟠 Major
SageAttention overrides are lost on the CGA→GMEM fallback re-build.
`setKernelParams` is called again on Lines 304-305 and the resulting `kernelParams` overwrites the SageAttention pointers/log block sizes set at Lines 254-261, so any request that hits this fallback will launch with null `ptrSageAttnSfs*` and zero `mLogNumEltsPerSageAttnBlk*`. Please either move the override block into a helper and re-invoke it here, or push the override logic into `KernelParams::setKernelParams` so it is applied unconditionally.
🔧 Proposed fix

```diff
-  kernelParams = KernelParams::setKernelParams(
-      params, kernelMeta, ctaLaunchParams.mMaxNumCtasQ, ctaLaunchParams.mMaxNumCtasKv);
+  kernelParams = KernelParams::setKernelParams(
+      params, kernelMeta, ctaLaunchParams.mMaxNumCtasQ, ctaLaunchParams.mMaxNumCtasKv);
+  kernelParams.ptrSageAttnSfsQ = params.ptrSageAttnSfsQ;
+  kernelParams.ptrSageAttnSfsK = params.ptrSageAttnSfsK;
+  kernelParams.ptrSageAttnSfsP = params.ptrSageAttnSfsP;
+  kernelParams.ptrSageAttnSfsV = params.ptrSageAttnSfsV;
+  kernelParams.mLogNumEltsPerSageAttnBlkQ = sageParamEncode(kernelMeta.mNumEltsPerSageAttnBlkQ);
+  kernelParams.mLogNumEltsPerSageAttnBlkK = sageParamEncode(kernelMeta.mNumEltsPerSageAttnBlkK);
+  kernelParams.mLogNumEltsPerSageAttnBlkP = sageParamEncode(kernelMeta.mNumEltsPerSageAttnBlkP);
+  kernelParams.mLogNumEltsPerSageAttnBlkV = sageParamEncode(kernelMeta.mNumEltsPerSageAttnBlkV);
```
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@include/flashinfer/trtllm/fmha/fmhaKernels.cuh` around lines 289 - 309, The SageAttention overrides (ptrSageAttnSfs* and mLogNumEltsPerSageAttnBlk* stored in kernelParams) are lost when kernelParams is rebuilt by KernelParams::setKernelParams during the CGA→GMEM fallback; fix by reapplying the SageAttention override logic immediately after kernelParams is reassigned here (i.e., call the same helper that set the Sage pointers/log-block sizes earlier) or refactor KernelParams::setKernelParams to carry over or initialize the SageAttention fields unconditionally so they are preserved after setKernelParams; locate the reassign site of kernelParams, the original override block that sets ptrSageAttnSfs*/mLogNumEltsPerSageAttnBlk*, and either invoke that helper after setKernelParams or update KernelParams::setKernelParams to perform the override.
🧹 Nitpick comments (3)
include/flashinfer/trtllm/fmha/kernelParams.h (1)
717-726: `reshapeFactorKv` is derived from K's bit width only.
When K and V have different element sizes, this factor is applied identically to both the K and V TMA shapes (via the `reshapeFactor` argument of `makeTmaShapeStrideKv`). If V has a different byte width, the resulting V shape may not land on the 128B box width this heuristic targets, even though `tileShapeV[0]` is rescaled later. Worth a one-line comment documenting that K drives the reshape and V byte-parity is restored only through the tileShape rescale at Lines 753-759.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@include/flashinfer/trtllm/fmha/kernelParams.h` around lines 717 - 726, Add a one-line comment next to the reshapeFactorKv calculation (symbols: reshapeFactorKv, canReshapeTmaKv, get_size_in_bits, kernelMeta.mDataTypeK) explaining that the reshape factor is computed from K's element bit-width only and will be applied to both K and V TMA shapes via makeTmaShapeStrideKv, and that any byte-width mismatch for V is corrected later by rescaling tileShapeV (tileShapeV rescale logic around makeTmaShapeStrideKv usage) to restore 128B box alignment.
include/flashinfer/trtllm/fmha/fmhaRunner.cuh (2)
38-47: Datatype asymmetry between K and V is intentional — worth documenting.
K now accepts `INT8` while V does not. This is consistent with the available cubins (V kernels aren't built for `INT8`), but the asymmetry is surprising to readers and will bite anyone who writes `TllmGenFmhaRunner(q, DATA_TYPE_INT8, o)` via the 3-arg convenience ctor (Line 57) — it will hit the V validation with `DATA_TYPE_INT8` and throw despite INT8 being "supported for K". A brief comment here, or an earlier guard in the convenience ctor, would make this much less confusing.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@include/flashinfer/trtllm/fmha/fmhaRunner.cuh` around lines 38 - 47, The K-v datatype asymmetry (mDtypeK allowing DATA_TYPE_INT8 while mDtypeV does not) is intentional but surprising; add a brief clarifying comment next to the FLASHINFER_CHECKs referencing mDtypeK and mDtypeV explaining that V kernels are not built for INT8, and also add an explicit early-guard in the 3-arg convenience constructor TllmGenFmhaRunner(...) (or the function that forwards those args) that either converts an INT8 V request to a supported V type or throws a clear error before reaching the mDtypeV validation so callers like TllmGenFmhaRunner(q, DATA_TYPE_INT8, o) fail with a descriptive message about V not supporting INT8.
56-58: Convenience constructor drops SageAttention block sizes.
The 3-argument form forwards to the primary constructor with all four `numEltsPerSageAttnBlk*` defaulted to 0, i.e. SageAttention-off. That's fine for the existing call site in `csrc/trtllm_fmha_kernel_launcher.cu`, but if a future caller needs SageAttention they must use the 5-arg form — worth a comment so nobody assumes the 3-arg ctor quietly enables Sage when the runtime params carry Sage pointers.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@include/flashinfer/trtllm/fmha/fmhaRunner.cuh` around lines 56 - 58, The 3-argument convenience constructor TllmGenFmhaRunner(dtypeQ,dtypeKv,dtypeOut) silently disables SageAttention by forwarding to the main ctor with all numEltsPerSageAttnBlk* parameters set to 0; add a brief comment next to this constructor (or update its docstring) stating that it intentionally defaults all numEltsPerSageAttnBlk* to 0 and therefore SageAttention is off, and that callers who need SageAttention must use the 5-argument constructor that accepts explicit numEltsPerSageAttnBlk* values and Sage pointers.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@include/flashinfer/trtllm/fmha/fmhaKernels.cuh`:
- Around line 249-253: The lambda sageParamEncode uses GCC/Clang-only
__builtin_ctz which breaks MSVC builds; modify sageParamEncode to detect
compiler and use a portable fallback: keep the existing ternary guard (blockSize
== 0) but replace the direct __builtin_ctz call with a small wrapper function
(e.g., count_trailing_zeros) that at compile-time selects __builtin_ctz for
GCC/Clang, _BitScanForward/_tzcnt/_lzcnt for MSVC, and a simple loop fallback
otherwise; ensure you still call FLASHINFER_CHECK for power-of-two validation
and return 0 for blockSize==0 to preserve behavior.
In `@include/flashinfer/trtllm/fmha/kernelParams.h`:
- Around line 750-759: The if-block adjusting tileShapeV has inconsistent
indentation: the body lines inside the if (the comment and the tileShapeV[0]
assignment) and the closing brace are indented with 4 spaces instead of the
file's 2-space style. Fix by reindenting the comment, the tileShapeV[0]
calculation, and the closing brace to use 2-space indentation so the block
around symbols tileShapeK, tileShapeV, storeTransformedKvInTmem, and
kernelMeta.mDataTypeK / kernelMeta.mDataTypeV matches the file's formatting
standard (or run pre-commit/clang-format to apply the same change).
- Around line 753-759: Add a divisibility check before scaling tileShapeV[0] in
the block guarded by storeTransformedKvInTmem and kernelMeta.mDataTypeK !=
kernelMeta.mDataTypeV to prevent silent integer truncation: compute numerator =
int64(tileShapeV[0]) * get_size_in_bits(kernelMeta.mDataTypeK) and denominator =
get_size_in_bits(kernelMeta.mDataTypeV), assert (numerator % denominator == 0)
using FLASHINFER_CHECK with a clear error message, then assign tileShapeV[0] =
static_cast<uint32_t>(numerator / denominator); this ensures exact divisibility
when adjusting tileShapeV[0] for differing K/V bit widths.
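For the `__builtin_ctz` portability point, one possible shape for the wrapper is sketched below. This is an illustration under the prompt's assumptions, not FlashInfer code; C++20 callers could simply use `std::countr_zero` from `<bit>` instead.

```cpp
#include <cassert>
#include <cstdint>
#if defined(_MSC_VER)
#include <intrin.h>
#endif

// Portable trailing-zero count: compiler intrinsic where available, simple
// loop fallback elsewhere. Returns 32 for x == 0 so callers keep their own
// zero guard (as sageParamEncode already does for blockSize == 0).
inline int count_trailing_zeros(uint32_t x) {
  if (x == 0) return 32;
#if defined(__GNUC__) || defined(__clang__)
  return __builtin_ctz(x);
#elif defined(_MSC_VER)
  unsigned long idx;
  _BitScanForward(&idx, x);
  return static_cast<int>(idx);
#else
  int n = 0;
  for (; (x & 1u) == 0; x >>= 1) ++n;
  return n;
#endif
}
```

A `sageParamEncode` built on this would keep the `blockSize == 0` guard and the power-of-two validation (`(blockSize & (blockSize - 1)) == 0`) before calling the wrapper.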
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: f8be8e90-4a85-4d2b-91eb-15dad98a33c3
📒 Files selected for processing (5)
- csrc/fmhaReduction.cu
- include/flashinfer/trtllm/fmha/fmhaKernels.cuh
- include/flashinfer/trtllm/fmha/fmhaRunner.cuh
- include/flashinfer/trtllm/fmha/fmhaRunnerParams.h
- include/flashinfer/trtllm/fmha/kernelParams.h
🚧 Files skipped from review as they are similar to previous changes (2)
- csrc/fmhaReduction.cu
- include/flashinfer/trtllm/fmha/fmhaRunnerParams.h
212e310 to 6c56ac8 (Compare)
/bot run
6c56ac8 to 31d2b46 (Compare)
/bot run
🧹 Nitpick comments (2)
include/flashinfer/trtllm/fmha/fmhaKernels.cuh (2)
197-204: hashID parameter name still `sparseMla` while metadata field is now `mSparseAttn`.
Line 202 correctly casts the new `int mSparseAttn` to a bool via `!= 0`, but the underlying `hashID(..., bool sparseMla, ...)` signature and the bit-55 comment at line 181 still reference the old `sparseMla` naming. Functionally correct (bit semantics unchanged: any non-zero sparse-attn value collapses to the same bit), just a naming consistency nit if you want the header to match the renamed field.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@include/flashinfer/trtllm/fmha/fmhaKernels.cuh` around lines 197 - 204, The hashID function still uses the old parameter name sparseMla and the bit-55 comment refers to sparseMla even though KernelMeta renamed the field to mSparseAttn; update the hashID signature/parameter name to reflect mSparseAttn (or sparseAttn) and adjust the bit-55 comment to mention mSparseAttn (or sparseAttn) so naming is consistent with KernelMeta and the cast (mSparseAttn != 0) remains functionally identical.
116-122: Sage kernel substring filter: Consider tightening to word-boundary match for robustness.
The substring match `std::strstr(mFuncName, "Sage")` will skip any kernel whose symbol name contains the literal `"Sage"` anywhere in it. The comment correctly notes that SageAttention kernels share the same hashID as their non-Sage counterparts (block sizes are intentionally excluded from the hash), so filtering is necessary to avoid false conflicts.
However, since kernel names come from externally-generated TensorRT LLM metadata, consider using a more specific match pattern (e.g., `"_Sage"` or `"Sage_"` prefix/word boundary) to reduce the risk of accidentally filtering a non-SageAttention kernel if future naming conventions introduce the substring elsewhere.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@include/flashinfer/trtllm/fmha/fmhaKernels.cuh` around lines 116 - 122, The current filter uses std::strstr(kernelMeta.mFuncName, "Sage") which may match unintended symbols; tighten the check in the loop that examines kernelMeta.mFuncName to only exclude true SageAttention kernels by matching a stricter pattern (e.g., check for substrings like "_Sage", "Sage_", or the full token "SageAttention") instead of a raw "Sage" substring. Update the condition around kernelMeta.mFuncName to use these more specific tests (or convert to std::string and use find/regex for word-boundary matching) so only intended Sage kernels are skipped.
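A stricter filter along those lines might look like the sketch below. The underscore-delimited token convention is an assumption here; the actual trtllm-gen symbol naming should be checked before adopting it.

```cpp
#include <cassert>
#include <cstring>

// Match "Sage" only when it appears as an underscore-delimited token, so a
// hypothetical future kernel name that merely contains the letters "Sage"
// in another word is not accidentally filtered out.
inline bool isSageKernelName(const char* funcName) {
  return std::strstr(funcName, "_Sage") != nullptr ||
         std::strstr(funcName, "Sage_") != nullptr;
}
```

Converting `mFuncName` to `std::string` and using `find` with explicit boundary checks (or a regex) would allow an even tighter word-boundary match if needed.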
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: 2b93f353-a0cd-4fff-b5e9-c06aaf817f34
📒 Files selected for processing (3)
- csrc/fmhaReduction.cu
- include/flashinfer/trtllm/fmha/fmhaKernels.cuh
- include/flashinfer/trtllm/fmha/kernelParams.h
🚧 Files skipped from review as they are similar to previous changes (1)
- csrc/fmhaReduction.cu
/bot run

/bot run

/bot run
Minimal header changes to match the new trtllm-gen FMHA cubin MetaInfo struct layout:
- TllmGenFmhaKernelMetaInfo: renamed mSparseMla (bool) -> mSparseAttn (int). Callers convert to bool via `!= 0`.
- KernelParams (GPU-side struct): renamed mSparseMlaTopK -> mSparseAttnTopK and moved immediately after mSkipSoftmaxThresholdScaleFactor to match the layout expected by the new kernels.

The K/V dtype split (mDataTypeKv -> mDataTypeK/V) and SageAttention block size fields present in the new struct are layout-compatible but not used, so no code changes are needed for those -- existing references to mDataTypeKv still compile since the cubin-supplied struct keeps that field alongside the new mDataTypeK/V.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
New cubins from cubin_publishing pipeline 49098275 fix the gpt-oss-120b B200 TP=2 prefix caching corruption (NVBug 5922676). Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
The trtllm-gen cubin at commit 134850621 ships zero SwapsMmaAb CgaSmemReduction kernels at headDimV >= 512. This covers both:
- MLA decode (headDimQk=576, headDimV=512): the prior cubin shipped 234 MLA CGA kernels; 134850621 strips them due to an over-broad `headDim >= 512 && SwapsMmaAb && CGA -> skip` rule in trtllm-gen's ExportCubin.
- Non-MLA headDim=512 (head_dim_512 tests added in flashinfer-ai#2959): same skip predicate drops these too.

Promoting to CgaSmemReduction in either regime misses the kernel map, so gate the promotion on `params.mHeadDimV < 512`. Additionally, for tileSizeQ >= 32 the CGA variant would exceed the 232KB Blackwell smem limit, so the guard is correct on the hardware side as well. This guard can be narrowed once trtllm-gen ships a cubin with the tileSizeQ>=32 + headDimPerCtaV>=512 skip predicate (trtllm-gen MR !928).

Also moves the kernel-not-found check ahead of the iterator dereference in loadKernel to avoid UB on miss.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
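The loadKernel check-before-dereference fix mentioned above follows a standard map-lookup pattern; a minimal sketch (the map type, key, and error handling are illustrative, not FlashInfer's actual signatures):

```cpp
#include <cassert>
#include <cstdint>
#include <stdexcept>
#include <string>
#include <unordered_map>

// Look the kernel up first and only dereference the iterator after confirming
// the key exists; dereferencing end() on a miss is undefined behavior.
const std::string& loadKernel(const std::unordered_map<uint64_t, std::string>& kernelMap,
                              uint64_t hashId) {
  auto it = kernelMap.find(hashId);
  if (it == kernelMap.end()) {
    throw std::runtime_error("kernel not found for hashId");
  }
  return it->second;
}
```

This is why the gated promotion matters: when the CgaSmemReduction variant is absent from the cubin, the lookup now fails loudly instead of invoking UB.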
8b003a4 to 365a9ad (Compare)
/bot run
📌 Description
The branch has 2 commits:
🔍 Related Issues
🚀 Pull Request Checklist
Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.
✅ Pre-commit Checks
- Installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- Installed the hooks with `pre-commit install`.
- Ran `pre-commit run --all-files` and fixed any reported issues.
🧪 Tests
(unittest, etc.).
Reviewer Notes
Summary by CodeRabbit
Bug Fixes
Improvements
Chores