Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,8 @@ void run(Data const& data, void* stream)
TLLM_CHECK_WITH_INFO(
data.mNumExperts % 4 == 0, "Routing kernel expects #experts %d to be a multiple of 4.", data.mNumExperts);

bool const useSingleBlock = data.mNumTokens <= BlockKernelMaxNumTokens;
bool const useSingleBlock = data.mNumTokens <= BlockKernelMaxNumTokens
|| (data.mNumTokens <= DynBlockKernelMaxNumTokens && data.mNumExperts <= DynBlockKernelMaxNumExperts);

bool const useSingleCluster = data.mNumTokens <= ((data.mPtrScores != nullptr || data.mPtrTopKIds != nullptr)
? MaxNumTokensSingleClusterScores
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@ static constexpr int MaxNumTokensSingleCluster = NumBlocksPerCluster * NumThread
static constexpr int MaxNumTokensSingleClusterScores = NumBlocksPerCluster * NumWarps;

static constexpr int BlockKernelMaxNumTokens = 4;
static constexpr int DynBlockKernelMaxNumTokens = 16;
static constexpr int DynBlockKernelMaxNumExperts = 512;

template <typename DataType, typename InputType, int VecSize, int K, bool DoSoftmaxBeforeTopK>
__forceinline__ __device__ void routingTopKExperts(cg::thread_block_tile<WarpSize> const& warp,
Expand Down

Large diffs are not rendered by default.

66 changes: 66 additions & 0 deletions cpp/tests/unit_tests/kernels/routing/routingRenormalizeTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -341,6 +341,72 @@ TYPED_TEST(RoutingRenormalizeKernelTest, DeviceLevelParallelizationTop4)
this->runTest(param);
};

TYPED_TEST(RoutingRenormalizeKernelTest, DynBlockBasic)
{
RoutingKernelTestParam param(RoutingMethodType::Renormalize, /*numTokens=*/8,
/*numExperts=*/128, /*topK=*/8,
/*expertParallelization=*/1, /*expertParallelizationId=*/0, /*tileTokensDim=*/256,
/*paddingLog2=*/3, /*localExpertsStrideLog2=*/0,
/*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/false, /*hasInvalidTopKInput=*/false,
/*nGroup*/ 0, /*topkGroup*/ 0, /*routedScalingFactor*/ 1.0f, /*requiredComputeCapability*/ 9);
this->runTest(param);
};

TYPED_TEST(RoutingRenormalizeKernelTest, DynBlockMaxTokens)
{
RoutingKernelTestParam param(RoutingMethodType::Renormalize, /*numTokens=*/16,
/*numExperts=*/512, /*topK=*/8,
/*expertParallelization=*/1, /*expertParallelizationId=*/0, /*tileTokensDim=*/256,
/*paddingLog2=*/3, /*localExpertsStrideLog2=*/0,
/*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/false, /*hasInvalidTopKInput=*/false,
/*nGroup*/ 0, /*topkGroup*/ 0, /*routedScalingFactor*/ 1.0f, /*requiredComputeCapability*/ 9);
this->runTest(param);
};

TYPED_TEST(RoutingRenormalizeKernelTest, DynBlockWithExpertParallelization)
{
RoutingKernelTestParam param(RoutingMethodType::Renormalize, /*numTokens=*/12,
/*numExperts=*/512, /*topK=*/8,
/*expertParallelization=*/2, /*expertParallelizationId=*/1, /*tileTokensDim=*/192,
/*paddingLog2=*/3, /*localExpertsStrideLog2=*/0,
/*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/false, /*hasInvalidTopKInput=*/false,
/*nGroup*/ 0, /*topkGroup*/ 0, /*routedScalingFactor*/ 1.0f, /*requiredComputeCapability*/ 9);
this->runTest(param);
};

TYPED_TEST(RoutingRenormalizeKernelTest, DynBlockWithTopKAsInput)
{
RoutingKernelTestParam param(RoutingMethodType::Renormalize, /*numTokens=*/8,
/*numExperts=*/128, /*topK=*/8,
/*expertParallelization=*/1, /*expertParallelizationId=*/0, /*tileTokensDim=*/256,
/*paddingLog2=*/3, /*localExpertsStrideLog2=*/0,
/*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/true, /*hasInvalidTopKInput=*/false,
/*nGroup*/ 0, /*topkGroup*/ 0, /*routedScalingFactor*/ 1.0f, /*requiredComputeCapability*/ 9);
this->runTest(param);
};

TYPED_TEST(RoutingRenormalizeKernelTest, DynBlockWithInvalidTopKInput)
{
RoutingKernelTestParam param(RoutingMethodType::Renormalize, /*numTokens=*/10,
/*numExperts=*/512, /*topK=*/8,
/*expertParallelization=*/2, /*expertParallelizationId=*/0, /*tileTokensDim=*/256,
/*paddingLog2=*/3, /*localExpertsStrideLog2=*/0,
/*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/true, /*hasInvalidTopKInput=*/true,
/*nGroup*/ 0, /*topkGroup*/ 0, /*routedScalingFactor*/ 1.0f, /*requiredComputeCapability*/ 9);
this->runTest(param);
};

TYPED_TEST(RoutingRenormalizeKernelTest, DynBlockWithRenormalizeNaive)
{
RoutingKernelTestParam param(RoutingMethodType::RenormalizeNaive, /*numTokens=*/16,
/*numExperts=*/512, /*topK=*/8,
/*expertParallelization=*/1, /*expertParallelizationId=*/0, /*tileTokensDim=*/256,
/*paddingLog2=*/3, /*localExpertsStrideLog2=*/0,
/*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/false, /*hasInvalidTopKInput=*/false,
/*nGroup*/ 0, /*topkGroup*/ 0, /*routedScalingFactor*/ 1.0f, /*requiredComputeCapability*/ 9);
this->runTest(param);
};

TYPED_TEST(RoutingRenormalizeKernelTest, BlockLevelParallelizationLargeN)
{
RoutingKernelTestParam param(RoutingMethodType::Renormalize, /*numTokens=*/4,
Expand Down
Loading
Loading