Skip to content

Commit a7ca962

Browse files
author
menyu
committed
EPv2: address review (graph gating, fp8 fail-fast, helper scope, robustness)
Two rounds of code-review fixes: - server_args: disable the prefill CUDA graph under EPv2 -- only the direct-mode decode masked-GEMM path is capture-safe; the direct-mode prefill (extend) path uses a non-masked layout with a host readback and is not capture-validated. - fp8: DeepGEMM UE8M0 weight requant now fails fast for the EPv2 FusedMoE layer with a clear message, instead of asserting isinstance(layer, DeepEPMoE). - deepseek_v2: revert the 3 invented a2a helpers back to the original inline backend checks plus `or is_epv2()`, so EPv2 integration is purely additive. Fix two over-broad sites where a wide helper had replaced narrower checks: the AMD gfx95 allocator-size path (restore is_deepep_class_backend + epv2) and enable_a2a_moe (restore is_deepep/is_mooncake + epv2) -- unrelated backends (nixl / ascend / flashinfer / megamoe) keep their original behavior. - epv2: dispatch_b/combine_b raise a clear RuntimeError when called without a preceding dispatch_a/combine_a, aligned with the combine_a stage check. - kernels: document that the masked-slab overflow fast-fail is skipped during CUDA graph capture, and that safety then relies on the static max_m = cap * ep_group_size upper bound. - utils: drop the now-unused a2a helpers; clarify the capability-resolver comment (it reads runner flags to build the contract, like the DeepEP dispatcher does; the dispatcher itself only consumes the resolved contract). No functional or perf change for EPv2 or DeepEP: re-verified chat-completions 3-question correctness (direct + hybrid), unit tests (7 passed), and 4 throughput points (decode/prefill, EPv2 vs DeepEP) -- all within run-to-run noise of the pre-fix numbers.
1 parent 57bffbe commit a7ca962

6 files changed

Lines changed: 76 additions & 58 deletions

File tree

