
Commit 6c56ac8

PerkzZheng and claude committed
Sync trtllm FMHA: mSparseMla -> mSparseAttn for new cubin struct
Minimal header changes to match the new trtllm-gen FMHA cubin MetaInfo struct layout:

- TllmGenFmhaKernelMetaInfo: renamed mSparseMla (bool) -> mSparseAttn (int). Callers convert to bool via `!= 0`.
- KernelParams (GPU-side struct): renamed mSparseMlaTopK -> mSparseAttnTopK and moved it immediately after mSkipSoftmaxThresholdScaleFactor to match the layout expected by the new kernels.

The K/V dtype split (mDataTypeKv -> mDataTypeK/V) and the SageAttention block size fields present in the new struct are layout-compatible but not used, so no code changes are needed for those; existing references to mDataTypeKv still compile because the cubin-supplied struct keeps that field alongside the new mDataTypeK/V.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
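For illustration, a minimal sketch of what the rename looks like from a caller's side, assuming a stripped-down MetaInfo struct; everything except the mSparseAttn field name is illustrative and not the actual cubin-generated header:

#include <cstdint>

// Illustrative only; the real TllmGenFmhaKernelMetaInfo carries many more fields.
struct MetaInfoSketch {
  // Before this commit: bool mSparseMla;
  int32_t mSparseAttn;  // non-zero means the kernel handles sparse attention
};

// Call sites that used to read the bool directly now convert explicitly.
bool usesSparseAttn(MetaInfoSketch const& kernelMeta) {
  return kernelMeta.mSparseAttn != 0;
}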
1 parent ffa04bb commit 6c56ac8

3 files changed: 7 additions & 6 deletions


csrc/fmhaReduction.cu

Lines changed: 2 additions & 2 deletions
@@ -81,7 +81,7 @@ __global__ void __launch_bounds__(NumThreadsPerCta, 2)
   seqLenKv = seqLenKv - ((params.mMaxSeqLenQ - 1) - ctaIdxQ);
   // Consider sparseMlaTopK.
   if (sparseMla) {
-    seqLenKv = min(seqLenKv, params.mSparseMlaTopK);
+    seqLenKv = min(seqLenKv, params.mSparseAttnTopK);
   }
   // The actual number of CtasKv (TileSizeKv is always 128 for now).
   int32_t numCtasKv{min((seqLenKv + 127) / 128, params.mMaxNumCtasKv)};
@@ -361,7 +361,7 @@ void runFmhaReduction(TllmGenFmhaKernelMetaInfo const& kernelMeta, KernelParams
   }

   // Launch the kernel.
-  cudaLaunchKernelEx(&config, kernel, params, kernelMeta.mSparseMla, numCtasForReduction,
+  cudaLaunchKernelEx(&config, kernel, params, kernelMeta.mSparseAttn != 0, numCtasForReduction,
                      numCtasForAllHeads, numHeadDimCtasV);
   cudaError_t err = cudaGetLastError();
   FLASHINFER_CHECK(err == cudaSuccess, "Failed to launch kernel: ", cudaGetErrorString(err));
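For context, a small host-side sketch of the arithmetic this hunk touches: when sparse attention is enabled, only the top-K K/V tokens contribute, so the effective K/V length is clamped before the 128-wide CTA count is derived. The helper function and its name are illustrative, not part of the library:

#include <algorithm>
#include <cstdint>

// Illustrative helper mirroring the clamp + CTA-count math in the hunk above.
int32_t numCtasKvSketch(int32_t seqLenKv, bool sparseAttn, int32_t sparseAttnTopK,
                        int32_t maxNumCtasKv) {
  if (sparseAttn) {
    // Only the top-K K/V tokens are attended, so the effective K/V length is capped.
    seqLenKv = std::min(seqLenKv, sparseAttnTopK);
  }
  // TileSizeKv is always 128 for now: round up to whole K/V CTAs, then cap at the maximum.
  return std::min((seqLenKv + 127) / 128, maxNumCtasKv);
}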

include/flashinfer/trtllm/fmha/fmhaKernels.cuh

Lines changed: 1 addition & 1 deletion
@@ -191,7 +191,7 @@ class TllmGenFmhaKernel {
         kernelMeta.mTileScheduler, kernelMeta.mMultiCtasKvMode,
         kernelMeta.mHeadDimPerCtaV, kernelMeta.mHeadDimQk, kernelMeta.mHeadDimV,
         kernelMeta.mTileSizeQ, kernelMeta.mTileSizeKv, kernelMeta.mNumTokensPerPage,
-        kernelMeta.mReuseSmemKForV, kernelMeta.m2CtaMma, kernelMeta.mSparseMla,
+        kernelMeta.mReuseSmemKForV, kernelMeta.m2CtaMma, kernelMeta.mSparseAttn != 0,
         kernelMeta.mSkipsSoftmaxWhenPossible);
   }

include/flashinfer/trtllm/fmha/kernelParams.h

Lines changed: 4 additions & 3 deletions
@@ -194,13 +194,14 @@ struct KernelParams {
   float mScaleSfO;
   // Threshold to decide whether warp skips softmax ops
   float mSkipSoftmaxThresholdScaleFactor;
+  // The sparse attention topK value. Must immediately follow mSkipSoftmaxThresholdScaleFactor
+  // to match the GPU struct layout expected by trtllm-gen kernels.
+  int32_t mSparseAttnTopK;
   // The start token index in SF tensor. Used for FP4 SF offset calculation in generation phase
   // kernel when inflight batching is enabled in TRT-LLM.
   int32_t mStartTokenIdxSfO;
   // The sum of sequence lengths for Q and K/V.
   int32_t mSumOfSeqLensQ, mSumOfSeqLensKv;
-  // The sparseMla topK value.
-  int32_t mSparseMlaTopK;
   // The flag to use block sparse attention.
   bool mUseBlockSparseAttention;
   // Whether the indices for K & V pages are shared as unified index.
@@ -879,7 +880,7 @@ struct KernelParams {
   // indices.
   FLASHINFER_CHECK(!options.mSparseMla || (options.mSparseMlaTopK % 4) == 0,
                    "SparseMlaTopK must be a multiple of 4");
-  params.mSparseMlaTopK = options.mSparseMlaTopK;
+  params.mSparseAttnTopK = options.mSparseMlaTopK;
   // TODO: Integrate trtllm block-sparse attention kernels when needed.
   params.mUseBlockSparseAttention = false;
   // Whether the indices for K & V pages are shared as unified index (vLLM/FlashInfer).
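The "must immediately follow" comment above is a layout contract between this host-side struct and the trtllm-gen cubins. A minimal, self-contained sketch of how such a constraint could be asserted at compile time; the struct below is a stand-in, not the actual KernelParams definition:

#include <cstddef>
#include <cstdint>

// Stand-in struct with only the fields relevant to the layout contract.
struct KernelParamsSketch {
  float mSkipSoftmaxThresholdScaleFactor;
  int32_t mSparseAttnTopK;  // must sit directly after the threshold field
  int32_t mStartTokenIdxSfO;
};

static_assert(offsetof(KernelParamsSketch, mSparseAttnTopK) ==
                  offsetof(KernelParamsSketch, mSkipSoftmaxThresholdScaleFactor) + sizeof(float),
              "mSparseAttnTopK must immediately follow mSkipSoftmaxThresholdScaleFactor");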

0 commit comments