2 changes: 1 addition & 1 deletion docs/advanced_features/server_arguments.md
@@ -311,7 +311,7 @@ Please consult the documentation below and [server_args.py](https://github.com/s
| Argument | Description | Defaults | Options |
| --- | --- | --- | --- |
| `--expert-parallel-size`<br>`--ep-size`<br>`--ep` | The expert parallelism size. | `1` | Type: int |
| `--moe-a2a-backend` | Select the backend for all-to-all communication for expert parallelism. | `none` | `none`, `deepep`, `mooncake`, `ascend_fuseep` |
| `--moe-a2a-backend` | Select the backend for all-to-all communication for expert parallelism. | `none` | `none`, `deepep`, `mooncake`, `mori`, `ascend_fuseep` |
| `--moe-runner-backend` | Choose the runner backend for MoE. | `auto` | `auto`, `deep_gemm`, `triton`, `triton_kernel`, `flashinfer_trtllm`, `flashinfer_cutlass`, `flashinfer_mxfp4`, `flashinfer_cutedsl`, `cutlass` |
| `--flashinfer-mxfp4-moe-precision` | Choose the computation precision of the FlashInfer MXFP4 MoE. | `default` | `default`, `bf16` |
| `--enable-flashinfer-allreduce-fusion` | Enable FlashInfer allreduce fusion with Residual RMSNorm. | `False` | bool flag (set to enable) |
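For context, a hedged launch sketch showing how the newly documented `mori` backend would be selected. Flag names come from the table above; the model path and parallel size are placeholders, not part of this PR:

```python
# Hypothetical launch sketch: expert parallelism with the mori
# all-to-all backend. Model path and ep-size are placeholders.
import subprocess

subprocess.run(
    [
        "python", "-m", "sglang.launch_server",
        "--model-path", "org/moe-model",   # placeholder model
        "--ep-size", "8",                  # expert parallelism degree
        "--moe-a2a-backend", "mori",       # backend added in this PR
    ]
)
```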
11 changes: 9 additions & 2 deletions python/sglang/srt/batch_overlap/operations_strategy.py
@@ -7,6 +7,9 @@
from sglang.srt.batch_overlap.operations import Operation
from sglang.srt.layers.moe.token_dispatcher import DeepEPConfig
from sglang.srt.model_executor.forward_batch_info import ForwardMode
from sglang.srt.utils import is_hip

_is_hip = is_hip()


@dataclass
@@ -91,7 +94,9 @@ def _compute_moe_deepseek_layer_operations_strategy_tbo(
def _compute_moe_deepseek_blog_prefill(layer):
device_properties = torch.cuda.get_device_properties(device="cuda")
total_num_sms = device_properties.multi_processor_count
deep_gemm_num_sms = total_num_sms - DeepEPConfig.get_instance().num_sms
deep_gemm_num_sms = None
if not _is_hip:
deep_gemm_num_sms = total_num_sms - DeepEPConfig.get_instance().num_sms

return OperationsStrategy(
deep_gemm_num_sms=deep_gemm_num_sms,
@@ -168,7 +173,9 @@ def _compute_moe_qwen3_layer_operations_strategy_tbo(
def _compute_moe_qwen3_prefill(layer):
device_properties = torch.cuda.get_device_properties(device="cuda")
total_num_sms = device_properties.multi_processor_count
deep_gemm_num_sms = total_num_sms - DeepEPConfig.get_instance().num_sms
deep_gemm_num_sms = None
if not _is_hip:
deep_gemm_num_sms = total_num_sms - DeepEPConfig.get_instance().num_sms

return OperationsStrategy(
deep_gemm_num_sms=deep_gemm_num_sms,
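Both hunks apply the same guard: on CUDA, DeepGEMM is given the SMs left over after DeepEP reserves its communication SMs, while on ROCm that partitioning does not apply, so `deep_gemm_num_sms` stays `None`. A minimal sketch of the platform check, assuming sglang's `is_hip()` follows the usual `torch.version.hip` convention (not copied from `sglang.srt.utils`):

```python
# Minimal sketch of the platform check (assumption: is_hip() follows
# the common torch.version.hip convention).
import torch

def is_hip() -> bool:
    # torch.version.hip is a version string on ROCm builds, None otherwise.
    return torch.version.hip is not None
```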
5 changes: 5 additions & 0 deletions python/sglang/srt/batch_overlap/two_batch_overlap.py
@@ -30,6 +30,7 @@
from sglang.srt.layers.moe.token_dispatcher import (
DeepEPDispatcher,
MooncakeEPDispatcher,
MoriEPDispatcher,
)
from sglang.srt.layers.moe.token_dispatcher.base import BaseDispatcher
from sglang.srt.managers.schedule_batch import ScheduleBatch
@@ -1027,6 +1028,10 @@ def __init__(self, **kwargs):
self._inners = [
MooncakeEPDispatcher(**kwargs) for _ in range(num_inner_dispatchers)
]
elif get_moe_a2a_backend().is_mori():
self._inners = [
MoriEPDispatcher(**kwargs) for _ in range(num_inner_dispatchers)
]

def _execute(self, name, tbo_subbatch_index: Optional[int] = None, **kwargs):
return getattr(self._inners[tbo_subbatch_index or 0], name)(**kwargs)
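The wrapper above keeps one dispatcher instance per overlapped sub-batch and routes each call by sub-batch index; the new branch simply extends that pattern to `MoriEPDispatcher`. A self-contained sketch of the delegation pattern (toy classes, illustrative names only, not sglang API):

```python
# Toy sketch of the per-sub-batch delegation pattern used above.
from typing import Optional

class _ToyDispatcher:
    def __init__(self, tag: int) -> None:
        self.tag = tag

    def dispatch(self, payload: str) -> str:
        return f"dispatcher[{self.tag}] got {payload}"

class _ToyCombined:
    def __init__(self, n: int) -> None:
        # One inner dispatcher per two-batch-overlap sub-batch.
        self._inners = [_ToyDispatcher(i) for i in range(n)]

    def _execute(self, name: str, tbo_subbatch_index: Optional[int] = None, **kwargs):
        # Route by sub-batch index; None falls back to the first dispatcher.
        return getattr(self._inners[tbo_subbatch_index or 0], name)(**kwargs)

combined = _ToyCombined(2)
print(combined._execute("dispatch", tbo_subbatch_index=1, payload="tokens"))
```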
214 changes: 162 additions & 52 deletions python/sglang/srt/layers/attention/aiter_backend.py
@@ -431,7 +431,7 @@ def init_forward_metadata(self, forward_batch: ForwardBatch):
# num_kv_splits_indptr = None

if forward_batch.forward_mode.is_decode_or_idle():
if spec_info is None:
if spec_info is None or forward_batch.forward_mode.is_idle():
kv_indptr[1 : bs + 1] = torch.cumsum(forward_batch.seq_lens, dim=0)
kv_indptr = kv_indptr[: bs + 1]
kv_indices = torch.empty(
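The fix in this hunk matters because an idle forward pass carries no usable speculative indices, so it must take the plain cumulative-sum path even when `spec_info` is set. A self-contained sketch of the corrected branch condition (the `ForwardMode` enum here is a stand-in, not sglang's):

```python
# Stand-in sketch of the decode/idle branch condition fixed above.
from enum import Enum, auto

class ForwardMode(Enum):  # stand-in, not sglang's ForwardMode
    DECODE = auto()
    IDLE = auto()

    def is_idle(self) -> bool:
        return self is ForwardMode.IDLE

def takes_plain_kv_path(spec_info, mode: ForwardMode) -> bool:
    # Mirrors: spec_info is None or forward_batch.forward_mode.is_idle()
    return spec_info is None or mode.is_idle()

assert takes_plain_kv_path(None, ForwardMode.DECODE)    # no spec info
assert takes_plain_kv_path(object(), ForwardMode.IDLE)  # idle overrides
```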
@@ -1074,6 +1074,17 @@ def init_forward_metadata_replay_cuda_graph(
seq_lens_cpu: Optional[torch.Tensor],
):

num_kv_splits = None
# num_kv_splits_indptr = None

work_metadata = None
work_info_set = None
work_indptr = None

reduce_indptr = None
reduce_final_map = None
reduce_partial_map = None

if forward_mode.is_decode_or_idle():
kv_indptr = self.kv_indptr
kv_indices = self.cuda_graph_kv_indices
@@ -1093,6 +1104,58 @@
kv_indptr[: spec_info.kv_indptr.shape[0]] = spec_info.kv_indptr
kv_indices[: spec_info.kv_indices.shape[0]] = spec_info.kv_indices

if self.use_mla:
qo_indptr = self.qo_indptr_[: bs + 1]
qo_indptr[1 : bs + 1] = torch.cumsum(
self.cuda_graph_kv_last_page_len[:bs], dim=0
)
kv_last_page_len = self.cuda_graph_kv_last_page_len[:bs]
max_q_len = 1

if _use_mla_ps_kernel:
num_kv_splits = self.max_split_per_batch

self.make_mla_meta_data(
qo_indptr,
kv_indptr,
kv_last_page_len,
self.work_metadata,
self.work_info_set,
self.work_indptr,
self.reduce_indptr,
self.reduce_final_map,
self.reduce_partial_map,
max_q_len,
fast_mode=fast_mode,
max_split_per_batch=num_kv_splits,
intra_batch_mode=intra_batch_mode,
)

work_metadata = self.work_metadata
work_info_set = self.work_info_set
work_indptr = self.work_indptr

reduce_indptr = self.reduce_indptr
reduce_final_map = self.reduce_final_map
reduce_partial_map = self.reduce_partial_map

self.forward_metadata = ForwardMetadata(
kv_indptr,
kv_indices,
qo_indptr,
kv_last_page_len,
max_q_len,
kv_indptr[-1].item(),
work_metadata=work_metadata,
work_info_set=work_info_set,
work_indptr=work_indptr,
reduce_indptr=reduce_indptr,
reduce_final_map=reduce_final_map,
reduce_partial_map=reduce_partial_map,
num_kv_splits=num_kv_splits,
# num_kv_splits_indptr=num_kv_splits_indptr,
)

elif forward_mode.is_target_verify():
bs = len(req_pool_indices)
qo_indptr = self.qo_indptr[: bs + 1]
@@ -1117,7 +1180,57 @@
self.req_to_token.stride(0),
)

kv_last_page_len = self.cuda_graph_kv_last_page_len[:bs]
max_q_len = self.num_draft_tokens

# if self.kv_cache_dtype == fp8_dtype:
if _use_mla_ps_kernel:

num_kv_splits = self.max_split_per_batch

self.make_mla_meta_data(
qo_indptr,
kv_indptr,
kv_last_page_len,
self.work_metadata,
self.work_info_set,
self.work_indptr,
self.reduce_indptr,
self.reduce_final_map,
self.reduce_partial_map,
max_q_len,
fast_mode=fast_mode,
max_split_per_batch=num_kv_splits,
intra_batch_mode=intra_batch_mode,
)

work_metadata = self.work_metadata
work_info_set = self.work_info_set
work_indptr = self.work_indptr

reduce_indptr = self.reduce_indptr
reduce_final_map = self.reduce_final_map
reduce_partial_map = self.reduce_partial_map

self.forward_metadata = ForwardMetadata(
kv_indptr,
kv_indices,
qo_indptr,
kv_last_page_len,
max_q_len,
kv_indptr[-1].item(),
work_metadata=work_metadata,
work_info_set=work_info_set,
work_indptr=work_indptr,
reduce_indptr=reduce_indptr,
reduce_final_map=reduce_final_map,
reduce_partial_map=reduce_partial_map,
num_kv_splits=num_kv_splits,
# num_kv_splits_indptr=num_kv_splits_indptr,
)

elif forward_mode.is_draft_extend():
num_tokens_per_bs = self.speculative_num_steps + 1
seq_lens = seq_lens[:bs]
accept_lens = spec_info.accept_length[:bs]
qo_indptr = self.qo_indptr[: bs + 1]
@@ -1135,6 +1248,54 @@
self.req_to_token.stride(0),
)

kv_last_page_len = self.cuda_graph_kv_last_page_len[:bs]
max_q_len = num_tokens_per_bs

if _use_mla_ps_kernel:

num_kv_splits = self.max_split_per_batch

self.make_mla_meta_data(
qo_indptr,
kv_indptr,
kv_last_page_len,
self.work_metadata,
self.work_info_set,
self.work_indptr,
self.reduce_indptr,
self.reduce_final_map,
self.reduce_partial_map,
max_q_len,
fast_mode=fast_mode,
max_split_per_batch=num_kv_splits,
intra_batch_mode=intra_batch_mode,
)

work_metadata = self.work_metadata
work_info_set = self.work_info_set
work_indptr = self.work_indptr

reduce_indptr = self.reduce_indptr
reduce_final_map = self.reduce_final_map
reduce_partial_map = self.reduce_partial_map

self.forward_metadata = ForwardMetadata(
kv_indptr,
kv_indices,
qo_indptr,
kv_last_page_len,
max_q_len,
kv_indptr[-1].item(),
work_metadata=work_metadata,
work_info_set=work_info_set,
work_indptr=work_indptr,
reduce_indptr=reduce_indptr,
reduce_final_map=reduce_final_map,
reduce_partial_map=reduce_partial_map,
num_kv_splits=num_kv_splits,
# num_kv_splits_indptr=num_kv_splits_indptr,
)

else:
raise ValueError("Invalid forward mode")
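All three replay branches above (decode/idle, target-verify, draft-extend) now build the MLA partial-split metadata up front when `_use_mla_ps_kernel` is set; the later hunks in this file delete the matching `layer.layer_id == 0 and _use_mla_ps_kernel` blocks from `forward_extend` and `forward_decode`. A toy sketch of the hoisting pattern, assuming the intent is one metadata build per replay rather than one per first-layer forward:

```python
# Toy sketch of the hoist: build shared metadata once when replay
# metadata is initialized, instead of lazily at layer 0 of each forward.
class ToyBackend:
    def __init__(self) -> None:
        self.forward_metadata = None

    def init_replay_metadata(self, seq_lens: list[int]) -> None:
        # One build per replay; every layer reads the same work partition.
        indptr = [0]
        for n in seq_lens:
            indptr.append(indptr[-1] + n)
        self.forward_metadata = {"work_indptr": indptr}

    def forward_layer(self, layer_id: int, q) -> list[int]:
        # No `layer_id == 0` special case: metadata already exists.
        return self.forward_metadata["work_indptr"]

backend = ToyBackend()
backend.init_replay_metadata([4, 2, 3])
print(backend.forward_layer(0, q=None))  # [0, 4, 6, 9]
```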

@@ -1366,23 +1527,6 @@ def forward_extend(

num_kv_splits = self.forward_metadata.num_kv_splits

if layer.layer_id == 0 and _use_mla_ps_kernel:
self.make_mla_meta_data(
self.forward_metadata.qo_indptr,
self.forward_metadata.kv_indptr,
self.forward_metadata.kv_last_page_len,
work_metadata,
work_info_set,
work_indptr,
reduce_indptr,
reduce_final_map,
reduce_partial_map,
self.forward_metadata.max_q_len,
fast_mode=fast_mode,
max_split_per_batch=num_kv_splits,
intra_batch_mode=intra_batch_mode,
)

mla_decode_fwd(
q,
K_Buffer.view(-1, 1, 1, layer.qk_head_dim),
@@ -1418,23 +1562,6 @@ def forward_extend(

num_kv_splits = self.forward_metadata.num_kv_splits

if layer.layer_id == 0 and _use_mla_ps_kernel:
self.make_mla_meta_data(
self.forward_metadata.qo_indptr,
self.forward_metadata.kv_indptr,
self.forward_metadata.kv_last_page_len,
work_metadata,
work_info_set,
work_indptr,
reduce_indptr,
reduce_final_map,
reduce_partial_map,
self.forward_metadata.max_q_len,
fast_mode=fast_mode,
max_split_per_batch=num_kv_splits,
intra_batch_mode=intra_batch_mode,
)

if self.forward_metadata.run_graph is not True:

bs, q_pad, q_mask = pad_sequence_with_mask(
@@ -1577,23 +1704,6 @@ def forward_decode(

num_kv_splits = self.forward_metadata.num_kv_splits

if layer.layer_id == 0 and _use_mla_ps_kernel:
self.make_mla_meta_data(
self.forward_metadata.qo_indptr,
self.forward_metadata.kv_indptr,
self.forward_metadata.kv_last_page_len,
work_metadata,
work_info_set,
work_indptr,
reduce_indptr,
reduce_final_map,
reduce_partial_map,
self.forward_metadata.max_q_len,
fast_mode=fast_mode,
max_split_per_batch=num_kv_splits,
intra_batch_mode=intra_batch_mode,
)

mla_decode_fwd(
q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
k_buffer.view(-1, 1, 1, layer.qk_head_dim),