python/sglang/srt/layers/moe/ep_moe/kernels.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1731,7 +1731,7 @@ def fp8_per_token_to_per_tensor_quant_triton(
17311731

17321732

17331733
# ---------------------------------------------------------------------------
1734-
# EPv2 decode masked-GEMM bridge (Claude): repack the expanded expert-packed
1734+
# EPv2 decode masked-GEMM bridge: repack the expanded expert-packed
17351735
# dispatch buffer into a regular [E_local, max_m, hidden] slab so DeepGEMM's
17361736
# *masked* grouped GEMM can bound compute by per-expert real counts (masked_m)
17371737
# instead of the dispatch capacity. All-GPU, static shapes -> cuda-graph safe.
@@ -1875,6 +1875,11 @@ def expand_to_masked_slab(
18751875
# Outside cuda graph capture, fail fast on slab overflow rather than return a
18761876
# silently truncated result. During capture we skip the host read to keep the
18771877
# path graph-safe; the eager warmup forward validates representative shapes.
1878+
# Safety under graph replay therefore relies on the static upper bound
1879+
# max_m = cap * ep_group_size holding: each rank sends at most `cap` tokens
1880+
# (enforced by the dispatch-entry assert) and a token contributes at most once
1881+
# per local expert, so no expert can exceed max_m. If those invariants change,
1882+
# graph replay would NOT fail-fast on overflow — re-validate before relying on it.
18781883
if not torch.cuda.is_current_stream_capturing() and int(overflow.item()) != 0:
18791884
raise RuntimeError(
18801885
f"EPv2 masked slab overflow: an expert received more than max_m="

python/sglang/srt/layers/moe/token_dispatcher/epv2.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -527,6 +527,10 @@ def dispatch_a(self, hidden_states: torch.Tensor, topk_output: TopKOutput) -> No
527527
self._dispatch_state = self._impl.dispatch_a(hidden_states, topk_output)
528528

529529
def dispatch_b(self) -> DispatchOutput:
530+
if self._dispatch_state is None:
531+
raise RuntimeError(
532+
"DeepEP v2 dispatch_b() called without a preceding dispatch_a()"
533+
)
530534
out = self._impl.dispatch_b(*self._dispatch_state)
531535
self._dispatch_state = None
532536
self._stage = _Stage.AFTER_DISPATCH
@@ -553,6 +557,10 @@ def combine_a(self, combine_input: CombineInput) -> None:
553557
self._combine_state = self._impl.combine_a(combine_input)
554558

555559
def combine_b(self) -> torch.Tensor:
560+
if self._combine_state is None:
561+
raise RuntimeError(
562+
"DeepEP v2 combine_b() called without a preceding combine_a()"
563+
)
556564
try:
557565
return self._impl.combine_b(*self._combine_state)
558566
finally:

python/sglang/srt/layers/moe/utils.py

Lines changed: 4 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -164,8 +164,10 @@ class EpV2RunnerCapability(NamedTuple):
164164
"""
165165
Describes the EPv2 dispatcher contract required by the active MoE runner.
166166
167-
The dispatcher should depend on this explicit contract instead of peeking at
168-
runner implementation details such as DeepGEMM JIT flags.
167+
This capability is resolved once (in get_epv2_runner_capability, which reads
168+
runner-side flags such as DeepGEMM JIT TMA/UE8M0 settings) and then consumed
169+
by the dispatcher. The dispatcher depends only on this resolved contract and
170+
does not peek at runner implementation details itself.
169171
"""
170172

171173
output_dtype: EpV2OutputDtype
@@ -468,48 +470,6 @@ def is_deepep_class_backend() -> bool:
468470
return b.is_deepep() or b.is_mooncake() or b.is_mori()
469471

470472

471-
def uses_a2a_moe_forward() -> bool:
472-
"""Return whether the active backend uses the A2A MoE forward path."""
473-
b = get_moe_a2a_backend()
474-
return (
475-
b.is_deepep()
476-
or b.is_mooncake()
477-
or b.is_nixl()
478-
or b.is_mori()
479-
or b.is_ascend_fuseep()
480-
or b.is_flashinfer()
481-
or b.is_epv2()
482-
)
483-
484-
485-
def uses_a2a_expert_parallel_metadata() -> bool:
486-
"""Return whether the backend needs EP metadata on DeepSeek MoE layers."""
487-
b = get_moe_a2a_backend()
488-
return (
489-
b.is_deepep()
490-
or b.is_mooncake()
491-
or b.is_nixl()
492-
or b.is_mori()
493-
or b.is_ascend_fuseep()
494-
or b.is_epv2()
495-
)
496-
497-
498-
def requires_shared_expert_tp1() -> bool:
499-
"""Return whether shared experts should be materialized with TP=1."""
500-
b = get_moe_a2a_backend()
501-
return (
502-
b.is_deepep()
503-
or b.is_mooncake()
504-
or b.is_nixl()
505-
or b.is_mori()
506-
or b.is_ascend_fuseep()
507-
or b.is_flashinfer()
508-
or b.is_megamoe()
509-
or b.is_epv2()
510-
)
511-
512-
513473
def is_flashinfer_cutedsl_v1_path() -> bool:
514474
"""CuteDSL v1 + DeepEP low-latency path (no MoeRunner, no autotune)."""
515475
return (

python/sglang/srt/layers/quantization/fp8.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1367,9 +1367,21 @@ def process_weights_after_loading_block_quant(self, layer: Module) -> None:
13671367
weight_block_size,
13681368
use_deepgemm_runner=will_use_deepgemm,
13691369
):
1370-
assert isinstance(
1371-
layer, DeepEPMoE
1372-
), "DeepGemm MoE is only supported with DeepEPMoE"
1370+
if not isinstance(layer, DeepEPMoE):
1371+
# UE8M0 in-place weight requant is only wired for the
1372+
# DeepEPMoE layer (legacy deepep backend). The EPv2
1373+
# backend uses FusedMoE, so fail fast with a clear
1374+
# message instead of asserting; use a pre-requantized
1375+
# FP8 checkpoint or --moe-a2a-backend deepep for
1376+
# checkpoints that need UE8M0 requant at load time.
1377+
raise NotImplementedError(
1378+
"DeepGEMM UE8M0 weight requant requires the "
1379+
f"DeepEPMoE layer, got {type(layer).__name__}. The "
1380+
"EPv2 backend does not support FP8 checkpoints that "
1381+
"need load-time UE8M0 requant yet; use a "
1382+
"pre-requantized checkpoint or --moe-a2a-backend "
1383+
"deepep."
1384+
)
13731385
requant_block_scale_ue8m0_for_deepgemm(
13741386
layer.w2_weight,
13751387
layer.w2_weight_scale_inv,

python/sglang/srt/models/deepseek_v2.py

Lines changed: 36 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -101,9 +101,6 @@
101101
is_deepep_class_backend,
102102
is_sbo_enabled,
103103
is_tbo_enabled,
104-
requires_shared_expert_tp1,
105-
uses_a2a_expert_parallel_metadata,
106-
uses_a2a_moe_forward,
107104
)
108105
from sglang.srt.layers.quantization.base_config import QuantizationConfig
109106
from sglang.srt.layers.quantization.fp8 import Fp8Config
@@ -703,7 +700,14 @@ def __init__(
703700
# explicitly requested for DSV4 checkpoints whose shared scales are
704701
# not divisible by the global TP size.
705702
_shared_expert_use_tp1 = (
706-
requires_shared_expert_tp1()
703+
get_moe_a2a_backend().is_deepep()
704+
or get_moe_a2a_backend().is_mooncake()
705+
or get_moe_a2a_backend().is_nixl()
706+
or get_moe_a2a_backend().is_mori()
707+
or get_moe_a2a_backend().is_ascend_fuseep()
708+
or get_moe_a2a_backend().is_flashinfer()
709+
or get_moe_a2a_backend().is_megamoe()
710+
or get_moe_a2a_backend().is_epv2()
707711
or should_use_flashinfer_cutlass_moe_fp4_allgather()
708712
or envs.SGLANG_SHARED_EXPERT_TP1.get()
709713
)
@@ -778,7 +782,14 @@ def __init__(
778782

779783
self.top_k = config.num_experts_per_tok
780784

781-
if uses_a2a_expert_parallel_metadata():
785+
if (
786+
get_moe_a2a_backend().is_deepep()
787+
or get_moe_a2a_backend().is_mooncake()
788+
or get_moe_a2a_backend().is_nixl()
789+
or get_moe_a2a_backend().is_mori()
790+
or get_moe_a2a_backend().is_ascend_fuseep()
791+
or get_moe_a2a_backend().is_epv2()
792+
):
782793
# TODO: we will support tp < ep in the future
783794
self.ep_size = get_parallel().moe_ep_size
784795
self.num_experts = (
@@ -794,7 +805,15 @@ def __init__(
794805
else None
795806
)
796807

797-
self._enable_a2a_moe = uses_a2a_moe_forward()
808+
self._enable_a2a_moe = (
809+
get_moe_a2a_backend().is_deepep()
810+
or get_moe_a2a_backend().is_mooncake()
811+
or get_moe_a2a_backend().is_nixl()
812+
or get_moe_a2a_backend().is_mori()
813+
or get_moe_a2a_backend().is_ascend_fuseep()
814+
or get_moe_a2a_backend().is_flashinfer()
815+
or get_moe_a2a_backend().is_epv2()
816+
)
798817
self._fuse_shared_experts_inside_sbo = SboFlags.fuse_shared_experts_inside_sbo()
799818

800819
def get_moe_weights(self):
@@ -2413,9 +2432,12 @@ def __init__(
24132432
for i in range(len(self.layers)):
24142433
if isinstance(self.layers[i].mlp, DeepseekV2MoE):
24152434
# tp_size = get_parallel().tp_size
2416-
# requires_shared_expert_tp1() is epv2-aware (is_deepep_class_backend
2417-
# is not); keep it so EPv2 also materializes shared experts at TP=1.
2418-
is_a2a_moe = requires_shared_expert_tp1()
2435+
# Keep the original deepep-class scope here and only add EPv2,
2436+
# so unrelated backends' allocator sizing is unchanged.
2437+
is_a2a_moe = (
2438+
is_deepep_class_backend()
2439+
or get_moe_a2a_backend().is_epv2()
2440+
)
24192441
tp_size = 1 if is_a2a_moe else get_parallel().tp_size
24202442
intermediate_size = (
24212443
config.moe_intermediate_size * config.n_shared_experts
@@ -2435,7 +2457,11 @@ def __init__(
24352457
)
24362458
)
24372459
self.layers_to_capture = []
2438-
self.enable_a2a_moe = uses_a2a_moe_forward()
2460+
self.enable_a2a_moe = (
2461+
get_moe_a2a_backend().is_deepep()
2462+
or get_moe_a2a_backend().is_mooncake()
2463+
or get_moe_a2a_backend().is_epv2()
2464+
)
24392465

24402466
# llama_4_scaling: for supporting Mistral-Large-3 model
24412467
self.llama_4_scaling_config = getattr(config, "llama_4_scaling", None)

python/sglang/srt/server_args.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5547,6 +5547,13 @@ def _handle_a2a_moe(self):
55475547
if not epv2_graph_ok:
55485548
self.cuda_graph_config.decode.backend = Backend.DISABLED
55495549
self.cuda_graph_config.prefill.backend = Backend.DISABLED
5550+
else:
5551+
# Only the direct-mode decode masked-GEMM path is capture-safe
5552+
# (static shapes, no host readback). The direct-mode prefill
5553+
# (extend) path goes through the non-masked contiguous layout with
5554+
# a host readback and is not capture-validated, so keep the decode
5555+
# graph but always disable the prefill graph under EPv2.
5556+
self.cuda_graph_config.prefill.backend = Backend.DISABLED
55505557
logger.warning(
55515558
f"DeepEP v2 MoE is enabled. The expert parallel size is adjusted to be the same as the tensor parallel size[{self.tp_size}]."
55525559
)

0 commit comments

Comments
 (0)