Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 9 additions & 2 deletions python/sglang/srt/layers/communicator.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
get_attention_dp_size,
get_attention_tp_rank,
get_attention_tp_size,
get_dp_global_num_tokens,
get_global_dp_buffer,
get_local_dp_buffer,
is_allocation_symmetric,
Expand All @@ -55,6 +56,7 @@
from sglang.srt.layers.flashinfer_comm_fusion import is_flashinfer_allreduce_unavailable
from sglang.srt.layers.moe import (
get_moe_a2a_backend,
should_use_dp_reduce_scatterv,
should_use_flashinfer_cutlass_moe_fp4_allgather,
)
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
Expand Down Expand Up @@ -1007,8 +1009,13 @@ def _scatter_hidden_states(
get_local_dp_buffer(),
hidden_states,
)
if allow_reduce_scatter and forward_batch.dp_padding_mode.is_max_len():
# When using padding, all_reduce is skipped after MLP and MOE and reduce scatter is used here instead.
if should_use_dp_reduce_scatterv():
get_tp_group().reduce_scatterv(
global_hidden_states,
output=hidden_states,
sizes=get_dp_global_num_tokens(),
)
elif allow_reduce_scatter and forward_batch.dp_padding_mode.is_max_len():
dp_reduce_scatter_tensor(hidden_states, global_hidden_states)
else:
dp_scatter(hidden_states, global_hidden_states, forward_batch)
Expand Down
2 changes: 2 additions & 0 deletions python/sglang/srt/layers/moe/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
get_tbo_token_distribution_threshold,
initialize_moe_config,
is_tbo_enabled,
should_use_dp_reduce_scatterv,
should_use_flashinfer_cutlass_moe_fp4_allgather,
)

Expand All @@ -23,6 +24,7 @@
"get_moe_a2a_backend",
"get_moe_runner_backend",
"get_deepep_mode",
"should_use_dp_reduce_scatterv",
"should_use_flashinfer_cutlass_moe_fp4_allgather",
"is_tbo_enabled",
"get_tbo_token_distribution_threshold",
Expand Down
15 changes: 15 additions & 0 deletions python/sglang/srt/layers/moe/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -289,6 +289,21 @@ def should_use_flashinfer_cutlass_moe_fp4_allgather():
)


def should_use_dp_reduce_scatterv():
    """Decide whether the standard dispatcher's combine() should use reduce_scatterv.

    Applies to DP attention combined with EP: replaces the default
    all-reduce + dp_scatter path after the MoE kernel. Dispatch-side
    communication is not affected.

    Returns:
        True when the fp4-allgather path is off, no a2a backend is active,
        DP attention is enabled with more than one DP rank, and the MoE EP
        world size matches the attention DP size.
    """
    # The fp4 allgather path has its own combine strategy; bail out first.
    if should_use_flashinfer_cutlass_moe_fp4_allgather():
        return False
    # An active all-to-all backend handles its own combine communication.
    if not get_moe_a2a_backend().is_none():
        return False
    if not is_dp_attention_enabled():
        return False
    dp_size = get_attention_dp_size()
    # reduce_scatterv only pays off with real DP, and only when EP ranks
    # line up one-to-one with attention DP ranks.
    return dp_size > 1 and get_moe_expert_parallel_world_size() == dp_size


@contextmanager
def speculative_moe_backend_context():
"""
Expand Down
2 changes: 2 additions & 0 deletions python/sglang/srt/models/qwen2_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.moe import (
get_moe_a2a_backend,
should_use_dp_reduce_scatterv,
should_use_flashinfer_cutlass_moe_fp4_allgather,
)
from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class
Expand Down Expand Up @@ -352,6 +353,7 @@ def forward(
and not should_allreduce_fusion
and not use_reduce_scatter
and not should_use_flashinfer_cutlass_moe_fp4_allgather()
and not should_use_dp_reduce_scatterv()
):
final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)

Expand Down
Loading