Support Sigmoid (sigmoid+topk) routing function #2869
EdalatiAli wants to merge 1 commit into flashinfer-ai:main
Conversation
Note: Reviews paused. It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review.
📝 Walkthrough: Adds a new `RoutingMethodType.Sigmoid` routing method (sigmoid scoring followed by top-k selection, without renormalization).
Sequence Diagram(s)
sequenceDiagram
rect rgba(200,200,255,0.5)
participant Client
end
rect rgba(200,255,200,0.5)
participant Runner
end
rect rgba(255,200,200,0.5)
participant CUDA_Kernel
end
rect rgba(255,255,200,0.5)
participant Postprocess
end
Client->>Runner: request MoE inference (routing=Sigmoid)
Runner->>Runner: configure Preprocess=Sigmoid, Postprocess=SumNormalize, normTopkProb=false
Runner->>CUDA_Kernel: launch kernel with TopK routing
CUDA_Kernel->>Postprocess: emit expert indices + logits
Postprocess->>Client: return final routed output
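To make the flow above concrete, here is a minimal host-side reference sketch of what the Sigmoid routing path computes: sigmoid scores, top-k selection, and a normalization step gated by the norm flag. Function and variable names are illustrative only, not the kernel's actual symbols.

#include <algorithm>
#include <cmath>
#include <utility>
#include <vector>

// Reference sketch: RoutingMethodType::Sigmoid is sigmoid -> top-k with
// normalization gated off; SigmoidRenorm additionally sum-normalizes the
// selected weights (normTopkProb == true).
std::vector<std::pair<int, float>> routeSigmoidTopK(std::vector<float> const& logits,
                                                    int topK, bool normTopkProb) {
  std::vector<std::pair<int, float>> scored;
  scored.reserve(logits.size());
  for (int e = 0; e < static_cast<int>(logits.size()); ++e) {
    scored.emplace_back(e, 1.f / (1.f + std::exp(-logits[e])));  // sigmoid score
  }
  // Select the top-k experts by score.
  std::partial_sort(scored.begin(), scored.begin() + topK, scored.end(),
                    [](auto const& a, auto const& b) { return a.second > b.second; });
  scored.resize(topK);
  if (normTopkProb) {  // SigmoidRenorm path: renormalize top-k weights to sum to 1
    float sum = 0.f;
    for (auto const& p : scored) sum += p.second;
    for (auto& p : scored) p.second /= sum;
  }
  return scored;  // (expert index, routing weight) pairs
}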
🚥 Pre-merge checks: ✅ 4 passed | ❌ 1 failed (1 warning)
Summary of Changes: Hello, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed. This pull request significantly enhances the Mixture-of-Experts (MoE) routing capabilities by introducing new sigmoid-based routing functions and a robust, policy-driven architecture. The changes aim to provide greater flexibility and support for diverse MoE model designs, ensuring efficient and accurate expert selection under various activation and normalization schemes. The refactoring also streamlines the codebase by consolidating common routing utilities and optimizing kernel instantiations.
Code Review
The pull request refactors MoE routing kernels by introducing a new policy-based dispatch system, consolidating common routing logic into trtllm_fused_moe_routing_common.cu and trtllm_fused_moe_routing_custom.cu, and removing the trtllm_fused_moe_routing_renormalize.cu file. It also adds new SigmoidRenorm and Sigmoid routing methods, and introduces a norm_topk_prob parameter to control top-K probability normalization. Review comments highlight concerns about the removal of specific dtype checks for routing_logits in trtllm_fused_moe_kernel_launcher.cu, suggesting that specific checks for each routing method should be maintained. There are also suggestions to use data_ptr() directly instead of static_cast<float*> for routing_logits to prevent type casting issues, and to use loadScalar for routing_bias to handle types correctly. Additionally, feedback points out that checks for routing_bias dimensions and shape should be conditional if not all methods use it, and that hardcoded expert tier values in trtllm_fused_moe_routing_deepseek.cu are not scalable and should be configurable. Minor comment inaccuracies regarding mUsePdl and mPdlOverlapWithNext in trtllm_fused_moe_routing_custom.cu were also noted.
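For readers new to the pattern, a rough outline of what such preprocess/postprocess policy dispatch can look like follows. The struct names mirror ones mentioned in this review (SigmoidPreprocess, NoOpPreprocess, SumNormalizePostprocess), but the real definitions in RoutingCustomPolicy.cuh may differ; treat this as an assumption-laden sketch, not the PR's actual code.

// Hypothetical outline of policy-based routing dispatch. Each policy is a
// stateless struct; the kernel template composes them at compile time.
struct SigmoidPreprocess {
  __device__ static float apply(float logit) { return 1.f / (1.f + __expf(-logit)); }
};

struct NoOpPreprocess {
  __device__ static float apply(float logit) { return logit; }
};

struct SumNormalizePostprocess {
  // When normTopkProb is false, the selected weight passes through unchanged.
  __device__ static float apply(float score, float topKSum, bool normTopkProb) {
    return normTopkProb ? score / topKSum : score;
  }
};

template <typename Preprocess, typename Postprocess, typename KernelParams>
__global__ void routingTopKKernel(KernelParams params) {
  // Per-thread sketch: score experts with Preprocess::apply, cooperatively
  // select top-k, then emit Postprocess::apply(score, topKSum,
  // params.mNormTopkProb) as the routing weight.
}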
I am having trouble creating individual review comments, so my feedback is listed below.
csrc/trtllm_fused_moe_kernel_launcher.cu (1799-1804)
The removal of the specific dtype checks for routing_logits based on RoutingMethodType is concerning. It's important to ensure that all routing methods now correctly handle both dl_float32 and dl_bfloat16 for routing_logits. If certain routing methods still require a specific dtype, this change could introduce errors. It's better to have specific checks for each routing method.
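For illustration, a per-method check could preserve the stricter behavior, sketched here with the TVM_FFI_ICHECK idiom used elsewhere in this launcher. Which methods require which dtype is an assumption for the sake of the example, not taken from the PR.

// Hypothetical per-method dtype validation; the method-to-dtype mapping
// below is illustrative only.
switch (static_cast<RoutingMethodType>(routing_method_type)) {
  case RoutingMethodType::DeepSeekV3:
    TVM_FFI_ICHECK(routing_logits.dtype() == dl_float32)
        << "DeepSeekV3 routing: routing_logits must be float32.";
    break;
  case RoutingMethodType::Sigmoid:
  case RoutingMethodType::SigmoidRenorm:
  case RoutingMethodType::Renormalize:
    TVM_FFI_ICHECK(routing_logits.dtype() == dl_float32 ||
                   routing_logits.dtype() == dl_bfloat16)
        << "Sigmoid/Renormalize routing: routing_logits must be float32 or bfloat16.";
    break;
  default:
    break;
}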
csrc/trtllm_fused_moe_kernel_launcher.cu (393-397)
The addition of mRoutingLogitsDtype and norm_topk_prob as arguments to the routing_runner.run function call is correct, but it's crucial to ensure that these parameters are correctly passed and handled in all subsequent calls to this function throughout the codebase. It's important to verify that the data types and values are consistent with the expected behavior of the routing kernel.
csrc/trtllm_fused_moe_kernel_launcher.cu (904-907)
The static_cast<float*> here is concerning. It's better to use data_ptr() directly and let the kernel handle the type. This is especially important if routing_logits is not always a float*. This could lead to incorrect memory access or type casting issues.
args->routing_logits = static_cast<float*>(routing_logits.value().data_ptr());
csrc/trtllm_fused_moe_routing_deepseek.cu (162-165)
Using loadScalar is better than static_cast<float>(params.mPtrRoutingBias[threadExpert]) as it handles the type correctly. This avoids potential type casting issues.
? static_cast<OutputT>(
loadScalar(params.mPtrRoutingBias, threadExpert, params.mDtypeBias))
: invalidScore;
csrc/trtllm_fused_moe_kernel_launcher.cu (1735-1736)
The type validation for routing_logits is now more general, allowing both dl_float32 and dl_bfloat16. This is good for flexibility, but it's crucial to ensure that the kernel implementation correctly handles both data types. Add a comment to explain why the type check is now more permissive.
TVM_FFI_ICHECK(routing_logits.dtype() == dl_float32 || routing_logits.dtype() == dl_bfloat16)
<< "FP8 per-tensor MoE: routing_logits must be float or bfloat16."; // Allow both float32 and bfloat16
csrc/trtllm_fused_moe_kernel_launcher.cu (929-930)
Using routing_logits.value().data_ptr() is better than static_cast<float*>(routing_logits.value().data_ptr()) as it avoids potential type casting issues. The kernel should handle the type appropriately.
args->routing_logits = routing_logits.value().data_ptr();
csrc/trtllm_fused_moe_kernel_launcher.cu (1840-1846)
The addition of checks for routing_bias dimensions and shape is good for robustness. However, it's important to ensure that these checks are consistent with the expected behavior of all routing methods that use routing_bias. If some methods don't use it, the check should be conditional.
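A conditional version of those checks might look like the following. Whether each method consumes routing_bias, and the tensor accessor names, are assumptions for illustration.

// Hypothetical: validate routing_bias only for methods that consume it.
auto const method = static_cast<RoutingMethodType>(routing_method_type);
bool const usesBias = method == RoutingMethodType::DeepSeekV3 ||
                      method == RoutingMethodType::Sigmoid ||
                      method == RoutingMethodType::SigmoidRenorm;
if (usesBias && routing_bias.has_value()) {
  TVM_FFI_ICHECK(routing_bias.value().ndim() == 1)
      << "routing_bias must be a 1-D tensor.";
  TVM_FFI_ICHECK(routing_bias.value().size(0) == args->num_experts)
      << "routing_bias must have one entry per expert.";
}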
csrc/trtllm_fused_moe_routing_deepseek.cu (67-77)
The use of hardcoded values for expert tiers is not scalable. These values should be configurable or derived from the input data to allow for flexibility in different MoE configurations. Consider using a function or a lookup table to determine the appropriate expert tier based on the number of experts.
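One shape such a helper could take is sketched below; the tier values are the ones that appear elsewhere in this thread (128, 256, 576, 1024), and the function itself is hypothetical.

// Hypothetical tier lookup: round the expert count up to the nearest
// instantiated kernel tier instead of hardcoding branches inline.
inline int getExpertTier(int numExperts) {
  constexpr int kTiers[] = {128, 256, 576, 1024};
  for (int tier : kTiers) {
    if (numExperts <= tier) return tier;
  }
  return -1;  // unsupported; the caller should reject this expert count
}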
csrc/trtllm_fused_moe_kernel_launcher.cu (881-892)
Adding RoutingMethodType::Sigmoid to this conditional block is correct for enabling the new routing method. However, it's important to ensure that the logic within this block is appropriate for all routing methods included, and that the comment accurately reflects the supported top_k values for all methods in the group. Consider updating the comment to be more general, or adding separate checks with distinct comments for each routing method if their requirements diverge.
TVM_FFI_ICHECK(args->top_k <= 10 && args->top_k > 0)
<< "Current routing kernel (no groups) only supports top_k<=10 && top_k>0.";
csrc/trtllm_fused_moe_routing_custom.cu (48-50)
The comment states that when MaxNumExperts > 1024 the actual thread count is capped at 1024, and the code does cap it: NumThreadsBlock is assigned MaxNumExperts <= 1024 ? MaxNumExperts : 1024. The comment also claims that each thread then handles multiple experts; please verify that the kernel actually does so.
csrc/trtllm_fused_moe_routing_custom.cu (82-84)
Using params.mUsePdl is better than KernelParams::UsePdl as it uses the runtime value instead of the compile time value.
if (params.mUsePdl) {
csrc/trtllm_fused_moe_routing_custom.cu (311-312)
Using params.mUsePdl is better than KernelParams::UsePdl as it uses the runtime value instead of the compile time value.
if (params.mUsePdl) {
csrc/trtllm_fused_moe_routing_custom.cu (571)
The comment is not accurate. The code is checking for data.mUsePdl not KernelParams::UsePdl.
csrc/trtllm_fused_moe_routing_custom.cu (600)
The comment is not accurate. The code is checking for mutableData.mPdlOverlapWithNext not data.mPdlOverlapWithNext.
csrc/trtllm_fused_moe_routing_custom.cu (634)
The comment is not accurate. The code is checking for mutableData.mPdlOverlapWithNext not data.mPdlOverlapWithNext.
csrc/trtllm_fused_moe_routing_custom.cu (640)
The comment is not accurate. The code is checking for mutableData.mPdlOverlapWithNext not data.mPdlOverlapWithNext.
Force-pushed 5f2751e to 756c10f.
Actionable comments posted: 4
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
flashinfer/fused_moe/core.py (1)
2301-2308: ⚠️ Potential issue | 🟡 Minor
The routing-method docs now skip TopK = 5. Each list jumps from RenormalizeNaive = 4 straight to SigmoidRenorm = 6, so the public docs no longer match RoutingMethodType. This is also the natural place to explain how the new norm_topk_prob knob differs from SigmoidRenorm.
Also applies to: 2398-2405, 2835-2842, 2972-2979, 3089-3096
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/fused_moe/core.py` around lines 2301 - 2308, The routing-method docs skip the enum value TopK = 5 and don't explain how the new norm_topk_prob differs from SigmoidRenorm; update every affected docstring/comment block (mentions near RoutingMethodType and the doc sections around lines referenced) to include the missing "5: TopK" entry and add a short sentence comparing norm_topk_prob to SigmoidRenorm (what behavior changes and when to use it), ensuring references to RoutingMethodType and norm_topk_prob are accurate and consistent across all occurrences (around the blocks you noted).
🧹 Nitpick comments (3)
tests/moe/test_trtllm_gen_fused_moe.py (1)
3757-3879: Add one custom-routing config to the dtype-flexibility matrix.
This matrix only exercises Renormalize and DeepSeekV3. The new Default/Sigmoid/SigmoidRenorm branch in csrc/trtllm_fused_moe_runner.cu never runs here with routing_logits_dtype=torch.float32, so the newly added dtypeLogits plumbing for the custom routing implementation is still uncovered.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/moe/test_trtllm_gen_fused_moe.py` around lines 3757 - 3879, The test matrix in test_routing_dtype_flexibility never covers the new Default/Sigmoid/SigmoidRenorm custom-routing path, so add an additional pytest.param entry in the "routing_config" param list that uses routing_method_type=RoutingMethodType.Default (or RoutingMethodType.Sigmoid/SigmoidRenorm), set appropriate fields (e.g., num_experts, top_k, padding, n_groups/top_k_groups as needed), include has_routing_bias True/False as relevant, and add the same compatible_moe_impls and compatible_intermediate_size entries; this ensures run_moe_test(...) exercises the csrc/trtllm_fused_moe_runner.cu branch that uses dtypeLogits so routing_logits_dtype=torch.float32 is covered.
include/flashinfer/trtllm/fused_moe/RoutingKernel.cuh (1)
49-62: Silent fallback in loadScalar for unsupported dtypes.
The function returns 0.f for dtypes other than Fp32 and Bfloat16. This could silently produce incorrect routing behavior if a caller accidentally passes an unsupported dtype (e.g., Fp16). Consider adding an assertion or handling Fp16 explicitly if it's a valid use case:
♻️ Suggested improvement
__forceinline__ __device__ float loadScalar(void const* ptr, int idx,
                                            batchedGemm::trtllm::gen::Dtype dtype) {
  namespace tg = batchedGemm::trtllm::gen;
  switch (dtype) {
    case tg::Dtype::Fp32:
      return static_cast<float const*>(ptr)[idx];
    case tg::Dtype::Bfloat16:
      return static_cast<float>(static_cast<__nv_bfloat16 const*>(ptr)[idx]);
+   case tg::Dtype::Fp16:
+     return static_cast<float>(static_cast<__half const*>(ptr)[idx]);
    default:
+     assert(false && "Unsupported dtype in loadScalar");
      return 0.f;
  }
}
Verify each finding against the current code and only fix it if needed. In `@include/flashinfer/trtllm/fused_moe/RoutingKernel.cuh` around lines 49 - 62, The switch in loadScalar silently returns 0.f for unsupported tg::Dtype values, which can mask bugs; update loadScalar to explicitly handle tg::Dtype::Fp16 by reading __half values and converting to float, and change the default case to a device-side assertion (e.g., assert(false)) followed by a safe return to satisfy the compiler, so callers and future dtype additions fail loudly instead of silently producing 0.0; target the function loadScalar and the tg::Dtype enum when making this change.
csrc/trtllm_fused_moe_routing_custom.cu (1)
446-479: Coop kernel launch wrapper covers reasonable expert tiers.
The launchCoopKernel function dispatches based on expert count tiers up to 576. However, the coop kernel constraint in runPostTopKPipeline (line 80 in routing_common.cu) requires mNumExperts <= 1024, so expert counts between 576 and 1024 would fall through to the warning at lines 477-478. Consider adding tiers for 640, 768, 896, and 1024 experts, or updating the warning message to indicate that only experts ≤ 576 are supported in the coop path.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@csrc/trtllm_fused_moe_routing_custom.cu` around lines 446 - 479, launchCoopKernel currently handles expert tiers up to NumExperts576Experts but runPostTopKPipeline allows mNumExperts <= 1024, so experts in (576,1024] fall through to the warning; extend the dispatch in launchCoopKernel by adding additional else-if branches for the missing tiers (e.g., NumExperts640Experts, NumExperts768Experts, NumExperts896Experts, NumExperts1024Experts) using the same LAUNCH_ROUTING_WITH_POLICIES invocation pattern (coopLaunch=true, routingIndicesCoopKernel, numBlocksCoop, numThreadsHist, NoOpPreprocess, NoOpPostprocess, appropriate NumExpertsXXXExperts and NumTop8Experts), or alternatively update the final FLASHINFER_WARN message to explicitly state that only up to 576 experts are supported by the coop path; ensure the change keeps behavior consistent with runPostTopKPipeline and references the same kernel policy macro (LAUNCH_ROUTING_WITH_POLICIES) and function name launchCoopKernel.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@csrc/trtllm_fused_moe_routing_deepseek.cu`:
- Around line 474-484: The fast-path for precomputed TopK (the branch that
checks data.mPtrTopKIds and data.mPtrTopKPacked/mPtrScores) runs
runPostTopKPipeline(data, numThreadsHist, stream) after calling
getMaxNumExperts(data.mNumExperts) which can return 0 for unsupported expert
counts; move or replicate the expert-count guard (the check against
MaxSupportedExpertCount for data.mNumExperts) before taking the precomputed-TopK
fast path so you never call getMaxNumExperts/runPostTopKPipeline with an
unsupported expert count, or explicitly check data.mNumExperts <=
MaxSupportedExpertCount and error/return before computing numThreadsHist and
calling runPostTopKPipeline.
- Around line 162-165: The code reads routing bias unconditionally and casts it
to OutputT which can dereference a null params.mPtrRoutingBias and lose
precision; change the logic around biasVal (and the similar block at the other
occurrence) to load the bias into a float via loadScalar only if
params.mPtrRoutingBias is non-nullptr, otherwise treat it as 0.0f, perform all
bias math and ranking in float, and only cast to OutputT at the final assignment
if required; update uses of expertSelected, loadScalar, params.mDtypeBias,
OutputT, and invalidScore accordingly so null pointers are guarded and BF16
paths keep full float precision during comparisons.
In `@flashinfer/fused_moe/core.py`:
- Around line 1347-1350: The fake/meta op signatures registered via
register_fake_op must match the real op signatures; update the parameter lists
of _fake_trtllm_bf16_moe, _fake_trtllm_fp8_per_tensor_scale_moe,
_fake_trtllm_fp8_block_scale_moe, _fake_trtllm_fp4_block_scale_moe, and
_fake_trtllm_mxint4_block_scale_moe to include the new norm_topk_prob: bool =
True kwarg (and mirror any default) so they exactly match the real op signature
that now accepts norm_topk_prob; ensure the register_fake_op-decorated functions
and any callers/tracing paths use the updated parameter name and default to
avoid binding failures during tracing/compile.
In `@include/flashinfer/trtllm/common/cudaUtils.h`:
- Around line 276-291: getSMVersion and getMultiProcessorCount call
cudaGetDevice and cudaDeviceGetAttribute without checking their return values;
update both functions to check the cudaError_t results from cudaGetDevice and
each cudaDeviceGetAttribute call and surface failures (do not silently return
derived 0/-1). On error, produce a clear failure path: throw a
std::runtime_error (or return a sentinel and log, but prefer throwing)
containing the CUDA error string from cudaGetErrorString and context (function
name, device and which attribute failed). Reference getSMVersion and
getMultiProcessorCount and the CUDA calls (cudaGetDevice,
cudaDeviceGetAttribute, cudaGetErrorString) when locating and fixing the code.
---
Outside diff comments:
In `@flashinfer/fused_moe/core.py`:
- Around line 2301-2308: The routing-method docs skip the enum value TopK = 5
and don't explain how the new norm_topk_prob differs from SigmoidRenorm; update
every affected docstring/comment block (mentions near RoutingMethodType and the
doc sections around lines referenced) to include the missing "5: TopK" entry and
add a short sentence comparing norm_topk_prob to SigmoidRenorm (what behavior
changes and when to use it), ensuring references to RoutingMethodType and
norm_topk_prob are accurate and consistent across all occurrences (around the
blocks you noted).
---
Nitpick comments:
In `@csrc/trtllm_fused_moe_routing_custom.cu`:
- Around line 446-479: launchCoopKernel currently handles expert tiers up to
NumExperts576Experts but runPostTopKPipeline allows mNumExperts <= 1024, so
experts in (576,1024] fall through to the warning; extend the dispatch in
launchCoopKernel by adding additional else-if branches for the missing tiers
(e.g., NumExperts640Experts, NumExperts768Experts, NumExperts896Experts,
NumExperts1024Experts) using the same LAUNCH_ROUTING_WITH_POLICIES invocation
pattern (coopLaunch=true, routingIndicesCoopKernel, numBlocksCoop,
numThreadsHist, NoOpPreprocess, NoOpPostprocess, appropriate
NumExpertsXXXExperts and NumTop8Experts), or alternatively update the final
FLASHINFER_WARN message to explicitly state that only up to 576 experts are
supported by the coop path; ensure the change keeps behavior consistent with
runPostTopKPipeline and references the same kernel policy macro
(LAUNCH_ROUTING_WITH_POLICIES) and function name launchCoopKernel.
In `@include/flashinfer/trtllm/fused_moe/RoutingKernel.cuh`:
- Around line 49-62: The switch in loadScalar silently returns 0.f for
unsupported tg::Dtype values, which can mask bugs; update loadScalar to
explicitly handle tg::Dtype::Fp16 by reading __half values and converting to
float, and change the default case to a device-side assertion (e.g.,
assert(false)) followed by a safe return to satisfy the compiler, so callers and
future dtype additions fail loudly instead of silently producing 0.0; target the
function loadScalar and the tg::Dtype enum when making this change.
In `@tests/moe/test_trtllm_gen_fused_moe.py`:
- Around line 3757-3879: The test matrix in test_routing_dtype_flexibility never
covers the new Default/Sigmoid/SigmoidRenorm custom-routing path, so add an
additional pytest.param entry in the "routing_config" param list that uses
routing_method_type=RoutingMethodType.Default (or
RoutingMethodType.Sigmoid/SigmoidRenorm), set appropriate fields (e.g.,
num_experts, top_k, padding, n_groups/top_k_groups as needed), include
has_routing_bias True/False as relevant, and add the same compatible_moe_impls
and compatible_intermediate_size entries; this ensures run_moe_test(...)
exercises the csrc/trtllm_fused_moe_runner.cu branch that uses dtypeLogits so
routing_logits_dtype=torch.float32 is covered.
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: e38a2b19-3d03-4fcf-93ef-f790f5322327
📒 Files selected for processing (18)
- csrc/trtllm_fused_moe_kernel_launcher.cu
- csrc/trtllm_fused_moe_routing_common.cu
- csrc/trtllm_fused_moe_routing_custom.cu
- csrc/trtllm_fused_moe_routing_deepseek.cu
- csrc/trtllm_fused_moe_routing_llama4.cu
- csrc/trtllm_fused_moe_routing_renormalize.cu
- csrc/trtllm_fused_moe_runner.cu
- flashinfer/fused_moe/core.py
- flashinfer/jit/fused_moe.py
- include/flashinfer/trtllm/common/cudaUtils.h
- include/flashinfer/trtllm/fused_moe/DevKernel.h
- include/flashinfer/trtllm/fused_moe/RoutingCustomPolicy.cuh
- include/flashinfer/trtllm/fused_moe/RoutingDevKernel.h
- include/flashinfer/trtllm/fused_moe/RoutingKernel.cuh
- include/flashinfer/trtllm/fused_moe/RoutingKernel.h
- include/flashinfer/trtllm/fused_moe/RoutingKernelTopK.cuh
- include/flashinfer/trtllm/fused_moe/runner.h
- tests/moe/test_trtllm_gen_fused_moe.py
💤 Files with no reviewable changes (1)
- csrc/trtllm_fused_moe_routing_renormalize.cu
auto biasVal = expertSelected
                   ? static_cast<OutputT>(
                         loadScalar(params.mPtrRoutingBias, threadExpert, params.mDtypeBias))
                   : invalidScore;
Handle routing_bias == nullptr and keep bias math in float.
routing_bias is still optional at the launcher/API layer, so this unconditional loadScalar(params.mPtrRoutingBias, ...) can read through nullptr. Casting a float32 bias down to OutputT before ranking also drops precision on the BF16 path and can perturb expert selection. Load into float and treat a missing bias as 0.f.
Proposed fix
- auto biasVal = expertSelected
- ? static_cast<OutputT>(
- loadScalar(params.mPtrRoutingBias, threadExpert, params.mDtypeBias))
- : invalidScore;
+ float biasVal = invalidScoreFloat;
+ if (expertSelected) {
+ biasVal = params.mPtrRoutingBias != nullptr
+ ? static_cast<float>(
+ loadScalar(params.mPtrRoutingBias, threadExpert, params.mDtypeBias))
+ : 0.f;
+ }
@@
-  auto scoreBias = float{scoreSigmoid + float{biasVal}};
+  auto scoreBias = scoreSigmoid + biasVal;

Also applies to: 196-198
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@csrc/trtllm_fused_moe_routing_deepseek.cu` around lines 162 - 165, The code
reads routing bias unconditionally and casts it to OutputT which can dereference
a null params.mPtrRoutingBias and lose precision; change the logic around
biasVal (and the similar block at the other occurrence) to load the bias into a
float via loadScalar only if params.mPtrRoutingBias is non-nullptr, otherwise
treat it as 0.0f, perform all bias math and ranking in float, and only cast to
OutputT at the final assignment if required; update uses of expertSelected,
loadScalar, params.mDtypeBias, OutputT, and invalidScore accordingly so null
pointers are guarded and BF16 paths keep full float precision during
comparisons.
if (data.mPtrTopKIds != nullptr ||
    (data.mPtrTopKPacked != nullptr && data.mPtrScores == nullptr)) {
  if (data.mPtrTopKIds != nullptr) {
    FLASHINFER_CHECK(
        data.mPtrTopKWeights != nullptr,
        "When mPtrTopKIds is provided, mPtrTopKWeights must also be provided for "
        "DeepSeek routing.");
  }
  int const numThreadsHist = getMaxNumExperts(data.mNumExperts);
  runPostTopKPipeline(data, numThreadsHist, stream);
  return;
Keep the expert-count guard before the precomputed-topK fast path.
This branch returns before the later mNumExperts <= MaxSupportedExpertCount check. For unsupported expert counts, getMaxNumExperts() only warns and returns 0, so runPostTopKPipeline(data, 0, stream) gets an invalid thread count.
Proposed fix
if (data.mPtrTopKIds != nullptr ||
(data.mPtrTopKPacked != nullptr && data.mPtrScores == nullptr)) {
+ FLASHINFER_CHECK(data.mNumExperts <= MaxSupportedExpertCount,
+ "Routing kernel expects `#experts` %d <= `#threads` %d", data.mNumExperts,
+ MaxSupportedExpertCount);
if (data.mPtrTopKIds != nullptr) {
FLASHINFER_CHECK(
data.mPtrTopKWeights != nullptr,
"When mPtrTopKIds is provided, mPtrTopKWeights must also be provided for "
"DeepSeek routing.");
}
int const numThreadsHist = getMaxNumExperts(data.mNumExperts);
+ FLASHINFER_CHECK(numThreadsHist > 0, "Unsupported numExperts %d", data.mNumExperts);
runPostTopKPipeline(data, numThreadsHist, stream);
return;
}🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@csrc/trtllm_fused_moe_routing_deepseek.cu` around lines 474 - 484, The
fast-path for precomputed TopK (the branch that checks data.mPtrTopKIds and
data.mPtrTopKPacked/mPtrScores) runs runPostTopKPipeline(data, numThreadsHist,
stream) after calling getMaxNumExperts(data.mNumExperts) which can return 0 for
unsupported expert counts; move or replicate the expert-count guard (the check
against MaxSupportedExpertCount for data.mNumExperts) before taking the
precomputed-TopK fast path so you never call
getMaxNumExperts/runPostTopKPipeline with an unsupported expert count, or
explicitly check data.mNumExperts <= MaxSupportedExpertCount and error/return
before computing numThreadsHist and calling runPostTopKPipeline.
inline int getSMVersion() {
  int device{-1};
  cudaGetDevice(&device);
  int sm_major = 0;
  int sm_minor = 0;
  cudaDeviceGetAttribute(&sm_major, cudaDevAttrComputeCapabilityMajor, device);
  cudaDeviceGetAttribute(&sm_minor, cudaDevAttrComputeCapabilityMinor, device);
  return sm_major * 10 + sm_minor;
}

inline int getMultiProcessorCount() {
  int device{-1};
  cudaGetDevice(&device);
  int count = 0;
  cudaDeviceGetAttribute(&count, cudaDevAttrMultiProcessorCount, device);
  return count;
}
Check the CUDA query results before returning device metadata.
Both helpers ignore cudaGetDevice / cudaDeviceGetAttribute failures. If any of those calls fail, the caller gets 0/-1-derived values and may silently pick the wrong routing path instead of surfacing the CUDA error.
🔧 Proposed fix
inline int getSMVersion() {
int device{-1};
- cudaGetDevice(&device);
+ FLASHINFER_CHECK(cudaGetDevice(&device) == cudaSuccess, "CUDA error in cudaGetDevice");
int sm_major = 0;
int sm_minor = 0;
- cudaDeviceGetAttribute(&sm_major, cudaDevAttrComputeCapabilityMajor, device);
- cudaDeviceGetAttribute(&sm_minor, cudaDevAttrComputeCapabilityMinor, device);
+ FLASHINFER_CHECK(cudaDeviceGetAttribute(&sm_major, cudaDevAttrComputeCapabilityMajor, device) ==
+ cudaSuccess,
+ "CUDA error querying compute capability major");
+ FLASHINFER_CHECK(cudaDeviceGetAttribute(&sm_minor, cudaDevAttrComputeCapabilityMinor, device) ==
+ cudaSuccess,
+ "CUDA error querying compute capability minor");
return sm_major * 10 + sm_minor;
}
inline int getMultiProcessorCount() {
int device{-1};
- cudaGetDevice(&device);
+ FLASHINFER_CHECK(cudaGetDevice(&device) == cudaSuccess, "CUDA error in cudaGetDevice");
int count = 0;
- cudaDeviceGetAttribute(&count, cudaDevAttrMultiProcessorCount, device);
+ FLASHINFER_CHECK(cudaDeviceGetAttribute(&count, cudaDevAttrMultiProcessorCount, device) ==
+ cudaSuccess,
+ "CUDA error querying multiprocessor count");
return count;
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@include/flashinfer/trtllm/common/cudaUtils.h` around lines 276 - 291,
getSMVersion and getMultiProcessorCount call cudaGetDevice and
cudaDeviceGetAttribute without checking their return values; update both
functions to check the cudaError_t results from cudaGetDevice and each
cudaDeviceGetAttribute call and surface failures (do not silently return derived
0/-1). On error, produce a clear failure path: throw a std::runtime_error (or
return a sentinel and log, but prefer throwing) containing the CUDA error string
from cudaGetErrorString and context (function name, device and which attribute
failed). Reference getSMVersion and getMultiProcessorCount and the CUDA calls
(cudaGetDevice, cudaDeviceGetAttribute, cudaGetErrorString) when locating and
fixing the code.
Force-pushed 756c10f to b2115af.
@ChristinaZ may you help review?

/bot run

[FAILED] Pipeline #47872154: 10/20 passed
// Sigmoid -> TopK (no renormalization)
routingData.mPreprocessType = RoutingPreprocessType::Sigmoid;
routingData.mPostprocessType = RoutingPostprocessType::SumNormalize;
routingData.mNormTopkProb = false;
Hi, thanks for your work. Please correct me if I am wrong: since I added a norm_topk_prob parameter to the Python APIs like trtllm_fp8_block_scale_moe_op, trtllm_fp8_per_tensor_scale_moe_op, trtllm_fp4_block_scale_moe_op, trtllm_mxint4_block_scale_moe_op, and trtllm_bf16_moe_op in flashinfer/fused_moe/core.py, I think we can use RoutingMethodType::SigmoidRenorm with routingData.mNormTopkProb set to false (this parameter is passed to Runner::run in csrc/trtllm_fused_moe_runner.cu) to behave the same as RoutingMethodType::Sigmoid.
And one more question regarding our requirements: what are the maximum number of experts and the maximum top-K value we need to support? I ask in order to limit the instantiation cases; I set a tier list for different routing methods. For now, the maximum expert count is 256 and the maximum K value is 8 for SigmoidRenorm. If we need to support larger values, we need to add one line here https://github.com/flashinfer-ai/flashinfer/blob/main/include/flashinfer/trtllm/fused_moe/RoutingCustomPolicy.cuh#L463. For example:
/// Sigmoid + SumNormalize (SigmoidRenorm: Sigmoid -> TopK -> Renormalize).
/// NOTE: Currently only covers ≤256 experts. If a model requires more, add a larger Tier here.
template <>
struct PolicyTraits<SigmoidPreprocess, SumNormalizePostprocess> {
using Pairs = TierList<Tier<128, 8>, // Small expert counts (≤128 experts)
Tier<256, 8>, // Medium expert counts (≤256 experts)
Tier<1024, 32> // Large expert counts (≤1024 experts) and K values
>;
};
Hi @ChristinaZ , thank you for the comments!

- It wasn't clear to me that we can set norm_topk_prob=False and use RoutingMethodType::SigmoidRenorm to get the same behaviour as RoutingMethodType::Sigmoid. If norm_topk_prob is used only for RoutingMethodType::SigmoidRenorm, I'd recommend removing norm_topk_prob and adding RoutingMethodType::Sigmoid; this avoids confusion and simplifies integration into other frameworks, since adding an extra param to the Python APIs won't be needed. What do you think?
- The current requirements should be fine, but we can add more tiers if needed.
So sorry that I missed your reply. Thanks for your advice. I think your suggestion is great. I can then open a new PR to remove the parameter norm_topk_prob. We can merge this PR first.
No worries @ChristinaZ , thanks for your feedback, I updated the PR to fix the conflicts
pls resolve merge conflicts, thanks for the contrib!
Signed-off-by: EdalatiAli <aliedalati@cohere.com>
Force-pushed 7c1efbf to 4e3a1e3.
Actionable comments posted: 2
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@csrc/trtllm_fused_moe_kernel_launcher.cu`:
- Around line 993-1001: This branch validates only args->top_k but must also
enforce the documented expert-count ceiling for the Sigmoid/SigmoidRenorm
policy; update the validation in the branch that checks
static_cast<RoutingMethodType>(routing_method_type) for
RoutingMethodType::Renormalize, RenormalizeNaive, SigmoidRenorm, and Sigmoid to
also assert args->num_experts <= 256 (and > 0) using TVM_FFI_ICHECK (either by
extending the existing ICHECK condition or adding a separate ICHECK) and provide
a clear error message mentioning the expert limit so launcher rejects
num_experts > 256 upfront.
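A sketch of the extended check described in the prompt above, reusing the TVM_FFI_ICHECK shape already shown in this review (the branch structure and message wording are illustrative):

// Hypothetical extension: also bound num_experts for the no-groups methods.
auto const method = static_cast<RoutingMethodType>(routing_method_type);
if (method == RoutingMethodType::Renormalize ||
    method == RoutingMethodType::RenormalizeNaive ||
    method == RoutingMethodType::SigmoidRenorm ||
    method == RoutingMethodType::Sigmoid) {
  TVM_FFI_ICHECK(args->top_k <= 10 && args->top_k > 0)
      << "Current routing kernel (no groups) only supports top_k<=10 && top_k>0.";
  TVM_FFI_ICHECK(args->num_experts > 0 && args->num_experts <= 256)
      << "Sigmoid/Renormalize routing currently supports at most 256 experts.";
}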
In `@include/flashinfer/trtllm/fused_moe/runner.h`:
- Around line 55-58: Restore the original numeric ordinal for Unspecified so
existing serialized/ABI consumers remain stable: keep Unspecified = 8 and assign
Sigmoid a new, explicit value (not 8) to avoid overlap; update the enum entries
for Sigmoid and Unspecified in runner.h (the Sigmoid and Unspecified enum
constants) so Unspecified retains 8 and Sigmoid moves to the next unused value,
ensuring no other enum entries change.
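Concretely, the suggested fix in the prompt above amounts to something like the following. Only the ordinals named in this thread are shown; the elided entries and the exact enum declaration are assumptions.

// Sketch: keep Unspecified at its original ordinal 8 and append Sigmoid
// after it, so existing serialized/ABI values stay stable.
enum class RoutingMethodType {
  // ... earlier entries elided ...
  RenormalizeNaive = 4,
  TopK = 5,
  SigmoidRenorm = 6,
  // ... entry 7 elided ...
  Unspecified = 8,  // unchanged
  Sigmoid = 9,      // new: takes the next unused value
};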
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: 938dc975-0000-43ac-ab40-b02d458de428
📒 Files selected for processing (6)
- csrc/trtllm_fused_moe_kernel_launcher.cu
- csrc/trtllm_fused_moe_runner.cu
- flashinfer/fused_moe/core.py
- flashinfer/tllm_enums.py
- include/flashinfer/trtllm/fused_moe/RoutingCustomPolicy.cuh
- include/flashinfer/trtllm/fused_moe/runner.h
✅ Files skipped from review due to trivial changes (2)
- include/flashinfer/trtllm/fused_moe/RoutingCustomPolicy.cuh
- flashinfer/fused_moe/core.py
🚧 Files skipped from review as they are similar to previous changes (1)
- flashinfer/tllm_enums.py
Thank you @aleozlx! I updated the PR to fix the conflicts.
📌 Description
Depends on #2803.

This PR adds RoutingMethodType.Sigmoid to support a routing function that applies sigmoid before top-k (without renormalization), for use by MoE layers that rely on this routing function.

🔍 Related Issues
🚀 Pull Request Checklist
Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.
✅ Pre-commit Checks
- I have installed pre-commit by running pip install pre-commit (or used your preferred method).
- I have installed the hooks with pre-commit install.
- I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

🧪 Tests

- I have added or updated tests as needed (unittest, etc.).

Reviewer Notes
Summary by CodeRabbit
- New Features
- Behavior Change
- Documentation
- Tests