Replace all-reduce + dp_scatter with reduce_scatterv for DP attention #22642
Open
YAMY1234 wants to merge 2 commits into sgl-project:main
Conversation
For DP attention with EP, the default MoE combine path performs an all-reduce followed by dp_scatter, which amounts to two separate communication steps. This PR replaces them with a single reduce_scatterv call that fuses the reduce and scatter into one operation, improving throughput by ~7.7% (53k -> 57k tok/s on Qwen3.5-397B-A17B-FP8 DEP4). Only the post-kernel communication (combine phase) is changed; the dispatch phase and kernel inputs remain untouched. Made-with: Cursor
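Why fusing the two steps helps can be sanity-checked with the textbook ring-algorithm cost model for collectives. This is an illustrative calculation under standard assumptions, not profiled data from this PR:

```python
# Per-rank bytes moved by ring collectives over N ranks and D bytes of
# payload (textbook ring-algorithm costs, not measured values):
#   all-reduce     = 2 * (N - 1) / N * D   (a reduce-scatter + an all-gather)
#   reduce-scatter =     (N - 1) / N * D
def ring_allreduce_bytes(n: int, d: float) -> float:
    return 2 * (n - 1) / n * d


def ring_reduce_scatter_bytes(n: int, d: float) -> float:
    return (n - 1) / n * d


n, d = 4, 1.0  # e.g. DP4, normalized payload size
savings = 1 - ring_reduce_scatter_bytes(n, d) / ring_allreduce_bytes(n, d)
# savings == 0.5: dropping the all-gather half removes 50% of the volume.
# Kernel-launch overhead and fixed latency terms explain why the measured
# per-kernel win reported below (~37%) sits under the ideal 50%.
```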
Contributor
Code Review
This pull request implements a reduce_scatterv optimization for MoE layers when using Data Parallel attention with Expert Parallelism. The review feedback suggests refactoring the communication logic in LayerCommunicator to use pre-allocated buffers for better memory efficiency and compatibility with symmetric memory. Additionally, a potential synchronization issue was identified in the qwen2_moe model where the shared expert might perform an inconsistent all-reduce, leading to tensor mismatches.
Ensures symmetric memory compatibility by using the standard DP buffer allocation path, and avoids an extra torch.empty inside reduce_scatterv. Made-with: Cursor
Contributor
Author
/tag-and-rerun-ci
Fridge003
approved these changes
Apr 13, 2026
Contributor
Author
/rerun-failed-ci
Motivation
For DP attention with Expert Parallelism (EP), the default MoE communication path performs two separate operations after the MoE kernel:
- tensor_model_parallel_all_reduce: reduces expert outputs across all DP workers
- dp_scatter: extracts each worker's local token slice from the global result

This is functionally equivalent to a single reduce_scatterv, which fuses the reduce and scatter into one NCCL collective, cutting the number of communication rounds in half.

Modifications
4 files changed, ~35 lines added
- python/sglang/srt/layers/moe/utils.py: Added should_use_dp_reduce_scatterv(), which activates when DP attention + EP is enabled, no DeepEP or FP4 allgather path is active, and ep_size == dp_size.
- python/sglang/srt/models/qwen2_moe.py: Skip tensor_model_parallel_all_reduce when should_use_dp_reduce_scatterv() is true (the reduction is deferred to the communicator).
- python/sglang/srt/layers/communicator.py: In CommunicateSummableTensorPairFn._scatter_hidden_states, use reduce_scatterv (via get_tp_group().reduce_scatterv) instead of dp_scatter when the flag is active. Output is allocated from get_local_dp_buffer() for symmetric memory compatibility.
- python/sglang/srt/layers/moe/__init__.py: Export the new utility function.

The dispatch phase and kernel inputs are completely untouched; only the post-kernel communication (combine/scatter) is changed.
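The gating helper can be pictured as a pure predicate over the parallelism configuration. This is a hypothetical sketch only: the parameter names are assumptions, and the real should_use_dp_reduce_scatterv() in sglang reads server/config state rather than taking explicit arguments:

```python
# Hypothetical standalone version of the gating predicate described above.
# All parameters are illustrative; the real helper consults sglang's
# global config rather than accepting them directly.
def should_use_dp_reduce_scatterv(
    dp_attention_enabled: bool,
    ep_enabled: bool,
    use_deepep: bool,
    use_fp4_allgather: bool,
    ep_size: int,
    dp_size: int,
) -> bool:
    """Use the fused reduce_scatterv combine path only when the plain
    all-reduce + dp_scatter path would otherwise run."""
    return (
        dp_attention_enabled
        and ep_enabled
        and not use_deepep            # DeepEP has its own combine path
        and not use_fp4_allgather     # FP4 allgather path is incompatible
        and ep_size == dp_size        # ranks and token slices line up 1:1
    )
```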
Accuracy Tests
GSM8K 8-shot on Qwen3.5-397B-A17B-FP8, DP4 EP4, 1319 examples, max_tokens=16384:
Accuracy is identical — the optimization is mathematically equivalent (reduce + scatter = reduce_scatter).
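The equivalence can be demonstrated with a single-process toy, using plain Python lists as stand-ins for tensors and NCCL ranks (shapes and values are made up; this is not the sglang code path):

```python
# Toy check: all-reduce (sum over ranks) followed by per-rank slicing
# produces the same result as a reduce-scatter of the same inputs.
WORLD_SIZE = 4        # e.g. DP4
TOKENS_PER_RANK = 3   # each rank owns 3 token rows

# Simulated per-rank MoE partial outputs: rank r holds a partial value
# for EVERY token (WORLD_SIZE * TOKENS_PER_RANK entries).
partials = [
    [float(r * 100 + t) for t in range(WORLD_SIZE * TOKENS_PER_RANK)]
    for r in range(WORLD_SIZE)
]

# Baseline path: all-reduce (elementwise sum across ranks) ...
reduced = [sum(col) for col in zip(*partials)]
# ... then dp_scatter: each rank keeps only its own token slice.
baseline = [
    reduced[r * TOKENS_PER_RANK:(r + 1) * TOKENS_PER_RANK]
    for r in range(WORLD_SIZE)
]

# Fused path: reduce-scatter sums each rank's slice directly, never
# materializing the full reduced result on every rank.
fused = [
    [sum(p[i] for p in partials)
     for i in range(r * TOKENS_PER_RANK, (r + 1) * TOKENS_PER_RANK)]
    for r in range(WORLD_SIZE)
]

assert baseline == fused  # bitwise-identical summation order here
```

In the real collective, floating-point reduction order can differ between the two paths, but the GSM8K result above shows this has no measurable accuracy impact.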
Speed Tests and Profiling
Throughput
Max-throughput benchmark on Qwen3.5-397B-A17B-FP8, 1×GB200 node (4 GPUs), DP4 EP4 TP4, ISL=1000 OSL=1, concurrency=4096:
Profiling (100 decode steps, Torch Profiler)
NCCL communication summary (DP0, TP0):
Per-kernel latency (single decode step, steady state):
- ncclDevKernel_Reduce_Sum_bf16 (reduce_scatterv)
- ncclDevKernel_AllReduce_Sum_bf16 (baseline)

Why reduce_scatterv is faster:
AllReduce = ReduceScatter + AllGather: it reduces data across all ranks and broadcasts the full result back to every rank. In DP attention, each rank only needs its own token subset for the next attention layer, so the AllGather half is wasted work. reduce_scatterv performs only the reduce-scatter phase, delivering each rank exactly the tokens it owns. This cuts the communication volume roughly in half (~37% per-kernel latency reduction, ~13.6% total NCCL time reduction), directly translating to the +7.7% end-to-end throughput gain.

Checklist
Review and Merge Process
/tag-and-rerun-ci, /tag-run-ci-label, /rerun-failed-ci