12 changes: 12 additions & 0 deletions cpp/kernels/xqa/defines.h
@@ -129,6 +129,18 @@ static_assert(SPEC_DEC, "SPEC_Q_SEQ_LEN should only be used when SPEC_DEC is enabled");
#define SLIDING_WINDOW 0
#endif

#ifndef SKIP_SOFTMAX_ATTN
#define SKIP_SOFTMAX_ATTN 0
#endif

#ifndef SKIP_SOFTMAX_ATTN_BLOCK_STATS
#define SKIP_SOFTMAX_ATTN_BLOCK_STATS 0
#endif

#ifndef SKIP_SOFTMAX_ATTN_FIX_THRESHOLD_GREATER_THAN_ONE
#define SKIP_SOFTMAX_ATTN_FIX_THRESHOLD_GREATER_THAN_ONE 1
#endif

// 0 - no PDL
// 1 - naive PDL
// 2 - aggressive PDL (implemented only in mha_sm90.cu for now)
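Context for the three new guards (not part of the diff): they follow the `#ifndef`-default pattern used throughout defines.h, so skip-softmax attention and its block statistics are off by default, while the greater-than-one threshold fix defaults to on; a build overrides any of them with `-D` flags. A minimal standalone sketch of the pattern (file name and flag value are hypothetical):

```cpp
// skip_softmax_demo.cpp -- hypothetical illustration; build with e.g.
//   nvcc -DSKIP_SOFTMAX_ATTN=1 skip_softmax_demo.cpp
#include <cstdio>

#ifndef SKIP_SOFTMAX_ATTN
#define SKIP_SOFTMAX_ATTN 0 // same off-by-default guard as defines.h
#endif

int main()
{
    // Prints 1 when the build passes -DSKIP_SOFTMAX_ATTN=1, else 0.
    std::printf("SKIP_SOFTMAX_ATTN = %d\n", SKIP_SOFTMAX_ATTN);
    return 0;
}
```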
1 change: 1 addition & 0 deletions cpp/kernels/xqa/gmma.cuh
@@ -106,6 +106,7 @@ __device__ inline MatDesc makeMatDesc(void const* data, uint32_t dimKByteOffset,
asm volatile("trap;\n");
return 0;
}();
assert(__cvta_generic_to_shared(data) % baseAlign == 0);
uint32_t const baseOffset = ((patternAddr % baseAlign == 0) ? 0U : ((patternAddr >> 0x7) & 0x7));
return MatDesc{
/*addr=*/MatDesc::encode(__cvta_generic_to_shared(data)),
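The single added line asserts that the shared-memory address encoded into the GMMA matrix descriptor is aligned to `baseAlign`, the value produced by the immediately-invoked lambda just above (truncated in this view). A standalone device-side sketch of the same check (the wrapper function is hypothetical):

```cpp
#include <cassert>
#include <cstdint>

// __cvta_generic_to_shared() converts a generic pointer into its shared-memory
// window address; descriptor encoding assumes that address is baseAlign-aligned,
// which this assert makes explicit in debug builds.
__device__ void assertSmemAligned(void const* data, uint32_t baseAlign)
{
    assert(__cvta_generic_to_shared(data) % baseAlign == 0);
}
```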
45 changes: 27 additions & 18 deletions cpp/kernels/xqa/mha.cu
@@ -2734,6 +2734,25 @@ static constexpr auto kernel_mha = kernel_mha_impl;
#endif

#ifndef GENERATE_CUBIN
uint32_t computeNbSubSeqPerSeqMHA(cudaDeviceProp const& prop, uint32_t batchSize, uint32_t nbKHeads, uint32_t maxSeqLen)
{
if (!allowMultiBlockMode)
{
return 1;
}
auto const env = std::getenv("XQA_NB_SUB_SEQ");
if (env != nullptr)
{
int32_t const val = std::stoi(env);
if (val > 0)
{
return val;
}
}
return std::min<uint32_t>(
std::max<uint32_t>(1U, prop.multiProcessorCount / (batchSize * nbKHeads)), divUp(maxSeqLen, ctaTile.x));
}
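This function is the former in-function lambda from `launchMHA` (removed further down) hoisted to file scope so it can be declared in mha.h and reused. As a worked illustration of the heuristic, with all numbers hypothetical:

```cpp
#include <algorithm>
#include <cstdint>
#include <cstdio>

static uint32_t divUp(uint32_t a, uint32_t b) { return (a + b - 1) / b; }

int main()
{
    // Hypothetical inputs: 132 SMs (e.g. an H100), 2 sequences, 8 KV heads,
    // 4096-token max sequence length, 64-token CTA tiles (ctaTile.x).
    uint32_t const multiProcessorCount = 132, batchSize = 2, nbKHeads = 8;
    uint32_t const maxSeqLen = 4096, ctaTileX = 64;
    // SM-based bound: spread the batchSize * nbKHeads CTA groups across the SMs.
    uint32_t const smBound = std::max<uint32_t>(1U, multiProcessorCount / (batchSize * nbKHeads)); // 8
    // Tile bound: never split finer than one sub-sequence per KV tile.
    uint32_t const tileBound = divUp(maxSeqLen, ctaTileX); // 64
    std::printf("nbSubSeqPerSeq = %u\n", std::min(smBound, tileBound)); // prints 8
    return 0;
}
```

Setting the `XQA_NB_SUB_SEQ` environment variable to a positive integer bypasses the heuristic entirely.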

void launchMHA(cudaDeviceProp const& prop, uint32_t nbKHeads,
#if SLIDING_WINDOW
uint32_t slidingWinSize,
@@ -2771,6 +2790,13 @@ void launchMHA(cudaDeviceProp const& prop, uint32_t nbKHeads,
// int8/fp8 KV cache.
#if SPEC_DEC
SpecDecParams const& specDecParams,
#endif
#if SKIP_SOFTMAX_ATTN
float const skipSoftmaxThresholdScaleFactor, // for compatibility with mha_sm90.cu only
#if SKIP_SOFTMAX_ATTN_BLOCK_STATS
uint32_t* __restrict__ skippedBlockCount, // for compatibility with mha_sm90.cu only
uint32_t* __restrict__ totalBlockCount, // for compatibility with mha_sm90.cu only
#endif
#endif
uint32_t* semaphores, void* scratch, cudaStream_t stream)
{
@@ -2793,24 +2819,7 @@ void launchMHA(cudaDeviceProp const& prop, uint32_t nbKHeads,
uint32_t const nbQHeads = nbKHeads * headGrpSize;

// const uint32_t nbSubSeqPerSeq = allowMultiBlockMode ? DBG_NB_CTAS_PER_SEQ : 1;
uint32_t const nbSubSeqPerSeq = [&]() -> uint32_t
{
if (!allowMultiBlockMode)
{
return 1;
}
auto const env = std::getenv("XQA_NB_SUB_SEQ");
if (env != nullptr)
{
int32_t const val = std::stoi(env);
if (val > 0)
{
return val;
}
}
return std::min<uint32_t>(
std::max<uint32_t>(1U, prop.multiProcessorCount / (batchSize * nbKHeads)), divUp(maxSeqLen, ctaTile.x));
}();
uint32_t const nbSubSeqPerSeq = computeNbSubSeqPerSeqMHA(prop, batchSize, nbKHeads, maxSeqLen);
// gridDim.z == batchSize && gridDim.y == nbKHeads && gridDim.x == nbSubSeqPerSeq
#if SPEC_DEC
const uint32_t nbTokenBlocksPerGrp = divUp(qSeqLen * headGrpSize, rowsPerBlock);
18 changes: 18 additions & 0 deletions cpp/kernels/xqa/mha.h
@@ -90,6 +90,9 @@ struct BeamSearchParams
// match trt-llm API.
};

uint32_t computeNbSubSeqPerSeqMHA(
cudaDeviceProp const& prop, uint32_t batchSize, uint32_t nbKHeads, uint32_t maxSeqLen);

void launchMHA(cudaDeviceProp const& prop, uint32_t const nbKHeads,
#if SLIDING_WINDOW
uint32_t slidingWinSize,
@@ -127,9 +130,18 @@ void launchMHA(cudaDeviceProp const& prop, uint32_t const nbKHeads,
// int8/fp8 KV cache.
#if SPEC_DEC
SpecDecParams const& specDecParams,
#endif
#if SKIP_SOFTMAX_ATTN
float const skipSoftmaxThresholdScaleFactor,
#if SKIP_SOFTMAX_ATTN_BLOCK_STATS
uint32_t* __restrict__ skippedBlockCount, uint32_t* __restrict__ totalBlockCount,
#endif
#endif
uint32_t* semaphores, void* scratch, cudaStream_t stream);

uint32_t computeNbSubSeqPerSeqHopperF8MHA(
cudaDeviceProp const& prop, uint32_t batchSize, uint32_t nbKHeads, uint32_t maxSeqLen);

void launchHopperF8MHA(cudaDeviceProp const& prop, uint32_t nbKHeads,
#if SLIDING_WINDOW
uint32_t slidingWinSize,
@@ -167,6 +179,12 @@ void launchHopperF8MHA(cudaDeviceProp const& prop, uint32_t nbKHeads,
// int8/fp8 KV cache.
#if SPEC_DEC
SpecDecParams const& specDecParams,
#endif
#if SKIP_SOFTMAX_ATTN
float const skipSoftmaxThresholdScaleFactor,
#if SKIP_SOFTMAX_ATTN_BLOCK_STATS
uint32_t* __restrict__ skippedBlockCount, uint32_t* __restrict__ totalBlockCount,
#endif
#endif
uint32_t* semaphores, void* scratch, cudaStream_t stream);
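When `SKIP_SOFTMAX_ATTN_BLOCK_STATS` is enabled, `skippedBlockCount` and `totalBlockCount` are device-side counters that the kernel is presumably expected to update (an assumption based on the parameter names; the kernel-side code is not shown in these hunks). A minimal host-side sketch of allocating them and reading back a skip ratio after a launch:

```cpp
#include <cstdint>
#include <cstdio>
#include <cuda_runtime.h>

int main()
{
    uint32_t *skippedBlockCount = nullptr, *totalBlockCount = nullptr;
    cudaMalloc(&skippedBlockCount, sizeof(uint32_t));
    cudaMalloc(&totalBlockCount, sizeof(uint32_t));
    cudaMemset(skippedBlockCount, 0, sizeof(uint32_t));
    cudaMemset(totalBlockCount, 0, sizeof(uint32_t));

    // ... launchMHA(...) would go here, with the two pointers passed between
    //     skipSoftmaxThresholdScaleFactor and semaphores ...

    uint32_t skipped = 0, total = 0;
    cudaMemcpy(&skipped, skippedBlockCount, sizeof(uint32_t), cudaMemcpyDeviceToHost);
    cudaMemcpy(&total, totalBlockCount, sizeof(uint32_t), cudaMemcpyDeviceToHost);
    if (total != 0)
    {
        std::printf("skipped %u / %u KV blocks (%.1f%%)\n", skipped, total, 100.0 * skipped / total);
    }
    cudaFree(skippedBlockCount);
    cudaFree(totalBlockCount);
    return 0;
}
```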
