
feat: Fuse shared experts into trtllm_gen moe (fp8)#2625

Open
nv-yunzheq wants to merge 9 commits into flashinfer-ai:main from nv-yunzheq:DSR1_shared_expert_fusion

Conversation

@nv-yunzheq
Collaborator

@nv-yunzheq nv-yunzheq commented Feb 23, 2026

For #2551
Integrating NVIDIA/TensorRT-LLM#11143

🔍 Related Issues

🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or used your preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

If you are unsure about how to set up pre-commit, see the pre-commit documentation.

🧪 Tests

  • Tests have been added or updated as needed.
  • All tests are passing (unittest, etc.).

Reviewer Notes

Summary by CodeRabbit

  • New Features

    • Support for fused shared experts in MoE runtime and kernels, expanding per-token expert counts and routing behavior.
  • Bug Fixes

    • Improved runtime validations and sizing for fused-expert configurations; clearer error messages for weight/scale sizing.
  • Tests

    • Added and extended tests covering fused shared expert scenarios and routing variants.
  • Documentation

    • Clarified num_experts meaning in FP8 block-scale MoE docs.

@gemini-code-assist
Contributor

Summary of Changes

Hello @nv-yunzheq, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request integrates the concept of 'fused shared experts' into the existing TensorRT-LLM MoE framework, particularly for FP8 operations. The primary goal is to optimize the handling of experts that are shared across multiple tokens by incorporating them directly into the MoE kernel's routing and execution logic. This change affects how memory is allocated, how routing decisions are made, and how the overall MoE computation is performed, leading to more streamlined and potentially faster processing of MoE layers with shared components.

Highlights

  • Fused Shared Experts Integration: Introduced support for fusing shared experts into the TensorRT-LLM Mixture-of-Experts (MoE) kernel, specifically for FP8 quantization, enhancing efficiency and flexibility.
  • Updated MoE Kernel Logic: Modified the MoE kernel launchers and runners to correctly account for and process num_fused_shared_experts in tensor allocations, routing calculations, and workspace sizing.
  • DeepSeekV3 Routing Enhancement: The DeepSeekV3 routing method was updated to properly handle the routing and weighting of both regular and fused shared experts.
  • Routing Method Compatibility Checks: Added explicit checks to ensure that Llama4 and Renormalize routing methods do not support the fused shared expert functionality, preventing misuse.
  • Python API and Test Coverage: Extended the Python API to expose the num_fused_shared_experts parameter and added new test cases to validate the correctness of the fused shared experts feature.


Changelog
  • csrc/trtllm_fused_moe_kernel_launcher.cu
    • Adjusted tensor allocation sizes for num_tokens_per_expert, expanded_idx_to_permuted_idx, expert_indexes, and expert_count_histogram to include num_fused_shared_experts.
    • Updated getMaxPermutedPaddedCount and getMaxNumCtasInBatchDim calls to use totalExpertsPerToken and totalNumExperts.
    • Modified moe_runner->getDefaultValidConfigIndex and getValidConfigIndices to use effectiveTopK and effectiveLocalExperts.
    • Passed num_fused_shared_experts to the routing_runner.run method.
  • csrc/trtllm_fused_moe_routing_deepseek.cu
    • Updated idxTopK and idxShared calculations to use mTotalExpertsPerToken.
    • Added logic to write packed scores and weights for fused shared experts.
    • Introduced numExperts, topK, and numThreadsHist variables to reflect total experts including fused shared ones.
    • Added FLASHINFER_CHECK for num_fused_shared_experts being less than warp size.
    • Modified useSingleCluster condition to include data.mNumTokens * topK.
    • Adjusted maxTokensCoop calculation to use topK.
    • Incremented data.mNumExperts, data.mTopK, and data.mNumLocalExperts by data.mNumFusedSharedExperts if shared experts are present.
    • Updated expandedIdxSize calculation to use topK.
  • csrc/trtllm_fused_moe_runner.cu
    • Added numFusedSharedExpert parameter to the Runner::run method signature.
    • Introduced totalExpertsPerToken variable in run method for DeepSeekV3 routing.
    • Added mNumFusedSharedExperts and mTotalExpertsPerToken to routingData for DeepSeekV3.
    • Implemented logic to calculate mSharedExpertTokenOffset and mSharedExpertNumTokens for shared experts distribution.
    • Added FLASHINFER_CHECK to disallow numFusedSharedExpert for Llama4 and Renormalize routing methods.
    • Introduced totalNumExperts and totalExpertsPerToken in setOpsData for activation and finalize data setup.
    • Updated activationData.topK and finalizeData.numExperts, finalizeData.topK to use total expert counts.
    • Modified getWorkspaceSizeInBytes to use totalLocalExperts and totalExpertsPerToken for workspace calculations.
    • Updated mPermuteGemm1.run and mGemm2.run calls to use totalLocalExperts and totalExpertsPerToken.
  • flashinfer/fused_moe/core.py
    • Added num_fused_shared_experts as an optional argument to trtllm_fp8_block_scale_moe_op and trtllm_fp8_block_scale_moe functions.
    • Passed num_fused_shared_experts to the C++ backend calls.
    • Updated docstrings to describe the new num_fused_shared_experts parameter and its implications on weight tensors.
    • Modified the forward method to pass num_fused_shared_experts to the underlying C++ function.
  • include/flashinfer/trtllm/fused_moe/RoutingKernel.h
    • Added mNumFusedSharedExperts, mSharedExpertTokenOffset, mSharedExpertNumTokens, and mTotalExpertsPerToken members to DataBase and KernelParamsBase structs.
  • include/flashinfer/trtllm/fused_moe/runner.h
    • Added numFusedSharedExpert parameter to the Runner::run method declaration.
    • Added num_fused_shared_experts member to the MoERunnerArgs struct.
  • tests/moe/test_trtllm_gen_fused_moe.py
    • Introduced num_routed_experts calculation in call_moe.
    • Adjusted num_experts and top_k parameters passed to trtllm_fp8_block_scale_moe to account for fused shared experts.
    • Modified routing_reference to accept num_fused_shared_experts and include shared experts in its logic.
    • Updated routing_reference_no_aux to pass num_fused_shared_experts to routing_reference.
    • Modified _compute_moe_actual_unified to pass num_fused_shared_experts to call_moe.
    • Adjusted run_moe_test to calculate total_experts and pass num_fused_shared_experts to reference implementations and moe_args.
    • Updated gemm1_weights and gemm2_weights tensor shapes in run_moe_test to accommodate total_experts.
    • Added new pytest.param entries for DSv3_fused_shared_1 and DSv3_fused_shared_2 to test the new functionality.
Activity
  • The pull request aims to integrate a feature from TensorRT-LLM PR 11143.
  • It addresses issue Shared expert fusion integration #2551, likely related to MoE functionality.
  • The author, nv-yunzheq, has included standard pre-commit checks and test requirements in the PR description.

@coderabbitai
Contributor

coderabbitai Bot commented Feb 23, 2026

Note

Reviews paused

It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior by changing the reviews.auto_review.auto_pause_after_reviewed_commits setting.

Use the following commands to manage reviews:

  • @coderabbitai resume to resume automatic reviews.
  • @coderabbitai review to trigger a single review.

📝 Walkthrough

This PR adds support for fused shared experts by threading num_fused_shared_experts through routing kernels, runner APIs, CUDA launchers, Python bindings, and tests; workspace sizing, tensor shapes, routing outputs, and GEMM/workspace calculations are updated to account for fused shared experts.

Changes

Cohort / File(s) Summary
  • Routing Kernel & Params (csrc/fused_moe/trtllm_backend/trtllm_fused_moe_routing_deepseek.cu, include/flashinfer/trtllm/fused_moe/RoutingKernel.h): Add fused-shared fields (mNumFusedSharedExperts, mSharedExpertTokenOffset, mSharedExpertNumTokens, mTotalExpertsPerToken) and expand routing kernel output/indexing and validations to write and account for fused shared expert entries (expanded topK/numExperts).
  • Runner API & Logic (csrc/trtllm_fused_moe_runner.cu, include/flashinfer/trtllm/fused_moe/runner.h): Runner::run signature updated to accept numFusedSharedExpert; MoERunnerArgs gains num_fused_shared_experts. Activation/finalize, GEMM workspace sizing, and GEMM launches use fused-aware totals (totalExpertsPerToken, totalLocalExperts). Added routing guards for unsupported modes with fused experts.
  • Kernel Launcher & FP8 Entrypoints (csrc/trtllm_fused_moe_kernel_launcher.cu, flashinfer/fused_moe/core.py): Thread num_fused_shared_experts into launcher and Python entrypoints; compute totalExpertsPerToken/totalLocalExperts; adjust FP8 block-scale weight tensor allocations and scale validations; update exported FP8 entry signature to accept optional fused count.
  • Bindings & Utilities (csrc/moe_utils_binding.cu): Initialize fused-shared routing fields in moe_sort (prevents uninitialized routingData fields on this code path).
  • Tests (tests/moe/test_trtllm_gen_fused_moe.py): Extend routing reference and test harness to accept and propagate num_fused_shared_experts; update total_experts/top_k in test scaffolding; add DeepSeek fused-shared test cases and adjust kernel kwargs.

Sequence Diagram

sequenceDiagram
    participant Client as Client
    participant Py as Python Binding
    participant Launcher as Kernel Launcher
    participant Routing as Routing Kernel
    participant Runner as MoE Runner
    participant GEMM as GEMM Kernels

    Client->>Py: call trtllm_fp8_block_scale_moe(..., num_fused_shared_experts=N)
    Py->>Launcher: forward args including N
    Launcher->>Launcher: compute totalExpertsPerToken = top_k + N, totalLocalExperts = local_num_experts + N, allocate workspaces
    Launcher->>Routing: routing_runner.run(..., numFusedSharedExpert=N, ...)
    Routing->>Routing: emit routed experts + fused shared expert indices/weights (expanded topK)
    Routing-->>Launcher: routing outputs (expanded indices, counts)
    Launcher->>Runner: moe_runner.run(..., topK=top_k+N, localNumExperts=local+N, ...)
    Runner->>GEMM: launch PermuteGemm1/Gemm2 with fused-aware dimensions
    GEMM-->>Runner: results
    Runner-->>Launcher: aggregated MoE output
    Launcher-->>Py: return result
    Py-->>Client: deliver output
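The fused-aware totals in the sequence above can be sketched as a small helper. This is an illustrative sketch only: the function name, parameter names, and return structure are assumptions, not actual FlashInfer APIs.

```python
def fused_moe_totals(top_k: int, local_num_experts: int, num_experts: int,
                     num_fused_shared_experts: int = 0):
    """Compute the fused-aware expert counts threaded through the launcher."""
    if num_fused_shared_experts < 0:
        raise ValueError("num_fused_shared_experts must be non-negative")
    # Each token now carries its routed top-k slots plus the fused entries.
    total_experts_per_token = top_k + num_fused_shared_experts
    # Local and global expert counts grow by the same fused amount.
    total_local_experts = local_num_experts + num_fused_shared_experts
    total_num_experts = num_experts + num_fused_shared_experts
    return total_experts_per_token, total_local_experts, total_num_experts

# A DeepSeek-V3-style shape with one fused shared expert: top_k=8 and
# 256 routed experts expand to 9 slots per token and 257 experts overall.
print(fused_moe_totals(8, 256, 256, 1))  # (9, 257, 257)
```

Workspace sizing, routing output widths, and GEMM launch dimensions in the PR all derive from these totals rather than the raw top_k/num_experts.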

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~45 minutes

Possibly related issues

Possibly related PRs

Suggested reviewers

  • djmmoss
  • yongwww
  • cyx-6
  • yzh119
  • bkryu
  • jimmyzho
  • sricketts
  • joker-eph
  • samuellees

Poem

🐰 Hopping through kernels, I tally each guest,

Shared experts join the MoE fest,
TopK grows by those fused and true,
Routing and runner now count the new crew,
A joyful hop — the tests pass too!

🚥 Pre-merge checks | ✅ 1 | ❌ 2

❌ Failed checks (1 warning, 1 inconclusive)
  • Docstring Coverage (⚠️ Warning): Docstring coverage is 33.33%, below the required threshold of 80.00%. Resolution: write docstrings for the functions missing them to satisfy the coverage threshold.
  • Description check (❓ Inconclusive): The PR description references related issues and provides context about integrating upstream changes, but largely contains an unchecked template rather than substantive implementation details. Resolution: fill in the Description section with details about what the PR does and why it's needed; ensure pre-commit checks and tests are addressed in the checklist before merging.

✅ Passed checks (1 passed)
  • Title check (✅ Passed): The title clearly and concisely describes the main change: adding shared expert fusion to the trtllm_gen MoE implementation with FP8 support.


@nv-yunzheq
Collaborator Author

/bot run

@flashinfer-bot
Collaborator

GitLab MR !344 has been created, and the CI pipeline #44669282 is currently running. I'll report back once the pipeline job completes.

Contributor

@gemini-code-assist gemini-code-assist Bot left a comment


Code Review

This pull request introduces the fusion of shared experts into the trtllm_gen MoE implementation, specifically for FP8. The changes cover the routing kernel, the launcher, and the Python API. While the integration logic for shared experts is mostly sound, there are a few critical issues regarding histogram initialization and template dispatching in the routing kernel that could lead to undefined behavior or incorrect results in multi-GPU or large-token scenarios.

Comment on lines +668 to +672
if (data.mNumFusedSharedExperts > 0) {
  data.mNumExperts += data.mNumFusedSharedExperts;
  data.mTopK += data.mNumFusedSharedExperts;
  data.mNumLocalExperts += data.mNumFusedSharedExperts;
}
Contributor

high

Updating data.mNumExperts and data.mTopK after the first kernel launch (line 656 or 662) leads to several issues:

  1. numThreadsMain (line 655) and the histogram initialization inside routingMainKernel (line 85) use the original routed expert count, meaning the histogram entries for shared experts are never initialized to zero. This can cause garbage values to be used as offsets in subsequent permutation kernels.
  2. The dispatching macro LAUNCH_ROUTING_DEEPSEEK uses data.mNumExperts to select the MaxNumExperts template parameter. If the total expert count (routed + shared) crosses a threshold (e.g., 256 to 257), the first and second launches will use different template instantiations, which is inconsistent.

You should calculate the total expert count and top-k at the beginning of runImpl and ensure that initialization kernels use the total count, while routingMainKernel receives the routed count for its indexing logic.
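The ordering hazard can be modeled in plain Python (a toy sketch, no CUDA; all names here are illustrative): the histogram must be zero-initialized over the total expert count before any accumulation, rather than widening the count after the first kernel has already run.

```python
def build_expert_histogram(expanded_expert_ids, num_routed_experts,
                           num_fused_shared_experts):
    # Compute the total up front and zero-initialize over it, mirroring the
    # suggested fix of calculating totals at the beginning of runImpl.
    total_experts = num_routed_experts + num_fused_shared_experts
    histogram = [0] * total_experts
    for expert_id in expanded_expert_ids:
        histogram[expert_id] += 1
    return histogram

# Two routed experts (ids 0-1) plus one fused shared expert (id 2).
# Had only the routed slots been zeroed, slot 2 would start from garbage.
print(build_expert_histogram([0, 1, 2, 2], 2, 1))  # [1, 1, 2]
```

In the CUDA code the analogous failure is silent: the uninitialized slots feed offset computations in the permutation kernels.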

Comment on lines +617 to +619
FLASHINFER_CHECK(data.mNumFusedSharedExperts <= WarpSize,
                 "Number of fused shared experts (%d) must be less than warp size.",
                 data.mNumFusedSharedExperts);
Contributor

medium

The check for mNumFusedSharedExperts <= WarpSize is currently placed inside the if (data.mNumExpertGroups > 1) block. However, routingMainKernel always assumes that shared experts can be handled by a single warp (using laneIdx), regardless of whether expert groups are used. This check should be moved outside the conditional block to ensure it is always enforced.

weight_layout=weight_layout,
do_finalize=do_finalize,
enable_pdl=enable_pdl,
num_fused_shared_experts=num_fused_shared_experts,
Contributor

medium

The num_fused_shared_experts parameter should be included in the instance_key used by the MoERunner (around line 1045). Since the kernel's performance and configuration depend on the total number of experts (routed + shared), omitting this from the key might lead to the autotuner returning a suboptimal tactic if multiple calls with different shared expert counts are made.
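The caching concern can be sketched as follows (illustrative names; this is not the actual MoERunner code): any parameter that changes the effective problem shape belongs in the tuning-cache key.

```python
_tactic_cache = {}

def get_or_tune_tactic(num_experts, top_k, num_fused_shared_experts, tune_fn):
    # Including the fused count in the key prevents configurations with
    # different shared-expert counts from sharing one cached tactic.
    key = (num_experts, top_k, num_fused_shared_experts)
    if key not in _tactic_cache:
        _tactic_cache[key] = tune_fn(key)
    return _tactic_cache[key]
```

With the fused count in the key, calling first with 0 and then with 1 fused shared expert tunes twice, once per effective shape; omitting it would silently reuse the first tactic for the second, differently shaped problem.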

Contributor

@coderabbitai coderabbitai Bot left a comment

Actionable comments posted: 3

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
csrc/trtllm_fused_moe_kernel_launcher.cu (1)

810-821: ⚠️ Potential issue | 🟠 Major

Shape validation for precomputed routing with fused shared experts is inconsistent.

The shape check at line 818 validates expert_indices.size(1) == args->top_k, but when fused shared experts are enabled, precomputed indices should account for the additional fused entries. At line 892, totalExpertsPerToken is calculated as args->top_k + args->num_fused_shared_experts, and the expert_weights tensor is allocated with this dimension (line 897). If precomputed routing is used alongside fused shared experts, the shape validation should check expert_indices.size(1) == totalExpertsPerToken instead of just args->top_k to ensure consistency with the routing output tensors.

🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@csrc/trtllm_fused_moe_routing_deepseek.cu`:
- Around line 616-619: The check ensuring data.mNumFusedSharedExperts <=
WarpSize must be unconditional because fused shared-expert writes use laneIdx <
mNumFusedSharedExperts regardless of expert group count; move the
FLASHINFER_CHECK(data.mNumFusedSharedExperts <= WarpSize, ...) out of the if
(data.mNumExpertGroups > 1) block so it always runs, ensuring
data.mNumFusedSharedExperts is validated before any code paths that use
mNumFusedSharedExperts/mNumFusedSharedExperts-induced lane comparisons or writes
(references: data.mNumFusedSharedExperts, WarpSize, mNumFusedSharedExperts,
mNumExpertGroups).
- Around line 571-574: routingInitExpertCounts currently initializes only 2 *
data.mNumExperts using the pre-fusion value then data.mNumExperts is incremented
to include mNumFusedSharedExperts, leaving histogram slots for fused-shared
experts uninitialized; fix by making the initialization cover the full fused
range (initialize 2 * (data.mNumExperts + data.mNumFusedSharedExperts)) or by
moving the mutation of data.mNumExperts (add mNumFusedSharedExperts) before
calling routingInitExpertCounts so the kernel initializes the correct size, and
ensure subsequent kernels that atomicAdd into expert-count slots will see zeros
for indices [original_mNumExperts, original_mNumExperts +
mNumFusedSharedExperts). Also move the check data.mNumFusedSharedExperts <=
WarpSize out of the if (data.mNumExpertGroups > 1) block so the
fused-shared-expert write logic (the unconditional write at the fused shared
expert site) consistently validates the WarpSize constraint regardless of group
count.

In `@include/flashinfer/trtllm/fused_moe/RoutingKernel.h`:
- Around line 102-107: The new fused-shared expert members
(mNumFusedSharedExperts, mSharedExpertTokenOffset, mSharedExpertNumTokens,
mTotalExpertsPerToken) are uninitialized; initialize them to the same safe
defaults used in KernelParamsBase (e.g., zero) by adding default member
initializers or setting them in the DataBase constructor so callers that don't
set them won't propagate garbage into kernel params and cause routing/OOB
errors.

---

Outside diff comments:
In `@csrc/trtllm_fused_moe_kernel_launcher.cu`:
- Around line 810-821: In check_routing(), the validation of
expert_indices.dim(1) only compares to args->top_k but must account for fused
shared experts; compute an expected width like int expectedPerToken =
args->top_k + args->num_fused_shared_experts (or just use args->top_k when
num_fused_shared_experts is zero) and replace the existing
TVM_FFI_ICHECK_EQ(expert_indices.size(1), args->top_k) with a check against
expectedPerToken so precomputed routing matches the allocation for
totalExpertsPerToken and expert_weights.

ℹ️ Review info

Configuration used: defaults

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 26ef055 and 259d279.

📒 Files selected for processing (7)
  • csrc/trtllm_fused_moe_kernel_launcher.cu
  • csrc/trtllm_fused_moe_routing_deepseek.cu
  • csrc/trtllm_fused_moe_runner.cu
  • flashinfer/fused_moe/core.py
  • include/flashinfer/trtllm/fused_moe/RoutingKernel.h
  • include/flashinfer/trtllm/fused_moe/runner.h
  • tests/moe/test_trtllm_gen_fused_moe.py

Comment thread on csrc/trtllm_fused_moe_routing_deepseek.cu (Outdated)
Comment on lines +616 to +619

FLASHINFER_CHECK(data.mNumFusedSharedExperts <= WarpSize,
                 "Number of fused shared experts (%d) must be less than warp size.",
                 data.mNumFusedSharedExperts);
Contributor

⚠️ Potential issue | 🟡 Minor

fusedSharedExperts <= WarpSize check should be unconditional.

This validation is guarded by if (data.mNumExpertGroups > 1) (line 605), but the fused shared expert writes at lines 261-265 and 272-274 use laneIdx < mNumFusedSharedExperts regardless of expert groups. If mNumExpertGroups <= 1 and mNumFusedSharedExperts > WarpSize, the writes would silently skip some fused experts.

Suggested fix

Move the check out of the if (data.mNumExpertGroups > 1) block:

+  FLASHINFER_CHECK(data.mNumFusedSharedExperts <= WarpSize,
+                   "Number of fused shared experts (%d) must be less than warp size.",
+                   data.mNumFusedSharedExperts);
+
   if (data.mNumExpertGroups > 1) {
     FLASHINFER_CHECK(data.mNumExpertGroups <= MaxNumGroups,
                      ...);
     ...
-
-    FLASHINFER_CHECK(data.mNumFusedSharedExperts <= WarpSize,
-                     "Number of fused shared experts (%d) must be less than warp size.",
-                     data.mNumFusedSharedExperts);
   }
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@csrc/trtllm_fused_moe_routing_deepseek.cu` around lines 616 - 619, The check
ensuring data.mNumFusedSharedExperts <= WarpSize must be unconditional because
fused shared-expert writes use laneIdx < mNumFusedSharedExperts regardless of
expert group count; move the FLASHINFER_CHECK(data.mNumFusedSharedExperts <=
WarpSize, ...) out of the if (data.mNumExpertGroups > 1) block so it always
runs, ensuring data.mNumFusedSharedExperts is validated before any code paths
that use mNumFusedSharedExperts/mNumFusedSharedExperts-induced lane comparisons
or writes (references: data.mNumFusedSharedExperts, WarpSize,
mNumFusedSharedExperts, mNumExpertGroups).

Comment on lines +102 to +107

/// For fused shared expert
int32_t mNumFusedSharedExperts;
int32_t mSharedExpertTokenOffset;
int32_t mSharedExpertNumTokens;
int32_t mTotalExpertsPerToken;
Contributor

⚠️ Potential issue | 🟠 Major

Initialize fused-shared expert metadata fields to safe defaults.

These new DataBase members are currently uninitialized, so any caller that forgets to set them will propagate garbage into kernel params and risk incorrect routing / OOB indexing. Mirror the KernelParamsBase defaults.

🛠️ Suggested fix
-  int32_t mNumFusedSharedExperts;
-  int32_t mSharedExpertTokenOffset;
-  int32_t mSharedExpertNumTokens;
-  int32_t mTotalExpertsPerToken;
+  int32_t mNumFusedSharedExperts{0};
+  int32_t mSharedExpertTokenOffset{0};
+  int32_t mSharedExpertNumTokens{0};
+  int32_t mTotalExpertsPerToken{0};

@flashinfer-bot
Collaborator

[FAILED] Pipeline #44669282: 13/20 passed

@nv-yunzheq nv-yunzheq force-pushed the DSR1_shared_expert_fusion branch from 259d279 to 2255bca on March 9, 2026 17:53
@nv-yunzheq
Collaborator Author

/bot run

@nv-yunzheq nv-yunzheq marked this pull request as ready for review March 9, 2026 18:45
@nv-yunzheq nv-yunzheq requested a review from bkryu as a code owner March 9, 2026 18:45
@flashinfer-bot
Collaborator

GitLab MR !344 has been updated with latest changes, and the CI pipeline #45731067 is currently running. I'll report back once the pipeline job completes.

Contributor

@coderabbitai coderabbitai Bot left a comment

Actionable comments posted: 1

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (3)
flashinfer/fused_moe/core.py (1)

1804-1836: ⚠️ Potential issue | 🟠 Major

Fake op signature mismatch: missing num_fused_shared_experts parameter.

The _fake_trtllm_fp8_block_scale_moe function signature must exactly mirror the real op trtllm_fp8_block_scale_moe_op. The real op has num_fused_shared_experts: int = 0 at line 1661, but the fake op is missing this parameter. This will cause issues with torch.compile or other tracing scenarios.

Suggested fix
 @register_fake_op("flashinfer::trtllm_fp8_block_scale_moe")
 def _fake_trtllm_fp8_block_scale_moe(
     routing_logits: Optional[torch.Tensor],
     topk_ids: Optional[torch.Tensor],
     expert_weights: Optional[torch.Tensor],
     routing_bias: Optional[torch.Tensor],
     hidden_states: torch.Tensor,
     hidden_states_scale: torch.Tensor,
     gemm1_weights: torch.Tensor,
     gemm1_weights_scale: torch.Tensor,
     gemm2_weights: torch.Tensor,
     gemm2_weights_scale: torch.Tensor,
     output: torch.Tensor,
     num_experts: int,
     top_k: int,
     n_group: Optional[int],
     topk_group: Optional[int],
     intermediate_size: int,
     local_expert_offset: int,
     local_num_experts: int,
     routed_scaling_factor: Optional[float],
     routing_method_type: int = 0,
     use_shuffled_weight: bool = False,
     weight_layout: int = 0,
     do_finalize: bool = True,
     enable_pdl: Optional[bool] = None,
     tune_max_num_tokens: int = 8192,
     fp8_quantization_type: Fp8QuantizationType = Fp8QuantizationType.DeepSeekFp8,
+    num_fused_shared_experts: int = 0,
 ) -> List[torch.Tensor]:

Based on learnings: "When reviewing files that define fake ops decorated with register_fake_op (e.g., in flashinfer/fused_moe/*), ensure the function signatures exactly mirror the real op they stand in for."

csrc/trtllm_fused_moe_kernel_launcher.cu (2)

1857-1874: ⚠️ Potential issue | 🟠 Major

Validate num_fused_shared_experts before using it in size math.

The new FFI parameter is folded directly into totalExpertsPerToken and totalLocalExperts. A negative value can drive those counts to zero or below and break tile selection/workspace sizing before any lower-layer routing checks run.
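A minimal Python-level sketch of the kind of guard intended here (names and the `[0, warp_size]` bound are illustrative assumptions; the actual fix belongs in the C++ FFI layer):

```python
def validate_num_fused_shared_experts(value, warp_size=32):
    """Reject out-of-range values before they enter workspace size math.

    The [0, warp_size] upper bound mirrors the routing kernel's
    mNumFusedSharedExperts <= WarpSize requirement; treat it as an
    assumption here, not the authoritative limit.
    """
    n = 0 if value is None else int(value)
    if not 0 <= n <= warp_size:
        raise ValueError(
            f"num_fused_shared_experts must be in [0, {warp_size}], got {n}"
        )
    return n
```

Clamping at the boundary keeps totalExpertsPerToken and totalLocalExperts non-negative before any tile selection runs.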

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@csrc/trtllm_fused_moe_kernel_launcher.cu` around lines 1857 - 1874, The code
uses the FFI parameter num_fused_shared_experts directly in size math
(totalExpertsPerToken, totalLocalExperts) which can be negative; validate and
clamp it before use (e.g., ensure num_fused_shared_experts >= 0 and fits
expected bounds) and reject or adjust invalid values; specifically,
check/convert num_fused_shared_experts (and the optional
num_fused_shared_experts.value_or(0)) to a non-negative int64_t before computing
totalExpertsPerToken and totalLocalExperts, and add a defensive check that
aborts or logs an error if the provided FFI value is out of acceptable range so
computeSelectedTileN and downstream launchers
(Fp8BlockScaleLauncher::getSupportedTileNums, computeSelectedTileN,
MoERunnerArgs) never see negative counts.

919-929: ⚠️ Potential issue | 🟠 Major

Precomputed routing tensors still use the old top_k width.

Only the internally allocated expert_weights buffer is widened to top_k + num_fused_shared_experts. If the caller provides precomputed expert_indices / expert_weights, this path still follows the old top_k contract elsewhere, so fused-shared precomputed routing will either reject correctly sized tensors or consume too few columns during finalize.
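A toy illustration of the width contract (pure Python, hypothetical shapes): finalize reads top_k + num_fused_shared_experts columns per token, so a precomputed buffer sized to the old top_k runs out of columns.

```python
def finalize_row(weights_row, top_k, num_fused_shared_experts):
    # finalize consumes routed columns [0, top_k) plus the appended
    # shared-expert columns [top_k, top_k + num_fused_shared_experts)
    total_k = top_k + num_fused_shared_experts
    return sum(weights_row[j] for j in range(total_k))

top_k, nfse = 8, 1
widened = [0.1] * (top_k + nfse)  # new contract: top_k + nfse columns
legacy = [0.1] * top_k            # old contract: top_k columns

finalize_row(widened, top_k, nfse)     # fine
try:
    finalize_row(legacy, top_k, nfse)  # too few columns
except IndexError:
    pass
```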

♻️ Duplicate comments (2)
csrc/trtllm_fused_moe_routing_deepseek.cu (2)

620-623: ⚠️ Potential issue | 🟡 Minor

The mNumFusedSharedExperts <= WarpSize check should be unconditional.

This validation is guarded by if (data.mNumExpertGroups > 1) (line 609), but the fused shared expert writes at lines 261-265 and 272-274 use laneIdx < mNumFusedSharedExperts regardless of expert groups. If mNumExpertGroups <= 1 and mNumFusedSharedExperts > WarpSize, the writes would silently skip some fused experts.

Suggested fix

Move the check out of the if (data.mNumExpertGroups > 1) block:

+  FLASHINFER_CHECK(data.mNumFusedSharedExperts <= WarpSize,
+                   "Number of fused shared experts (%d) must be less than warp size.",
+                   data.mNumFusedSharedExperts);
+
   if (data.mNumExpertGroups > 1) {
     FLASHINFER_CHECK(data.mNumExpertGroups <= MaxNumGroups,
                      ...);
     ...
-
-    FLASHINFER_CHECK(data.mNumFusedSharedExperts <= WarpSize,
-                     "Number of fused shared experts (%d) must be less than warp size.",
-                     data.mNumFusedSharedExperts);
   }
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@csrc/trtllm_fused_moe_routing_deepseek.cu` around lines 620 - 623, The check
ensuring data.mNumFusedSharedExperts <= WarpSize must be unconditional: move the
FLASHINFER_CHECK(data.mNumFusedSharedExperts <= WarpSize, ...) out of the if
(data.mNumExpertGroups > 1) block so it always runs; this prevents laneIdx <
mNumFusedSharedExperts conditions (used in the fused shared expert writes) from
silently skipping experts when data.mNumExpertGroups <= 1 and
mNumFusedSharedExperts > WarpSize. Ensure the unique symbols FLASHINFER_CHECK,
data.mNumFusedSharedExperts, WarpSize, data.mNumExpertGroups, and laneIdx are
referenced when relocating the check.

666-676: ⚠️ Potential issue | 🟠 Major

Expert count histogram not initialized for fused shared expert indices.

The routingInitExpertCounts kernel (lines 666-669) initializes 2 * data.mNumExperts elements using the pre-mutation value. After the kernel completes, data.mNumExperts is incremented at lines 673-675 to include mNumFusedSharedExperts. Subsequent kernels (lines 678+) use the mutated value but access uninitialized histogram slots for indices [original_mNumExperts, original_mNumExperts + mNumFusedSharedExperts).

This causes atomicAdd operations to accumulate into uninitialized values for fused shared expert slots.

Suggested fix

Either move the mutation before the histogram initialization or expand the initialization range:

+  if (data.mNumFusedSharedExperts > 0) {
+    data.mNumExperts += data.mNumFusedSharedExperts;
+    data.mTopK += data.mNumFusedSharedExperts;
+    data.mNumLocalExperts += data.mNumFusedSharedExperts;
+  }
+
   if (data.mPtrTopKIds == nullptr) {
     ...
   } else {
     // Reset the global histograms.
     LAUNCH_ROUTING_DEEPSEEK(data, false, routingInitExpertCounts,
                             (2 * data.mNumExperts - 1) / numThreadsHist + 1, numThreadsHist,
                             /*smemSize=*/0,
                             stream, data.mNumExpertGroups > 1);
   }

-  if (data.mNumFusedSharedExperts > 0) {
-    data.mNumExperts += data.mNumFusedSharedExperts;
-    data.mTopK += data.mNumFusedSharedExperts;
-    data.mNumLocalExperts += data.mNumFusedSharedExperts;
-  }
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@csrc/trtllm_fused_moe_routing_deepseek.cu` around lines 666 - 676, The
histogram for expert counts is initialized by the routingInitExpertCounts kernel
using the pre-mutation value of data.mNumExperts, but data.mNumExperts is then
increased by mNumFusedSharedExperts, leaving the new fused-shared slots
uninitialized; fix by either moving the mutation of data.mNumExperts (and
data.mTopK/data.mNumLocalExperts) to occur before the
LAUNCH_ROUTING_DEEPSEEK(...) call that invokes routingInitExpertCounts so the
kernel initializes the full range, or modify the initialization invocation to
cover (2 * (data.mNumExperts + data.mNumFusedSharedExperts)) elements (or
equivalent) so routingInitExpertCounts explicitly zeroes the fused-shared
indices; update references to routingInitExpertCounts, data.mNumExperts,
data.mNumFusedSharedExperts, and the LAUNCH_ROUTING_DEEPSEEK call accordingly.
🧹 Nitpick comments (1)
flashinfer/fused_moe/core.py (1)

1757-1759: Redundant None check for num_fused_shared_experts.

Since num_fused_shared_experts is typed as int = 0 at line 1661 (not Optional[int]), the None check on line 1759 is unnecessary. The parameter can never be None at this point.

Suggested simplification
-        _nfse = num_fused_shared_experts if num_fused_shared_experts is not None else 0
+        _nfse = num_fused_shared_experts
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/fused_moe/core.py` around lines 1757 - 1759, The assignment uses
an unnecessary None check for num_fused_shared_experts (typed as int with
default 0); simplify by removing the conditional and directly assign _nfse =
num_fused_shared_experts in the scope where num_fused_shared_experts is passed
(refer to the variables num_fused_shared_experts and _nfse in this
function/class), ensuring no Optional handling remains.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@csrc/trtllm_fused_moe_runner.cu`:
- Around line 106-117: The partition math wrongly assumes uniform shards —
replace the division-based computation (numDevices, deviceIndex derived from
numExperts / localNumExperts and localExpertOffset / localNumExperts) with logic
that computes device boundaries from actual per-rank expert counts: build the
cumulative expert-count prefix (using the actual localNumExperts for each
device/rank) to find the device index and the exact token-offset/length for
routingData.mSharedExpertTokenOffset and routingData.mSharedExpertNumTokens;
ensure you use numTokens scaled by each device's expert count slice (not simple
baseTokensPerDevice/remainingTokens across a uniform numDevices), and reference
localExpertOffset, localNumExperts, numExperts when mapping into the cumulative
ranges so uneven sharding yields correct offsets and lengths.

---

Outside diff comments:
In `@csrc/trtllm_fused_moe_kernel_launcher.cu`:
- Around line 1857-1874: The code uses the FFI parameter
num_fused_shared_experts directly in size math (totalExpertsPerToken,
totalLocalExperts) which can be negative; validate and clamp it before use
(e.g., ensure num_fused_shared_experts >= 0 and fits expected bounds) and reject
or adjust invalid values; specifically, check/convert num_fused_shared_experts
(and the optional num_fused_shared_experts.value_or(0)) to a non-negative
int64_t before computing totalExpertsPerToken and totalLocalExperts, and add a
defensive check that aborts or logs an error if the provided FFI value is out of
acceptable range so computeSelectedTileN and downstream launchers
(Fp8BlockScaleLauncher::getSupportedTileNums, computeSelectedTileN,
MoERunnerArgs) never see negative counts.

In `@flashinfer/fused_moe/core.py`:
- Around line 1804-1836: The fake op _fake_trtllm_fp8_block_scale_moe must
exactly mirror the real op signature trtllm_fp8_block_scale_moe_op: add the
missing parameter num_fused_shared_experts: int = 0 to the fake function
signature (position it where the real op declares it) so tracing/torch.compile
sees identical parameters; update any callers or tests if they rely on
positional args to ensure compatibility.

---

Duplicate comments:
In `@csrc/trtllm_fused_moe_routing_deepseek.cu`:
- Around line 620-623: The check ensuring data.mNumFusedSharedExperts <=
WarpSize must be unconditional: move the
FLASHINFER_CHECK(data.mNumFusedSharedExperts <= WarpSize, ...) out of the if
(data.mNumExpertGroups > 1) block so it always runs; this prevents laneIdx <
mNumFusedSharedExperts conditions (used in the fused shared expert writes) from
silently skipping experts when data.mNumExpertGroups <= 1 and
mNumFusedSharedExperts > WarpSize. Ensure the unique symbols FLASHINFER_CHECK,
data.mNumFusedSharedExperts, WarpSize, data.mNumExpertGroups, and laneIdx are
referenced when relocating the check.
- Around line 666-676: The histogram for expert counts is initialized by the
routingInitExpertCounts kernel using the pre-mutation value of data.mNumExperts,
but data.mNumExperts is then increased by mNumFusedSharedExperts, leaving the
new fused-shared slots uninitialized; fix by either moving the mutation of
data.mNumExperts (and data.mTopK/data.mNumLocalExperts) to occur before the
LAUNCH_ROUTING_DEEPSEEK(...) call that invokes routingInitExpertCounts so the
kernel initializes the full range, or modify the initialization invocation to
cover (2 * (data.mNumExperts + data.mNumFusedSharedExperts)) elements (or
equivalent) so routingInitExpertCounts explicitly zeroes the fused-shared
indices; update references to routingInitExpertCounts, data.mNumExperts,
data.mNumFusedSharedExperts, and the LAUNCH_ROUTING_DEEPSEEK call accordingly.

---

Nitpick comments:
In `@flashinfer/fused_moe/core.py`:
- Around line 1757-1759: The assignment uses an unnecessary None check for
num_fused_shared_experts (typed as int with default 0); simplify by removing the
conditional and directly assign _nfse = num_fused_shared_experts in the scope
where num_fused_shared_experts is passed (refer to the variables
num_fused_shared_experts and _nfse in this function/class), ensuring no Optional
handling remains.

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: e7cc95b5-315d-4900-b7a7-8eb9b5984c8d

📥 Commits

Reviewing files that changed from the base of the PR and between 259d279 and 2255bca.

📒 Files selected for processing (7)
  • csrc/trtllm_fused_moe_kernel_launcher.cu
  • csrc/trtllm_fused_moe_routing_deepseek.cu
  • csrc/trtllm_fused_moe_runner.cu
  • flashinfer/fused_moe/core.py
  • include/flashinfer/trtllm/fused_moe/RoutingKernel.h
  • include/flashinfer/trtllm/fused_moe/runner.h
  • tests/moe/test_trtllm_gen_fused_moe.py
🚧 Files skipped from review as they are similar to previous changes (1)
  • include/flashinfer/trtllm/fused_moe/RoutingKernel.h

Comment on lines +106 to +117
int32_t const numDevices = (localNumExperts > 0) ? numExperts / localNumExperts : 1;
int32_t const deviceIndex = (localNumExperts > 0) ? localExpertOffset / localNumExperts : 0;
int32_t const baseTokensPerDevice = numTokens / numDevices;
int32_t const remainingTokens = numTokens % numDevices;

if (deviceIndex < remainingTokens) {
routingData.mSharedExpertTokenOffset = (baseTokensPerDevice + 1) * deviceIndex;
routingData.mSharedExpertNumTokens = baseTokensPerDevice + 1;
} else {
routingData.mSharedExpertTokenOffset = remainingTokens + deviceIndex * baseTokensPerDevice;
routingData.mSharedExpertNumTokens = baseTokensPerDevice;
}

⚠️ Potential issue | 🟠 Major

Shared-expert token partition assumes uniform expert shards.

numDevices = numExperts / localNumExperts and deviceIndex = localExpertOffset / localNumExperts are only correct when every rank owns the same routed-expert count. The visible checks here only require localExpertOffset + localNumExperts <= numExperts, so uneven sharding will compute the wrong mSharedExpertTokenOffset/mSharedExpertNumTokens range and route fused shared experts against the wrong token slice.
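A small Python sketch of the failure mode under uneven sharding, with a prefix-sum lookup as one possible remedy (the per-rank counts below are made up for illustration):

```python
import bisect

def device_index_division(local_expert_offset, local_num_experts):
    # current scheme: only correct when every rank owns local_num_experts experts
    return local_expert_offset // local_num_experts

def device_index_prefix(per_rank_expert_counts, local_expert_offset):
    # cumulative prefix of the actual per-rank counts handles uneven shards
    starts, total = [], 0
    for count in per_rank_expert_counts:
        starts.append(total)
        total += count
    return bisect.bisect_right(starts, local_expert_offset) - 1

# uneven sharding: rank 0 owns 96 experts, rank 1 owns 160 (256 total);
# rank 1 therefore starts at expert offset 96
assert device_index_division(96, 160) == 0      # wrong device index
assert device_index_prefix([96, 160], 96) == 1  # correct
```

The same prefix ranges would also give the exact token offset and length for each rank's shared-expert slice.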

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@csrc/trtllm_fused_moe_runner.cu` around lines 106 - 117, The partition math
wrongly assumes uniform shards — replace the division-based computation
(numDevices, deviceIndex derived from numExperts / localNumExperts and
localExpertOffset / localNumExperts) with logic that computes device boundaries
from actual per-rank expert counts: build the cumulative expert-count prefix
(using the actual localNumExperts for each device/rank) to find the device index
and the exact token-offset/length for routingData.mSharedExpertTokenOffset and
routingData.mSharedExpertNumTokens; ensure you use numTokens scaled by each
device's expert count slice (not simple baseTokensPerDevice/remainingTokens
across a uniform numDevices), and reference localExpertOffset, localNumExperts,
numExperts when mapping into the cumulative ranges so uneven sharding yields
correct offsets and lengths.

@flashinfer-bot
Collaborator

[FAILED] Pipeline #45731067: 8/20 passed

aleozlx and others added 2 commits April 14, 2026 13:57
Merged origin/main into nv-yunzheq/DSR1_shared_expert_fusion.

Resolved conflicts:
- csrc/trtllm_fused_moe_kernel_launcher.cu: merged signature (num_fused_shared_experts + act_type/norm_topk_prob), combined expert_weights allocation logic
- csrc/trtllm_fused_moe_runner.cu: merged signature, adopted routingCustom framework with shared expert guard
- csrc/trtllm_fused_moe_routing_deepseek.cu: git rm (deleted on main, needs manual port)
- flashinfer/fused_moe/core.py: merged all new params in signatures and call sites
- tests/moe/test_trtllm_gen_fused_moe.py: merged test param additions

Co-Authored-By: Claude <noreply@anthropic.com>
The PR's shared expert fusion changes were made to the old
csrc/trtllm_fused_moe_routing_deepseek.cu which was deleted on main
and relocated to csrc/fused_moe/trtllm_backend/. This commit ports
those changes to the new location:

- routingMainKernel: use mTotalExpertsPerToken stride for TopK output,
  write shared expert indices (weight=1.0) after routed experts
- run(): compute adjusted topK/numExperts accounting for shared experts,
  add bounds check (mNumFusedSharedExperts <= WarpSize),
  bump data.mNumExperts/mTopK/mNumLocalExperts post-routing so the
  permutation pipeline sees the full expanded expert set
- Adjust single-cluster threshold, maxTokensCoop, and expandedIdxSize
  to use the shared-expert-adjusted topK

Co-Authored-By: Claude <noreply@anthropic.com>
@aleozlx
Collaborator

aleozlx commented Apr 14, 2026

/bot run

@flashinfer-bot
Collaborator

GitLab MR !344 has been updated with latest changes, and the CI pipeline #48530875 is currently running. I'll report back once the pipeline job completes.

@aleozlx
Collaborator

aleozlx commented Apr 14, 2026

@nv-yunzheq would you check merge conflict resolution if possible

how about nvfp4, do you have a pointer?

@coderabbitai coderabbitai Bot left a comment

Actionable comments posted: 2

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (2)
flashinfer/fused_moe/core.py (2)

2767-2796: ⚠️ Potential issue | 🔴 Critical

Critical: Missing num_fused_shared_experts argument causes parameter misalignment.

The trtllm_fp8_block_scale_routed_moe function call is missing the num_fused_shared_experts argument at position 27. This causes all subsequent arguments to be passed to the wrong parameters:

  • activation_type (e.g., 3) is passed to num_fused_shared_experts
  • True is passed to activation_type (coerced to 1)
  • norm_topk_prob is not passed (uses default)

This will cause incorrect MoE computation when the activation_type value is interpreted as a shared expert count.
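The shift is easy to reproduce with a toy signature (parameter names mirror the real op, but this is only an illustration):

```python
def moe_op(output, num_fused_shared_experts=0, activation_type=3, norm_topk_prob=True):
    return num_fused_shared_experts, activation_type, norm_topk_prob

# Caller omits num_fused_shared_experts yet keeps passing the trailing
# arguments positionally: activation_type lands in num_fused_shared_experts,
# and True lands in activation_type.
assert moe_op("out", 3, True) == (3, True, True)

# With the argument supplied, everything lines up again.
assert moe_op("out", 0, 3, True) == (0, 3, True)
```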

🐛 Proposed fix to add missing argument
     result = get_trtllm_moe_sm100_module().trtllm_fp8_block_scale_moe(
         None,  # routing_logits
         topk_ids,
         None,  # expert_weights
         routing_bias,
         hidden_states,
         hidden_states_scale,
         gemm1_weights,
         gemm1_weights_scale,
         gemm2_weights,
         gemm2_weights_scale,
         output,
         num_experts,
         top_k,
         n_group,
         topk_group,
         intermediate_size,
         local_expert_offset,
         local_num_experts,
         routed_scaling_factor,
         routing_method_type,
         use_shuffled_weight,
         weight_layout,
         do_finalize,
         enable_pdl,
         tune_max_num_tokens,
         fp8_quantization_type,
+        0,  # num_fused_shared_experts: not supported for pre-routed MoE
         activation_type,
         True,  # norm_topk_prob: not used for pre-computed routing
     )
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/fused_moe/core.py` around lines 2767 - 2796, The call to
get_trtllm_moe_sm100_module().trtllm_fp8_block_scale_moe is missing the
num_fused_shared_experts argument, causing parameter misalignment
(activation_type and norm_topk_prob are shifted); update the call to include the
proper num_fused_shared_experts value (same type/variable used elsewhere for
fused shared experts) inserted immediately before the activation_type argument
so subsequent parameters (activation_type, norm_topk_prob) line up with the
function signature.

1823-1853: ⚠️ Potential issue | 🔴 Critical

Critical: Fake op signature missing num_fused_shared_experts parameter.

The _fake_trtllm_fp8_block_scale_moe function is missing the num_fused_shared_experts parameter that was added to trtllm_fp8_block_scale_moe_op at line 1658. Fake ops must exactly mirror the real op signatures for torch.compile/tracing to work correctly.

🐛 Proposed fix to add missing parameter
 @register_fake_op("flashinfer::trtllm_fp8_block_scale_moe")
 def _fake_trtllm_fp8_block_scale_moe(
     routing_logits: Optional[torch.Tensor],
     topk_ids: Optional[torch.Tensor],
     expert_weights: Optional[torch.Tensor],
     routing_bias: Optional[torch.Tensor],
     hidden_states: torch.Tensor,
     hidden_states_scale: torch.Tensor,
     gemm1_weights: torch.Tensor,
     gemm1_weights_scale: torch.Tensor,
     gemm2_weights: torch.Tensor,
     gemm2_weights_scale: torch.Tensor,
     output: torch.Tensor,
     num_experts: int,
     top_k: int,
     n_group: Optional[int],
     topk_group: Optional[int],
     intermediate_size: int,
     local_expert_offset: int,
     local_num_experts: int,
     routed_scaling_factor: Optional[float],
     routing_method_type: int = 0,
     use_shuffled_weight: bool = False,
     weight_layout: int = 0,
     do_finalize: bool = True,
     enable_pdl: Optional[bool] = None,
     tune_max_num_tokens: int = 8192,
     fp8_quantization_type: Fp8QuantizationType = Fp8QuantizationType.DeepSeekFp8,
+    num_fused_shared_experts: int = 0,
     activation_type: int = ActivationType.Swiglu.value,
     norm_topk_prob: bool = True,
 ) -> List[torch.Tensor]:
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/fused_moe/core.py` around lines 1823 - 1853, The fake op
_fake_trtllm_fp8_block_scale_moe must include the new parameter
num_fused_shared_experts to exactly mirror the real op signature; update the
function signature to add num_fused_shared_experts with the same type and
default as trtllm_fp8_block_scale_moe_op (e.g., int or Optional[int] matching
the real op) and propagate that new parameter in the fake-op declaration so
torch.compile/tracing sees an identical signature (no additional logic changes
needed inside the function).
🧹 Nitpick comments (2)
flashinfer/fused_moe/core.py (1)

1774-1792: Redundant None check for num_fused_shared_experts.

At line 1776, _nfse = num_fused_shared_experts if num_fused_shared_experts is not None else 0 is redundant because num_fused_shared_experts is typed as int = 0 at line 1658 and can never be None. You can simplify to use the parameter directly.

♻️ Suggested simplification
         num_fused_shared_experts=num_fused_shared_experts,
     )
-    _nfse = num_fused_shared_experts if num_fused_shared_experts is not None else 0
     # Call the C++ function for block scale MoE
     intermediate_output = moe_op.trtllm_fp8_block_scale_moe(
         routing_logits,
         topk_ids,
         expert_weights,
         routing_bias,
         hidden_states,
         hidden_states_scale,
         gemm1_weights,
         gemm1_weights_scale,
         gemm2_weights,
         gemm2_weights_scale,
         output,
         num_experts,
         top_k,
-        _nfse,
+        num_fused_shared_experts,
         n_group,
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/fused_moe/core.py` around lines 1774 - 1792, The local variable
_nfse and its None-check are redundant because num_fused_shared_experts is
declared as an int (default 0); remove the _nfse assignment and pass
num_fused_shared_experts directly into the call to
moe_op.trtllm_fp8_block_scale_moe (replace the _nfse argument with
num_fused_shared_experts), and delete the unused _nfse variable to simplify
core.py around the block that constructs intermediate_output.
csrc/fused_moe/trtllm_backend/trtllm_fused_moe_routing_deepseek.cu (1)

555-560: Explain why the shared-expert expansion mutates Data in place.

This hot-path mutation is easy to misread because launchMainKernel() consumes the pre-fusion counts and the remaining launches consume the expanded counts from the same object. A brief note on why this was chosen instead of staging a copied Data / separate post-pass would make the trade-off much clearer.

As per coding guidelines, "For performance-critical hot paths, leave comments with justification for special algorithmic choices and mention alternative approaches considered".

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@csrc/fused_moe/trtllm_backend/trtllm_fused_moe_routing_deepseek.cu` around
lines 555 - 560, The in-place mutation of Data (incrementing data.mNumExperts,
data.mTopK, data.mNumLocalExperts when data.mNumFusedSharedExperts > 0) is
confusing because launchMainKernel() uses the pre-fusion counts while subsequent
pipeline stages expect expanded counts; update the source by adding a concise
comment immediately above this block explaining why we mutate Data in place for
performance (avoid copying Data or making a separate post-pass), note that
launchMainKernel() intentionally consumes the original counts and the
permutation pipeline requires the expanded counts, and briefly mention the
considered alternatives (cloning Data or staging a post-pass) and why they were
rejected for this hot path.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@csrc/fused_moe/trtllm_backend/trtllm_fused_moe_routing_deepseek.cu`:
- Around line 540-546: Compute and validate the fused expert total before using
it: after calculating numExperts = data.mNumExperts +
data.mNumFusedSharedExperts, add a guard (via FLASHINFER_CHECK or equivalent)
that numExperts <= MaxSupportedExpertCount (or handle getMaxNumExperts returning
0) so getMaxNumExperts(numExperts) cannot return 0; update the block around
numExperts/topK/numThreadsHist to validate the fused total (referencing
data.mNumExperts, data.mNumFusedSharedExperts, numExperts, getMaxNumExperts, and
MaxSupportedExpertCount) and ensure downstream logic that uses useSingleCluster
/ maxTokensCoop only runs for supported expert counts.
- Around line 345-358: Before emitting appended shared-expert IDs, clear the
corresponding mPtrExpertCounts slots so stale device memory can't corrupt later
histogram/prefix-scan; in the same kernel code path that writes packed/shared
entries (where laneIdx, idxShared, params.mNumFusedSharedExperts,
params.mNumExperts, params.mPtrTopKPacked, params.mPtrTopKWeights are used) set
params.mPtrExpertCounts[params.mNumExperts + laneIdx] = 0 (guarded by laneIdx <
params.mNumFusedSharedExperts) so the extra expert-count/offset slots are zeroed
on-device prior to writing the shared-expert outputs (this complements the
earlier routingMainKernel initialization that only clears 2 * params.mNumExperts
entries).

---

Outside diff comments:
In `@flashinfer/fused_moe/core.py`:
- Around line 2767-2796: The call to
get_trtllm_moe_sm100_module().trtllm_fp8_block_scale_moe is missing the
num_fused_shared_experts argument, causing parameter misalignment
(activation_type and norm_topk_prob are shifted); update the call to include the
proper num_fused_shared_experts value (same type/variable used elsewhere for
fused shared experts) inserted immediately before the activation_type argument
so subsequent parameters (activation_type, norm_topk_prob) line up with the
function signature.
- Around line 1823-1853: The fake op _fake_trtllm_fp8_block_scale_moe must
include the new parameter num_fused_shared_experts to exactly mirror the real op
signature; update the function signature to add num_fused_shared_experts with
the same type and default as trtllm_fp8_block_scale_moe_op (e.g., int or
Optional[int] matching the real op) and propagate that new parameter in the
fake-op declaration so torch.compile/tracing sees an identical signature (no
additional logic changes needed inside the function).

---

Nitpick comments:
In `@csrc/fused_moe/trtllm_backend/trtllm_fused_moe_routing_deepseek.cu`:
- Around line 555-560: The in-place mutation of Data (incrementing
data.mNumExperts, data.mTopK, data.mNumLocalExperts when
data.mNumFusedSharedExperts > 0) is confusing because launchMainKernel() uses
the pre-fusion counts while subsequent pipeline stages expect expanded counts;
update the source by adding a concise comment immediately above this block
explaining why we mutate Data in place for performance (avoid copying Data or
making a separate post-pass), note that launchMainKernel() intentionally
consumes the original counts and the permutation pipeline requires the expanded
counts, and briefly mention the considered alternatives (cloning Data or staging
a post-pass) and why they were rejected for this hot path.

In `@flashinfer/fused_moe/core.py`:
- Around line 1774-1792: The local variable _nfse and its None-check are
redundant because num_fused_shared_experts is declared as an int (default 0);
remove the _nfse assignment and pass num_fused_shared_experts directly into the
call to moe_op.trtllm_fp8_block_scale_moe (replace the _nfse argument with
num_fused_shared_experts), and delete the unused _nfse variable to simplify
core.py around the block that constructs intermediate_output.

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 740993a3-4d67-4790-8f5f-7226f67ad2dd

📥 Commits

Reviewing files that changed from the base of the PR and between 95a01d0 and b06cef5.

📒 Files selected for processing (5)
  • csrc/fused_moe/trtllm_backend/trtllm_fused_moe_routing_deepseek.cu
  • csrc/moe_utils_binding.cu
  • csrc/trtllm_fused_moe_kernel_launcher.cu
  • csrc/trtllm_fused_moe_runner.cu
  • flashinfer/fused_moe/core.py
🚧 Files skipped from review as they are similar to previous changes (2)
  • csrc/trtllm_fused_moe_kernel_launcher.cu
  • csrc/trtllm_fused_moe_runner.cu

Comment on lines +345 to +358
if (laneIdx < params.mNumFusedSharedExperts && params.mPtrTopKPacked != nullptr) {
PackedScoreIdx<OutputT> packedScore{static_cast<OutputT>(1.0F),
static_cast<int16_t>(params.mNumExperts + laneIdx)};
params.mPtrTopKPacked[idxShared] = packedScore;
}

if (laneIdx < params.mTopK && params.mPtrTopKWeights != nullptr &&
params.mPtrTopKIds == nullptr) {
params.mPtrTopKWeights[idxTopK] = finalScore;
}

if (laneIdx < params.mNumFusedSharedExperts && params.mPtrTopKWeights != nullptr) {
params.mPtrTopKWeights[idxShared] = static_cast<OutputT>(1.0F);
}

⚠️ Potential issue | 🔴 Critical

Zero the extra mPtrExpertCounts slots before emitting shared experts.

These appended expert IDs flow through the same non-single-cluster permutation path as routed experts, but routingMainKernel still clears only 2 * params.mNumExperts entries in mPtrExpertCounts (Lines 165-170). The shared-expert count/offset slots therefore start from stale device memory and can corrupt the histogram/prefix-scan for large-token runs.

🐛 Companion fix in the earlier expert-count initialization
   if (params.mPtrExpertCounts) {
     int32_t globalThreadIdx = blockIdx.x * blockDim.x + threadIdx.x;
     int32_t globalThreadStride = gridDim.x * blockDim.x;
-    int32_t expertCountsNum = 2 * params.mNumExperts;
+    int32_t expertCountsNum = 2 * (params.mNumExperts + params.mNumFusedSharedExperts);
     initArr(globalThreadIdx, expertCountsNum, globalThreadStride, params.mPtrExpertCounts, 0);
   }
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@csrc/fused_moe/trtllm_backend/trtllm_fused_moe_routing_deepseek.cu` around
lines 345 - 358, Before emitting appended shared-expert IDs, clear the
corresponding mPtrExpertCounts slots so stale device memory can't corrupt later
histogram/prefix-scan; in the same kernel code path that writes packed/shared
entries (where laneIdx, idxShared, params.mNumFusedSharedExperts,
params.mNumExperts, params.mPtrTopKPacked, params.mPtrTopKWeights are used) set
params.mPtrExpertCounts[params.mNumExperts + laneIdx] = 0 (guarded by laneIdx <
params.mNumFusedSharedExperts) so the extra expert-count/offset slots are zeroed
on-device prior to writing the shared-expert outputs (this complements the
earlier routingMainKernel initialization that only clears 2 * params.mNumExperts
entries).

Comment on lines +540 to +546
int const numExperts = data.mNumExperts + data.mNumFusedSharedExperts;
int const topK = data.mTopK + data.mNumFusedSharedExperts;
int const numThreadsHist = getMaxNumExperts(numExperts);

FLASHINFER_CHECK(topK <= MaxSupportedTopExperts,
"Routing kernel expects topK experts <= %d, got %d", MaxSupportedTopExperts,
topK);

⚠️ Potential issue | 🟠 Major

Guard the fused total against the DeepSeek dispatch ceiling.

Only the pre-fusion data.mNumExperts is validated against MaxSupportedExpertCount. If data.mNumExperts + data.mNumFusedSharedExperts exceeds 512, getMaxNumExperts(numExperts) returns 0 here, which breaks the later useSingleCluster / maxTokensCoop math and leaves the post-topK launches on unsupported expert tiers.

🛡️ Minimal guard before using the fused-aware total
   int const numExperts = data.mNumExperts + data.mNumFusedSharedExperts;
   int const topK = data.mTopK + data.mNumFusedSharedExperts;
+  if (data.mPtrPermutedIdxSize != nullptr) {
+    FLASHINFER_CHECK(numExperts <= MaxSupportedExpertCount,
+                     "Permutation pipeline supports at most %d total experts, got %d",
+                     MaxSupportedExpertCount, numExperts);
+  }
   int const numThreadsHist = getMaxNumExperts(numExperts);
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@csrc/fused_moe/trtllm_backend/trtllm_fused_moe_routing_deepseek.cu` around
lines 540 - 546, Compute and validate the fused expert total before using it:
after calculating numExperts = data.mNumExperts + data.mNumFusedSharedExperts,
add a guard (via FLASHINFER_CHECK or equivalent) that numExperts <=
MaxSupportedExpertCount (or handle getMaxNumExperts returning 0) so
getMaxNumExperts(numExperts) cannot return 0; update the block around
numExperts/topK/numThreadsHist to validate the fused total (referencing
data.mNumExperts, data.mNumFusedSharedExperts, numExperts, getMaxNumExperts, and
MaxSupportedExpertCount) and ensure downstream logic that uses useSingleCluster
/ maxTokensCoop only runs for supported expert counts.

Fix function signature line wrapping in trtllm_fp8_block_scale_moe()
to conform to the project's clang-format configuration.

Co-Authored-By: Claude <noreply@anthropic.com>
@coderabbitai coderabbitai Bot left a comment

Actionable comments posted: 2

🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@csrc/trtllm_fused_moe_kernel_launcher.cu`:
- Around line 1970-1972: The tile/config discovery is still using unfused
dimensions; update resolveMoeTileAndConfig,
Fp8BlockScaleLauncher::getValidConfigs, and trtllm_get_valid_moe_configs to use
the fused totals (totalExpertsPerToken and totalLocalExperts) and the stored
args->num_fused_shared_experts instead of raw top_k and local_num_experts;
thread these fused totals through any calls that compute or cache tiles/configs,
update the fallback resolution path to bucket on the fused totals, and ensure
the exported trtllm_get_valid_moe_configs forwards the fused totals so autotuned
configs match prepare_moe_common() validation.
- Around line 1195-1198: The precomputed-routing path is unsafe with fused
shared experts because the runner expects top_k + num_fused_shared_experts
columns but precomputed validation only checks against top_k; update the call
site that uses use_precomputed, args->num_fused_shared_experts,
args->routing_logits and workspace.routing_expert_indexes to either (a) reject
precomputed routing when args->num_fused_shared_experts > 0 by returning an
error/setting use_precomputed=false, or (b) validate both
workspace.routing_expert_indexes and args->routing_logits widths against
(args->top_k + args->num_fused_shared_experts) before invoking the routing
runner and abort if they are smaller; implement one of these fixes where the
call is made so precomputed buffers cannot be indexed past their last column.

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 90b0079d-8f7e-423e-b8c8-af4c05ac683b

📥 Commits

Reviewing files that changed from the base of the PR and between b06cef5 and 44ca3f5.

📒 Files selected for processing (1)
  • csrc/trtllm_fused_moe_kernel_launcher.cu

Comment on lines 1195 to +1198
use_precomputed ? nullptr : args->routing_logits, args->routing_bias, args->num_tokens,
args->num_experts, args->top_k, args->n_group, args->topk_group, args->local_expert_offset,
args->local_num_experts, args->routed_scaling_factor, workspace.routing_expert_indexes,
args->num_experts, args->top_k, args->num_fused_shared_experts, args->n_group,
args->topk_group, args->local_expert_offset, args->local_num_experts,
args->routed_scaling_factor, workspace.routing_expert_indexes,

⚠️ Potential issue | 🔴 Critical

Update the precomputed-routing contract for fused shared experts.

This call now tells the routing runner to consume top_k + num_fused_shared_experts, but the precomputed path still validates expert_indices against top_k and never validates precomputed expert_weights. With fused shared experts enabled, top-k-wide precomputed buffers will be misinterpreted and can be indexed past their last column. Either reject precomputed routing when num_fused_shared_experts > 0, or validate both tensors against the fused width before calling the runner.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@csrc/trtllm_fused_moe_kernel_launcher.cu` around lines 1195 - 1198, The
precomputed-routing path is unsafe with fused shared experts because the runner
expects top_k + num_fused_shared_experts columns but precomputed validation only
checks against top_k; update the call site that uses use_precomputed,
args->num_fused_shared_experts, args->routing_logits and
workspace.routing_expert_indexes to either (a) reject precomputed routing when
args->num_fused_shared_experts > 0 by returning an error/setting
use_precomputed=false, or (b) validate both workspace.routing_expert_indexes and
args->routing_logits widths against (args->top_k +
args->num_fused_shared_experts) before invoking the routing runner and abort if
they are smaller; implement one of these fixes where the call is made so
precomputed buffers cannot be indexed past their last column.

Comment on lines +1970 to +1972
int64_t const nFusedShared = num_fused_shared_experts.value_or(0);
int64_t const totalExpertsPerToken = top_k + nFusedShared;
int64_t const totalLocalExperts = local_num_experts + nFusedShared;

⚠️ Potential issue | 🟠 Major

Make tile/config discovery fused-aware as well.

You compute fused totals here and store args->num_fused_shared_experts, but fallback tile resolution still buckets on raw top_k and local_num_experts. That leaves resolveMoeTileAndConfig(...), Fp8BlockScaleLauncher::getValidConfigs(...), and the exported trtllm_get_valid_moe_configs(...) surface describing a different problem shape than prepare_moe_common() validates, so cached tactics or the [-1, -1] fallback can become invalid for fused-shared runs.

Possible direction
-  auto const [tile_N, config] = resolveMoeTileAndConfig(config_index, supported_tile_nums,
-                                                        num_tokens, top_k, local_num_experts);
+  auto const [tile_N, config] = resolveMoeTileAndConfig(config_index, supported_tile_nums,
+                                                        num_tokens, totalExpertsPerToken,
+                                                        totalLocalExperts);

Fp8BlockScaleLauncher::getValidConfigs(...) and trtllm_get_valid_moe_configs(...) need the same fused totals threaded through, otherwise autotuned configs will still be generated against the unfused dimensions.

Also applies to: 1985-1985, 2010-2011

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@csrc/trtllm_fused_moe_kernel_launcher.cu` around lines 1970 - 1972, The
tile/config discovery is still using unfused dimensions; update
resolveMoeTileAndConfig, Fp8BlockScaleLauncher::getValidConfigs, and
trtllm_get_valid_moe_configs to use the fused totals (totalExpertsPerToken and
totalLocalExperts) and the stored args->num_fused_shared_experts instead of raw
top_k and local_num_experts; thread these fused totals through any calls that
compute or cache tiles/configs, update the fallback resolution path to bucket on
the fused totals, and ensure the exported trtllm_get_valid_moe_configs forwards
the fused totals so autotuned configs match prepare_moe_common() validation.

…eplay_out)

Merged origin/main into nv-yunzheq/DSR1_shared_expert_fusion.
Resolved new conflicts from PR flashinfer-ai#3024 (routing_replay_out support)
by keeping this PR's shared-expert fields alongside main's routing replay fields.

Co-Authored-By: Claude <noreply@anthropic.com>