WIP: Add B12x support for non-gated MoEs (Nemotron) #39920
askliar wants to merge 9 commits into vllm-project:main
Conversation
Code Review
This pull request introduces support for the b12x external backend, targeting the SM12x Blackwell architecture for both fused MoE and linear kernels. Key changes include the implementation of B12xNvFp4LinearKernel and B12xExperts, along with updates to the modular kernel framework to support direct output writing, which optimizes memory management by bypassing temporary workspace allocations. Review feedback correctly identified redundant backend registrations in the environment variable configuration and duplicate logic within the MoE backend mapping function.
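The "direct output writing" change is essentially about letting the experts kernel fill a caller-provided output tensor instead of allocating its own temporary workspace. A minimal illustrative sketch of the idea, assuming hypothetical `fused_experts_into` and `apply_experts` names (not vLLM's actual modular-kernel interface):

```python
import torch


def fused_experts_into(hidden_states: torch.Tensor, out: torch.Tensor) -> None:
    # Stand-in for the fused-MoE kernel: writes its result directly into `out`.
    out.copy_(torch.nn.functional.silu(hidden_states))


def apply_experts(hidden_states: torch.Tensor,
                  out: torch.Tensor | None = None) -> torch.Tensor:
    if out is None:
        # Fallback path: allocate a temporary buffer for the kernel output.
        out = torch.empty_like(hidden_states)
    # Direct-output path: the kernel writes into the caller's buffer,
    # skipping the extra allocation and copy.
    fused_experts_into(hidden_states, out)
    return out


# The caller owns the output buffer; no intermediate workspace is created.
x = torch.randn(4, 8)
y = torch.empty_like(x)
apply_experts(x, out=y)
```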
| "b12x", | ||
| "cutlass", | ||
| "b12x", |
| elif backend == NvFp4MoeBackend.B12X: | ||
| from vllm.model_executor.layers.fused_moe.experts.b12x_nvfp4_moe import ( | ||
| B12xExperts, | ||
| ) | ||
|
|
||
| return [B12xExperts] | ||
|
|
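The first hunk shows the redundancy the review calls out: "b12x" is registered twice in the backend choices. One way the de-duplicated registration might look (the variable name here is hypothetical, not vLLM's actual env-var config):

```python
# Each backend should appear exactly once in the choices list.
NVFP4_MOE_BACKEND_CHOICES = [
    "b12x",
    "cutlass",
]
```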
Force-pushed from 864afd8 to 52ce4aa.
Adds FlashInferCuteDSLSM12xExperts targeting SM120/SM121 (RTX Pro 6000 / DGX Spark) using cute_dsl_fused_moe_nvfp4 from FlashInfer PRs #3051 and #3066. The kernel fuses token dispatch, W1 GEMM, SwiGLU, and W2 GEMM into a single call; BF16 hidden states are passed directly because activation quantization is fused internally.

- vllm/utils/flashinfer.py: lazy import wrappers for cute_dsl_fused_moe_nvfp4 and convert_sf_to_mma_layout; adds the has_flashinfer_cutedsl_sm12x_moe() availability probe
- experts/flashinfer_cutedsl_moe.py: FlashInferCuteDSLSM12xExperts, with a TODO to adopt the plan()/run() API from PR #3066
- oracle/nvfp4.py: FLASHINFER_CUTEDSL_SM12X backend enum and routing; falls back to FLASHINFER_CUTLASS on SM12x when the FlashInfer PRs are absent
- flashinfer_fp4_moe.py: SM12X added to the FlashInfer weight-prep path and the w1/w3 → w3/w1 reorder list
- tests/kernels/moe/test_cutedsl_sm12x_moe.py: correctness tests against a BF16 torch reference; module-level skip when SM120 hardware or the FlashInfer PRs are absent

Signed-off-by: Meenakshi Venkataraman <meenakshiv@nvidia.com>
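The availability probe mentioned for vllm/utils/flashinfer.py presumably follows the usual lazy-import pattern; a sketch of that pattern, assuming the FlashInfer symbol is importable from the top-level package (the real probe may check differently):

```python
import functools
import importlib.util


@functools.cache
def has_flashinfer_cutedsl_sm12x_moe() -> bool:
    # Cheap check first: is FlashInfer installed at all?
    if importlib.util.find_spec("flashinfer") is None:
        return False
    try:
        # Only true on builds that include the CuteDSL NVFP4 fused-MoE entry
        # point from the FlashInfer PRs (import path assumed for this sketch).
        from flashinfer import cute_dsl_fused_moe_nvfp4  # noqa: F401
    except ImportError:
        return False
    return True
```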
…_unquantized_inputs

The test was failing (all 24 cases, abs diff ~3904) because it used vLLM's NVFP4 convention: scaled_fp4_quant(w, w_gs) bakes w_gs ≈ 8960 into the block scales and sets g1_alphas = 1/w_gs. But launch_sm120_moe uses w1_alpha as *both* the activation input_gs (arg 17) and the weight dequant factor (arg 18), conflating the two roles. With input_gs = 1/w_gs, activations are scaled up by w_gs inside the kernel, producing outputs ~8960× too large.

Fix the test to use FlashInfer's convention: fp4_quantize(global_scale=1.0, is_sf_swizzled_layout=True), so block_scale = max_abs/fp4_max and all alphas are 1.0, satisfying both conflated roles simultaneously.

Also add expects_unquantized_inputs=True to FlashInferCuteDSLSM12xExperts: cute_dsl_fused_moe_nvfp4 quantizes activations internally and must receive BF16 hidden states. Without this override the modular kernel pre-quantizes to FP4 (size k//2) before apply(), breaking convert_sf_to_mma_layout, which expects the full k dimension.

Verified: 24/24 passed on SM121 (DGX SparkX2, p4242-0064).

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
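The two scale conventions described above can be checked with a few lines of arithmetic; the numbers below are illustrative (fp4_max = 6.0 for the E2M1 format), and the formulas are the ones quoted in the commit messages:

```python
fp4_max = 6.0        # largest magnitude representable in FP4 (E2M1)
max_abs = 3.0        # example per-block max |w|
w_gs = 8960.0        # example global weight scale from the commit

# vLLM convention: the global scale is baked into the block scale,
# undone at dequant time via g1_alphas = 1 / w_gs.
vllm_block_scale = max_abs * w_gs / fp4_max
vllm_alpha = 1.0 / w_gs

# FlashInfer convention required by the SM12x kernel: global scale of 1.0,
# so the block scale is just max_abs / fp4_max and all alphas are 1.0.
fi_block_scale = max_abs / fp4_max
fi_alpha = 1.0

# launch_sm120_moe reuses w1_alpha as the activation input_gs, so passing
# vllm_alpha = 1/w_gs also rescales the activations by ~w_gs inside the
# kernel, which is why outputs came out roughly 8960x too large.
print(vllm_block_scale / fi_block_scale)  # == w_gs
```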
- FlashInferCuteDSLSM12xExperts.process_weights_after_loading: normalize the float8 block scales by w13_weight_scale_2 (= 1/w_gs) at load time, then set weight_scale_2 = 1.0. This converts vLLM's NVFP4 convention (block_scale = max_abs * w_gs / fp4_max) to the SM12x kernel's required convention (block_scale = max_abs / fp4_max, g1_alphas = 1.0) without re-quantizing the packed FP4 values, which are identical in both conventions. Unlike other backends, the activation scale is NOT baked in (that would break the conflated activation-gs role in launch_sm120_moe).
- kernel.py: add "flashinfer_cutedsl_sm12x" to the MoEBackend Literal so the --moe-backend CLI arg accepts it without an "invalid choice" error.
- test skip message: "RTX Pro 6000 / DGX Spark" (not "Blackwell GeForce").

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
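A sketch of the load-time normalization described in the first bullet above, with hypothetical tensor names; it simply folds 1/w_gs into the float8 block scales so the downstream alphas can become 1.0:

```python
import torch


def normalize_b12x_block_scales(w13_blockscale: torch.Tensor,
                                w13_weight_scale_2: torch.Tensor) -> torch.Tensor:
    # block_scale_vllm = max_abs * w_gs / fp4_max
    # block_scale_b12x = max_abs / fp4_max = block_scale_vllm * (1 / w_gs)
    # w13_weight_scale_2 already holds 1 / w_gs, so a single multiply suffices;
    # the packed FP4 weight values themselves are left untouched.
    scales = w13_blockscale.to(torch.float32) * w13_weight_scale_2.to(torch.float32)
    return scales.to(torch.float8_e4m3fn)


# Afterwards weight_scale_2 would be set to 1.0, as the commit describes.
```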
Force-pushed from 52ce4aa to 3d7663e.
- Updated references from FlashInferCuteDSLSM12xExperts to FlashInferB12xExperts across the codebase.
- Modified test cases to reflect the new B12x naming convention and adjusted skip conditions accordingly.
- Enhanced the `make_dummy_moe_config` function to accept activation type and is_act_and_mul parameters.
- Introduced a new test for FlashInferB12x with ReLU2 activation.
- Updated backend handling in various modules to support the B12x configuration.

This change aligns with the new backend structure and improves clarity in the codebase.
Signed-off-by: Andrii Skliar <askliar@nvidia.com>
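For context on the ReLU2 test and the new is_act_and_mul parameter: gated MoEs split the FC1 output in two and multiply (SwiGLU-style), while non-gated MoEs such as Nemotron apply a plain activation like squared ReLU. A small sketch of that distinction (illustrative only, not the vLLM helper itself):

```python
import torch
import torch.nn.functional as F


def moe_activation(x: torch.Tensor, activation: str, is_act_and_mul: bool) -> torch.Tensor:
    if is_act_and_mul:
        # Gated MoE: FC1 produces [gate, up]; apply SiLU to the gate and multiply.
        gate, up = x.chunk(2, dim=-1)
        return F.silu(gate) * up
    if activation == "relu2":
        # Non-gated MoE (e.g. Nemotron): plain squared ReLU on the FC1 output.
        return torch.square(F.relu(x))
    raise ValueError(f"unsupported activation: {activation}")
```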
@askliar - have you got it to work on Spark?
@eugr yes, it's still pretty flaky - hold off for a little bit; I will update you as soon as it's in good shape (should take a day or two).
…t conversion

- Introduced dynamic per-block quantization for FC2 input activations to prevent saturation during scaling.
- Added conversion of swizzled 3D scale factors to 6D MMA layout for compatibility with the SM12x kernel.
- Updated the run method to utilize the new MMA layout for weight scales.

These changes improve the performance and accuracy of the FlashInferB12xExperts implementation.

Signed-off-by: Andrii Skliar <askliar@nvidia.com>
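A rough sketch of what dynamic per-block quantization of the FC2 input looks like, assuming the NVFP4 block size of 16 and fp4_max = 6.0; the real kernel additionally packs values to FP4 and produces scale factors in the swizzled/MMA layout mentioned above:

```python
import torch

FP4_MAX = 6.0   # largest magnitude representable in FP4 (E2M1)
BLOCK = 16      # NVFP4 scaling block size along the hidden dimension


def dynamic_per_block_quant(x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    # Compute one scale per 16-element block from that block's max |x|,
    # so large activation outputs do not saturate the FP4 range.
    m, k = x.shape
    blocks = x.reshape(m, k // BLOCK, BLOCK)
    scales = blocks.abs().amax(dim=-1, keepdim=True) / FP4_MAX
    scales = scales.clamp(min=torch.finfo(torch.float32).tiny)
    q = (blocks / scales).clamp(-FP4_MAX, FP4_MAX)  # would be packed to FP4 in the kernel
    return q.reshape(m, k), scales.squeeze(-1)
```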
@askliar - will this help? flashinfer-ai/flashinfer#3080
@eugr yes, for sure! It does add support.
Also, this PR should go in soon-ish: #39921 - it makes Nemotron BF16 small-M GEMMs close to SoL (speed-of-light).
@askliar - sm120/121 confusion strikes again... Not sure if that can be addressed in FlashInfer, or whether we need to wait for an upstream CUTLASS fix.
@eugr Ah yes, my bad - I did not share a few extra bits, and those will most likely take a bit of time to resolve. Also, nvidia-cutlass-dsl must be 4.4.2; version 4.5.0 generates bad PTX on SM121.
@askliar - if I build with this PR and #39921 (tinygemm), I can't get b12x to work, getting:

I guess I could try your branch, but I wanted to apply the PRs to current main instead, as that could be more easily incorporated into my build pipeline (which I don't want to lag behind vLLM main). I also tried the other PR you mentioned, but there are some issues with that one too.
No description provided.