
Commit de6200d

[None][revert] Revert "[TRTLLM-11119][feat] Blackwell SageAttention, Integrate into AttentionOp API (NVIDIA#11718)" (NVIDIA#12679)
Signed-off-by: yunruis <205571022+yunruis@users.noreply.github.com>
1 parent c60615a commit de6200d

File tree

119 files changed: +107, -3038 lines

Some content is hidden: large commits have part of their diff hidden by default in this view.


cpp/tensorrt_llm/common/attentionOp.cpp

Lines changed: 6 additions & 149 deletions
@@ -20,7 +20,6 @@
 #include "tensorrt_llm/common/envUtils.h"
 #include "tensorrt_llm/common/logger.h"
 #include "tensorrt_llm/common/memoryUtils.h"
-#include "tensorrt_llm/common/sageQuant.h"
 #include "tensorrt_llm/kernels/decoderMaskedMultiheadAttention.h"
 #include "tensorrt_llm/kernels/flashMLA/flash_mla.h"
 #include "tensorrt_llm/kernels/gptKernels.h"
@@ -779,10 +778,6 @@ size_t AttentionOp::getWorkspaceSizeForContext(nvinfer1::DataType type, int32_t
         = mNumAttnHeads * dim_k_per_head; // Assuming effective num_kv_heads = head_num for layout
     int const total_v_dim_all_heads
         = mNumAttnHeads * dim_v_per_head; // Assuming effective num_kv_heads = head_num for layout
-    bool const useSageAttnSeparateQkv = mEnableContextFMHA && !mIsMLAEnabled && mFmhaDispatcher->isSeparateQAndKvInput()
-        && (mSageAttnNumEltsPerBlkQ > 0 || mSageAttnNumEltsPerBlkK > 0 || mSageAttnNumEltsPerBlkV > 0)
-        && mFP8ContextFMHA;
-
     // Packed fp8 qkv buffer size for normal fp8 context FMHA
     size_t fp8_qkv_buffer_size = mFP8ContextFMHA && mEnableContextFMHA && !mFmhaDispatcher->isSeparateQAndKvInput()
         ? max_num_tokens * size_t(local_hidden_units_qo + 2 * local_hidden_units_kv)
@@ -808,22 +803,6 @@ size_t AttentionOp::getWorkspaceSizeForContext(nvinfer1::DataType type, int32_t
             fp8_v_buf_size = mChunkPrefillBufferBatchSize * max_num_tokens * static_cast<size_t>(total_v_dim_all_heads);
         }
     }
-    else if (useSageAttnSeparateQkv)
-    {
-        fp8_q_buf_size = max_num_tokens * static_cast<size_t>(local_hidden_units_qo);
-        fp8_k_buf_size = max_num_tokens * static_cast<size_t>(local_hidden_units_kv);
-        fp8_v_buf_size = max_num_tokens * static_cast<size_t>(local_hidden_units_kv);
-    }
-
-    int32_t const q_max_n_blk = mSageAttnNumEltsPerBlkQ > 0 ? tc::divUp(input_seq_length, mSageAttnNumEltsPerBlkQ) : 0;
-    int32_t const k_max_n_blk = mSageAttnNumEltsPerBlkK > 0 ? tc::divUp(input_seq_length, mSageAttnNumEltsPerBlkK) : 0;
-    size_t const sage_q_sfs_buffer_size
-        = useSageAttnSeparateQkv ? sizeof(float) * mNumAttnHeads * batch_size * static_cast<size_t>(q_max_n_blk) : 0;
-    size_t const sage_k_sfs_buffer_size
-        = useSageAttnSeparateQkv ? sizeof(float) * mNumAttnKVHeads * batch_size * static_cast<size_t>(k_max_n_blk) : 0;
-    size_t const sage_v_sfs_buffer_size = useSageAttnSeparateQkv
-        ? sizeof(float) * tc::divUp(local_hidden_units_kv, std::max(1, mSageAttnNumEltsPerBlkV))
-        : 0;
 
     size_t const padding_offset_size = mEnableContextFMHA ? 0 : sizeof(int) * max_num_tokens;
     size_t const encoder_padding_offset_size = mEnableContextFMHA ? 0 : sizeof(int) * max_num_tokens;
@@ -839,7 +818,7 @@ size_t AttentionOp::getWorkspaceSizeForContext(nvinfer1::DataType type, int32_t
         ? 0
         : (2 * size * cpMaxPaddedSequenceLength * getHeadSize() * (mNumHeads + 2 * mNumKVHeads) + cu_seqlens_size);
 
-    int const NUM_BUFFERS = 26;
+    int const NUM_BUFFERS = 23;
     size_t workspaces[NUM_BUFFERS];
     workspaces[0] = CUBLAS_WORKSPACE_SIZE;
     workspaces[1] = attention_mask_size;
@@ -863,10 +842,7 @@ size_t AttentionOp::getWorkspaceSizeForContext(nvinfer1::DataType type, int32_t
     workspaces[19] = fmha_scheduler_counter;
     workspaces[20] = fmha_bmm1_scale_size;
     workspaces[21] = fmha_bmm2_scale_size;
-    workspaces[22] = sage_q_sfs_buffer_size;
-    workspaces[23] = sage_k_sfs_buffer_size;
-    workspaces[24] = sage_v_sfs_buffer_size;
-    workspaces[25] = cpWorkspaceSize;
+    workspaces[22] = cpWorkspaceSize;
     context_workspace_size = tc::calculateTotalWorkspaceSize(workspaces, NUM_BUFFERS);
 
     return context_workspace_size;
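
The two hunks above shrink the context-phase workspace accounting from 26 to 23 sub-buffers because the three SageAttention scale-factor buffers no longer need to be carved out of the shared allocation. For readers unfamiliar with the pattern, here is a minimal, self-contained sketch of how such a workspace is typically sized and then walked; the helper names and the 128-byte alignment are illustrative assumptions, not the actual tc::calculateTotalWorkspaceSize / nextWorkspacePtr implementations.

#include <cstddef>
#include <cstdint>

// Illustrative only: each sub-buffer size is rounded up to a fixed alignment and the
// workspace is carved sequentially, mirroring how attentionOp.cpp sums `workspaces[]`
// and later walks the blob. The alignment value is an assumption.
constexpr size_t kAlignment = 128;

static size_t alignUp(size_t n)
{
    return (n + kAlignment - 1) / kAlignment * kAlignment;
}

// Total bytes needed for `count` sub-buffers with the given sizes.
static size_t totalWorkspaceSize(size_t const* sizes, int count)
{
    size_t total = 0;
    for (int i = 0; i < count; ++i)
    {
        total += alignUp(sizes[i]);
    }
    return total;
}

// Carve the next sub-buffer out of an already-allocated workspace blob.
static void* nextSubBuffer(int8_t* base, size_t& offset, size_t size)
{
    void* ptr = size > 0 ? base + offset : nullptr;
    offset += alignUp(size);
    return ptr;
}

Dropping a sub-buffer under this scheme only requires removing its entry from the size array and its corresponding carve-out call, which is exactly what the revert does for the three sage_*_sfs buffers.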
@@ -1442,10 +1418,6 @@ int AttentionOp::enqueueContext(EnqueueContextParams<T> const& params, cudaStrea
         = mNumAttnHeads * dim_k_per_head; // Assuming effective num_kv_heads = head_num for layout
     int const total_v_dim_all_heads
         = mNumAttnHeads * dim_v_per_head; // Assuming effective num_kv_heads = head_num for layout
-    bool const useSageAttnSeparateQkv = mEnableContextFMHA && !mIsMLAEnabled && mFmhaDispatcher->isSeparateQAndKvInput()
-        && (mSageAttnNumEltsPerBlkQ > 0 || mSageAttnNumEltsPerBlkK > 0 || mSageAttnNumEltsPerBlkV > 0)
-        && mFP8ContextFMHA;
-
     // Packed fp8 qkv buffer size for normal fp8 context FMHA
     size_t fp8_qkv_buffer_size = mEnableContextFMHA && mFP8ContextFMHA && !mFmhaDispatcher->isSeparateQAndKvInput()
         ? params.num_tokens * (local_hidden_units_qo + 2 * local_hidden_units_kv)
@@ -1471,26 +1443,6 @@ int AttentionOp::enqueueContext(EnqueueContextParams<T> const& params, cudaStrea
             fp8_v_buf_size = params.total_kv_len * static_cast<size_t>(total_v_dim_all_heads);
         }
     }
-    else if (useSageAttnSeparateQkv)
-    {
-        fp8_q_buf_size = params.num_tokens * static_cast<size_t>(local_hidden_units_qo);
-        fp8_k_buf_size = params.total_kv_len * static_cast<size_t>(local_hidden_units_kv);
-        fp8_v_buf_size = params.total_kv_len * static_cast<size_t>(local_hidden_units_kv);
-    }
-    int32_t const q_max_n_blk
-        = mSageAttnNumEltsPerBlkQ > 0 ? tc::divUp(params.input_seq_length, mSageAttnNumEltsPerBlkQ) : 0;
-    int32_t const k_max_n_blk
-        = mSageAttnNumEltsPerBlkK > 0 ? tc::divUp(params.input_seq_length, mSageAttnNumEltsPerBlkK) : 0;
-    // SageAttention V scales are shared across tokens and partitioned on the flattened hidden dimension (H * D).
-    int32_t const v_max_n_blk
-        = mSageAttnNumEltsPerBlkV > 0 ? tc::divUp(local_hidden_units_kv, mSageAttnNumEltsPerBlkV) : 0;
-    size_t const sage_q_sfs_buffer_size = useSageAttnSeparateQkv
-        ? sizeof(float) * mNumAttnHeads * params.batch_size * static_cast<size_t>(q_max_n_blk)
-        : 0;
-    size_t const sage_k_sfs_buffer_size = useSageAttnSeparateQkv
-        ? sizeof(float) * mNumAttnKVHeads * params.batch_size * static_cast<size_t>(k_max_n_blk)
-        : 0;
-    size_t const sage_v_sfs_buffer_size = useSageAttnSeparateQkv ? sizeof(float) * static_cast<size_t>(v_max_n_blk) : 0;
     size_t const padding_offset_size
         = mEnableContextFMHA ? 0 : sizeof(int) * params.batch_size * params.input_seq_length;
     size_t const encoder_padding_offset_size
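
The deleted scale-factor sizing above is plain ceiling-division bookkeeping: one float per (head, batch, block) for Q and K, and one float per hidden-dimension block for V. A worked example with made-up numbers (sequence length, head count, batch size, and block size below are illustrative assumptions, not values read from the op):

#include <cstdio>

// Illustrative arithmetic only. With input_seq_length = 1000 and a Q block of
// 64 tokens, divUp(1000, 64) = 16 blocks, so one float scale per (head, batch, block).
int main()
{
    auto divUp = [](int a, int b) { return (a + b - 1) / b; };
    int const inputSeqLength = 1000;
    int const numEltsPerBlkQ = 64;
    int const numAttnHeads = 32;
    int const batchSize = 8;

    int const qMaxNumBlocks = divUp(inputSeqLength, numEltsPerBlkQ); // 16
    size_t const sageQScaleBytes = sizeof(float) * numAttnHeads * batchSize * qMaxNumBlocks;
    std::printf("q blocks = %d, q scale buffer = %zu bytes\n", qMaxNumBlocks, sageQScaleBytes); // 16384 bytes
    return 0;
}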
@@ -1545,12 +1497,6 @@ int AttentionOp::enqueueContext(EnqueueContextParams<T> const& params, cudaStrea
         = reinterpret_cast<float*>(nextWorkspacePtr(workspace_byte_ptr, offset, fmha_bmm1_scale_size));
     float* fmha_bmm2_scale_ptr
         = reinterpret_cast<float*>(nextWorkspacePtr(workspace_byte_ptr, offset, fmha_bmm2_scale_size));
-    float* sage_q_sfs_buf
-        = reinterpret_cast<float*>(nextWorkspacePtr(workspace_byte_ptr, offset, sage_q_sfs_buffer_size));
-    float* sage_k_sfs_buf
-        = reinterpret_cast<float*>(nextWorkspacePtr(workspace_byte_ptr, offset, sage_k_sfs_buffer_size));
-    float* sage_v_sfs_buf
-        = reinterpret_cast<float*>(nextWorkspacePtr(workspace_byte_ptr, offset, sage_v_sfs_buffer_size));
 
     T* gatherInBuffer = reinterpret_cast<T*>(nextWorkspacePtr(workspace_byte_ptr, offset, cpWorkspaceSize));
     T* gatherOutBuffer = gatherInBuffer + cpMaxPadedSequenceLength * getHeadSize() * (mNumHeads + 2 * mNumKVHeads);
@@ -1792,69 +1738,7 @@ int AttentionOp::enqueueContext(EnqueueContextParams<T> const& params, cudaStrea
     }
     else
     {
-        if (useSageAttnSeparateQkv)
-        {
-            TLLM_CHECK_WITH_INFO(mFP8ContextFMHA, "SageAttention kernel runs under mFP8ContextFMHA option.");
-            TLLM_CHECK_WITH_INFO(
-                mFmhaDispatcher->isSupported(), "SageAttention has no unfused fallback implemented.");
-            TLLM_CHECK_WITH_INFO(
-                mSageAttnNumEltsPerBlkQ > 0 && mSageAttnNumEltsPerBlkK > 0 && mSageAttnNumEltsPerBlkV == 1,
-                "SageQuant requires positive block sizes for Q and K while the block size for V must be 1.");
-            TLLM_CHECK_WITH_INFO(!params.kv_scale_quant_orig,
-                "SageAttention disregards the configured params.kv_scale_quant_orig, invalidating the result.");
-            check_cuda_error(cudaMemsetAsync(sage_v_sfs_buf, 0, sage_v_sfs_buffer_size, stream));
-
-            tc::SageQuantParams qkParams{};
-            qkParams.headDim = getHeadSize();
-            qkParams.inputType = std::is_same_v<T, __nv_bfloat16> ? DATA_TYPE_BF16 : DATA_TYPE_FP16;
-            qkParams.quantType = mSageAttnQkInt8 ? DATA_TYPE_INT8 : DATA_TYPE_E4M3;
-            qkParams.vStage = 0;
-            qkParams.sumSeqLensV = params.total_kv_len;
-            qkParams.numHeadsV = mNumAttnKVHeads;
-            qkParams.ptrV = params.v_ptr;
-            qkParams.ptrVQuant = fp8_v_buf;
-            qkParams.ptrVScale = sage_v_sfs_buf;
-            qkParams.smCount = mMultiProcessorCount;
-            qkParams.stream = stream;
-
-            // Quantize into Fp8Q, SfsQ, SfsV
-            if (mSageAttnNumEltsPerBlkQ > 0)
-            {
-                qkParams.sumSeqLensQk = params.num_tokens;
-                qkParams.numHeads = mNumAttnHeads;
-                qkParams.tokenBlockSize = mSageAttnNumEltsPerBlkQ;
-                qkParams.ptrQk = attention_input;
-                qkParams.ptrQkQuant = fp8_q_buf;
-                qkParams.ptrQkScale = sage_q_sfs_buf;
-                qkParams.vStage = 1;
-                tc::invokeSageQuant(qkParams);
-            }
-            else
-            {
-                invokeCudaCast(fp8_q_buf, attention_input, params.num_tokens * local_hidden_units_qo, stream);
-            }
-
-            // Quantize into Fp8K, SfsK, Fp8V
-            if (mSageAttnNumEltsPerBlkK > 0)
-            {
-                qkParams.sumSeqLensQk = params.total_kv_len;
-                qkParams.numHeads = mNumAttnKVHeads;
-                qkParams.tokenBlockSize = mSageAttnNumEltsPerBlkK;
-                qkParams.ptrQk = params.k_ptr;
-                qkParams.ptrQkQuant = fp8_k_buf;
-                qkParams.ptrQkScale = sage_k_sfs_buf;
-                qkParams.vStage = 2;
-                tc::invokeSageQuant(qkParams);
-            }
-            else
-            {
-                invokeCudaCast(fp8_k_buf, params.k_ptr, params.total_kv_len * local_hidden_units_kv, stream);
-            }
-        }
-        else
-        {
-            invokeQKVPreprocessing(preprocessingParams, stream);
-        }
+        invokeQKVPreprocessing(preprocessingParams, stream);
     }
     sync_check_cuda_error(stream);
     {
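
The removed branch quantized Q and K per token block (and V per hidden-dimension block) via tc::invokeSageQuant before handing FP8/INT8 buffers plus per-block scales to the fused kernel. The CUDA kernel itself is not part of this diff; the host-side sketch below only illustrates the per-block dynamic-quantization idea behind it, with invented names and plain float storage standing in for the e4m3/int8 output.

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <vector>

// CPU sketch of per-block dynamic quantization: for each block of `blockSize`
// elements, record amax / quantMax as the scale and store value / scale.
// Names, layout, and float "quantized" storage are assumptions for illustration;
// the actual SageQuant kernel operates on device tensors per head and token block.
void quantizePerBlock(std::vector<float> const& input, size_t blockSize, float quantMax,
    std::vector<float>& quantized, std::vector<float>& scales)
{
    quantized.resize(input.size());
    scales.clear();
    for (size_t start = 0; start < input.size(); start += blockSize)
    {
        size_t const end = std::min(input.size(), start + blockSize);
        float amax = 0.f;
        for (size_t i = start; i < end; ++i)
        {
            amax = std::max(amax, std::fabs(input[i]));
        }
        float const scale = amax > 0.f ? amax / quantMax : 1.f;
        scales.push_back(scale);
        for (size_t i = start; i < end; ++i)
        {
            quantized[i] = input[i] / scale; // would be cast to e4m3 / int8 on device
        }
    }
}

The attention kernel then consumes the quantized blocks together with their scales, which is why the reverted code threaded sage_q_sfs_buf, sage_k_sfs_buf, and sage_v_sfs_buf into fmhaParams below.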
@@ -1934,23 +1818,9 @@ int AttentionOp::enqueueContext(EnqueueContextParams<T> const& params, cudaStrea
     }
     else
     {
-        if (useSageAttnSeparateQkv)
-        {
-            fmhaParams.qkvPtr = nullptr;
-            fmhaParams.qPtr = reinterpret_cast<void const*>(fp8_q_buf);
-            fmhaParams.kPtr = reinterpret_cast<void const*>(fp8_k_buf);
-            fmhaParams.vPtr = reinterpret_cast<void const*>(fp8_v_buf);
-
-            fmhaParams.qScalePtr = sage_q_sfs_buf;
-            fmhaParams.kScalePtr = sage_k_sfs_buf;
-            fmhaParams.vScalePtr = sage_v_sfs_buf;
-        }
-        else
-        {
-            fmhaParams.qkvPtr = mFP8ContextFMHA ? reinterpret_cast<void const*>(fp8_qkv_buffer)
-                                                : reinterpret_cast<void const*>(attention_input);
-            fmhaParams.qPtr = reinterpret_cast<void const*>(q_buf_2_);
-        }
+        fmhaParams.qkvPtr = mFP8ContextFMHA ? reinterpret_cast<void const*>(fp8_qkv_buffer)
+                                            : reinterpret_cast<void const*>(attention_input);
+        fmhaParams.qPtr = reinterpret_cast<void const*>(q_buf_2_);
     }
     // TODO: add contiguous kv buffer (cross-attention).
     fmhaParams.kvPtr = nullptr;
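
For context on the pointer wiring above: the fused-MHA parameter struct either receives one packed QKV pointer or, in the reverted SEPARATE_Q_K_V layout, three separate tensor pointers plus their per-block scale pointers. A stripped-down sketch of the two shapes, with field names simplified from those visible in the diff (this is not the real parameter-struct definition):

#include <cstddef>

// Simplified stand-in for the fused-MHA parameter struct seen in the diff.
// Only the fields relevant to the packed-vs-separate input layouts are kept.
struct FmhaParamsSketch
{
    // Packed layout: one interleaved [tokens, (Q + 2*KV) * headDim] buffer.
    void const* qkvPtr = nullptr;
    // Separate layout (used by the reverted SageAttention path): one buffer per
    // tensor plus per-block dequantization scales from the quantization pass.
    void const* qPtr = nullptr;
    void const* kPtr = nullptr;
    void const* vPtr = nullptr;
    float const* qScalePtr = nullptr;
    float const* kScalePtr = nullptr;
    float const* vScalePtr = nullptr;
};

// After the revert only the packed wiring remains: set qkvPtr (and qPtr for the
// paged-KV case) and leave the separate pointers null.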
@@ -2877,22 +2747,13 @@ int AttentionOp::initialize() noexcept
         fmhaParams.attentionInputLayout = (mPagedKVCache && mPagedContextFMHA) ? AttentionInputLayout::Q_PAGED_KV
                                                                                : AttentionInputLayout::PACKED_QKV;
     }
-    if (!mIsMLAEnabled && mFP8ContextFMHA
-        && (mSageAttnNumEltsPerBlkQ > 0 || mSageAttnNumEltsPerBlkK > 0 || mSageAttnNumEltsPerBlkV > 0))
-    {
-        fmhaParams.attentionInputLayout = AttentionInputLayout::SEPARATE_Q_K_V;
-    }
     fmhaParams.isSPadded = !mRemovePadding;
     fmhaParams.numQHeads = mNumAttnHeads;
    fmhaParams.numKvHeads = mNumAttnKVHeads;
     fmhaParams.numTokensPerBlock = mTokensPerBlock;
     fmhaParams.headSize = mHeadSize;
     fmhaParams.headSizeV = mHeadSize;
     fmhaParams.qScaling = mQScaling;
-    fmhaParams.sageBlockSizeQ = mSageAttnNumEltsPerBlkQ;
-    fmhaParams.sageBlockSizeK = mSageAttnNumEltsPerBlkK;
-    fmhaParams.sageBlockSizeV = mSageAttnNumEltsPerBlkV;
-    fmhaParams.dataTypeQkReinterpret = mSageAttnQkInt8 ? DATA_TYPE_INT8 : DATA_TYPE_E4M3;
 
     // mFmhaDispatcher is not used for generation MLA, but we still need to modify these values to avoid selecting
     // the wrong kernel, no matter mIsGenerationMLA is true or false
@@ -3199,10 +3060,6 @@ std::string AttentionOp::toString() const
     ss << "mPosShiftEnabled: " << std::boolalpha << mPosShiftEnabled << std::endl;
     ss << "mPagedContextFMHA: " << std::boolalpha << mPagedContextFMHA << std::endl;
     ss << "mFP8ContextFMHA: " << std::boolalpha << mFP8ContextFMHA << std::endl;
-    ss << "mSageAttnNumEltsPerBlkQ: " << mSageAttnNumEltsPerBlkQ << std::endl;
-    ss << "mSageAttnNumEltsPerBlkK: " << mSageAttnNumEltsPerBlkK << std::endl;
-    ss << "mSageAttnNumEltsPerBlkV: " << mSageAttnNumEltsPerBlkV << std::endl;
-    ss << "mSageAttnQkInt8: " << std::boolalpha << mSageAttnQkInt8 << std::endl;
     ss << "mFP8AttenOutput: " << std::boolalpha << mFP8AttenOutput << std::endl;
     ss << "mFP8ContextMLA: " << std::boolalpha << mFP8ContextMLA << std::endl;
     ss << "mDenseContextFMHA: " << std::boolalpha << mDenseContextFMHA << std::endl;

cpp/tensorrt_llm/common/attentionOp.h

Lines changed: 1 addition & 12 deletions
@@ -119,10 +119,6 @@ class AttentionOp
         // this is a buffer of size [num_tokens, num_heads_q] with each element
         // representing the max and LSE/denominator of the softmax values
         float2* softmax_stats = nullptr;
-        // Optional SageAttention scaling factors.
-        float const* sage_attn_sfs_q = nullptr;
-        float const* sage_attn_sfs_k = nullptr;
-        float const* sage_attn_sfs_v = nullptr;
     };
 
     template <typename T>
@@ -523,12 +519,6 @@ class AttentionOp
     // Skip softmax threshold scale factor.
     float mSkipSoftmaxThresholdScaleFactorPrefill = 0;
    float mSkipSoftmaxThresholdScaleFactorDecode = 0;
-    // Optional SageAttention block sizes.
-    // Currently, these are only consumed by the TllmGen backend path.
-    int mSageAttnNumEltsPerBlkQ = 0;
-    int mSageAttnNumEltsPerBlkK = 0;
-    int mSageAttnNumEltsPerBlkV = 0;
-    bool mSageAttnQkInt8 = false;
 #ifdef SKIP_SOFTMAX_STAT
     uint32_t* mSkipSoftmaxTotalBlocks;
     uint32_t* mSkipSoftmaxSkippedBlocks;
@@ -551,8 +541,7 @@ class AttentionOp
             mAttnTpSize, mAttnTpRank, mAttnCpSize, mAttnCpRank, mUlyssesMQABroadcast, mEnableContextFMHA,
             mFMHAForceFP32Acc, mMultiBlockMode, mEnableXQA, mUseKVCache, mSkipAttn, mFuseFp4Quant,
             mNbMultiBlockSemaphores, mAttentionChunkSize.value_or(-1), mSkipSoftmaxThresholdScaleFactorPrefill,
-            mSkipSoftmaxThresholdScaleFactorDecode, mSageAttnNumEltsPerBlkQ, mSageAttnNumEltsPerBlkK,
-            mSageAttnNumEltsPerBlkV, mSageAttnQkInt8);
+            mSkipSoftmaxThresholdScaleFactorDecode);
     };
 
 private:
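
The last header hunk shortens what appears to be a packed list of the op's configuration members (the enclosing function's name and its consumer lie outside the hunk, so this is an inference). The general pattern of bundling configuration fields into a single std::tuple, so that adding or removing a feature touches one expression, looks roughly like this sketch with invented fields:

#include <tuple>

// Sketch of the "bundle all config members into one tuple" pattern suggested by
// the hunk above; the class, fields, and defaults here are illustrative only.
class ExampleOp
{
public:
    auto data() const
    {
        // Removing a feature means dropping its members from this pack,
        // just as the revert drops the four mSageAttn* entries above.
        return std::make_tuple(mNumHeads, mNumKvHeads, mHeadSize, mEnableContextFMHA);
    }

private:
    int mNumHeads = 32;
    int mNumKvHeads = 8;
    int mHeadSize = 128;
    bool mEnableContextFMHA = true;
};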
