[Feature] Add DeepEPv2 (ElasticBuffer) MoE A2A backend #29525
MengYu10151 wants to merge 5 commits into
Code Review
This pull request introduces support for the DeepEP v2 (ElasticBuffer) MoE A2A backend (epv2), integrating it with both the deep_gemm and triton runners. It adds the EpV2Dispatcher, implements Triton kernels for repacking expanded expert-packed buffers into masked-GEMM slabs to optimize DeepGEMM execution, and updates model configurations, server arguments, and CI scripts accordingly. The feedback highlights several critical and high-severity issues: a missing boundary mask in the _fwd_kernel_ep_scatter_psum_init Triton kernel that can cause out-of-bounds writes, the potential propagation of NaN values due to using torch.empty instead of torch.zeros for gather_out, a robustness improvement for retrieving LOCAL_RANK during distributed initialization, and a missing import check before calling the FP8 quantization kernel in the EPv2 dispatcher.
    off_expert = tl.arange(0, BLOCK_E)
    for start_m in tl.range(0, cur_token_num, BLOCK_E, num_stages=4):
        tl.store(m_indices + cur_start + start_m + off_expert, cur_expert)
The Triton kernel _fwd_kernel_ep_scatter_psum_init writes to m_indices without a boundary mask. Per-expert token counts are generally not multiples of BLOCK_E, and program instances run concurrently, so the unmasked store can write past the end of the current expert's region and corrupt the indices of subsequent experts. A mask idx < cur_end must be applied to prevent non-deterministic corruption and out-of-bounds writes.
Suggested change:

    off_expert = tl.arange(0, BLOCK_E)
    for start_m in tl.range(0, cur_token_num, BLOCK_E, num_stages=4):
        idx = cur_start + start_m + off_expert
        tl.store(m_indices + idx, cur_expert, mask=idx < cur_end)
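The effect of the missing mask can be reproduced without a GPU. The following is a minimal pure-Python sketch, not the kernel itself: `init_m_indices`, its `order` argument (the order in which per-expert program instances happen to run) and the slack padding are all assumptions made for illustration.

```python
def init_m_indices(token_counts, order, masked, block_e=4):
    """Simulate per-expert programs that fill m_indices in block_e-wide stores."""
    # Exclusive prefix sum: starts[e] is the first row owned by expert e.
    starts = [0]
    for n in token_counts:
        starts.append(starts[-1] + n)
    total = starts[-1]
    # Slack at the end so an unmasked spill is visible instead of an IndexError.
    m_indices = [-1] * (total + block_e)
    for cur_expert in order:
        cur_start, cur_end = starts[cur_expert], starts[cur_expert + 1]
        cur_token_num = cur_end - cur_start
        for start_m in range(0, cur_token_num, block_e):
            for off in range(block_e):
                idx = cur_start + start_m + off
                if masked and idx >= cur_end:
                    continue  # equivalent of mask=idx < cur_end
                m_indices[idx] = cur_expert
    return m_indices[:total]


# Expert 0 owns 3 rows, expert 1 owns 2 rows; expert 1's program runs first.
print(init_m_indices([3, 2], order=[1, 0], masked=True))
print(init_m_indices([3, 2], order=[1, 0], masked=False))
```

With the mask, the result is independent of execution order. Without it, the last store of expert 0 overwrites the first row of expert 1 whenever expert 1's program finishes earlier, which is why the corruption is non-deterministic on real hardware.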
    gather_out = torch.empty(
        running_state["hidden_states_shape"],
        device=running_state["hidden_states_device"],
        dtype=torch.bfloat16,
    )
Using torch.empty to allocate gather_out can leave unrouted or padding tokens with uninitialized garbage/NaN values. Since gather_out is the final output of the MoE layer and is added directly to the residual connection, any garbage values will propagate through the model and potentially cause NaNs. Allocating with torch.zeros is a safer and necessary defensive practice here.
Suggested change:

    gather_out = torch.zeros(
        running_state["hidden_states_shape"],
        device=running_state["hidden_states_device"],
        dtype=torch.bfloat16,
    )
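The propagation argument can be illustrated with plain Python floats. This is a sketch under assumptions: `combine` and `add_residual` are hypothetical helpers, NaN stands in for whatever uninitialized memory happens to contain, and whether the real combine path leaves any row unwritten is not confirmed here.

```python
import math


def combine(num_tokens, routed_rows, fill):
    """Allocate an output filled with `fill`, then write only the routed rows."""
    out = [fill] * num_tokens
    for row, value in routed_rows.items():
        out[row] = value
    return out


def add_residual(residual, moe_out):
    """Element-wise residual connection."""
    return [r + m for r, m in zip(residual, moe_out)]


residual = [1.0, 2.0, 3.0]
routed = {0: 0.5, 2: -0.5}  # row 1 is never written (padding or unrouted)

garbage = combine(3, routed, fill=float("nan"))  # stand-in for torch.empty
zeroed = combine(3, routed, fill=0.0)            # stand-in for torch.zeros

print(add_residual(residual, garbage))
print(add_residual(residual, zeroed))
```

Only the zero-filled buffer leaves the residual of the unwritten row intact; the NaN in the other buffer survives the addition and would reach every later layer.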
    init_local_rank = local_rank
    if init_local_rank == -1:
        if distributed_init_method == "env://":
            init_local_rank = int(os.environ.get("LOCAL_RANK", "0"))
        else:
            init_local_rank = rank
When local_rank is -1 and the init method is not "env://", the code currently falls back to using rank directly. However, launchers like torchrun still set the LOCAL_RANK environment variable even when using other initialization methods (e.g., TCP or shared file). Checking LOCAL_RANK first regardless of the initialization method is more robust and standard.
Suggested change:

    init_local_rank = local_rank
    if init_local_rank == -1:
        env_local_rank = os.environ.get("LOCAL_RANK")
        if env_local_rank is not None:
            init_local_rank = int(env_local_rank)
        else:
            init_local_rank = rank
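The suggested precedence (explicit argument, then `LOCAL_RANK`, then the global rank) can be isolated into a small testable function. This is a sketch only: `resolve_local_rank` and its injectable `environ` parameter are hypothetical and not part of the PR.

```python
import os


def resolve_local_rank(local_rank, rank, environ=None):
    """Pick the local rank: explicit value, else LOCAL_RANK, else global rank."""
    environ = os.environ if environ is None else environ
    if local_rank != -1:
        return local_rank
    env_local_rank = environ.get("LOCAL_RANK")
    if env_local_rank is not None:
        return int(env_local_rank)
    return rank


# Global rank 11 on a multi-node job, launcher exported LOCAL_RANK=3.
print(resolve_local_rank(-1, 11, {"LOCAL_RANK": "3"}))
# No launcher variable at all: fall back to the global rank.
print(resolve_local_rank(-1, 11, {}))
```

Passing the environment as a parameter keeps the function independent of the process state, so the fallback order can be unit-tested without launching a distributed job.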
    if self._uses_fp8_dispatch_output():
        if use_masked:
In the use_masked branch, sglang_per_token_group_quant_fp8 is called directly without calling _ensure_fp8_quant_available(). If the quantization kernel is not available (e.g., due to import errors), this will raise a confusing TypeError: 'NoneType' object is not callable instead of a clean ImportError. Calling _ensure_fp8_quant_available() at the start of the FP8 block ensures a clean fail-fast behavior.
Suggested change:

    if self._uses_fp8_dispatch_output():
        _ensure_fp8_quant_available()
        if use_masked:
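The fail-fast pattern referred to above usually pairs a guarded import with a small check function. The sketch below assumes that shape; the module name `hypothetical_fp8_kernels`, the symbol `per_token_group_quant_fp8` and the error text are invented for illustration and are not the PR's actual code.

```python
try:
    # Hypothetical module path: stands in for the optional FP8 kernel package.
    from hypothetical_fp8_kernels import per_token_group_quant_fp8
except ImportError:
    per_token_group_quant_fp8 = None


def ensure_fp8_quant_available():
    """Raise a descriptive ImportError instead of failing later on None()."""
    if per_token_group_quant_fp8 is None:
        raise ImportError(
            "FP8 per-token-group quantization kernel is not available; "
            "install the kernel package or select a BF16 dispatch output."
        )


def quantize_for_dispatch(x):
    ensure_fp8_quant_available()  # checked once, before either branch runs
    return per_token_group_quant_fp8(x)
```

Without the check, calling the `None` placeholder raises `TypeError: 'NoneType' object is not callable`, which hides the real cause (the missing import) from the user.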
Introduce epv2 as a new MoE all-to-all backend backed by DeepEP v2 ElasticBuffer, kept semantically separate from the legacy deepep backend -- it only replaces the expert-parallel dispatch/combine. Adds EpV2Dispatcher with two-phase dispatch_a/b + combine_a/b mirroring the DeepEP dispatcher, runner-capability resolution (MoeA2ABackend.EPV2 / EpV2OutputDtype / EpV2RunnerCapability with fail-fast on unsupported runner/dtype), server args (--moe-a2a-backend epv2, --epv2-mode direct|hybrid, --epv2-dispatcher-output-dtype auto|fp8|bf16) with TBO/SBO fail-fast, env knobs, dispatcher registration, and DeepSeek-V2/V3 wiring. DeepEP import is guarded so importing the module without DeepEP v2 installed does not fail.
Bridge EPv2 dispatch output to the MoE runners. The deep_gemm adapter handles the FP8 + scale contract: direct/decode expanded layout is repacked to a masked [E_local, max_m, hidden] slab for the masked grouped GEMM (cuda-graph safe via static shapes), and hybrid/prefill non-expanded layout is scattered to the contiguous grouped GEMM. The triton adapter consumes BF16 with valid-row compaction. New repack kernels (expand_to_masked_slab / masked_slab_to_expand, ep_scatter_from_psum, ep_expand_init_m_indices_from_psum) are block-row vectorized and apply top-k weights in-kernel. Adds a CUDA unit test for the masked-slab repack roundtrip (bf16/fp8, empty experts, boundary, overflow) that depends only on triton kernels, so it runs in CI without DeepEP installed.
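The masked-slab repack described in this commit can be pictured with a pure-Python roundtrip. This is a simplified sketch, not the Triton kernels: it assumes the expanded layout is a list of rows grouped contiguously by expert with per-expert counts, and it omits FP8 scales and the in-kernel top-k weighting. The function names mirror the kernel names from the commit message, but the signatures are hypothetical.

```python
def expand_to_masked_slab(rows, counts, max_m, pad=0.0):
    """Repack expert-packed rows into a static [E_local][max_m][hidden] slab."""
    if sum(counts) != len(rows):
        raise ValueError("counts do not add up to the number of rows")
    hidden = len(rows[0]) if rows else 0
    slab, masked_m, start = [], [], 0
    for expert, n in enumerate(counts):
        if n > max_m:
            # Overflow check; per the PR it is skipped under CUDA graph capture.
            raise ValueError(f"expert {expert} has {n} rows, max_m is {max_m}")
        block = [list(r) for r in rows[start:start + n]]
        block += [[pad] * hidden for _ in range(max_m - n)]  # pad to max_m
        slab.append(block)
        masked_m.append(n)  # number of valid rows the masked GEMM should read
        start += n
    return slab, masked_m


def masked_slab_to_expand(slab, masked_m):
    """Inverse repack: keep only the first masked_m[e] rows of each expert."""
    return [list(row) for block, n in zip(slab, masked_m) for row in block[:n]]


rows = [[1.0, 1.0], [2.0, 2.0], [3.0, 3.0]]
slab, masked_m = expand_to_masked_slab(rows, counts=[2, 0, 1], max_m=2)
print(masked_m)
print(masked_slab_to_expand(slab, masked_m) == rows)
```

The slab shape depends only on the number of local experts and `max_m`, never on the routing, which is the property that makes the masked path compatible with static-shape CUDA graph capture.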
DeepEP v2 calculate_buffer_size -> ncclTeamWorld can segfault when the world process group is created without a device_id. Pass device_id on the EPv2 group init path so this is deterministic, scoped to the EPv2 path without affecting non-EPv2 startup.
Bump the pinned DeepEP commit (ci_install_deepep.sh and docker/Dockerfile) from 9af0e0d0 to d4f41e4 (v1.2.1-32), the first pin that ships EPv2 ElasticBuffer (EPv2 was introduced in DeepEP sgl-project#605 / b306af0). d4f41e4 is a descendant of the old pin and still exports the legacy v1 Buffer, so the existing deepep backend keeps working while the epv2 backend can import ElasticBuffer. EPv2 links NCCL symmetric memory (nccl::NCCLSymmetricMemoryContext), so the build needs an NCCL shipping the symmetric-memory headers (verified with nvidia-nccl-cu13>=2.30.7).
…stness)

Two rounds of code-review fixes:

- server_args: disable the prefill CUDA graph under EPv2 -- only the direct-mode decode masked-GEMM path is capture-safe; the direct-mode prefill (extend) path uses a non-masked layout with a host readback and is not capture-validated.
- fp8: DeepGEMM UE8M0 weight requant now fails fast for the EPv2 FusedMoE layer with a clear message, instead of asserting isinstance(layer, DeepEPMoE).
- deepseek_v2: revert the 3 invented a2a helpers back to the original inline backend checks plus `or is_epv2()`, so EPv2 integration is purely additive. Fix two over-broad sites where a wide helper had replaced narrower checks: the AMD gfx95 allocator-size path (restore is_deepep_class_backend + epv2) and enable_a2a_moe (restore is_deepep/is_mooncake + epv2) -- unrelated backends (nixl / ascend / flashinfer / megamoe) keep their original behavior.
- epv2: dispatch_b/combine_b raise a clear RuntimeError when called without a preceding dispatch_a/combine_a, aligned with the combine_a stage check.
- kernels: document that the masked-slab overflow fast-fail is skipped during CUDA graph capture, and that safety then relies on the static max_m = cap * ep_group_size upper bound.
- utils: drop the now-unused a2a helpers; clarify the capability-resolver comment (it reads runner flags to build the contract, like the DeepEP dispatcher does; the dispatcher itself only consumes the resolved contract).

No functional or perf change for EPv2 or DeepEP: re-verified chat-completions 3-question correctness (direct + hybrid), unit tests (7 passed), and 4 throughput points (decode/prefill, EPv2 vs DeepEP) -- all within run-to-run noise of the pre-fix numbers.
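The stage check on the two-phase API can be sketched as a small state machine. This is an illustration under assumptions: the class `TwoPhaseDispatcher`, the stage names and the error wording are hypothetical, and the real `EpV2Dispatcher` carries far more state than a single payload.

```python
class TwoPhaseDispatcher:
    """Toy dispatcher that enforces dispatch_a -> dispatch_b -> combine_a -> combine_b."""

    def __init__(self):
        self._stage = "idle"
        self._payload = None

    def _advance(self, expected, new, caller):
        # Fail fast with a clear message when a phase is called out of order.
        if self._stage != expected:
            raise RuntimeError(
                f"{caller} called in stage {self._stage!r}, expected {expected!r}"
            )
        self._stage = new

    def dispatch_a(self, tokens):
        self._advance("idle", "after_dispatch_a", "dispatch_a")
        self._payload = tokens  # the real code would launch the async dispatch

    def dispatch_b(self):
        self._advance("after_dispatch_a", "after_dispatch_b", "dispatch_b")
        return self._payload

    def combine_a(self, expert_out):
        self._advance("after_dispatch_b", "after_combine_a", "combine_a")
        self._payload = expert_out

    def combine_b(self):
        self._advance("after_combine_a", "idle", "combine_b")
        out, self._payload = self._payload, None
        return out
```

Calling `dispatch_b()` on a fresh instance raises `RuntimeError` immediately, rather than failing later on a missing handle or stale buffer.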
9058031 to a7ca962
Motivation
DeepEP v2 introduces `ElasticBuffer`, a faster expert-parallel all-to-all with both a hierarchical (hybrid) and a direct communication mode. This PR integrates it as a new standalone MoE A2A backend `epv2`, kept semantically separate from the legacy `deepep` backend: it only replaces the expert-parallel dispatch/combine and leaves the existing DeepEP path untouched. Target workload is DeepGEMM FP8 MoE (e.g. DeepSeek-V3 / V4-Flash) on EP.
Key semantics:

- `--moe-a2a-backend epv2` with `--epv2-mode direct|hybrid`: the mode maps to ElasticBuffer's communication mode (decode uses `direct` → native expanded layout → masked GEMM + CUDA graph; prefill uses `hybrid` → non-expanded layout → contiguous GEMM). This is not equivalent to legacy DeepEP `normal`/`low_latency`; the mode is fixed at server init.
- `--epv2-dispatcher-output-dtype auto|fp8|bf16`: the runner↔dtype contract is `deep_gemm` → FP8 activation+scale (main path), `triton` → BF16 (functional); any other runner/dtype fails fast.
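The runner↔dtype contract above amounts to a small lookup with fail-fast validation. The sketch below is one reading of it, assuming that `auto` resolves to the runner's native dtype and that every other mismatch is rejected; `resolve_output_dtype` is a hypothetical helper, not the PR's `EpV2RunnerCapability` resolver.

```python
# Native dispatcher output dtype per runner, as stated in the description.
NATIVE_DTYPE = {"deep_gemm": "fp8", "triton": "bf16"}


def resolve_output_dtype(runner, requested="auto"):
    """Resolve --epv2-dispatcher-output-dtype for a runner, or fail fast."""
    if runner not in NATIVE_DTYPE:
        raise ValueError(f"epv2 does not support runner {runner!r}")
    native = NATIVE_DTYPE[runner]
    if requested == "auto":
        return native
    if requested != native:
        raise ValueError(
            f"runner {runner!r} expects {native!r} dispatch output, got {requested!r}"
        )
    return requested


print(resolve_output_dtype("deep_gemm"))        # auto resolves to the native dtype
print(resolve_output_dtype("triton", "bf16"))   # explicit and matching
```

Rejecting mismatches at server start keeps the dispatcher itself simple: it only consumes an already-resolved contract and never has to branch on runner flags.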
Modifications
- `token_dispatcher/epv2.py`, `utils.py`, `server_args.py`, `environ.py`, dispatcher registration, DeepSeek-V2/V3 wiring: `EpV2Dispatcher` with two-phase `dispatch_a/b` + `combine_a/b` (mirrors the DeepEP dispatcher), runner-capability resolution (`MoeA2ABackend.EPV2` / `EpV2OutputDtype` / `EpV2RunnerCapability`, fail-fast on unsupported runner/dtype) and `--epv2-*` server args with TBO/SBO fail-fast. DeepEP import is guarded, so importing the module without DeepEP v2 installed does not fail (CI-safe).
- `moe_runner/deep_gemm.py`, `moe_runner/triton.py`, `ep_moe/kernels.py`, `fused_moe_triton/layer.py`: the deep_gemm adapter bridges the FP8+scale contract. The direct/decode expanded layout is repacked into a masked `[E_local, max_m, hidden]` slab for the masked grouped GEMM (static shapes → CUDA-graph safe); the hybrid/prefill non-expanded layout is scattered to the contiguous grouped GEMM. The triton adapter consumes BF16. New repack kernels (`expand_to_masked_slab` / `masked_slab_to_expand`, `ep_scatter_from_psum`, `ep_expand_init_m_indices_from_psum`) are block-row vectorized and apply top-k weights in-kernel.
- `distributed/parallel_state.py`: pass `device_id` on the EPv2 group init path (DeepEP v2 `calculate_buffer_size → ncclTeamWorld` can segfault without it), scoped to the EPv2 path.
- `scripts/ci/cuda/ci_install_deepep.sh`, `docker/Dockerfile`: bump the pinned DeepEP commit to `d4f41e4` (the first pin shipping EPv2 `ElasticBuffer`; a descendant of the previous pin that still exports the legacy v1 `Buffer`, so the existing deepep backend keeps working). EPv2 links NCCL symmetric memory, so the build needs an NCCL shipping the symmetric-memory headers (verified with `nvidia-nccl-cu13>=2.30.7`).
- `server_args.py`: EPv2 enables the CUDA graph only for the direct-mode decode masked-GEMM path, which has static shapes. The direct-mode prefill (extend) path uses a non-masked layout with a host readback and is not capture-safe, so the prefill graph is disabled under EPv2.
- `fp8.py`: DeepGEMM UE8M0 load-time weight requant is only wired for the legacy `DeepEPMoE` layer. Under the EPv2 `FusedMoE` layer it fails fast with a clear message (use a pre-requantized FP8 checkpoint or `--moe-a2a-backend deepep`) rather than asserting on the layer type.

Accuracy Tests
Setup: DeepSeek-V4-Flash-FP8, H20×8, `--tp 8 --dp 8 --ep 8 --enable-dp-attention`, `--moe-runner-backend deep_gemm`, `--kv-cache-dtype fp8_e4m3`.

- `/v1/chat/completions` 3-question set (capitals CN/JP, `17*23+19`, fox translation) PASS on both `--epv2-mode direct` and `hybrid`.
- `test/registered/unit/layers/moe/test_epv2_masked_slab.py` (masked-slab repack roundtrip: bf16/fp8, empty experts, single-hot, boundary, overflow): 7 passed. Depends only on Triton kernels, so it runs in CI without DeepEP.
Speed Tests and Profiling
Setup as above; normal (non-disaggregated) server, TBO/SBO off, single run.
End-to-end throughput
- Decode: `direct` vs `low_latency`, ISL=1 OSL=1024 CC=1024 cap=128, graph on
- Prefill: `hybrid` vs `normal`, ISL=1024 OSL=1 CC=128 cap=1024, np=512

(Single run; decode gap floats −0.8%…−2.1% and prefill +0.6%…+1.3% across runs; direction stable, magnitude needs multi-run averaging.)
Per-module GPU-kernel time (end-to-end torch trace, total = per-call × calls, single TP rank)
dispatch/combine totals include spin-wait (the dispatch kernel busy-waits for all
ranks to sync); since both backends run the same batch and sync points, the total
still reflects which backend's communication is faster overall. Comparing per-call
alone is misleading — DeepEP LL's decode dispatch is 2 kernels per call (x430 vs
EPv2 x215).
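The point about per-call versus total time is simple arithmetic. In the sketch below the call counts (430 and 215) come from the text above, while the per-call durations are invented purely for illustration and are not measurements from this PR.

```python
def total_us(per_call_us, calls):
    """Total kernel time = per-call time multiplied by the number of calls."""
    return per_call_us * calls


# Per-call values below are hypothetical; only the call counts are from the PR.
deepep_ll_per_call_us, deepep_ll_calls = 10, 430  # 2 kernels per dispatch
epv2_per_call_us, epv2_calls = 15, 215            # 1 kernel per dispatch

deepep_ll_total = total_us(deepep_ll_per_call_us, deepep_ll_calls)
epv2_total = total_us(epv2_per_call_us, epv2_calls)

print(deepep_ll_total, epv2_total)
```

With these made-up inputs the backend with the slower individual kernel still has the lower total, because it launches half as many kernels; this is why the tables compare totals.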
Decode (ms), DeepEP LL vs EPv2 direct — ISL=1 OSL=256 CC=128:
Prefill (ms), DeepEP normal vs EPv2 hybrid — ISL=1024 OSL=1 CC=128:
Read-out:
kernels, same shapes.
dispatch −3.8%, combine −17%). EPv2 issues a single elastic dispatch/combine; DeepEP normal runs the multi-kernel legacy intranode path (`notify_dispatch` + `dispatch` + `cached_notify_combine` + `combine`).
layout natively, zero repack).
is no CUDA graph, so the communication win lands directly on the critical path
(+2.4% throughput). In decode, the repack is pure compute on the critical path
(fully charged), while the dispatch total is spin-wait-dominated and its saving
does not fully land on the critical path — so the repack cost dominates the net.
Checklist
registered-test checks all green on the changed files).
server-arg behavior (prefill-graph-disable, runner/dtype reject) needs a real GPU/model to materialize `cuda_graph_config`, so it is covered by e2e smoke (boot log shows `cuda_graph_config(decode=full, prefill=disabled)`).

Notes for reviewers
- `DEEPEP_COMMIT` affects all DeepEP users (not just EPv2). `d4f41e4` is back-compatible (legacy `Buffer` still present); a maintainer should run the legacy DeepEP nightly to confirm no regression. On this single-node H20 dev box, legacy `deepep --deepep-mode normal` on `d4f41e4` passes the 3-question set, and `low_latency` runs correctly (a standalone single-run repro landed at ~14.9k tok/s, in line with the prior baseline), so `d4f41e4` does not break legacy. This is a compatibility check run separately; the decode gap table above uses its own co-measured LL baseline (15010), so the two LL numbers come from different runs.
- `epv2_mode` couples ElasticBuffer hybrid/direct with the SGLang expanded/non-expanded layout policy (documented above); a future PR could split them into separate knobs. Capacity-overflow precheck and keyed EPv2Buffer cache are noted as follow-ups.