Skip to content

Commit d622e27

Browse files
[NVFP4] NVFP4 MOE emulation fallback for H100/MI300/MI350, standardize TritonExperts usage for OCP MX emulation (vllm-project#35737)
Signed-off-by: Felix Marty <Felix.Marty@amd.com> Signed-off-by: fxmarty-amd <felmarty@amd.com> Co-authored-by: Kyle Sayers <kylesayrs@gmail.com>
1 parent 5f76b3f commit d622e27

12 files changed

Lines changed: 601 additions & 121 deletions

File tree

tests/evals/gsm8k/configs/models-mi3xx.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,3 +2,5 @@ DeepSeek-R1-TP_MI325.yaml
22
DeepSeek-R1-DP_MI325.yaml
33
DeepSeek-V3.2-TP_MI325.yaml
44
DeepSeek-V3.2-DP_MI325.yaml
5+
Qwen3-30B-A3B-NVFP4.yaml
6+
Qwen3.5-35B-A3B-MXFP4-TP2.yaml

tests/models/quantization/test_nvfp4.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -120,3 +120,26 @@ def test_nvfp4(vllm_runner, model, eager, backend, monkeypatch):
120120
with vllm_runner(model, enforce_eager=eager) as llm:
121121
output = llm.generate_greedy(["1 2 3 4 5"], max_tokens=2)
122122
assert output[0][1] == "1 2 3 4 5 6"
123+
124+
125+
@pytest.mark.parametrize(
126+
"model",
127+
[
128+
"nvidia/Qwen3-30B-A3B-NVFP4",
129+
"RedHatAI/Qwen3-30B-A3B-NVFP4",
130+
],
131+
)
132+
@pytest.mark.parametrize("backend", ["emulation"])
133+
@pytest.mark.skipif(
134+
not current_platform.is_rocm(),
135+
reason="NVFP4 MOE emulation is only useful on AMD Instinct MI3xx",
136+
)
137+
def test_nvfp4_moe(vllm_runner, model, backend, monkeypatch):
138+
monkeypatch.setenv("VLLM_NVFP4_GEMM_BACKEND", backend)
139+
with vllm_runner(
140+
model,
141+
moe_backend=backend,
142+
load_format="dummy",
143+
hf_overrides={"num_hidden_layers": 2},
144+
) as llm:
145+
_ = llm.generate_greedy(["1 2 3 4 5"], max_tokens=2)

vllm/config/kernel.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,7 @@ def with_default(
115115
"flashinfer_cutedsl",
116116
"marlin",
117117
"aiter",
118+
"emulation",
118119
]
119120

120121

@@ -142,7 +143,10 @@ class KernelConfig:
142143
- "flashinfer_cutlass": Use FlashInfer with CUTLASS kernels
143144
- "flashinfer_cutedsl": Use FlashInfer with CuteDSL kernels (FP4 only)
144145
- "marlin": Use Marlin kernels (weight-only quantization)
145-
- "aiter": Use AMD AITer kernels (ROCm only)"""
146+
- "aiter": Use AMD AITer kernels (ROCm only)
147+
- "emulation": use BF16/FP16 GEMM, dequantizing weights and
148+
running QDQ on activations.
149+
"""
146150

147151
@field_validator("moe_backend", mode="before")
148152
@classmethod
Lines changed: 164 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,164 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3+
"""
4+
NVFP4 quantization emulation for MoE.
5+
6+
This file implements NVFP4 emulation for NVFP4 MOE in case the hardware used does not
7+
natively support NVFP4 MOE.
8+
9+
Weights are dequantized on the fly during each forward, we fall back to calling
10+
`TritonExperts` using BF16, and fake NVFP4 quantize-dequantize
11+
is applied on `a13`, `a2`.
12+
"""
13+
14+
import torch
15+
16+
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
17+
from vllm.logger import init_logger
18+
from vllm.model_executor.layers.fused_moe.activation import MoEActivation
19+
from vllm.model_executor.layers.fused_moe.config import (
20+
FusedMoEConfig,
21+
FusedMoEQuantConfig,
22+
)
23+
from vllm.model_executor.layers.fused_moe.fused_moe import TritonExperts
24+
from vllm.model_executor.layers.fused_moe.utils import moe_kernel_quantize_input
25+
from vllm.model_executor.layers.quantization.utils.nvfp4_emulation_utils import (
26+
dequantize_to_dtype,
27+
)
28+
from vllm.model_executor.layers.quantization.utils.quant_utils import (
29+
QuantKey,
30+
kNvfp4Dynamic,
31+
kNvfp4Static,
32+
)
33+
34+
logger = init_logger(__name__)
35+
36+
37+
class Nvfp4QuantizationEmulationTritonExperts(TritonExperts):
38+
"""
39+
Extension of TritonExperts to support emulated NVFP4 MoE experts.
40+
41+
It may be used for NVFP4 models when the device does not have
42+
native support for this dtype.
43+
"""
44+
45+
def __init__(
46+
self,
47+
moe_config: FusedMoEConfig,
48+
quant_config: FusedMoEQuantConfig,
49+
):
50+
super().__init__(moe_config, quant_config)
51+
logger.warning_once(
52+
"Using Nvfp4QuantizationEmulationTritonExperts MOE backend. This will"
53+
" dequantize weights on the fly and may be slower than native"
54+
" quantized MOE. Consider using a device with native quantization"
55+
" support (e.g. Nvidia Blackwell) for better performance."
56+
)
57+
58+
# `TritonExperts.apply` expects pre-dequantized weights,
59+
# which we handle in `apply` below.
60+
self.w1_scale_val = self.quant_config.w1_scale
61+
self.w2_scale_val = self.quant_config.w2_scale
62+
63+
self.quant_config._w1.scale = None
64+
self.quant_config._w2.scale = None
65+
66+
self.quantization_emulation = True
67+
68+
@property
69+
def quant_dtype(self) -> torch.dtype | str | None:
70+
return "nvfp4"
71+
72+
@property
73+
def expects_unquantized_inputs(self) -> bool:
74+
return True
75+
76+
@staticmethod
77+
def _supports_quant_scheme(
78+
weight_key: QuantKey | None,
79+
activation_key: QuantKey | None,
80+
) -> bool:
81+
return (weight_key, activation_key) == (kNvfp4Static, kNvfp4Dynamic)
82+
83+
def apply(
84+
self,
85+
output: torch.Tensor,
86+
hidden_states: torch.Tensor,
87+
w1: torch.Tensor,
88+
w2: torch.Tensor,
89+
topk_weights: torch.Tensor,
90+
topk_ids: torch.Tensor,
91+
activation: MoEActivation,
92+
global_num_experts: int,
93+
expert_map: torch.Tensor | None,
94+
a1q_scale: torch.Tensor | None,
95+
a2_scale: torch.Tensor | None,
96+
workspace13: torch.Tensor,
97+
workspace2: torch.Tensor,
98+
expert_tokens_meta: mk.ExpertTokensMetadata | None,
99+
apply_router_weight_on_input: bool,
100+
):
101+
"""
102+
Apply emulated quantized MoE computation.
103+
104+
This dequantizes the weights on the fly and calls fused_experts_impl
105+
with activation quantization support.
106+
"""
107+
# Dequantize weights if they are quantized
108+
# For NVFP4, weights are packed in uint8 format
109+
# w1 shape: [num_experts, 2*intermediate_size, hidden_size//2]
110+
# w2 shape: [num_experts, hidden_size, intermediate_size//2]
111+
assert w1.dtype == torch.uint8
112+
assert w2.dtype == torch.uint8
113+
114+
# Dequantize w1 from packed NVFP4 to fp16/bf16
115+
w13_global_scale = self.quant_config.g1_alphas
116+
117+
w1_dequant = dequantize_to_dtype(
118+
tensor_fp4=w1,
119+
tensor_sf=self.w1_scale_val,
120+
global_scale=w13_global_scale,
121+
dtype=hidden_states.dtype,
122+
block_size=16,
123+
swizzle=False,
124+
)
125+
126+
# Dequantize w2 from packed NVFP4 to fp16/bf16
127+
w2_global_scale = self.quant_config.g2_alphas
128+
129+
w2_dequant = dequantize_to_dtype(
130+
tensor_fp4=w2,
131+
tensor_sf=self.w2_scale_val,
132+
global_scale=w2_global_scale,
133+
dtype=hidden_states.dtype,
134+
block_size=16,
135+
swizzle=False,
136+
)
137+
138+
hidden_states, _ = moe_kernel_quantize_input(
139+
A=hidden_states,
140+
A_scale=self.quant_config.a1_gscale,
141+
quant_dtype="nvfp4",
142+
per_act_token_quant=False,
143+
quantization_emulation=True,
144+
)
145+
146+
# Activation quantization/dequantization is deferred to
147+
# `moe_kernel_quantize_input` in TritonExperts.apply.
148+
super().apply(
149+
output=output,
150+
hidden_states=hidden_states,
151+
w1=w1_dequant,
152+
w2=w2_dequant,
153+
topk_weights=topk_weights,
154+
topk_ids=topk_ids,
155+
activation=activation,
156+
global_num_experts=global_num_experts,
157+
expert_map=expert_map,
158+
a1q_scale=None,
159+
a2_scale=self.quant_config.a2_gscale,
160+
workspace13=workspace13,
161+
workspace2=workspace2,
162+
expert_tokens_meta=expert_tokens_meta,
163+
apply_router_weight_on_input=apply_router_weight_on_input,
164+
)

0 commit comments

Comments
 (0)