Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
The table of contents is too big for display.
Diff view
Diff view
  •  
  •  
  •  
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,8 @@ enum class ActType
//
// GatedSilu is a special case of SwiGlu where the alpha is 1.0 and the beta is 0.0.
SwiGlu,
Relu2
Relu2,
Silu
};

// Type of the element-wise activation to apply after the Gemm
Expand All @@ -59,6 +60,10 @@ enum class EltwiseActType
// act = relu(x0) ^ 2
// where x0 is the output of the Gemm.
Relu2,
// Silu is defined as the following operation:
// act = x0 * sigmoid(x0)
// where x0 is the output of the Gemm.
Silu
};

struct TrtllmGenBatchedGemmRunnerOptions
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -141,10 +141,10 @@ struct BatchedGemmData
// The rightmost dimension is contiguous in memory.
//
// If the DeepSeek FP8 recipe is not used and the format is MxFp{4,8}, MxInt4 or NvFp4:
// The layout of scaling factors for A is always R128c4.
// If the layout is R128c4,
// M must be a multiple of 128.
// K must be a multiple of 64.
// The "logical" shape is: [paddedM, K / P], where P is the scaling block size.
// K must be a multiple of 4 * P, where P is the scaling block size.
// The "logical" shape is: [paddedM, K / P].
// The R128c4 layout is: [paddedM / 128, K / P / 4, 512].
// The shape we use for TMA is: [paddedM / 128, K / P / 4, 2, 256].
// Where paddedM is M if (routeAct == true && batchM), or
Expand Down Expand Up @@ -302,7 +302,7 @@ struct BatchedGemmData

// The pre-activation scaling factor (typically dequantA * dequantB) for non-gated non-linear
// activation.
// Only used when non-linear activation is applied (e.g., GELU, Relu2).
// Only used when non-linear activation is applied (e.g., GELU, Relu2, Silu).
// When used, scaleC should be quantScaleC only, and this scale is applied before the
// activation. Shape is [B].
float const* mPtrScaleAct{nullptr};
Expand Down Expand Up @@ -786,7 +786,7 @@ class BatchedGemmInterface
{
numCtasBatch += batchM
? gemm::divUp(options.mBatchedM[bi], options.mTileM * options.mClusterDimX) * options.mClusterDimX
: gemm::divUp(options.mBatchedN[bi], options.mTileN);
: gemm::divUp(options.mBatchedN[bi], options.mTileN * options.mClusterDimY) * options.mClusterDimY;
}
}
// For MoE, mNumTokens != 0 and the number of CTAs is known only at runtime.
Expand Down Expand Up @@ -923,19 +923,21 @@ class BatchedGemmInterface
{
totalNumPaddedTokens += batchM
? gemm::divUpMul(options.mBatchedM[bi], options.mTileM * options.mClusterDimX)
: gemm::divUpMul(options.mBatchedN[bi], options.mTileN);
: gemm::divUpMul(options.mBatchedN[bi], options.mTileN * options.mClusterDimY);
}
}
else
{
// Get tile in token dim.
auto tileTokensDim = batchM ? options.mTileM * options.mClusterDimX : options.mTileN;
auto tileTokensDim
= batchM ? options.mTileM * options.mClusterDimX : options.mTileN * options.mClusterDimY;
totalNumPaddedTokens = data.mProblemDimensions.mMaxNumCtasInTokenDim * tileTokensDim;
}
// Get options from config.
auto& options = config.mOptions;

int const tokenTile = batchM ? options.mTileM * options.mClusterDimX : options.mTileN;
int const tokenTile
= batchM ? options.mTileM * options.mClusterDimX : options.mTileN * options.mClusterDimY;

