Add support for Relu2 in BF16 fused MoE #2864

Merged
aleozlx merged 7 commits into flashinfer-ai:main from amitz-nv:support-bf16-relu2-trtllm-gen-fused-moe-kernel
Apr 13, 2026
Conversation

@amitz-nv
Contributor

@amitz-nv amitz-nv commented Mar 23, 2026

📌 Description

  • Added support for the Relu2 non-gated activation in BF16 fused MoE by adding an activation_type parameter to the external API:
    • trtllm_bf16_moe
    • trtllm_bf16_routed_moe
    • Bf16MoeLauncher::init
  • Updated the trtllm-gen batched GEMM kernels.
  • Updated tests/moe/test_trtllm_gen_fused_moe.py::test_deepseekv3_routing to include BF16 with a Nemotron config, and fixed the Nemotron config's intermediate_size test param to match Nemotron 3 Super.
  • Fixed import issues found by pre-commit run --all-files.
  • Required by the trtllm-gen batched GEMM update: changed options.mNumStages == 4 to options.mNumStagesA == 4 && options.mNumStagesB == 4 in the prioritizePredefinedConfigs function in csrc/trtllm_batched_gemm_runner.cu, since mNumStages was split into separate A and B counts. (A sketch of the new activation_type values follows below.)
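The sketch below shows the activation values involved. It is a minimal, self-contained illustration, not the real API: the actual ActivationType enum lives in flashinfer.tllm_enums and may define more members; only Swiglu = 3 (the existing default) and Relu2 = 6 (added here) are taken from this PR.

    from enum import IntEnum

    # Assumed subset of flashinfer.tllm_enums.ActivationType; Swiglu=3 and
    # Relu2=6 are the values cited in this PR's commit messages.
    class ActivationType(IntEnum):
        Swiglu = 3  # gated (the default for trtllm_bf16_moe)
        Relu2 = 6   # non-gated (new; used by Nemotron-style MLPs)

    def pick_activation(model_family: str) -> int:
        # Illustration only: Nemotron configs use Relu2, others keep the default.
        if model_family.startswith("nemotron"):
            return ActivationType.Relu2.value
        return ActivationType.Swiglu.value

    assert pick_activation("nemotron_3_super") == 6
    assert pick_activation("kimi_k2") == 3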

🔍 Related Issues

🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or used your preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

If you are unsure about how to set up pre-commit, see the pre-commit documentation.

🧪 Tests

  • Tests have been added or updated as needed.
  • All tests are passing (unittest, etc.).

Reviewer Notes

Summary by CodeRabbit

  • New Features

    • MoE APIs now accept a validated runtime activation_type, enabling selectable activation functions for BF16 and FP8 inference.
  • Tests

    • Expanded DeepSeekV3 routing tests and added BF16 to non-gated activation coverage.
    • Updated test parameters to reflect new compatibility.
  • Bug Fixes

    • Adjusted kernel configuration prioritization for a specific corner-case path.
  • Refactor

    • Internal enum imports reorganized to a shared enums module.
  • Chores

    • Updated batched GEMM artifact path and checksum.

@coderabbitai
Contributor

coderabbitai Bot commented Mar 23, 2026

Note

Reviews paused

It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior by changing the reviews.auto_review.auto_pause_after_reviewed_commits setting.

Use the following commands to manage reviews:

  • @coderabbitai resume to resume automatic reviews.
  • @coderabbitai review to trigger a single review.

📝 Walkthrough

Activation type is now runtime-configurable across the BF16 MoE stack: Python APIs accept an integer activation_type which is validated and propagated through the C++ entrypoint into the Bf16MoeLauncher and kernel init; tests and weight-prep logic were adjusted for gated vs non-gated activations.
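A Python-side sketch of that validation step may help; the real check is the C++ validateAndCastActivationType helper in csrc, so the enum subset and error message below are assumptions:

    from enum import IntEnum

    class ActivationType(IntEnum):  # assumed subset of the real enum
        Swiglu = 3
        Relu2 = 6

    def validate_activation_type(raw: int) -> ActivationType:
        # Mirrors the fail-fast contract: reject unknown integers before they
        # can reach isGatedActivation() or the kernel Runner.
        try:
            return ActivationType(raw)
        except ValueError as e:
            raise ValueError(f"unsupported activation_type={raw}") from e

    act = validate_activation_type(6)  # ActivationType.Relu2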

Changes

Cohort / File(s) Summary
CUDA Kernel Launcher
csrc/trtllm_fused_moe_kernel_launcher.cu
Bf16MoeLauncher::init now accepts ActivationType activation_type; exported trtllm_bf16_moe(...) gains trailing int64_t activation_type, validated via validateAndCastActivationType(...) and forwarded. FP8 per-tensor entrypoint also uses validation.
Python MoE Core & Module
flashinfer/fused_moe/core.py, flashinfer/fused_moe/__init__.py
trtllm_bf16_moe and trtllm_bf16_routed_moe gained activation_type: int (default ActivationType.Swiglu.value); docstrings and weight-shape notes updated; explicit enum imports moved to ..tllm_enums.
Tests / Test Utils
tests/moe/test_trtllm_gen_fused_moe.py, tests/moe/utils.py
Tests thread activation_type through call_moe to BF16 op; weight prep uses is_gated_activation(args.activation_type) for permute-index caching; added QuantMode.BF16 to supported modes; DeepSeekV3 params adjusted.
Batched GEMM Configs
csrc/trtllm_batched_gemm_runner.cu
Corner-case (n==0 && k==0) prioritization now requires mNumStagesA == 4 and mNumStagesB == 4 (was mNumStages == 4), changing prioritized config ordering for that branch.
Artifacts / Checksums
flashinfer/artifacts.py
Updated ArtifactPath.TRTLLM_GEN_BMM directory hash and corresponding CheckSumHash.TRTLLM_GEN_BMM SHA-256 string.
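For intuition on the gated vs. non-gated distinction that the weight-prep logic now tracks: a gated activation such as Swiglu fuses the gate (w3) and up (w1) projections into a single FC1 matrix, doubling its leading dimension, while the non-gated Relu2 keeps a single projection. A small sketch under that assumption (the helper name and shapes are illustrative, not the repo's API):

    def fc1_weight_shape(intermediate_size: int, hidden_size: int,
                         is_gated: bool) -> tuple:
        # Gated: fused [w3; w1] -> [2 * intermediate_size, hidden_size].
        # Non-gated (Relu2): just w1 -> [intermediate_size, hidden_size].
        rows = 2 * intermediate_size if is_gated else intermediate_size
        return (rows, hidden_size)

    assert fc1_weight_shape(256, 128, is_gated=True) == (512, 128)
    assert fc1_weight_shape(256, 128, is_gated=False) == (256, 128)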

Sequence Diagram(s)

sequenceDiagram
    participant Py as Python API
    participant Core as flashinfer.fused_moe.core
    participant Cpp as trtllm entrypoint (csrc)
    participant Launcher as Bf16MoeLauncher
    participant Kernel as TRT-LLM kernel

    Py->>Core: trtllm_bf16_moe(..., activation_type=int)
    Core->>Core: validateAndCastActivationType(int) -> ActivationType
    Core->>Cpp: trtllm_bf16_moe(..., activation_type)
    Cpp->>Launcher: init(..., activation_type)
    Launcher->>Kernel: init_common(..., activation_type / isGatedActivation)
    Kernel-->>Launcher: configured
    Launcher-->>Cpp: ready to launch
    Cpp-->>Core: return results
    Core-->>Py: return Array<Tensor>

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~40 minutes

Possibly related PRs

Suggested labels

run-ci, op: moe-routing

Suggested reviewers

  • yzh119
  • cyx-6
  • bkryu
  • jimmyzho
  • jiahanc
  • nv-yunzheq
  • aleozlx
  • yongwww
  • sricketts

Poem

🐰 Hop, hop — activation picked on the fly,
From Python to CUDA I nudge and pry,
Gates learn to dance and weights rearrange,
Runtime choice now sets the MoE range,
A rabbit applauds this nimble change!

🚥 Pre-merge checks: 2 passed, 1 failed

❌ Failed checks (1 warning)

  • Docstring Coverage — ⚠️ Warning: docstring coverage is 40.00%, below the required threshold of 80.00%. Resolution: write docstrings for the functions that are missing them.

✅ Passed checks (2 passed)

  • Title check — ✅ Passed: the title "Add support for Relu2 in BF16 fused MoE" is clear, concise, and directly describes the main change in the pull request.
  • Description check — ✅ Passed: the description covers the key changes and includes the pre-commit and test checklist items, though the overall test pass status is not explicitly confirmed.


@gemini-code-assist
Contributor

Summary of Changes

Hello, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request enhances the BF16 Fused Mixture-of-Experts (MoE) functionality by integrating support for the Relu2 activation function. The changes involve extending the core C++ kernel and its Python bindings to allow specifying the activation type, moving beyond a fixed activation. This provides greater flexibility for model architectures utilizing BF16 MoE and is accompanied by updated test cases to confirm the new activation's behavior.

Highlights

  • Relu2 Activation Support: Added support for the Relu2 non-gated activation function within the BF16 Fused Mixture-of-Experts (MoE) implementation.
  • API and Kernel Updates: Modified the C++ kernel launcher and Python API to accept and propagate an activation_type parameter, allowing dynamic selection of activation functions.
  • Test Coverage Expansion: Updated existing tests to include validation for the new Relu2 activation, specifically with DeepSeek routing, ensuring correct functionality.


Contributor

@gemini-code-assist gemini-code-assist Bot left a comment


Code Review

This pull request introduces dynamic activation function selection for BF16 Mixture-of-Experts (MoE) operations. Previously, the activation type was hardcoded to Swiglu. The changes involve modifying C++ kernel launcher signatures and implementations to accept an ActivationType parameter, propagating this parameter through the Python frontend functions, and updating test cases to reflect and validate this new configurability. Test configurations for specific models and intermediate sizes were also adjusted, and BF16 was added to the list of supported quantization modes in test utilities. I have no feedback to provide as there were no review comments.

Contributor

@coderabbitai coderabbitai Bot left a comment


Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (3)
csrc/trtllm_fused_moe_kernel_launcher.cu (1)

1681-1697: ⚠️ Potential issue | 🟠 Major

Validate activation_type before the BF16 cast.

Line 1697 bypasses the new validateAndCastActivationType() helper and feeds unchecked values into isGatedActivation() / Runner. For a public int64_t FFI parameter, bad inputs should fail here with a deterministic ICHECK, not later inside runner setup.

Suggested fix
-  auto const activation = static_cast<ActivationType>(activation_type);
+  auto const activation = validateAndCastActivationType(activation_type);
tests/moe/test_trtllm_gen_fused_moe.py (1)

1439-1443: ⚠️ Potential issue | 🟠 Major

The new gated/non-gated flag is still aliased by the permute-index cache.

Line 1443 passes is_gated_act_gemm, but _maybe_get_cached_w3_w1_permute_indices() still memoizes only on ("w3_w1", dst_w3_w1_weight.shape) in flashinfer/fused_moe/core.py. Since cache_permute_indices is module-scoped, a gated BF16 case can poison a later Relu2 case with the same viewed shape, making this coverage order-dependent and permuting FC1 rows incorrectly.

Possible fix in flashinfer/fused_moe/core.py
-    cache_key = ("w3_w1", dst_w3_w1_weight.shape)
+    cache_key = (
+        "w3_w1",
+        dst_w3_w1_weight.shape,
+        epilogue_tile_m,
+        num_elts_per_sf,
+        is_gated_act_gemm,
+    )
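A standalone sketch of the memoization pattern at issue (the cache and helper names are stand-ins, not the repo's exact code): once the key includes the gating flag, a gated and a non-gated case with the same viewed shape no longer share an entry.

    import torch

    _cache = {}  # stand-in for the module-scoped cache_permute_indices

    def get_w3_w1_permute_indices(weight: torch.Tensor,
                                  is_gated_act_gemm: bool) -> torch.Tensor:
        # Keying on (name, shape, gating flag) prevents a gated Swiglu entry
        # from being reused for a non-gated Relu2 case of identical shape.
        key = ("w3_w1", tuple(weight.shape), is_gated_act_gemm)
        if key not in _cache:
            # Placeholder permutation; the real helper computes row-shuffle
            # indices for the trtllm-gen epilogue tiling.
            _cache[key] = torch.arange(weight.shape[0])
        return _cache[key]

    w = torch.empty(512, 128)
    assert get_w3_w1_permute_indices(w, True) is not get_w3_w1_permute_indices(w, False)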
flashinfer/fused_moe/core.py (1)

1323-1350: ⚠️ Potential issue | 🟡 Minor

Pre-existing signature mismatch in fake op.

The activation_type addition (line 1345) is correct. However, the fake op signature is missing routed_scaling_factor: Optional[float] between local_num_experts and routing_method_type compared to the real op at lines 1190-1213.

This pre-existing mismatch should be addressed to ensure the fake op mirrors the real op exactly.

🔧 Proposed fix to add missing parameter
     local_expert_offset: int,
     local_num_experts: int,
+    routed_scaling_factor: Optional[float],
     routing_method_type: int,
     use_shuffled_weight: bool,

Based on learnings: "When reviewing files that define fake ops decorated with register_fake_op, ensure the function signatures exactly mirror the real op they stand in for."
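As a minimal illustration of that parity rule, here is a sketch using PyTorch's torch.library machinery rather than flashinfer's own register_fake_op decorator; the op name and parameter list are hypothetical, trimmed to the parameters under discussion:

    from typing import Optional
    import torch

    @torch.library.custom_op("demo::bf16_moe", mutates_args=())
    def bf16_moe(hidden_states: torch.Tensor, local_num_experts: int,
                 routed_scaling_factor: Optional[float],
                 routing_method_type: int, activation_type: int) -> torch.Tensor:
        return hidden_states.clone()  # placeholder for the real kernel dispatch

    @bf16_moe.register_fake
    def _(hidden_states, local_num_experts, routed_scaling_factor,
          routing_method_type, activation_type):
        # Identical signature to the real op, routed_scaling_factor included,
        # even though the fake implementation only propagates shape and dtype.
        return hidden_states.new_empty(hidden_states.shape)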


ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 5a59e148-d201-4efe-bf55-ea14b1ac3535

📥 Commits

Reviewing files that changed from the base of the PR and between 27cae50 and f3dae20.

📒 Files selected for processing (5)
  • csrc/trtllm_fused_moe_kernel_launcher.cu
  • flashinfer/fused_moe/__init__.py
  • flashinfer/fused_moe/core.py
  • tests/moe/test_trtllm_gen_fused_moe.py
  • tests/moe/utils.py

@amitz-nv force-pushed the support-bf16-relu2-trtllm-gen-fused-moe-kernel branch from f3dae20 to 62e38fd on March 24, 2026 at 16:53.
Contributor

@coderabbitai coderabbitai Bot left a comment


Actionable comments posted: 1

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
tests/moe/test_trtllm_gen_fused_moe.py (1)

1439-1443: ⚠️ Potential issue | 🟠 Major

Include is_gated_act_gemm in the permute-cache key.

Passing the flag here still reuses whatever _maybe_get_cached_w3_w1_permute_indices() cached first, because the helper currently keys only on ("w3_w1", shape). With the module-scoped cache_permute_indices fixture, gated and non-gated cases that collapse to the same view(torch.uint8) shape can therefore reuse the wrong row order, so the BF16 shuffle becomes test-order dependent.

Please fix this in flashinfer/fused_moe/core.py by keying the cache on the activation mode as well, instead of only passing the flag at the call site.

Suggested helper-side fix
-    cache_key = ("w3_w1", dst_w3_w1_weight.shape)
+    cache_key = (
+        "w3_w1",
+        dst_w3_w1_weight.shape,
+        epilogue_tile_m,
+        num_elts_per_sf,
+        is_gated_act_gemm,
+    )

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: d7102610-1265-42b5-add1-237337333f3c

📥 Commits

Reviewing files that changed from the base of the PR and between f3dae20 and 62e38fd.

📒 Files selected for processing (5)
  • csrc/trtllm_fused_moe_kernel_launcher.cu
  • flashinfer/fused_moe/__init__.py
  • flashinfer/fused_moe/core.py
  • tests/moe/test_trtllm_gen_fused_moe.py
  • tests/moe/utils.py
🚧 Files skipped from review as they are similar to previous changes (2)
  • tests/moe/utils.py
  • flashinfer/fused_moe/core.py

(Outdated comment thread on csrc/trtllm_fused_moe_kernel_launcher.cu)
@amitz-nv force-pushed the support-bf16-relu2-trtllm-gen-fused-moe-kernel branch from 62e38fd to c397011 on March 30, 2026 at 09:29.
Contributor

@coderabbitai coderabbitai Bot left a comment


Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
tests/moe/test_trtllm_gen_fused_moe.py (1)

1439-1443: ⚠️ Potential issue | 🔴 Critical

Fix cache key collision: include is_gated_act_gemm in the permutation cache key.

The cache at _maybe_get_cached_w3_w1_permute_indices() (flashinfer/fused_moe/core.py line 108) uses only ("w3_w1", weight_shape) as the key and does not include is_gated_act_gemm. When the BF16 test harness now passes is_gated_act_gemm=is_gated_activation(args.activation_type), test cases with different activation types (e.g., Relu2 vs. Swiglu) but identical weight shapes will collide in the module-scoped _cache_permute_indices fixture, causing the wrong permutation to be reused. The DeepSeekV3 test matrix includes both nemotron_3_super (Relu2, non-gated) and kimi_k2 (Swiglu/Geglu, gated) variants that can trigger this collision. Update the cache key to include is_gated_act_gemm: cache_key = ("w3_w1", dst_w3_w1_weight.shape, is_gated_act_gemm).

♻️ Duplicate comments (1)
csrc/trtllm_fused_moe_kernel_launcher.cu (1)

1750-1761: ⚠️ Potential issue | 🟠 Major

FP8 per-tensor activation contract is still inconsistent.

trtllm_fp8_per_tensor_scale_moe() now accepts any valid activation_type, but trtllm_get_valid_moe_configs() later in this file still rejects non-gated per-tensor activations. The new FP8PerTensorMoe + Relu2 DeepSeekV3 matrix will hit that mismatch as soon as autotune asks for valid configs. Either reject non-gated activations here too, or lift the gated-only restriction in the valid-config/workspace path as well.


ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 808d3e9c-611b-4dcc-9b34-8d33bbad6460

📥 Commits

Reviewing files that changed from the base of the PR and between 62e38fd and c397011.

📒 Files selected for processing (7)
  • csrc/trtllm_batched_gemm_runner.cu
  • csrc/trtllm_fused_moe_kernel_launcher.cu
  • flashinfer/artifacts.py
  • flashinfer/fused_moe/__init__.py
  • flashinfer/fused_moe/core.py
  • tests/moe/test_trtllm_gen_fused_moe.py
  • tests/moe/utils.py
✅ Files skipped from review due to trivial changes (2)
  • tests/moe/utils.py
  • flashinfer/fused_moe/__init__.py
🚧 Files skipped from review as they are similar to previous changes (1)
  • flashinfer/fused_moe/core.py

TomerBN-Nvidia pushed a commit to TomerBN-Nvidia/flashinfer that referenced this pull request Mar 30, 2026
…nfer-ai#2864)

Adds runtime-configurable activation type to BF16 fused MoE:
- Bf16MoeLauncher::init accepts ActivationType parameter (was hardcoded Swiglu)
- trtllm_bf16_moe() and trtllm_bf16_routed_moe() gain activation_type param
- Updated batched GEMM artifacts and checksums
- Uses validateAndCastActivationType for safety

Supports Swiglu (3) and Relu2 (6) for Nemotron models.
TomerBN-Nvidia pushed a commit to TomerBN-Nvidia/flashinfer that referenced this pull request Mar 30, 2026
…MM cubins

Cherry-pick of flashinfer-ai#2864 (squashed) plus:
- activation_type param for trtllm_bf16_moe/trtllm_bf16_routed_moe (Swiglu=3, Relu2=6)
- routing_replay_out param for BF16 kernel (same pattern as FP8)
- Updated batched GEMM artifacts and checksums
- validateAndCastActivationType for safety
- Bf16MoeLauncher::init accepts ActivationType + routing_replay_out
@amitz-nv force-pushed the support-bf16-relu2-trtllm-gen-fused-moe-kernel branch from c397011 to e9df6f8 on March 31, 2026 at 16:13.
Contributor

@coderabbitai coderabbitai Bot left a comment


♻️ Duplicate comments (1)
csrc/trtllm_fused_moe_kernel_launcher.cu (1)

1784-1842: ⚠️ Potential issue | 🟠 Major

Reject non-gated activations in the FP8 per-tensor entrypoint.

The range check is good, but this path still accepts values like Relu2 while Fp8PerTensorLauncher::check_moe() and Fp8PerTensorLauncher::prepare_moe() are still hard-wired to the gated FC1 + gate-scale layout, and trtllm_get_valid_moe_configs() already rejects non-gated activations. Direct execution can still advertise a contract that the launcher does not implement.

💡 Suggested fix
   // Basic type validation
   auto dtype = hidden_states.dtype();
   auto activation = validateAndCastActivationType(activation_type);
+  if (!isGatedActivation(activation)) {
+    TVM_FFI_LOG_AND_THROW(NotImplementedError)
+        << "FP8 per-tensor currently supports gated activations only, "
+        << "got activation_type=" << activation_type << ".";
+  }

   if (static_cast<RoutingMethodType>(routing_method_type) == RoutingMethodType::DeepSeekV3) {

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 83355232-6031-4470-8370-9b3af9478586

📥 Commits

Reviewing files that changed from the base of the PR and between c397011 and e9df6f8.

📒 Files selected for processing (7)
  • csrc/trtllm_batched_gemm_runner.cu
  • csrc/trtllm_fused_moe_kernel_launcher.cu
  • flashinfer/artifacts.py
  • flashinfer/fused_moe/__init__.py
  • flashinfer/fused_moe/core.py
  • tests/moe/test_trtllm_gen_fused_moe.py
  • tests/moe/utils.py
✅ Files skipped from review due to trivial changes (1)
  • flashinfer/fused_moe/core.py
🚧 Files skipped from review as they are similar to previous changes (5)
  • csrc/trtllm_batched_gemm_runner.cu
  • tests/moe/utils.py
  • flashinfer/fused_moe/__init__.py
  • flashinfer/artifacts.py
  • tests/moe/test_trtllm_gen_fused_moe.py

@amirkl94
Contributor

amirkl94 commented Apr 1, 2026

/bot run

@flashinfer-bot
Collaborator

@amirkl94 is not authorized to trigger this CI job. cc: @yzh119, @sricketts, @yongwww

@amitz-nv force-pushed the support-bf16-relu2-trtllm-gen-fused-moe-kernel branch from e9df6f8 to 6048e0a on April 5, 2026 at 11:29.
@amitz-nv
Contributor Author

amitz-nv commented Apr 5, 2026

/bot run

@flashinfer-bot
Collaborator

GitLab MR !506 has been created, and the CI pipeline #47767264 is currently running. I'll report back once the pipeline job completes.

@jiahanc jiahanc added the run-ci label Apr 5, 2026
Collaborator

@jiahanc jiahanc left a comment


LGTM, thanks for the contribution!

@flashinfer-bot
Collaborator

[SUCCESS] Pipeline #47767264: 10/20 passed

@amitz-nv force-pushed the support-bf16-relu2-trtllm-gen-fused-moe-kernel branch from 5116227 to ce9b42a on April 7, 2026 at 11:10.
@amitz-nv
Contributor Author

amitz-nv commented Apr 7, 2026

/bot run

@flashinfer-bot
Collaborator

GitLab MR !506 has been updated with latest changes, and the CI pipeline #47917420 is currently running. I'll report back once the pipeline job completes.

@flashinfer-bot
Collaborator

[SUCCESS] Pipeline #47917420: 10/20 passed

@amitz-nv force-pushed the support-bf16-relu2-trtllm-gen-fused-moe-kernel branch from ce9b42a to 8e90827 on April 13, 2026 at 08:53.
@aleozlx
Collaborator

aleozlx commented Apr 13, 2026

Seems the prior artifact merge conflict has been resolved.

Enabling auto-merge now.

@aleozlx aleozlx enabled auto-merge (squash) April 13, 2026 08:55
@amitz-nv
Contributor Author

/bot run

@flashinfer-bot
Collaborator

GitLab MR !506 has been updated with latest changes, and the CI pipeline #48388053 is currently running. I'll report back once the pipeline job completes.

@amitz-nv
Contributor Author

I think #2982's pre-commit was run before #2966 was merged, which changed the parameters of trtllm_moe_finalize_allreduce_fusion, so now that both are merged the pre-commit fails. @aleozlx

@aleozlx
Collaborator

aleozlx commented Apr 13, 2026

Waiting on the pre-commit fix that's blocking CI: #3040

Signed-off-by: amitz-nv <203509407+amitz-nv@users.noreply.github.com>
…it run --all-files'

Signed-off-by: amitz-nv <203509407+amitz-nv@users.noreply.github.com>
Signed-off-by: amitz-nv <203509407+amitz-nv@users.noreply.github.com>
…trtllm_batched_gemm_runner.cu access to BatchedGemmOptions.mNumStages as it was split to A and B

Signed-off-by: amitz-nv <203509407+amitz-nv@users.noreply.github.com>
…e of fp32

Signed-off-by: amitz-nv <203509407+amitz-nv@users.noreply.github.com>
…shape

Signed-off-by: amitz-nv <203509407+amitz-nv@users.noreply.github.com>
Signed-off-by: amitz-nv <203509407+amitz-nv@users.noreply.github.com>
auto-merge was automatically disabled April 13, 2026 14:37

Head branch was pushed to by a user without write access

@amitz-nv force-pushed the support-bf16-relu2-trtllm-gen-fused-moe-kernel branch from 8e90827 to d055258 on April 13, 2026 at 14:37.
@aleozlx aleozlx enabled auto-merge (squash) April 13, 2026 17:42
@aleozlx aleozlx merged commit ede2225 into flashinfer-ai:main Apr 13, 2026
29 of 30 checks passed