Skip to content

Commit 9058031

Browse files
author
menyu
committed
EPv2: address review (graph gating, fp8 fail-fast, helper scope, robustness)
Two rounds of code-review fixes: - server_args: disable the prefill CUDA graph under EPv2 -- only the direct-mode decode masked-GEMM path is capture-safe; the direct-mode prefill (extend) path uses a non-masked layout with a host readback and is not capture-validated. - fp8: DeepGEMM UE8M0 weight requant now fails fast for the EPv2 FusedMoE layer with a clear message, instead of asserting isinstance(layer, DeepEPMoE). - deepseek_v2: revert the 3 invented a2a helpers back to the original inline backend checks plus `or is_epv2()`, so EPv2 integration is purely additive. Fix two over-broad sites where a wide helper had replaced narrower checks: the AMD gfx95 allocator-size path (restore is_deepep_class_backend + epv2) and enable_a2a_moe (restore is_deepep/is_mooncake + epv2) -- unrelated backends (nixl / ascend / flashinfer / megamoe) keep their original behavior. - epv2: dispatch_b/combine_b raise a clear RuntimeError when called without a preceding dispatch_a/combine_a, aligned with the combine_a stage check. - kernels: document that the masked-slab overflow fast-fail is skipped during CUDA graph capture, and that safety then relies on the static max_m = cap * ep_group_size upper bound. - utils: drop the now-unused a2a helpers; clarify the capability-resolver comment (it reads runner flags to build the contract, like the DeepEP dispatcher does; the dispatcher itself only consumes the resolved contract). No functional or perf change for EPv2 or DeepEP: re-verified chat-completions 3-question correctness (direct + hybrid), unit tests (7 passed), and 4 throughput points (decode/prefill, EPv2 vs DeepEP) -- all within run-to-run noise of the pre-fix numbers.
1 parent 395b55c commit 9058031

6 files changed

Lines changed: 75 additions & 58 deletions

File tree