auto const numTokens = totalNumPaddedTokens;
auto const intermediateDim = batchM ? options.mN : options.mM;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -100,18 +100,18 @@ struct BatchedGemmOptions : public gemmGatedAct::GemmGatedActOptions
tg::Dtype dtypeA, tg::Dtype dtypeB, tg::Dtype dtypeC, tg::Dtype dtypeMmaA, tg::Dtype dtypeMmaB,
gemm::EltwiseActType eltwiseActType, bool enablesEarlyExit, bool enablesDelayedEarlyExit,
bool enablesGlobalPtxKnobs, int epilogueLdtmDps, int epilogueLdtmBits, int epilogueTileM, int epilogueTileN,
bool fuseUtccpWithUtcmma, bool gridTriggerSecondaryA, bool gridTriggerSecondaryB,
bool gridWaitForPrimaryEarlyExit, bool gridWaitForPrimaryA, bool gridWaitForPrimaryB, bool hoistLoadTaskInit,
bool hoistMmaTaskTryWaits, int k, gemm::KernelTraits kernelTraits, gemm::MatrixLayout layoutA,
gemm::MatrixLayout layoutB, int m, int mmaK, tg::MmaKind mmaKind, int mmaM, int mmaN, bool mockAllReduce, int n,
int numEpilogueWarps, int numRegsCastAWarps, int numRegsCopySfLdsSttm, int numRegsCopySparsityInfo,
int numRegsPerThreadEpilogueWarp, int numRegsPerThreadNonEpilogueWarp, int numSlicesForSplitK,
int numSlicesForSliceK, int numStages, int numStagesMma, int numStagesMmaWithinWorkTile,
int numStagesMmaAcrossWorkTile, int numStagesWorkId, bool outputDebugTensors, bool patchF2fp,
int32_t sfBlockSizeA, int32_t sfBlockSizeB, int32_t sfBlockSizeC, tg::SfLayout sfLayoutA,
tg::SfLayout sfLayoutB, tg::SfLayout sfLayoutC, int32_t sfReshapeFactor, bool sliceK, tg::Sparsity sparsityA,
gemm::SplitK splitK, int tileK, int tileM, int tileN, gemm::TileScheduler tileScheduler,
bool transposeMmaOutput, bool useCustomMmaSchedule, bool useDeepSeekFp8,
int fallbackClusterDimX, int fallbackClusterDimY, int fallbackClusterDimZ, bool fuseUtccpWithUtcmma,
bool gridTriggerSecondaryA, bool gridTriggerSecondaryB, bool gridWaitForPrimaryEarlyExit,
bool gridWaitForPrimaryA, bool gridWaitForPrimaryB, bool hoistLoadTaskInit, bool hoistMmaTaskTryWaits, int k,
gemm::KernelTraits kernelTraits, gemm::MatrixLayout layoutA, gemm::MatrixLayout layoutB, int m, int mmaK,
tg::MmaKind mmaKind, int mmaM, int mmaN, bool mockAllReduce, int n, int numEpilogueWarps, int numRegsCastAWarps,
int numRegsCopySfLdsSttm, int numRegsCopySparsityInfo, int numRegsPerThreadEpilogueWarp,
int numRegsPerThreadNonEpilogueWarp, int numSlicesForSplitK, int numSlicesForSliceK, int numStages,
int numStagesMma, int numStagesMmaWithinWorkTile, int numStagesMmaAcrossWorkTile, int numStagesWorkId,
bool outputDebugTensors, bool patchF2fp, int32_t sfBlockSizeA, int32_t sfBlockSizeB, int32_t sfBlockSizeC,
tg::SfLayout sfLayoutA, tg::SfLayout sfLayoutB, tg::SfLayout sfLayoutC, int32_t sfReshapeFactor, bool sliceK,
tg::Sparsity sparsityA, gemm::SplitK splitK, int tileK, int tileM, int tileN, gemm::TileScheduler tileScheduler,
bool transposeMmaOutput, bool useCustomMmaSchedule, bool useDeepSeekFp8, bool useFlexibleClusterDims,
bool useHoistTryWaitForCustomMmaSchedule, bool useMaxTmemOverlap, bool usePerTokenSfA, bool usePerTokenSfB,
bool useShuffledMatrix, bool useTmaStore, bool useTwoTmaLoadWarps, bool useTwoMmaWarps,
bool useUnrollLoop2xForMma, int validM, int validN, int validK, int worldSize,
Expand All @@ -127,17 +127,18 @@ struct BatchedGemmOptions : public gemmGatedAct::GemmGatedActOptions
gemm::GemmOptions(allReduceAlgo, biasType, blockK, clcFastDrain, clusterDimX, clusterDimY, clusterDimZ,
ctaSwizzleType, dtypeAcc, dtypeA, dtypeB, dtypeC, dtypeMmaA, dtypeMmaB, eltwiseActType,
enablesEarlyExit, enablesDelayedEarlyExit, enablesGlobalPtxKnobs, epilogueLdtmDps, epilogueLdtmBits,
epilogueTileM, epilogueTileN, fuseUtccpWithUtcmma, gridTriggerSecondaryA, gridTriggerSecondaryB,
gridWaitForPrimaryEarlyExit, gridWaitForPrimaryA, gridWaitForPrimaryB, hoistLoadTaskInit,
hoistMmaTaskTryWaits, k, kernelTraits, layoutA, layoutB, m, mmaK, mmaKind, mmaM, mmaN, mockAllReduce, n,
numEpilogueWarps, numRegsCastAWarps, numRegsCopySfLdsSttm, numRegsCopySparsityInfo,
numRegsPerThreadEpilogueWarp, numRegsPerThreadNonEpilogueWarp, numSlicesForSplitK, numSlicesForSliceK,
numStages, numStagesMma, numStagesMmaWithinWorkTile, numStagesMmaAcrossWorkTile, numStagesWorkId,
outputDebugTensors, patchF2fp, sfBlockSizeA, sfBlockSizeB, sfBlockSizeC, sfLayoutA, sfLayoutB,
sfLayoutC, sfReshapeFactor, sliceK, sparsityA, splitK, tileK, tileM, tileN, tileScheduler,
transposeMmaOutput, useCustomMmaSchedule, useDeepSeekFp8, useHoistTryWaitForCustomMmaSchedule,
useMaxTmemOverlap, usePerTokenSfA, usePerTokenSfB, useShuffledMatrix, useTmaStore, useTwoTmaLoadWarps,
useTwoMmaWarps, useUnrollLoop2xForMma, validM, validN, validK, worldSize),
epilogueTileM, epilogueTileN, fallbackClusterDimX, fallbackClusterDimY, fallbackClusterDimZ,
fuseUtccpWithUtcmma, gridTriggerSecondaryA, gridTriggerSecondaryB, gridWaitForPrimaryEarlyExit,
gridWaitForPrimaryA, gridWaitForPrimaryB, hoistLoadTaskInit, hoistMmaTaskTryWaits, k, kernelTraits,
layoutA, layoutB, m, mmaK, mmaKind, mmaM, mmaN, mockAllReduce, n, numEpilogueWarps, numRegsCastAWarps,
numRegsCopySfLdsSttm, numRegsCopySparsityInfo, numRegsPerThreadEpilogueWarp,
numRegsPerThreadNonEpilogueWarp, numSlicesForSplitK, numSlicesForSliceK, numStages, numStagesMma,
numStagesMmaWithinWorkTile, numStagesMmaAcrossWorkTile, numStagesWorkId, outputDebugTensors, patchF2fp,
sfBlockSizeA, sfBlockSizeB, sfBlockSizeC, sfLayoutA, sfLayoutB, sfLayoutC, sfReshapeFactor, sliceK,
sparsityA, splitK, tileK, tileM, tileN, tileScheduler, transposeMmaOutput, useCustomMmaSchedule,
useDeepSeekFp8, useFlexibleClusterDims, useHoistTryWaitForCustomMmaSchedule, useMaxTmemOverlap,
usePerTokenSfA, usePerTokenSfB, useShuffledMatrix, useTmaStore, useTwoTmaLoadWarps, useTwoMmaWarps,
useUnrollLoop2xForMma, validM, validN, validK, worldSize),
actType, clampBeforeAct)
, mBatchedM(batchedM)
, mBatchedN(batchedN)
Expand Down Expand Up @@ -310,7 +311,7 @@ inline bool checkAndUpdateBatchedGemmOptions(
TLLM_CHECK_ERROR((options.mRouteSfsImpl.value() == RouteImpl::Ldgsts
|| options.mRouteSfsImpl.value() == RouteImpl::LdgPlusSts)
&& options.mRouteImpl == RouteImpl::Tma,
"RouteSfsImpl must be equal to RouteImpl, or Ldgsts/LdgPlusSts, when RouteImpl is Tma");
"RouteSfsImpl must be equal to RouteImpl, or Ldgsts/LdgPlusSts when RouteImpl is Tma");
}
else if (!options.mRouteSfsImpl.has_value())
{
Expand Down Expand Up @@ -379,8 +380,6 @@ inline bool checkAndUpdateBatchedGemmOptions(

if (doesRouteImplUseTma(options.mRouteSfsImpl.value()))
{
TLLM_CHECK_ERROR(!batchM, "UTMALDG.GATHER4 only supported for batch N.");

if (tg::mmaKindIsBlockFmt(options.mMmaKind))
{
int const numEltsPerSfRoute = batchM ? options.mSfBlockSizeA : options.mSfBlockSizeB;
Expand All @@ -392,8 +391,9 @@ inline bool checkAndUpdateBatchedGemmOptions(

if (!batchM || doesRouteImplUseNoRoute(options.mRouteImpl))
{
TLLM_CHECK_ERROR(options.mSfLayoutA == tg::SfLayout::R128c4,
"options.mSfLayoutA has to be tg::SfLayout::R128c4 when not being routed");
bool isSupportedSfLayoutA = options.mSfLayoutA == tg::SfLayout::R128c4;
TLLM_CHECK_ERROR(isSupportedSfLayoutA, "options.mSfLayoutA has to be R128cX when not batch M or not routed",
tg::sfLayoutToString(options.mSfLayoutA));
}
}

Expand Down Expand Up @@ -422,12 +422,6 @@ inline bool checkAndUpdateBatchedGemmOptions(
options.mK % options.mTileK == 0, "K must be a multiple of tileK when using Ldg based SF routing");
}

if (options.mClusterDimX > 1 && batchM && options.mRouteSfsImpl.has_value())
{
TLLM_CHECK_ERROR(options.mRouteSfsImpl.value() != RouteImpl::Tma,
"2CTA BatchedGemm does not support routing Sf along M dimension with TMA.");
}

// Check if all elements in mBatchedM or mBatchedN are the same (uniform tokens per batch) and
// set mIsUniformNumTokensPerBatch and mBatchStride.
if (options.mIsUniformNumTokensPerBatch)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,10 @@ enum class EltwiseActType
// act = relu(x0) ^ 2
// where x0 is the output of the Gemm.
Relu2,
// Silu is defined as the following operation:
// act = x0 * sigmoid(x0)
// where x0 is the output of the Gemm.
Silu,
};

////////////////////////////////////////////////////////////////////////////////////////////////////
Expand Down
Loading