
Replace all-reduce + dp_scatter with reduce_scatterv for DP attention #22642

Open
YAMY1234 wants to merge 2 commits into sgl-project:main from YAMY1234:feat/dp-gatherv-scatterv

Conversation

Contributor

YAMY1234 commented Apr 12, 2026

Motivation

For DP attention with Expert Parallelism (EP), the default MoE communication path performs two separate operations after the MoE kernel:

  1. tensor_model_parallel_all_reduce — reduces expert outputs across all DP workers
  2. dp_scatter — extracts each worker's local token slice from the global result

This is functionally equivalent to a single reduce_scatterv, which fuses the reduce and scatter into one NCCL collective, cutting the number of communication rounds in half.
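The equivalence can be seen with a small pure-Python simulation (illustrative only, not the SGLang implementation): all-reducing and then slicing out each rank's tokens yields exactly what a fused reduce-scatter with variable splits delivers.

```python
# Simulate 3 DP ranks with uneven token counts (2, 3, 1) over 6 positions.

def all_reduce(rank_tensors):
    """Elementwise sum across ranks (all-reduce leaves the full sum on every rank)."""
    return [sum(vals) for vals in zip(*rank_tensors)]

def reduce_scatterv(rank_tensors, split_sizes):
    """Fused reduce + scatter: each rank receives only its own reduced slice."""
    reduced = [sum(vals) for vals in zip(*rank_tensors)]
    out, offset = [], 0
    for size in split_sizes:
        out.append(reduced[offset:offset + size])
        offset += size
    return out

splits = [2, 3, 1]
inputs = [[1, 2, 3, 4, 5, 6],
          [10, 20, 30, 40, 50, 60],
          [100, 200, 300, 400, 500, 600]]

# Baseline path: all-reduce, then each rank slices out its own tokens.
full = all_reduce(inputs)
baseline = [full[0:2], full[2:5], full[5:6]]

# Fused path: one reduce_scatterv produces the same per-rank slices.
assert reduce_scatterv(inputs, splits) == baseline
```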

Modifications

4 files changed, ~35 lines added

  • python/sglang/srt/layers/moe/utils.py: Added should_use_dp_reduce_scatterv() — activates when DP attention + EP is enabled, no DeepEP or FP4 allgather path is active, and ep_size == dp_size.
  • python/sglang/srt/models/qwen2_moe.py: Skip tensor_model_parallel_all_reduce when should_use_dp_reduce_scatterv() is true (the reduction is deferred to the communicator).
  • python/sglang/srt/layers/communicator.py: In CommunicateSummableTensorPairFn._scatter_hidden_states, use reduce_scatterv (via get_tp_group().reduce_scatterv) instead of dp_scatter when the flag is active. Output is allocated from get_local_dp_buffer() for symmetric memory compatibility.
  • python/sglang/srt/layers/moe/__init__.py: Export the new utility function.

The dispatch phase and kernel inputs are completely untouched — only the post-kernel communication (combine/scatter) is changed.
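A hypothetical sketch of the gating conditions described above (the parameter names and signature are assumptions for illustration, not SGLang's actual config API):

```python
def should_use_dp_reduce_scatterv(enable_dp_attention: bool,
                                  ep_size: int,
                                  dp_size: int,
                                  use_deepep: bool,
                                  use_fp4_allgather: bool) -> bool:
    # Only the plain DP-attention + EP combine path qualifies: DeepEP and
    # the FP4 allgather path have their own combine logic.
    if not enable_dp_attention or ep_size <= 1:
        return False
    if use_deepep or use_fp4_allgather:
        return False
    # The fused scatter assumes one token shard per rank, i.e. ep_size == dp_size.
    return ep_size == dp_size
```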

Accuracy Tests

GSM8K 8-shot on Qwen3.5-397B-A17B-FP8, DP4 EP4, 1319 examples, max_tokens=16384:

| Configuration | GSM8K Score | std |
|---|---|---|
| reduce_scatterv (this PR) | 97.86% | 0.1446 |
| baseline (all-reduce + dp_scatter) | 97.86% | 0.1446 |

Accuracy is identical — the optimization is mathematically equivalent (reduce + scatter = reduce_scatter).

Speed Tests and Profiling

Throughput

Max-throughput benchmark on Qwen3.5-397B-A17B-FP8, 1×GB200 node (4 GPUs), DP4 EP4 TP4, ISL=1000 OSL=1, concurrency=4096:

| Configuration | Throughput (tok/s) | Change |
|---|---|---|
| baseline (all-reduce + dp_scatter) | 53,006 | |
| reduce_scatterv (this PR) | 57,115 | +7.7% |

Profiling (100 decode steps, Torch Profiler)

NCCL communication summary (DP0, TP0):

| Metric | Baseline | reduce_scatterv | Change |
|---|---|---|---|
| Total NCCL time | 8,326 ms | 7,193 ms | −13.6% |
| AllReduce count | 4,920 (~49/step) | 2,460 (~25/step) | −50% |
| AllReduce total time | 8,324 ms | 4,385 ms | −47.3% |
| Reduce (scatterv) count | 0 | 2,460 (~25/step) | new |
| Reduce (scatterv) total time | 0 ms | 2,806 ms | new |
| Comm / Compute ratio | 105.5% | 97.3% | comm no longer the bottleneck |

Per-kernel latency (single decode step, steady state):

| Kernel | Duration | Note |
|---|---|---|
| ncclDevKernel_Reduce_Sum_bf16 (reduce_scatterv) | ~161 µs | replaces AllReduce for MoE layers |
| ncclDevKernel_AllReduce_Sum_bf16 (baseline) | ~257 µs | original MoE post-kernel comm |

Why reduce_scatterv is faster:

AllReduce = ReduceScatter + AllGather: it reduces data across all ranks and broadcasts the full result back to every rank. In DP attention, each rank only needs its own token subset for the next attention layer — the AllGather half is wasted work.

reduce_scatterv performs only the reduce-scatter phase, delivering each rank exactly the tokens it owns. This cuts the communication volume roughly in half (~37% per-kernel latency reduction, ~13.6% total NCCL time reduction), directly translating to the +7.7% end-to-end throughput gain.
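A back-of-envelope cost model makes the halving concrete. Under the standard ring-algorithm model (an assumption about NCCL's algorithm choice, not a measurement), all-reduce is a reduce-scatter phase plus an all-gather phase, so dropping the all-gather halves the bytes each rank moves:

```python
def ring_bytes_per_rank(n_ranks, bytes_per_rank_result, collective):
    """Approximate bytes moved per rank by a ring collective over a buffer
    whose reduced result is bytes_per_rank_result per rank."""
    total = n_ranks * bytes_per_rank_result          # full (unscattered) buffer
    per_phase = (n_ranks - 1) / n_ranks * total      # one ring pass
    if collective == "reduce_scatter":
        return per_phase                             # reduce-scatter phase only
    if collective == "all_reduce":
        return 2 * per_phase                         # reduce-scatter + all-gather
    raise ValueError(collective)

# DP4: reduce-scatter moves exactly half the bytes of all-reduce.
ar = ring_bytes_per_rank(4, 1.0, "all_reduce")       # 6.0 units
rs = ring_bytes_per_rank(4, 1.0, "reduce_scatter")   # 3.0 units
assert rs == ar / 2
```

The observed per-kernel reduction (~37%, 257 µs to 161 µs) is smaller than the ideal 50% because real kernels also pay fixed launch and latency overheads that do not scale with volume.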

Checklist

Review and Merge Process

  1. Ping Merge Oncalls to start the process. See the PR Merge Process.
  2. Get approvals from CODEOWNERS and other reviewers.
  3. Trigger CI tests with comments or contact authorized users to do so.
    • Common commands include /tag-and-rerun-ci, /tag-run-ci-label, /rerun-failed-ci
  4. After green CI and required approvals, ask Merge Oncalls or people with Write permission to merge the PR.

… MoE

For DP attention with EP, the default MoE combine path performs an
all-reduce followed by dp_scatter, which is equivalent to two separate
communication steps. This replaces them with a single reduce_scatterv
call that combines reduce and scatter in one operation, improving
throughput by ~7.7% (53k -> 57k tok/s on Qwen3.5-397B-A17B-FP8 DEP4).

Only the post-kernel communication (combine phase) is changed; the
dispatch phase and kernel inputs remain untouched.

Made-with: Cursor

gemini-code-assist bot left a comment

Code Review

This pull request implements a reduce_scatterv optimization for MoE layers when using Data Parallel attention with Expert Parallelism. The review feedback suggests refactoring the communication logic in LayerCommunicator to use pre-allocated buffers for better memory efficiency and compatibility with symmetric memory. Additionally, a potential synchronization issue was identified in the qwen2_moe model where the shared expert might perform an inconsistent all-reduce, leading to tensor mismatches.

Ensures symmetric memory compatibility by using the standard DP buffer
allocation path, and avoids an extra torch.empty inside reduce_scatterv.

Made-with: Cursor
YAMY1234 marked this pull request as ready for review April 13, 2026 03:35
YAMY1234 (Contributor, Author) commented:

/tag-and-rerun-ci

YAMY1234 (Contributor, Author) commented:

/rerun-failed-ci

