16 changes: 9 additions & 7 deletions cpp/tensorrt_llm/common/attentionOp.cpp
@@ -299,7 +299,7 @@ bool AttentionOp::convertMMHAParamsToXQAParams(tensorrt_llm::kernels::XQAParams&
     xqaParams.start_token_idx_sf = generationsParams.start_token_idx_sf;
     // Parameters for sparse attention
     xqaParams.sparse_params = mRuntimeSparseAttentionParams;
-    xqaParams.use_sparse_attention = useTllmGenSparseAttention();
+    xqaParams.use_sparse_attention_gen_paged = useTllmGenSparseAttentionPaged();
     // Skip softmax threshold.
     xqaParams.skip_softmax_threshold_scale_factor = mSkipSoftmaxThresholdScaleFactorDecode;
 #ifdef SKIP_SOFTMAX_STAT
@@ -948,7 +948,7 @@ size_t AttentionOp::getWorkspaceSizeForGeneration(nvinfer1::DataType type, int32
     size_t const cu_kv_seqlens_size = sizeof(int) * (batch_beam + 1);
     size_t const rotary_inv_freq_size = sizeof(float) * batch_beam * mRotaryEmbeddingDim / 2;
     // Two workspaces for sparse attention. One for the sequence lengths, and one for kv block offsets.
-    size_t const sparse_attn_cache_size = useTllmGenSparseAttention()
+    size_t const sparse_attn_cache_size = useTllmGenSparseAttentionPaged()
         ? sizeof(int) * (batch_beam + batch_beam * 2 * max_blocks_per_sequence) * mNumKVHeads
         : 0;
     xqa_workspaces[0] = cu_seqlens_size;
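For intuition, here is a minimal standalone sketch of the sizing rule above (a free function with assumed names; in the op these values are members and enqueue parameters). The factor of 2 on the block offsets presumably covers the two halves of the kv offset table.

```cpp
#include <cstddef>
#include <cstdint>

// Sketch of the sparse-attention workspace sizing: one sequence-length entry
// per sequence plus 2 * max_blocks_per_sequence block-offset entries per
// sequence, all int32, replicated per KV head. Returns 0 when the paged
// sparse path is disabled.
size_t sparseAttnCacheSize(int batch_beam, int max_blocks_per_sequence, int num_kv_heads, bool paged_sparse)
{
    if (!paged_sparse)
    {
        return 0;
    }
    size_t const seq_len_entries = static_cast<size_t>(batch_beam);
    size_t const block_offset_entries = static_cast<size_t>(batch_beam) * 2 * max_blocks_per_sequence;
    return sizeof(int32_t) * (seq_len_entries + block_offset_entries) * num_kv_heads;
}
```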
@@ -1120,14 +1120,14 @@ int AttentionOp::mlaGeneration(
             = reinterpret_cast<float const*>(params.bmm1_scale) + bmm1_scale_offset;
     }

-    // Set the following parameters if sparseMLA is used.
+    // Set the following parameters if sparse attention is used.
     if (useSparseMLA())
     {
-        tllmRunnerParams.mSparseMla = true;
-        tllmRunnerParams.mSparseMlaTopK = mRuntimeSparseAttentionParams.sparse_mla_topk;
+        tllmRunnerParams.mSparseAttention = SparseType::StaticTokenSparse;
+        tllmRunnerParams.mSparseTopK = mRuntimeSparseAttentionParams.sparse_topk;
         tllmRunnerParams.kvPageIdxPtr = reinterpret_cast<KVCacheIndex::UnderlyingType const*>(
             mRuntimeSparseAttentionParams.sparse_attn_indices);
-        tllmRunnerParams.kvPtr = mRuntimeSparseAttentionParams.sparse_mla_kv_cache_pool;
+        tllmRunnerParams.kvPtr = mRuntimeSparseAttentionParams.sparse_kv_cache_pool;
     }

     mTllmGenFMHARunner->run(tllmRunnerParams);
@@ -1851,7 +1851,7 @@ int AttentionOp::enqueueContext(EnqueueContextParams<T> const& params, cudaStrea
     fmhaParams.softmaxStatsPtr = params.softmax_stats;

     // Sparse attention parameters
-    if (useSparseMLA())
+    if (useTllmGenSparseAttention())
     {
         fmhaParams.sparse_params = mRuntimeSparseAttentionParams;
     }
@@ -2790,6 +2790,7 @@ int AttentionOp::initialize() noexcept
     fmhaParams.hasAlibi = isALiBi();
     fmhaParams.scaleAlibi = isAliBiWithScale();
     fmhaParams.useSparseMLA = useSparseMLA();
+    fmhaParams.useTllmGenSparseAttention = useTllmGenSparseAttention();

     // Load kernels from the pre-compiled cubins.
     mFmhaDispatcher.reset(new FmhaDispatcher(fmhaParams));
@@ -2960,6 +2961,7 @@ int AttentionOp::initialize() noexcept
     fixedParams.isPagedKv = mPagedKVCache;
     fixedParams.isSpecDecoding = mIsSpecDecodingEnabled;
     fixedParams.hasAlibi = isALiBi();
+    fixedParams.useTllmGenSparseAttention = useTllmGenSparseAttention();

     mXqaDispatcher.reset(new XqaDispatcher(fixedParams));

18 changes: 12 additions & 6 deletions cpp/tensorrt_llm/common/attentionOp.h
@@ -389,16 +389,21 @@ class AttentionOp
         return mUseSparseAttention && mPagedKVCache && mEnableXQA;
     }

-    [[nodiscard]] bool useTllmGenSparseAttention() const
+    [[nodiscard]] bool useTllmGenSparseAttentionPaged() const
     {
-        return mUseTllmGenSparseAttention && useSparseAttention();
+        return mUseTllmGenSparseAttentionPaged && useSparseAttention();
     }

     [[nodiscard]] bool useSparseMLA() const
     {
         return mUseSparseAttention && mUseTllmGen && mIsMLAEnabled;
     }

+    [[nodiscard]] bool useTllmGenSparseAttention() const
+    {
+        return useSparseMLA() || (mUseSparseAttention && mUseTllmGen && mUseTllmGenSparseAttention);
+    }
+
     [[nodiscard]] int smVersion() const
     {
         return mSM;
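Condensed, the predicate structure after this change looks like the sketch below (a plain struct stands in for the AttentionOp member flags; the logic mirrors the methods above):

```cpp
// Illustrative stand-in for the AttentionOp member flags; not the real class.
struct SparseFlags
{
    bool useSparseAttention, pagedKVCache, enableXQA;
    bool useTllmGen, isMLAEnabled;
    bool tllmGenSparse, tllmGenSparsePaged;
};

bool useSparseAttention(SparseFlags const& f)
{
    return f.useSparseAttention && f.pagedKVCache && f.enableXQA;
}

bool useTllmGenSparseAttentionPaged(SparseFlags const& f)
{
    return f.tllmGenSparsePaged && useSparseAttention(f);
}

bool useSparseMLA(SparseFlags const& f)
{
    return f.useSparseAttention && f.useTllmGen && f.isMLAEnabled;
}

// The new umbrella predicate: true for sparse MLA or for the non-MLA
// trtllm-gen sparse path.
bool useTllmGenSparseAttention(SparseFlags const& f)
{
    return useSparseMLA(f) || (f.useSparseAttention && f.useTllmGen && f.tllmGenSparse);
}
```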
@@ -479,6 +484,7 @@ class AttentionOp
     bool mIsGenerationMLA = false;
     bool mUseGenFlashMLA = false;
     bool mUseSparseAttention = false;
+    bool mUseTllmGenSparseAttentionPaged = false;
     bool mUseTllmGenSparseAttention = false;
     tensorrt_llm::kernels::MlaMetaParams mMLAParams;
     int mCpSize = 1;
@@ -536,10 +542,10 @@ class AttentionOp
         mPosShiftEnabled, mPagedContextFMHA, mFP8ContextFMHA, mFP8AttenOutput, mFP8ContextMLA, mFP8GenerationMLA,
         mChunkPrefillBufferBatchSize, mDenseContextFMHA, mHasFullAttentionMask, mIsSpecDecodingEnabled,
         mUseSpecDecoding, mIsSpecDecTree, mSpecDecodingIsGenerationLengthVariable, mSpecDecodingMaxGenerationLength,
-        mIsMLAEnabled, mIsGenerationMLA, mUseGenFlashMLA, mUseSparseAttention, mUseTllmGenSparseAttention,
-        mMLAParams.data(), mCpSize, mCpRank, mCpGroup, mNumAttnHeads, mNumAttnKVHeads, mNumKVHeadsOrigin,
-        mAttnTpSize, mAttnTpRank, mAttnCpSize, mAttnCpRank, mUlyssesMQABroadcast, mEnableContextFMHA,
-        mFMHAForceFP32Acc, mMultiBlockMode, mEnableXQA, mUseKVCache, mSkipAttn, mFuseFp4Quant,
+        mIsMLAEnabled, mIsGenerationMLA, mUseGenFlashMLA, mUseSparseAttention, mUseTllmGenSparseAttentionPaged,
+        mUseTllmGenSparseAttention, mMLAParams.data(), mCpSize, mCpRank, mCpGroup, mNumAttnHeads, mNumAttnKVHeads,
+        mNumKVHeadsOrigin, mAttnTpSize, mAttnTpRank, mAttnCpSize, mAttnCpRank, mUlyssesMQABroadcast,
+        mEnableContextFMHA, mFMHAForceFP32Acc, mMultiBlockMode, mEnableXQA, mUseKVCache, mSkipAttn, mFuseFp4Quant,
         mNbMultiBlockSemaphores, mAttentionChunkSize.value_or(-1), mSkipSoftmaxThresholdScaleFactorPrefill,
         mSkipSoftmaxThresholdScaleFactorDecode);
     };
@@ -145,6 +145,8 @@ struct MHARunnerFixedParams
     int sageBlockSizeV = 0;
     // Use sparse MLA ?
     bool useSparseMLA = false;
+    // Use sparse attention in trtllm-gen ?
+    bool useTllmGenSparseAttention = false;

     // Convert to string for debug.
     std::string convertToStrOutput()
@@ -195,6 +197,8 @@ struct MHARunnerFixedParams
         output += ", sageBlockSizeQ = " + std::to_string(sageBlockSizeQ);
         output += ", sageBlockSizeK = " + std::to_string(sageBlockSizeK);
         output += ", sageBlockSizeV = " + std::to_string(sageBlockSizeV);
+        output += ", useSparseMLA = " + std::string(useSparseMLA ? "true" : "false");
+        output += ", useTllmGenSparseAttention = " + std::string(useTllmGenSparseAttention ? "true" : "false");

         return output;
     }
@@ -288,7 +288,7 @@ void buildXQALaunchParams(XQALaunchParam<KVCacheBuffer>& launchParams, void*& in
     launchParams.bmm2_scale_ptr = reinterpret_cast<float*>(workspace);
     workspace = tensorrt_llm::common::nextWorkspacePtrWithAlignment(workspace, bmm2_scale_size);
     // Used for block sparse attention
-    if (params.use_sparse_attention)
+    if (params.use_sparse_attention_gen_paged)
     {
         launchParams.sparse_kv_block_offsets = reinterpret_cast<void*>(workspace);
         workspace = tensorrt_llm::common::nextWorkspacePtrWithAlignment(workspace, kv_block_offsets_size);
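A minimal sketch of what a workspace-carving helper of this shape typically does, under the assumption that it advances past the bytes just consumed and rounds the address up to a fixed alignment (the real helper and its alignment constant live in tensorrt_llm::common; this is not its actual implementation):

```cpp
#include <cstddef>
#include <cstdint>

// Assumed semantics only: skip `bytes` used at `base`, then round the address
// up to `align` so the next sub-buffer starts aligned. `align` must be a
// power of two; the 256-byte default is an assumption for illustration.
void* nextWorkspacePtrWithAlignmentSketch(void* base, size_t bytes, size_t align = 256)
{
    uintptr_t addr = reinterpret_cast<uintptr_t>(base) + bytes;
    addr = (addr + align - 1) & ~(align - 1);
    return reinterpret_cast<void*>(addr);
}
```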
@@ -121,7 +121,7 @@ struct XQAParams

     // sparse attention parameters
     SparseAttentionParams sparse_params;
-    bool use_sparse_attention = false;
+    bool use_sparse_attention_gen_paged = false;

     // Skip softmax threshold.
     float skip_softmax_threshold_scale_factor = 0;
@@ -210,7 +210,7 @@ struct XQAParams
            << "fp8_out_scale :" << fp8_out_scale << std ::endl
            << "encoder_input_lengths: " << encoder_input_lengths << std::endl
            << "sparse_params: " << sparse_params.toString() << std::endl
-           << "use_sparse_attention :" << (use_sparse_attention ? "true" : "false") << std ::endl
+           << "use_sparse_attention_gen_paged :" << (use_sparse_attention_gen_paged ? "true" : "false") << std ::endl
            << "skip_softmax_threshold_scale_factor :" << skip_softmax_threshold_scale_factor << std ::endl
 #ifdef SKIP_SOFTMAX_STAT
            << "skip_softmax_total_blocks :" << skip_softmax_total_blocks << std ::endl
18 changes: 10 additions & 8 deletions cpp/tensorrt_llm/kernels/fmhaDispatcher.cpp
@@ -131,10 +131,10 @@ bool FmhaDispatcher::isSupported()
     // the kernel is supported.
     tllmRunnerParams.mChunkedAttentionSize = INT_MAX;
     tllmRunnerParams.mAttentionWindowSize = INT_MAX;
-    // Set the kernel type and mask type if sparseMLA is used.
-    if (mFixedParams.useSparseMLA)
+    // Sparse context attention uses a generation-style kernel with per-token sparse indices.
+    if (mFixedParams.useTllmGenSparseAttention)
     {
-        tllmRunnerParams.mSparseMla = true;
+        tllmRunnerParams.mSparseAttention = SparseType::StaticTokenSparse;
         tllmRunnerParams.mKernelType = FmhaKernelType::Generation;
         tllmRunnerParams.mMaskType = TrtllmGenAttentionMaskType::Causal;
     }
@@ -243,16 +243,18 @@ void FmhaDispatcher::run(MHARunnerParams runnerParams)
     // For skip softmax
     tllmRunnerParams.mSkipSoftmaxThresholdScaleFactor = runnerParams.skipSoftmaxThresholdScaleFactor;
     tllmRunnerParams.stream = runnerParams.stream;
-    // Set the sparse attention parameters if sparseMLA is used.
-    if (mFixedParams.useSparseMLA)
+    // Sparse context attention: reuse the generation-style kernel with per-token sparse indices.
+    // Same approach as sparse MLA: keep the original batch structure; tileSizeQ only groups heads.
+    // The kernel iterates over tokens via numCtasPerSeqQ when maxSeqLenQ > 1.
+    if (mFixedParams.useTllmGenSparseAttention)
     {
-        tllmRunnerParams.mSparseMla = true;
-        tllmRunnerParams.mSparseMlaTopK = runnerParams.sparse_params.sparse_mla_topk;
+        tllmRunnerParams.mSparseAttention = SparseType::StaticTokenSparse;
+        tllmRunnerParams.mSparseTopK = runnerParams.sparse_params.sparse_topk;
         tllmRunnerParams.mKernelType = FmhaKernelType::Generation;
         tllmRunnerParams.mMaskType = TrtllmGenAttentionMaskType::Causal;
         tllmRunnerParams.kvPageIdxPtr
             = reinterpret_cast<int const*>(runnerParams.sparse_params.sparse_attn_indices);
-        tllmRunnerParams.kvPtr = runnerParams.sparse_params.sparse_mla_kv_cache_pool;
+        tllmRunnerParams.kvPtr = runnerParams.sparse_params.sparse_kv_cache_pool;
     }

     mTllmGenFMHARunner->run(tllmRunnerParams);
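To make "per-token sparse indices" concrete, a conceptual sketch follows (illustrative names, not the kernel): each query token attends only to the top-k KV positions listed for it, rather than to the full history.

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// Conceptual only: pick out the top-k KV positions for one query token from a
// flattened [num_tokens, top_k] index table like sparse_attn_indices.
std::vector<int32_t> sparseKvPositionsForToken(
    std::vector<int32_t> const& sparse_attn_indices, int token_idx, int top_k)
{
    auto begin = sparse_attn_indices.begin() + static_cast<std::ptrdiff_t>(token_idx) * top_k;
    return {begin, begin + top_k};
}
```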
16 changes: 8 additions & 8 deletions cpp/tensorrt_llm/kernels/sparseAttentionKernels.h
@@ -29,12 +29,12 @@ namespace kernels

 struct SparseAttentionParams
 {
-    int32_t* sparse_kv_indices{nullptr};     // [num_kv_heads, num_sparse_kv_indices]
-    int32_t* sparse_attn_indices{nullptr};   // [num_kv_heads, num_sparse_attn_indices]
-    int32_t* sparse_kv_offsets{nullptr};     // [num_contexts + 1]
-    int32_t* sparse_attn_offsets{nullptr};   // [num_generations + 1]
-    int32_t sparse_mla_topk{0};              // for DSA attention
-    void* sparse_mla_kv_cache_pool{nullptr}; // for DSA attention
+    int32_t* sparse_kv_indices{nullptr};   // [num_kv_heads, num_sparse_kv_indices]
+    int32_t* sparse_attn_indices{nullptr}; // [num_kv_heads, num_sparse_attn_indices]
+    int32_t* sparse_kv_offsets{nullptr};   // [num_contexts + 1]
+    int32_t* sparse_attn_offsets{nullptr}; // [num_generations + 1]
+    int32_t sparse_topk{0};
+    void* sparse_kv_cache_pool{nullptr};

     int32_t sparse_attn_indices_block_size{1};
     int32_t sparse_attn_indices_stride{0};
@@ -46,8 +46,8 @@ struct SparseAttentionParams
            << "sparse_attn_indices: " << this->sparse_attn_indices << std::endl
            << "sparse_kv_offsets: " << this->sparse_kv_offsets << std::endl
            << "sparse_attn_offsets: " << this->sparse_attn_offsets << std::endl
-           << "sparse_mla_topk: " << this->sparse_mla_topk << std::endl
-           << "sparse_mla_kv_cache_pool: " << this->sparse_mla_kv_cache_pool << std::endl
+           << "sparse_topk: " << this->sparse_topk << std::endl
+           << "sparse_kv_cache_pool: " << this->sparse_kv_cache_pool << std::endl
            << "sparse_attn_indices_block_size: " << this->sparse_attn_indices_block_size << std::endl
            << "sparse_attn_indices_stride: " << this->sparse_attn_indices_stride << std::endl;
         return ss.str();
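For callers migrating to the renamed fields, a small usage sketch (include path inferred from the file name above; values are placeholders):

```cpp
#include <cstdint>
#include <iostream>
#include <vector>

// Assumed include path for the struct shown in the diff above.
#include "tensorrt_llm/kernels/sparseAttentionKernels.h"

int main()
{
    using tensorrt_llm::kernels::SparseAttentionParams;

    // Placeholder indices: top-4 KV positions for two KV heads.
    std::vector<int32_t> attn_indices(2 * 4, 0);

    SparseAttentionParams params{};
    params.sparse_attn_indices = attn_indices.data();
    params.sparse_topk = 4;                // was: sparse_mla_topk
    params.sparse_kv_cache_pool = nullptr; // was: sparse_mla_kv_cache_pool

    std::cout << params.toString();
    return 0;
}
```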
29 Git LFS files not shown