python/sglang/srt/layers/moe/ep_moe/kernels.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1731,7 +1731,7 @@ def fp8_per_token_to_per_tensor_quant_triton(
17311731

17321732

17331733
# ---------------------------------------------------------------------------
1734-
# EPv2 decode masked-GEMM bridge (Claude): repack the expanded expert-packed
1734+
# EPv2 decode masked-GEMM bridge: repack the expanded expert-packed
17351735
# dispatch buffer into a regular [E_local, max_m, hidden] slab so DeepGEMM's
17361736
# *masked* grouped GEMM can bound compute by per-expert real counts (masked_m)
17371737
# instead of the dispatch capacity. All-GPU, static shapes -> cuda-graph safe.
@@ -1875,6 +1875,11 @@ def expand_to_masked_slab(
18751875
# Outside cuda graph capture, fail fast on slab overflow rather than return a
18761876
# silently truncated result. During capture we skip the host read to keep the
18771877
# path graph-safe; the eager warmup forward validates representative shapes.
1878+
# Safety under graph replay therefore relies on the static upper bound
1879+
# max_m = cap * ep_group_size holding: each rank sends at most `cap` tokens
1880+
# (enforced by the dispatch-entry assert) and a token contributes at most once
1881+
# per local expert, so no expert can exceed max_m. If those invariants change,
1882+
# graph replay would NOT fail-fast on overflow — re-validate before relying on it.
18781883
if not torch.cuda.is_current_stream_capturing() and int(overflow.item()) != 0:
18791884
raise RuntimeError(
18801885
f"EPv2 masked slab overflow: an expert received more than max_m="

python/sglang/srt/layers/moe/token_dispatcher/epv2.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -527,6 +527,10 @@ def dispatch_a(self, hidden_states: torch.Tensor, topk_output: TopKOutput) -> No
527527
self._dispatch_state = self._impl.dispatch_a(hidden_states, topk_output)
528528

529529
def dispatch_b(self) -> DispatchOutput:
530+
if self._dispatch_state is None:
531+
raise RuntimeError(
532+
"DeepEP v2 dispatch_b() called without a preceding dispatch_a()"
533+
)
530534
out = self._impl.dispatch_b(*self._dispatch_state)
531535
self._dispatch_state = None
532536
self._stage = _Stage.AFTER_DISPATCH
@@ -553,6 +557,10 @@ def combine_a(self, combine_input: CombineInput) -> None:
553557
self._combine_state = self._impl.combine_a(combine_input)
554558

555559
def combine_b(self) -> torch.Tensor:
560+
if self._combine_state is None:
561+
raise RuntimeError(
562+
"DeepEP v2 combine_b() called without a preceding combine_a()"
563+
)
556564
try:
557565
return self._impl.combine_b(*self._combine_state)
558566
finally:

python/sglang/srt/layers/moe/utils.py

Lines changed: 4 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -164,8 +164,10 @@ class EpV2RunnerCapability(NamedTuple):
164164
"""
165165
Describes the EPv2 dispatcher contract required by the active MoE runner.
166166
167-
The dispatcher should depend on this explicit contract instead of peeking at
168-
runner implementation details such as DeepGEMM JIT flags.
167+
This capability is resolved once (in get_epv2_runner_capability, which reads
168+
runner-side flags such as DeepGEMM JIT TMA/UE8M0 settings) and then consumed
169+
by the dispatcher. The dispatcher depends only on this resolved contract and
170+
does not peek at runner implementation details itself.
169171
"""
170172

171173
output_dtype: EpV2OutputDtype
@@ -468,48 +470,6 @@ def is_deepep_class_backend() -> bool:
468470
return b.is_deepep() or b.is_mooncake() or b.is_mori()
469471

470472

471-
def uses_a2a_moe_forward() -> bool:
472-
"""Return whether the active backend uses the A2A MoE forward path."""
473-
b = get_moe_a2a_backend()
474-
return (
475-
b.is_deepep()
476-
or b.is_mooncake()
477-
or b.is_nixl()
478-
or b.is_mori()
479-
or b.is_ascend_fuseep()
480-
or b.is_flashinfer()
481-
or b.is_epv2()
482-
)
483-
484-
485-
def uses_a2a_expert_parallel_metadata() -> bool:
486-
"""Return whether the backend needs EP metadata on DeepSeek MoE layers."""
487-
b = get_moe_a2a_backend()
488-
return (
489-
b.is_deepep()
490-
or b.is_mooncake()
491-
or b.is_nixl()
492-
or b.is_mori()
493-
or b.is_ascend_fuseep()
494-
or b.is_epv2()
495-
)
496-
497-
498-
def requires_shared_expert_tp1() -> bool:
499-
"""Return whether shared experts should be materialized with TP=1."""
500-
b = get_moe_a2a_backend()
501-
return (
502-
b.is_deepep()
503-
or b.is_mooncake()
504-
or b.is_nixl()
505-
or b.is_mori()
506-
or b.is_ascend_fuseep()
507-
or b.is_flashinfer()
508-
or b.is_megamoe()
509-
or b.is_epv2()
510-
)
511-
512-
513473
def is_flashinfer_cutedsl_v1_path() -> bool:
514474
"""CuteDSL v1 + DeepEP low-latency path (no MoeRunner, no autotune)."""
515475
return (

python/sglang/srt/layers/quantization/fp8.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1389,9 +1389,20 @@ def process_weights_after_loading_block_quant(self, layer: Module) -> None:
13891389
and will_use_deepgemm
13901390
and not layer.w13_weight_scale_inv.format_ue8m0
13911391
):
1392-
assert isinstance(
1393-
layer, DeepEPMoE
1394-
), "DeepGemm MoE is only supported with DeepEPMoE"
1392+
if not isinstance(layer, DeepEPMoE):
1393+
# UE8M0 in-place weight requant is only wired for the
1394+
# DeepEPMoE layer (legacy deepep backend). The EPv2 backend
1395+
# uses FusedMoE, so fail fast with a clear message instead of
1396+
# asserting; use a pre-requantized FP8 checkpoint or
1397+
# --moe-a2a-backend deepep for checkpoints that need UE8M0
1398+
# requant at load time.
1399+
raise NotImplementedError(
1400+
"DeepGEMM UE8M0 weight requant requires the DeepEPMoE "
1401+
f"layer, got {type(layer).__name__}. The EPv2 backend "
1402+
"does not support FP8 checkpoints that need load-time "
1403+
"UE8M0 requant yet; use a pre-requantized checkpoint or "
1404+
"--moe-a2a-backend deepep."
1405+
)
13951406
weight_block_size = self.quant_config.weight_block_size
13961407
requant_weight_ue8m0_inplace(
13971408
layer.w13_weight, layer.w13_weight_scale_inv, weight_block_size

python/sglang/srt/models/deepseek_v2.py

Lines changed: 36 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -101,9 +101,6 @@
101101
is_deepep_class_backend,
102102
is_sbo_enabled,
103103
is_tbo_enabled,
104-
requires_shared_expert_tp1,
105-
uses_a2a_expert_parallel_metadata,
106-
uses_a2a_moe_forward,
107104
)
108105
from sglang.srt.layers.quantization.base_config import QuantizationConfig
109106
from sglang.srt.layers.quantization.fp8 import Fp8Config
@@ -712,7 +709,14 @@ def __init__(
712709
# explicitly requested for DSV4 checkpoints whose shared scales are
713710
# not divisible by the global TP size.
714711
_shared_expert_use_tp1 = (
715-
requires_shared_expert_tp1()
712+
get_moe_a2a_backend().is_deepep()
713+
or get_moe_a2a_backend().is_mooncake()
714+
or get_moe_a2a_backend().is_nixl()
715+
or get_moe_a2a_backend().is_mori()
716+
or get_moe_a2a_backend().is_ascend_fuseep()
717+
or get_moe_a2a_backend().is_flashinfer()
718+
or get_moe_a2a_backend().is_megamoe()
719+
or get_moe_a2a_backend().is_epv2()
716720
or should_use_flashinfer_cutlass_moe_fp4_allgather()
717721
or envs.SGLANG_SHARED_EXPERT_TP1.get()
718722
)
@@ -787,7 +791,14 @@ def __init__(
787791

788792
self.top_k = config.num_experts_per_tok
789793

790-
if uses_a2a_expert_parallel_metadata():
794+
if (
795+
get_moe_a2a_backend().is_deepep()
796+
or get_moe_a2a_backend().is_mooncake()
797+
or get_moe_a2a_backend().is_nixl()
798+
or get_moe_a2a_backend().is_mori()
799+
or get_moe_a2a_backend().is_ascend_fuseep()
800+
or get_moe_a2a_backend().is_epv2()
801+
):
791802
# TODO: we will support tp < ep in the future
792803
self.ep_size = get_parallel().moe_ep_size
793804
self.num_experts = (
@@ -803,7 +814,15 @@ def __init__(
803814
else None
804815
)
805816

806-
self._enable_a2a_moe = uses_a2a_moe_forward()
817+
self._enable_a2a_moe = (
818+
get_moe_a2a_backend().is_deepep()
819+
or get_moe_a2a_backend().is_mooncake()
820+
or get_moe_a2a_backend().is_nixl()
821+
or get_moe_a2a_backend().is_mori()
822+
or get_moe_a2a_backend().is_ascend_fuseep()
823+
or get_moe_a2a_backend().is_flashinfer()
824+
or get_moe_a2a_backend().is_epv2()
825+
)
807826
self._fuse_shared_experts_inside_sbo = SboFlags.fuse_shared_experts_inside_sbo()
808827

809828
def get_moe_weights(self):
@@ -2424,9 +2443,12 @@ def __init__(
24242443
for i in range(len(self.layers)):
24252444
if isinstance(self.layers[i].mlp, DeepseekV2MoE):
24262445
# tp_size = get_parallel().tp_size
2427-
# requires_shared_expert_tp1() is epv2-aware (is_deepep_class_backend
2428-
# is not); keep it so EPv2 also materializes shared experts at TP=1.
2429-
is_a2a_moe = requires_shared_expert_tp1()
2446+
# Keep the original deepep-class scope here and only add EPv2,
2447+
# so unrelated backends' allocator sizing is unchanged.
2448+
is_a2a_moe = (
2449+
is_deepep_class_backend()
2450+
or get_moe_a2a_backend().is_epv2()
2451+
)
24302452
tp_size = 1 if is_a2a_moe else get_parallel().tp_size
24312453
intermediate_size = (
24322454
config.moe_intermediate_size * config.n_shared_experts
@@ -2446,7 +2468,11 @@ def __init__(
24462468
)
24472469
)
24482470
self.layers_to_capture = []
2449-
self.enable_a2a_moe = uses_a2a_moe_forward()
2471+
self.enable_a2a_moe = (
2472+
get_moe_a2a_backend().is_deepep()
2473+
or get_moe_a2a_backend().is_mooncake()
2474+
or get_moe_a2a_backend().is_epv2()
2475+
)
24502476

24512477
# llama_4_scaling: for supporting Mistral-Large-3 model
24522478
self.llama_4_scaling_config = getattr(config, "llama_4_scaling", None)

python/sglang/srt/server_args.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5547,6 +5547,13 @@ def _handle_a2a_moe(self):
55475547
if not epv2_graph_ok:
55485548
self.cuda_graph_config.decode.backend = Backend.DISABLED
55495549
self.cuda_graph_config.prefill.backend = Backend.DISABLED
5550+
else:
5551+
# Only the direct-mode decode masked-GEMM path is capture-safe
5552+
# (static shapes, no host readback). The direct-mode prefill
5553+
# (extend) path goes through the non-masked contiguous layout with
5554+
# a host readback and is not capture-validated, so keep the decode
5555+
# graph but always disable the prefill graph under EPv2.
5556+
self.cuda_graph_config.prefill.backend = Backend.DISABLED
55505557
logger.warning(
55515558
f"DeepEP v2 MoE is enabled. The expert parallel size is adjusted to be the same as the tensor parallel size[{self.tp_size}]."
55525559
)

0 commit comments

Comments
 (0)