
WIP: Add B12x support for non-gated MoEs (Nemotron) #39920

Draft
askliar wants to merge 9 commits into vllm-project:main from askliar:sm120-b12x-nvfp4-backend

Conversation

Contributor

askliar commented Apr 15, 2026

No description provided.

Contributor

gemini-code-assist (bot) left a comment


Code Review

This pull request introduces support for the b12x external backend, targeting SM12x Blackwell architecture for both fused MoE and linear kernels. Key changes include the implementation of B12xNvFp4LinearKernel and B12xExperts, along with updates to the modular kernel framework to support direct output writing, which optimizes memory management by bypassing temporary workspace allocations. Review feedback correctly identified redundant backend registrations in the environment variable configuration and duplicate logic within the MoE backend mapping function.
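For context, a minimal sketch of the direct-output idea (all names here are invented for illustration; the actual modular-kernel API differs):

  import torch

  def run_kernel(a: torch.Tensor, out: torch.Tensor) -> None:
      # stand-in for the fused kernel: writes its result in place
      torch.clamp(a, min=0.0, out=out)

  def apply_experts(a: torch.Tensor, out: torch.Tensor | None = None) -> torch.Tensor:
      if out is None:
          # fallback path: allocate a temporary workspace for the result
          out = torch.empty_like(a)
      run_kernel(a, out)  # write directly into the caller's buffer
      return out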

vllm/envs.py (outdated), comment on lines +1473 to +1475:
"b12x",
"cutlass",
"b12x",


Severity: high

The "b12x" backend is redundantly added twice to the list of valid choices for the VLLM_NVFP4_GEMM_BACKEND environment variable.

Suggested change (removing the duplicate "b12x"):

"b12x",
"cutlass",

Comment on lines +116 to +122
elif backend == NvFp4MoeBackend.B12X:
from vllm.model_executor.layers.fused_moe.experts.b12x_nvfp4_moe import (
B12xExperts,
)

return [B12xExperts]



Severity: high

The handling for NvFp4MoeBackend.B12X is duplicated in the backend_to_kernel_cls function. It was already implemented at the beginning of the function (lines 76-81). This redundant block should be removed.
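For clarity, a sketch of the deduplicated dispatch the comment asks for (the enum name and import path follow this PR; the stub enum and surrounding function shape are illustrative):

  from enum import Enum

  class NvFp4MoeBackend(Enum):
      # stub mirroring the PR's enum, for illustration only
      B12X = "b12x"
      CUTLASS = "cutlass"

  def backend_to_kernel_cls(backend: NvFp4MoeBackend) -> list[type]:
      # single B12X branch near the top; the duplicate block at
      # lines 116-122 is removed
      if backend == NvFp4MoeBackend.B12X:
          from vllm.model_executor.layers.fused_moe.experts.b12x_nvfp4_moe import (
              B12xExperts,
          )
          return [B12xExperts]
      raise NotImplementedError(backend)  # other backends elided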

mergify (bot) added the nvidia label Apr 15, 2026
askliar force-pushed the sm120-b12x-nvfp4-backend branch 2 times, most recently from 864afd8 to 52ce4aa on April 16, 2026 at 11:41
meena-at-work and others added 5 commits April 16, 2026 13:44
  Adds FlashInferCuteDSLSM12xExperts targeting SM120/SM121 (RTX Pro
  6000 / DGX Spark) using cute_dsl_fused_moe_nvfp4 from FlashInfer
  PRs #3051 and #3066. The kernel fuses token dispatch, W1 GEMM, SwiGLU,
  and W2 GEMM into a single call; BF16 hidden states are passed directly,
  as activation quantization is fused internally.

  - vllm/utils/flashinfer.py: lazy import wrappers for
    cute_dsl_fused_moe_nvfp4 and convert_sf_to_mma_layout; adds
    has_flashinfer_cutedsl_sm12x_moe() availability probe
  - experts/flashinfer_cutedsl_moe.py: FlashInferCuteDSLSM12xExperts
    with a TODO to adopt the plan/run() API from PR #3066
  - oracle/nvfp4.py: FLASHINFER_CUTEDSL_SM12X backend enum and routing;
    falls back to FLASHINFER_CUTLASS on SM12x when PRs are absent
  - flashinfer_fp4_moe.py: SM12X added to FI weight-prep path and
    w1/w3 → w3/w1 reorder list
  - tests/kernels/moe/test_cutedsl_sm12x_moe.py: correctness tests vs
    BF16 torch reference; module-level skip when SM120 hardware or the
    FlashInfer PRs are absent

Signed-off-by: Meenakshi Venkataraman <meenakshiv@nvidia.com>
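The lazy-import wrappers and availability probe mentioned above typically follow a pattern like this (a sketch; the flashinfer import location is an assumption, not the exact vllm/utils/flashinfer.py code):

  import functools

  @functools.cache
  def has_flashinfer_cutedsl_sm12x_moe() -> bool:
      """Probe: True only if the CuteDSL SM12x MoE entry point imports."""
      try:
          # assumed import location of the kernel from the FlashInfer PRs
          from flashinfer import cute_dsl_fused_moe_nvfp4  # noqa: F401
      except ImportError:
          return False
      return True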
…_unquantized_inputs

The test was failing (all 24 cases, abs diff ~3904) because it used vLLM's
NVFP4 convention: scaled_fp4_quant(w, w_gs) bakes w_gs≈8960 into block
scales and sets g1_alphas=1/w_gs.  But launch_sm120_moe uses w1_alpha as
*both* the activation input_gs (arg 17) and the weight dequant factor (arg
18), conflating the two roles.  With input_gs=1/w_gs, activations are
scaled up by w_gs inside the kernel, producing outputs ~8960× too large.

Fix the test to use FlashInfer's convention: fp4_quantize(global_scale=1.0,
is_sf_swizzled_layout=True) so block_scale=max_abs/fp4_max and all alphas
are 1.0, satisfying both conflated roles simultaneously.
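Spelled out, the two conventions conflated by launch_sm120_moe are (a restatement of the commit message, with fp4_max the largest representable FP4 magnitude):

  # vLLM convention:       block_scale = max_abs * w_gs / fp4_max
  #                        g1_alpha    = 1 / w_gs
  # FlashInfer convention: block_scale = max_abs / fp4_max
  #                        alpha       = 1.0
  #
  # Since launch_sm120_moe uses the same alpha as both the activation
  # input_gs and the weight dequant factor, only alpha == 1.0 satisfies
  # both roles at once; alpha = 1/w_gs scales activations by w_gs
  # (~8960x) inside the kernel.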

Also add expects_unquantized_inputs=True to FlashInferCuteDSLSM12xExperts:
cute_dsl_fused_moe_nvfp4 quantizes activations internally and must receive
BF16 hidden states.  Without this override the modular kernel pre-quantizes
to FP4 (size k//2) before apply(), breaking convert_sf_to_mma_layout which
expects the full k dimension.

Verified: 24/24 passed on SM121 (DGX SparkX2, p4242-0064).

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
- FlashInferCuteDSLSM12xExperts.process_weights_after_loading: normalize
  float8 block scales by w13_weight_scale_2 (=1/w_gs) at load time, then
  set weight_scale_2=1.0.  This converts vLLM's NVFP4 convention
  (block_scale = max_abs * w_gs / fp4_max) to the SM12x kernel's required
  convention (block_scale = max_abs / fp4_max, g1_alphas = 1.0), without
  re-quantising packed FP4 values which are identical in both conventions.
  Unlike other backends, the activation scale is NOT baked in (that would
  break the conflated activation-gs role in launch_sm120_moe); a sketch
  follows after this list.
- kernel.py: add "flashinfer_cutedsl_sm12x" to MoEBackend Literal so the
  --moe-backend CLI arg accepts it without "invalid choice" error.
- test skip message: "RTX Pro 6000 / DGX Spark" (not "Blackwell GeForce").
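A sketch of the load-time normalization from the first bullet (tensor names follow the commit message; the dtype round-trip is illustrative, not the exact vLLM code):

  import torch

  def normalize_b12x_scales(w13_weight_scale: torch.Tensor,
                            w13_weight_scale_2: torch.Tensor):
      # vLLM stores block_scale = max_abs * w_gs / fp4_max and
      # weight_scale_2 = 1 / w_gs; multiplying the float8 block scales
      # by weight_scale_2 yields the SM12x convention max_abs / fp4_max.
      bs = w13_weight_scale.to(torch.float32) * w13_weight_scale_2.to(torch.float32)
      return bs.to(torch.float8_e4m3fn), torch.ones_like(w13_weight_scale_2)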

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Signed-off-by: Andrii Skliar <askliar@nvidia.com>
askliar force-pushed the sm120-b12x-nvfp4-backend branch from 52ce4aa to 3d7663e on April 16, 2026 at 11:44
Andrii Skliar added 3 commits April 16, 2026 13:46
- Updated references from FlashInferCuteDSLSM12xExperts to FlashInferB12xExperts across the codebase.
- Modified test cases to reflect the new B12x naming convention and adjusted skip conditions accordingly.
- Enhanced the `make_dummy_moe_config` function to accept activation type and is_act_and_mul parameters.
- Introduced a new test for FlashInferB12x with ReLU2 activation.
- Updated backend handling in various modules to support the B12x configuration.

This change aligns with the new backend structure and improves clarity in the codebase.
Signed-off-by: Andrii Skliar <askliar@nvidia.com>
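For illustration, the extended test helper from the commit above would be exercised roughly like this for a non-gated (Nemotron-style) expert; the argument names come from the commit message, the values are hypothetical:

  # hypothetical usage of the new make_dummy_moe_config parameters
  cfg = make_dummy_moe_config(
      activation="relu2",    # new activation-type parameter
      is_act_and_mul=False,  # non-gated MoE: no gate/up multiply
  )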

eugr commented Apr 17, 2026

@askliar - have you got it to work on Spark?
I've built the latest vLLM with this PR applied (and another one that re-enables flashinfer_cutlass) and I'm getting this error. I have a flashinfer build with b12x support (also built from main).

(EngineCore pid=114)   File "/usr/local/lib/python3.12/dist-packages/vllm/utils/flashinfer.py", line 505, in flashinfer_mm_fp4
(EngineCore pid=114)     return flashinfer_mm_fp4_(
(EngineCore pid=114)            ^^^^^^^^^^^^^^^^^^^
(EngineCore pid=114)   File "/usr/local/lib/python3.12/dist-packages/flashinfer/utils.py", line 1223, in wrapper
(EngineCore pid=114)     raise BackendSupportedError(
(EngineCore pid=114) flashinfer.utils.BackendSupportedError: mm_fp4 does not support backend 'b12x' with capability 121
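(For reference, the capability number in that error can be checked with the standard PyTorch API; per this thread, DGX Spark reports 12.1 while RTX Pro 6000 reports 12.0.)

  import torch

  major, minor = torch.cuda.get_device_capability()
  print(f"compute capability: {major}{minor}")  # 121 on DGX Spark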

Contributor Author

askliar commented Apr 17, 2026

@eugr yes, it's still pretty flaky - hold off for a little bit, and I will update you as soon as it's in good shape (should take a day or two).
Also, FI main does not support Nemotron yet.

…t conversion

- Introduced dynamic per-block quantization for FC2 input activations to prevent saturation during scaling.
- Added conversion of swizzled 3D scale factors to 6D MMA layout for compatibility with the SM12x kernel.
- Updated the run method to utilize the new MMA layout for weight scales.

These changes improve the performance and accuracy of the FlashInferB12xExperts implementation.

Signed-off-by: Andrii Skliar <askliar@nvidia.com>

eugr commented Apr 20, 2026

@askliar - will this help? flashinfer-ai/flashinfer#3080

Contributor Author

askliar commented Apr 21, 2026

@eugr yes, for sure! It does add support.
Generally speaking, there are two branches:

  1. https://github.com/askliar/vllm/tree/askliar/b12x-with-tinygemm - it's my branch that I am using to push all the latest improvements for Nemotron. It might be less stable but it is (likely) the best you can get in terms of speed.
  2. Integrate flashinfer b12x MoE and FP4 GEMM kernels for SM120/121 #40082 - a more stable branch that's slightly behind mine but with basically identical performance

Also, this PR should go in soon-ish: #39921 - it makes Nemotron BF16 small-M GEMMs close to SoL.


eugr commented Apr 21, 2026

@askliar - sm120/121 confusion strikes again...

File "/usr/local/lib/python3.12/dist-packages/flashinfer/gemm/kernels/dense_blockscaled_gemm_sm120.py", line 144, in _setup_attributes
(EngineCore pid=114)     mma_op = cute.nvgpu.warp.MmaMXF4NVF4Op(
(EngineCore pid=114)              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore pid=114)   File "/usr/local/lib/python3.12/dist-packages/nvidia_cutlass_dsl/python_packages/cutlass/cute/nvgpu/warp/mma.py", line 325, in __init__
(EngineCore pid=114)     super().__init__(
(EngineCore pid=114)   File "<string>", line 9, in __init__
(EngineCore pid=114)   File "/usr/local/lib/python3.12/dist-packages/nvidia_cutlass_dsl/python_packages/cutlass/cute/nvgpu/warp/mma.py", line 139, in __post_init__
(EngineCore pid=114)     raise OpError(
(EngineCore pid=114) cutlass.cute.nvgpu.common.OpError: OpError: expects arch to be one of ['sm_120a'], but got Arch.sm_121a
(EngineCore pid=114) Error Code: MmaMXF4NVF4Op error

Not sure if that can be addressed in flashinfer, or we need to wait for upstream cutlass fix.

Contributor Author

askliar commented Apr 21, 2026

@eugr Ah yes, my bad - I did not share a few extra bits, which will most likely take a bit of time to be resolved.
First of all, the new command is this (the most important additions are CUTE_DSL_ARCH and the moe_backend in the spec-dec config):

FLASHINFER_DISABLE_VERSION_CHECK=1 \
  VLLM_USE_FLASHINFER_MOE_FP16=1 \
  VLLM_NVFP4_GEMM_BACKEND=flashinfer-b12x \
  VLLM_ALLOW_LONG_MAX_MODEL_LEN=1 \
  VLLM_FLASHINFER_ALLREDUCE_BACKEND=trtllm \
  VLLM_USE_FLASHINFER_MOE_FP4=1 \
  CUTE_DSL_ARCH=sm_121a \
  OMP_NUM_THREADS=16 \
  SAFETENSORS_FAST_GPU=1 \
  python3 -m vllm.entrypoints.openai.api_server \
    --model /workspace/models/NVIDIA-Nemotron-3-Super-120B-A12B-NVFP4 \
    --served-model-name nemotron-3-super --host 0.0.0.0 --port 9000 \
    --async-scheduling --dtype auto --kv-cache-dtype fp8 \
    --tensor-parallel-size 1 --pipeline-parallel-size 1 --data-parallel-size 1 \
    --trust-remote-code --gpu-memory-utilization 0.9 \
    --enable-chunked-prefill --max-num-seqs 4 --max-model-len 53000 \
    --attention-backend FLASHINFER --moe-backend flashinfer_b12x \
    --mamba_ssm_cache_dtype float32 --quantization fp4 \
    --mamba-cache-mode align --enable-prefix-caching \
    --max_num_batched_tokens 16484 \
    --speculative-config '{"method":"nemotron_h_mtp","num_speculative_tokens":3,"moe_backend":"flashinfer_cutlass"}' &

Also, nvidia-cutlass-dsl must be 4.4.2 — version 4.5.0 generates bad PTX on SM121.
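For example, the pin can be applied with pip (assuming the package is installed from PyPI in your image):

  pip install "nvidia-cutlass-dsl==4.4.2"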
Finally, even with that, cutlass-dsl has to be adjusted with these commands:

  # warp/mma.py: Add sm_121a alongside sm_120a
  # (file paths per the traceback above; adjust for your install)
  sed -i "s/if not arch == Arch.sm_120a:/if arch not in (Arch.sm_120a, Arch.sm_121a):/" \
    /usr/local/lib/python3.12/dist-packages/nvidia_cutlass_dsl/python_packages/cutlass/cute/nvgpu/warp/mma.py

  # tcgen05/mma.py: Add sm_120a and sm_121a to arch list
  sed -i "/Arch.sm_103a,/a\\        Arch.sm_120a,\n        Arch.sm_121a," \
    /usr/local/lib/python3.12/dist-packages/nvidia_cutlass_dsl/python_packages/cutlass/cute/nvgpu/tcgen05/mma.py

  # tcgen05/copy.py: Add sm_120f family
  sed -i "s/arch.is_family_of(Arch.sm_110f)/arch.is_family_of(Arch.sm_110f) or arch.is_family_of(Arch.sm_120f)/" \
    /usr/local/lib/python3.12/dist-packages/nvidia_cutlass_dsl/python_packages/cutlass/cute/nvgpu/tcgen05/copy.py


eugr commented Apr 22, 2026

@askliar - if I build with this PR and #39921 (tinygemm) applied, I can't get b12x to work; I'm getting:

ValueError: Invalid value 'flashinfer-b12x' for VLLM_NVFP4_GEMM_BACKEND. Valid options: ['flashinfer-cudnn', 'flashinfer-trtllm', 'flashinfer-cutlass', 'cutlass', 'marlin', 'emulation']

and

vllm serve: error: argument --moe-backend: invalid choice: 'flashinfer_b12x' (choose from 'aiter', 'auto', 'cutlass', 'deep_gemm', 'emulation', 'flashinfer_cutedsl', 'flashinfer_cutlass', 'flashinfer_trtllm', 'marlin', 'triton')

I guess I could try your branch, but I wanted to apply the PRs on top of current main instead, as that is more easily incorporated into my build pipeline (which I don't want to lag behind vLLM main).

I also tried the other PR you mentioned, but there are some issues with that one too.

