 #include "tensorrt_llm/common/envUtils.h"
 #include "tensorrt_llm/common/logger.h"
 #include "tensorrt_llm/common/memoryUtils.h"
-#include "tensorrt_llm/common/sageQuant.h"
 #include "tensorrt_llm/kernels/decoderMaskedMultiheadAttention.h"
 #include "tensorrt_llm/kernels/flashMLA/flash_mla.h"
 #include "tensorrt_llm/kernels/gptKernels.h"
@@ -779,10 +778,6 @@ size_t AttentionOp::getWorkspaceSizeForContext(nvinfer1::DataType type, int32_t
         = mNumAttnHeads * dim_k_per_head; // Assuming effective num_kv_heads = head_num for layout
     int const total_v_dim_all_heads
         = mNumAttnHeads * dim_v_per_head; // Assuming effective num_kv_heads = head_num for layout
-    bool const useSageAttnSeparateQkv = mEnableContextFMHA && !mIsMLAEnabled && mFmhaDispatcher->isSeparateQAndKvInput()
-        && (mSageAttnNumEltsPerBlkQ > 0 || mSageAttnNumEltsPerBlkK > 0 || mSageAttnNumEltsPerBlkV > 0)
-        && mFP8ContextFMHA;
-
     // Packed fp8 qkv buffer size for normal fp8 context FMHA
     size_t fp8_qkv_buffer_size = mFP8ContextFMHA && mEnableContextFMHA && !mFmhaDispatcher->isSeparateQAndKvInput()
         ? max_num_tokens * size_t(local_hidden_units_qo + 2 * local_hidden_units_kv)
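For orientation on the sizing formula above: the packed FP8 path stages one fused QKV buffer of `max_num_tokens * (q_hidden + 2 * kv_hidden)` e4m3 bytes. A minimal worked example with made-up dimensions (none of these numbers come from this change):

```cpp
#include <cstddef>
#include <cstdio>

int main()
{
    // Illustrative values only; real runs take them from the model config.
    std::size_t const max_num_tokens = 8192;
    std::size_t const local_hidden_units_qo = 32 * 128; // 32 Q heads, head size 128
    std::size_t const local_hidden_units_kv = 8 * 128;  // 8 KV heads, head size 128
    // One e4m3 byte per element: Q plus interleaved K and V per token.
    std::size_t const fp8_qkv_buffer_size
        = max_num_tokens * (local_hidden_units_qo + 2 * local_hidden_units_kv);
    std::printf("fp8 qkv staging buffer: %zu bytes (%.1f MiB)\n", fp8_qkv_buffer_size,
        fp8_qkv_buffer_size / (1024.0 * 1024.0));
    return 0;
}
```

With these numbers the buffer comes to 50331648 bytes (48.0 MiB), which is why it is only allocated when the FP8 context FMHA path actually runs.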
@@ -808,22 +803,6 @@ size_t AttentionOp::getWorkspaceSizeForContext(nvinfer1::DataType type, int32_t
             fp8_v_buf_size = mChunkPrefillBufferBatchSize * max_num_tokens * static_cast<size_t>(total_v_dim_all_heads);
         }
     }
-    else if (useSageAttnSeparateQkv)
-    {
-        fp8_q_buf_size = max_num_tokens * static_cast<size_t>(local_hidden_units_qo);
-        fp8_k_buf_size = max_num_tokens * static_cast<size_t>(local_hidden_units_kv);
-        fp8_v_buf_size = max_num_tokens * static_cast<size_t>(local_hidden_units_kv);
-    }
-
-    int32_t const q_max_n_blk = mSageAttnNumEltsPerBlkQ > 0 ? tc::divUp(input_seq_length, mSageAttnNumEltsPerBlkQ) : 0;
-    int32_t const k_max_n_blk = mSageAttnNumEltsPerBlkK > 0 ? tc::divUp(input_seq_length, mSageAttnNumEltsPerBlkK) : 0;
-    size_t const sage_q_sfs_buffer_size
-        = useSageAttnSeparateQkv ? sizeof(float) * mNumAttnHeads * batch_size * static_cast<size_t>(q_max_n_blk) : 0;
-    size_t const sage_k_sfs_buffer_size
-        = useSageAttnSeparateQkv ? sizeof(float) * mNumAttnKVHeads * batch_size * static_cast<size_t>(k_max_n_blk) : 0;
-    size_t const sage_v_sfs_buffer_size = useSageAttnSeparateQkv
-        ? sizeof(float) * tc::divUp(local_hidden_units_kv, std::max(1, mSageAttnNumEltsPerBlkV))
-        : 0;

     size_t const padding_offset_size = mEnableContextFMHA ? 0 : sizeof(int) * max_num_tokens;
     size_t const encoder_padding_offset_size = mEnableContextFMHA ? 0 : sizeof(int) * max_num_tokens;
@@ -839,7 +818,7 @@ size_t AttentionOp::getWorkspaceSizeForContext(nvinfer1::DataType type, int32_t
         ? 0
         : (2 * size * cpMaxPaddedSequenceLength * getHeadSize() * (mNumHeads + 2 * mNumKVHeads) + cu_seqlens_size);

-    int const NUM_BUFFERS = 26;
+    int const NUM_BUFFERS = 23;
     size_t workspaces[NUM_BUFFERS];
     workspaces[0] = CUBLAS_WORKSPACE_SIZE;
     workspaces[1] = attention_mask_size;
@@ -863,10 +842,7 @@ size_t AttentionOp::getWorkspaceSizeForContext(nvinfer1::DataType type, int32_t
     workspaces[19] = fmha_scheduler_counter;
     workspaces[20] = fmha_bmm1_scale_size;
     workspaces[21] = fmha_bmm2_scale_size;
-    workspaces[22] = sage_q_sfs_buffer_size;
-    workspaces[23] = sage_k_sfs_buffer_size;
-    workspaces[24] = sage_v_sfs_buffer_size;
-    workspaces[25] = cpWorkspaceSize;
+    workspaces[22] = cpWorkspaceSize;
     context_workspace_size = tc::calculateTotalWorkspaceSize(workspaces, NUM_BUFFERS);

     return context_workspace_size;
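Dropping the three SageAttention scale-factor slots shrinks the table from 26 entries to 23, with `cpWorkspaceSize` moving from slot 25 to slot 22; every entry still feeds `tc::calculateTotalWorkspaceSize`. A hedged sketch of what such a helper plausibly does (the 256-byte alignment is an assumption, not something this diff states):

```cpp
#include <cstddef>

// Sketch only: pad each buffer to a fixed alignment and sum, so the
// later pointer-carving pass can hand out aligned sub-allocations.
// The real tc::calculateTotalWorkspaceSize may differ in detail.
std::size_t calculateTotalWorkspaceSizeSketch(
    std::size_t const* workspaces, int count, std::size_t alignment = 256)
{
    std::size_t total = 0;
    for (int i = 0; i < count; ++i)
    {
        total += (workspaces[i] + alignment - 1) / alignment * alignment;
    }
    return total;
}
```

Because zero-sized entries contribute nothing, simply shrinking NUM_BUFFERS is equivalent to keeping the slots at size 0; removing them just keeps the table honest.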
@@ -1442,10 +1418,6 @@ int AttentionOp::enqueueContext(EnqueueContextParams<T> const& params, cudaStrea
         = mNumAttnHeads * dim_k_per_head; // Assuming effective num_kv_heads = head_num for layout
     int const total_v_dim_all_heads
         = mNumAttnHeads * dim_v_per_head; // Assuming effective num_kv_heads = head_num for layout
-    bool const useSageAttnSeparateQkv = mEnableContextFMHA && !mIsMLAEnabled && mFmhaDispatcher->isSeparateQAndKvInput()
-        && (mSageAttnNumEltsPerBlkQ > 0 || mSageAttnNumEltsPerBlkK > 0 || mSageAttnNumEltsPerBlkV > 0)
-        && mFP8ContextFMHA;
-
     // Packed fp8 qkv buffer size for normal fp8 context FMHA
     size_t fp8_qkv_buffer_size = mEnableContextFMHA && mFP8ContextFMHA && !mFmhaDispatcher->isSeparateQAndKvInput()
         ? params.num_tokens * (local_hidden_units_qo + 2 * local_hidden_units_kv)
@@ -1471,26 +1443,6 @@ int AttentionOp::enqueueContext(EnqueueContextParams<T> const& params, cudaStrea
             fp8_v_buf_size = params.total_kv_len * static_cast<size_t>(total_v_dim_all_heads);
         }
     }
-    else if (useSageAttnSeparateQkv)
-    {
-        fp8_q_buf_size = params.num_tokens * static_cast<size_t>(local_hidden_units_qo);
-        fp8_k_buf_size = params.total_kv_len * static_cast<size_t>(local_hidden_units_kv);
-        fp8_v_buf_size = params.total_kv_len * static_cast<size_t>(local_hidden_units_kv);
-    }
-    int32_t const q_max_n_blk
-        = mSageAttnNumEltsPerBlkQ > 0 ? tc::divUp(params.input_seq_length, mSageAttnNumEltsPerBlkQ) : 0;
-    int32_t const k_max_n_blk
-        = mSageAttnNumEltsPerBlkK > 0 ? tc::divUp(params.input_seq_length, mSageAttnNumEltsPerBlkK) : 0;
-    // SageAttention V scales are shared across tokens and partitioned on the flattened hidden dimension (H * D).
-    int32_t const v_max_n_blk
-        = mSageAttnNumEltsPerBlkV > 0 ? tc::divUp(local_hidden_units_kv, mSageAttnNumEltsPerBlkV) : 0;
-    size_t const sage_q_sfs_buffer_size = useSageAttnSeparateQkv
-        ? sizeof(float) * mNumAttnHeads * params.batch_size * static_cast<size_t>(q_max_n_blk)
-        : 0;
-    size_t const sage_k_sfs_buffer_size = useSageAttnSeparateQkv
-        ? sizeof(float) * mNumAttnKVHeads * params.batch_size * static_cast<size_t>(k_max_n_blk)
-        : 0;
-    size_t const sage_v_sfs_buffer_size = useSageAttnSeparateQkv ? sizeof(float) * static_cast<size_t>(v_max_n_blk) : 0;
     size_t const padding_offset_size
         = mEnableContextFMHA ? 0 : sizeof(int) * params.batch_size * params.input_seq_length;
     size_t const encoder_padding_offset_size
@@ -1545,12 +1497,6 @@ int AttentionOp::enqueueContext(EnqueueContextParams<T> const& params, cudaStrea
         = reinterpret_cast<float*>(nextWorkspacePtr(workspace_byte_ptr, offset, fmha_bmm1_scale_size));
     float* fmha_bmm2_scale_ptr
         = reinterpret_cast<float*>(nextWorkspacePtr(workspace_byte_ptr, offset, fmha_bmm2_scale_size));
-    float* sage_q_sfs_buf
-        = reinterpret_cast<float*>(nextWorkspacePtr(workspace_byte_ptr, offset, sage_q_sfs_buffer_size));
-    float* sage_k_sfs_buf
-        = reinterpret_cast<float*>(nextWorkspacePtr(workspace_byte_ptr, offset, sage_k_sfs_buffer_size));
-    float* sage_v_sfs_buf
-        = reinterpret_cast<float*>(nextWorkspacePtr(workspace_byte_ptr, offset, sage_v_sfs_buffer_size));

     T* gatherInBuffer = reinterpret_cast<T*>(nextWorkspacePtr(workspace_byte_ptr, offset, cpWorkspaceSize));
     T* gatherOutBuffer = gatherInBuffer + cpMaxPadedSequenceLength * getHeadSize() * (mNumHeads + 2 * mNumKVHeads);
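The pointer lines above all use one arena-carving idiom: a single workspace allocation, a running byte offset, and `nextWorkspacePtr` returning the next slice (the three `sage_*_sfs_buf` slices simply disappear along with their sizes). A minimal sketch of that idiom, under the same alignment assumption as before; the real helper's signature may differ:

```cpp
#include <cstddef>
#include <cstdint>

// Sketch of the carving idiom: return the current slice, then advance
// the shared offset, padded so the next caller starts aligned.
std::int8_t* nextWorkspacePtrSketch(
    std::int8_t* base, std::size_t& offset, std::size_t size, std::size_t alignment = 256)
{
    std::int8_t* ptr = size > 0 ? base + offset : nullptr;
    offset += (size + alignment - 1) / alignment * alignment;
    return ptr;
}
```

The carving order must match the slot order used when the total size was computed, which is why the removals here mirror the workspace-table removals above.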
@@ -1792,69 +1738,7 @@ int AttentionOp::enqueueContext(EnqueueContextParams<T> const& params, cudaStrea
     }
     else
     {
-        if (useSageAttnSeparateQkv)
-        {
-            TLLM_CHECK_WITH_INFO(mFP8ContextFMHA, "SageAttention kernel runs under mFP8ContextFMHA option.");
-            TLLM_CHECK_WITH_INFO(
-                mFmhaDispatcher->isSupported(), "SageAttention has no unfused fallback implemented.");
-            TLLM_CHECK_WITH_INFO(
-                mSageAttnNumEltsPerBlkQ > 0 && mSageAttnNumEltsPerBlkK > 0 && mSageAttnNumEltsPerBlkV == 1,
-                "SageQuant requires positive block sizes for Q and K while the block size for V must be 1.");
-            TLLM_CHECK_WITH_INFO(!params.kv_scale_quant_orig,
-                "SageAttention disregards the configured params.kv_scale_quant_orig, invalidating the result.");
-            check_cuda_error(cudaMemsetAsync(sage_v_sfs_buf, 0, sage_v_sfs_buffer_size, stream));
-
-            tc::SageQuantParams qkParams{};
-            qkParams.headDim = getHeadSize();
-            qkParams.inputType = std::is_same_v<T, __nv_bfloat16> ? DATA_TYPE_BF16 : DATA_TYPE_FP16;
-            qkParams.quantType = mSageAttnQkInt8 ? DATA_TYPE_INT8 : DATA_TYPE_E4M3;
-            qkParams.vStage = 0;
-            qkParams.sumSeqLensV = params.total_kv_len;
-            qkParams.numHeadsV = mNumAttnKVHeads;
-            qkParams.ptrV = params.v_ptr;
-            qkParams.ptrVQuant = fp8_v_buf;
-            qkParams.ptrVScale = sage_v_sfs_buf;
-            qkParams.smCount = mMultiProcessorCount;
-            qkParams.stream = stream;
-
-            // Quantize into Fp8Q, SfsQ, SfsV
-            if (mSageAttnNumEltsPerBlkQ > 0)
-            {
-                qkParams.sumSeqLensQk = params.num_tokens;
-                qkParams.numHeads = mNumAttnHeads;
-                qkParams.tokenBlockSize = mSageAttnNumEltsPerBlkQ;
-                qkParams.ptrQk = attention_input;
-                qkParams.ptrQkQuant = fp8_q_buf;
-                qkParams.ptrQkScale = sage_q_sfs_buf;
-                qkParams.vStage = 1;
-                tc::invokeSageQuant(qkParams);
-            }
-            else
-            {
-                invokeCudaCast(fp8_q_buf, attention_input, params.num_tokens * local_hidden_units_qo, stream);
-            }
-
-            // Quantize into Fp8K, SfsK, Fp8V
-            if (mSageAttnNumEltsPerBlkK > 0)
-            {
-                qkParams.sumSeqLensQk = params.total_kv_len;
-                qkParams.numHeads = mNumAttnKVHeads;
-                qkParams.tokenBlockSize = mSageAttnNumEltsPerBlkK;
-                qkParams.ptrQk = params.k_ptr;
-                qkParams.ptrQkQuant = fp8_k_buf;
-                qkParams.ptrQkScale = sage_k_sfs_buf;
-                qkParams.vStage = 2;
-                tc::invokeSageQuant(qkParams);
-            }
-            else
-            {
-                invokeCudaCast(fp8_k_buf, params.k_ptr, params.total_kv_len * local_hidden_units_kv, stream);
-            }
-        }
-        else
-        {
-            invokeQKVPreprocessing(preprocessingParams, stream);
-        }
+        invokeQKVPreprocessing(preprocessingParams, stream);
     }
     sync_check_cuda_error(stream);
     {
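For the record, the removed branch quantized Q and K in token blocks (one float scale per head, per sequence, per block, counted with `tc::divUp`) and V per channel, since the check above pinned `mSageAttnNumEltsPerBlkV` to 1. A small self-contained example of that bookkeeping, with a local `divUp` standing in for `tc::divUp` and illustrative sizes:

```cpp
#include <cstdio>

// Ceiling division, as tc::divUp computes it.
constexpr int divUp(int a, int b)
{
    return (a + b - 1) / b;
}

int main()
{
    // Illustrative sizes only. The removed path kept one float scale per
    // (head, sequence, token block) for Q and K.
    int const input_seq_length = 1000;
    int const elts_per_blk_q = 64;
    int const num_heads = 32;
    int const batch_size = 4;
    int const q_max_n_blk = divUp(input_seq_length, elts_per_blk_q); // 16 blocks
    std::printf("Q scale floats: %d\n", num_heads * batch_size * q_max_n_blk); // 2048
    return 0;
}
```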
@@ -1934,23 +1818,9 @@ int AttentionOp::enqueueContext(EnqueueContextParams<T> const& params, cudaStrea
     }
     else
     {
-        if (useSageAttnSeparateQkv)
-        {
-            fmhaParams.qkvPtr = nullptr;
-            fmhaParams.qPtr = reinterpret_cast<void const*>(fp8_q_buf);
-            fmhaParams.kPtr = reinterpret_cast<void const*>(fp8_k_buf);
-            fmhaParams.vPtr = reinterpret_cast<void const*>(fp8_v_buf);
-
-            fmhaParams.qScalePtr = sage_q_sfs_buf;
-            fmhaParams.kScalePtr = sage_k_sfs_buf;
-            fmhaParams.vScalePtr = sage_v_sfs_buf;
-        }
-        else
-        {
-            fmhaParams.qkvPtr = mFP8ContextFMHA ? reinterpret_cast<void const*>(fp8_qkv_buffer)
-                                                : reinterpret_cast<void const*>(attention_input);
-            fmhaParams.qPtr = reinterpret_cast<void const*>(q_buf_2_);
-        }
+        fmhaParams.qkvPtr = mFP8ContextFMHA ? reinterpret_cast<void const*>(fp8_qkv_buffer)
+                                            : reinterpret_cast<void const*>(attention_input);
+        fmhaParams.qPtr = reinterpret_cast<void const*>(q_buf_2_);
     }
     // TODO: add contiguous kv buffer (cross-attention).
     fmhaParams.kvPtr = nullptr;
@@ -2877,22 +2747,13 @@ int AttentionOp::initialize() noexcept
         fmhaParams.attentionInputLayout = (mPagedKVCache && mPagedContextFMHA) ? AttentionInputLayout::Q_PAGED_KV
                                                                                : AttentionInputLayout::PACKED_QKV;
     }
-    if (!mIsMLAEnabled && mFP8ContextFMHA
-        && (mSageAttnNumEltsPerBlkQ > 0 || mSageAttnNumEltsPerBlkK > 0 || mSageAttnNumEltsPerBlkV > 0))
-    {
-        fmhaParams.attentionInputLayout = AttentionInputLayout::SEPARATE_Q_K_V;
-    }
     fmhaParams.isSPadded = !mRemovePadding;
     fmhaParams.numQHeads = mNumAttnHeads;
     fmhaParams.numKvHeads = mNumAttnKVHeads;
     fmhaParams.numTokensPerBlock = mTokensPerBlock;
     fmhaParams.headSize = mHeadSize;
     fmhaParams.headSizeV = mHeadSize;
     fmhaParams.qScaling = mQScaling;
-    fmhaParams.sageBlockSizeQ = mSageAttnNumEltsPerBlkQ;
-    fmhaParams.sageBlockSizeK = mSageAttnNumEltsPerBlkK;
-    fmhaParams.sageBlockSizeV = mSageAttnNumEltsPerBlkV;
-    fmhaParams.dataTypeQkReinterpret = mSageAttnQkInt8 ? DATA_TYPE_INT8 : DATA_TYPE_E4M3;

     // mFmhaDispatcher is not used for generation MLA, but we still need to modify these values to avoid selecting
     // the wrong kernel, no matter mIsGenerationMLA is true or false
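With the `SEPARATE_Q_K_V` override deleted, the layout decision in `initialize()` reduces to the paged-context test that remains. A compressed sketch of that remaining choice, using only the enumerators visible in this diff (the real enum likely has more members):

```cpp
// Sketch: enumerators as named in the diff; the real definition lives elsewhere.
enum class AttentionInputLayout
{
    PACKED_QKV,     // Q, K, V fused per token in one tensor
    Q_PAGED_KV,     // separate Q plus paged KV-cache blocks
    SEPARATE_Q_K_V, // three discrete tensors; the removed SageAttention path used this
};

AttentionInputLayout selectLayoutSketch(bool pagedKVCache, bool pagedContextFMHA)
{
    // The SageAttention override to SEPARATE_Q_K_V no longer applies here.
    return (pagedKVCache && pagedContextFMHA) ? AttentionInputLayout::Q_PAGED_KV
                                              : AttentionInputLayout::PACKED_QKV;
}
```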
@@ -3199,10 +3060,6 @@ std::string AttentionOp::toString() const
     ss << "mPosShiftEnabled: " << std::boolalpha << mPosShiftEnabled << std::endl;
     ss << "mPagedContextFMHA: " << std::boolalpha << mPagedContextFMHA << std::endl;
     ss << "mFP8ContextFMHA: " << std::boolalpha << mFP8ContextFMHA << std::endl;
-    ss << "mSageAttnNumEltsPerBlkQ: " << mSageAttnNumEltsPerBlkQ << std::endl;
-    ss << "mSageAttnNumEltsPerBlkK: " << mSageAttnNumEltsPerBlkK << std::endl;
-    ss << "mSageAttnNumEltsPerBlkV: " << mSageAttnNumEltsPerBlkV << std::endl;
-    ss << "mSageAttnQkInt8: " << std::boolalpha << mSageAttnQkInt8 << std::endl;
     ss << "mFP8AttenOutput: " << std::boolalpha << mFP8AttenOutput << std::endl;
     ss << "mFP8ContextMLA: " << std::boolalpha << mFP8ContextMLA << std::endl;
     ss << "mDenseContextFMHA: " << std::boolalpha << mDenseContextFMHA << std::endl;