diff --git a/docs/advanced_features/server_arguments.md b/docs/advanced_features/server_arguments.md
index af02f824a52a..e9833a8a5020 100644
--- a/docs/advanced_features/server_arguments.md
+++ b/docs/advanced_features/server_arguments.md
@@ -311,7 +311,7 @@ Please consult the documentation below and [server_args.py](https://github.com/s
| Argument | Description | Defaults | Options |
| --- | --- | --- | --- |
| `--expert-parallel-size`
`--ep-size`
`--ep` | The expert parallelism size. | `1` | Type: int |
-| `--moe-a2a-backend` | Select the backend for all-to-all communication for expert parallelism. | `none` | `none`, `deepep`, `mooncake`, `ascend_fuseep`|
+| `--moe-a2a-backend` | Select the backend for all-to-all communication for expert parallelism. | `none` | `none`, `deepep`, `mooncake`, `mori`, `ascend_fuseep`|
| `--moe-runner-backend` | Choose the runner backend for MoE. | `auto` | `auto`, `deep_gemm`, `triton`, `triton_kernel`, `flashinfer_trtllm`, `flashinfer_cutlass`, `flashinfer_mxfp4`, `flashinfer_cutedsl`, `cutlass` |
| `--flashinfer-mxfp4-moe-precision` | Choose the computation precision of flashinfer mxfp4 moe | `default` | `default`, `bf16` |
| `--enable-flashinfer-allreduce-fusion` | Enable FlashInfer allreduce fusion with Residual RMSNorm. | `False` | bool flag (set to enable) |
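For context, a launch command that exercises the new `mori` choice for `--moe-a2a-backend` might look like the sketch below. The model path and parallel sizes are placeholders, not values taken from this patch.

```python
# Illustrative only: launch sglang with the MoRI all-to-all backend.
import subprocess

subprocess.run(
    [
        "python", "-m", "sglang.launch_server",
        "--model-path", "deepseek-ai/DeepSeek-V2-Lite",  # placeholder model
        "--tp-size", "8",
        "--ep-size", "8",
        "--moe-a2a-backend", "mori",
    ],
    check=True,
)
```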
diff --git a/python/sglang/srt/batch_overlap/operations_strategy.py b/python/sglang/srt/batch_overlap/operations_strategy.py
index 41f40275eb73..d39ad838577d 100644
--- a/python/sglang/srt/batch_overlap/operations_strategy.py
+++ b/python/sglang/srt/batch_overlap/operations_strategy.py
@@ -7,6 +7,9 @@
from sglang.srt.batch_overlap.operations import Operation
from sglang.srt.layers.moe.token_dispatcher import DeepEPConfig
from sglang.srt.model_executor.forward_batch_info import ForwardMode
+from sglang.srt.utils import is_hip
+
+_is_hip = is_hip()
@dataclass
@@ -91,7 +94,9 @@ def _compute_moe_deepseek_layer_operations_strategy_tbo(
def _compute_moe_deepseek_blog_prefill(layer):
device_properties = torch.cuda.get_device_properties(device="cuda")
total_num_sms = device_properties.multi_processor_count
- deep_gemm_num_sms = total_num_sms - DeepEPConfig.get_instance().num_sms
+ deep_gemm_num_sms = None
+ if not _is_hip:
+ deep_gemm_num_sms = total_num_sms - DeepEPConfig.get_instance().num_sms
return OperationsStrategy(
deep_gemm_num_sms=deep_gemm_num_sms,
@@ -168,7 +173,9 @@ def _compute_moe_qwen3_layer_operations_strategy_tbo(
def _compute_moe_qwen3_prefill(layer):
device_properties = torch.cuda.get_device_properties(device="cuda")
total_num_sms = device_properties.multi_processor_count
- deep_gemm_num_sms = total_num_sms - DeepEPConfig.get_instance().num_sms
+ deep_gemm_num_sms = None
+ if not _is_hip:
+ deep_gemm_num_sms = total_num_sms - DeepEPConfig.get_instance().num_sms
return OperationsStrategy(
deep_gemm_num_sms=deep_gemm_num_sms,
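The hunks above keep the CUDA SM budgeting (total SMs minus the SMs reserved by the DeepEP dispatcher) but leave `deep_gemm_num_sms` unset on HIP, where that split is not used. A minimal sketch of the same budgeting, assuming `torch.version.hip` as a stand-in for sglang's `is_hip()` and a placeholder value for the dispatcher's SM reservation:

```python
# Minimal sketch of the SM budgeting; dispatch_num_sms is a placeholder for
# DeepEPConfig.get_instance().num_sms.
from typing import Optional
import torch

def compute_deep_gemm_num_sms(dispatch_num_sms: int = 24) -> Optional[int]:
    props = torch.cuda.get_device_properties(device="cuda")
    total_num_sms = props.multi_processor_count
    if torch.version.hip is not None:
        # On ROCm the DeepGEMM SM split is not used; leave it unset.
        return None
    # On CUDA, hand DeepGEMM whatever the communication kernels do not reserve.
    return total_num_sms - dispatch_num_sms
```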
diff --git a/python/sglang/srt/batch_overlap/two_batch_overlap.py b/python/sglang/srt/batch_overlap/two_batch_overlap.py
index cfd2a54ed132..e2840fee0dde 100644
--- a/python/sglang/srt/batch_overlap/two_batch_overlap.py
+++ b/python/sglang/srt/batch_overlap/two_batch_overlap.py
@@ -30,6 +30,7 @@
from sglang.srt.layers.moe.token_dispatcher import (
DeepEPDispatcher,
MooncakeEPDispatcher,
+ MoriEPDispatcher,
)
from sglang.srt.layers.moe.token_dispatcher.base import BaseDispatcher
from sglang.srt.managers.schedule_batch import ScheduleBatch
@@ -1027,6 +1028,10 @@ def __init__(self, **kwargs):
self._inners = [
MooncakeEPDispatcher(**kwargs) for _ in range(num_inner_dispatchers)
]
+ elif get_moe_a2a_backend().is_mori():
+ self._inners = [
+ MoriEPDispatcher(**kwargs) for _ in range(num_inner_dispatchers)
+ ]
def _execute(self, name, tbo_subbatch_index: Optional[int] = None, **kwargs):
return getattr(self._inners[tbo_subbatch_index or 0], name)(**kwargs)
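The wrapper above keeps one inner EP dispatcher per two-batch-overlap sub-batch and routes each call by sub-batch index. A toy sketch of that fan-out pattern, with `Inner` standing in for a real dispatcher:

```python
# Toy sketch of the per-subbatch fan-out; Inner is a stand-in for an EP dispatcher.
from typing import Optional

class Inner:
    def __init__(self, tag: str):
        self.tag = tag
    def dispatch(self, payload: str) -> str:
        return f"{self.tag}:{payload}"

class TboFanout:
    def __init__(self, num_inner: int):
        self._inners = [Inner(f"sub{i}") for i in range(num_inner)]
    def _execute(self, name: str, tbo_subbatch_index: Optional[int] = None, **kwargs):
        # Route the call to the dispatcher owned by this sub-batch (default 0).
        return getattr(self._inners[tbo_subbatch_index or 0], name)(**kwargs)

fanout = TboFanout(num_inner=2)
print(fanout._execute("dispatch", tbo_subbatch_index=1, payload="tokens"))  # sub1:tokens
```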
diff --git a/python/sglang/srt/layers/attention/aiter_backend.py b/python/sglang/srt/layers/attention/aiter_backend.py
index d851040cfc37..0a2e579645b0 100644
--- a/python/sglang/srt/layers/attention/aiter_backend.py
+++ b/python/sglang/srt/layers/attention/aiter_backend.py
@@ -431,7 +431,7 @@ def init_forward_metadata(self, forward_batch: ForwardBatch):
# num_kv_splits_indptr = None
if forward_batch.forward_mode.is_decode_or_idle():
- if spec_info is None:
+ if spec_info is None or forward_batch.forward_mode.is_idle():
kv_indptr[1 : bs + 1] = torch.cumsum(forward_batch.seq_lens, dim=0)
kv_indptr = kv_indptr[: bs + 1]
kv_indices = torch.empty(
@@ -1074,6 +1074,17 @@ def init_forward_metadata_replay_cuda_graph(
seq_lens_cpu: Optional[torch.Tensor],
):
+ num_kv_splits = None
+ # num_kv_splits_indptr = None
+
+ work_metadata = None
+ work_info_set = None
+ work_indptr = None
+
+ reduce_indptr = None
+ reduce_final_map = None
+ reduce_partial_map = None
+
if forward_mode.is_decode_or_idle():
kv_indptr = self.kv_indptr
kv_indices = self.cuda_graph_kv_indices
@@ -1093,6 +1104,58 @@ def init_forward_metadata_replay_cuda_graph(
kv_indptr[: spec_info.kv_indptr.shape[0]] = spec_info.kv_indptr
kv_indices[: spec_info.kv_indices.shape[0]] = spec_info.kv_indices
+ if self.use_mla:
+ qo_indptr = self.qo_indptr_[: bs + 1]
+ qo_indptr[1 : bs + 1] = torch.cumsum(
+ self.cuda_graph_kv_last_page_len[:bs], dim=0
+ )
+ kv_last_page_len = self.cuda_graph_kv_last_page_len[:bs]
+ max_q_len = 1
+
+ if _use_mla_ps_kernel:
+ num_kv_splits = self.max_split_per_batch
+
+ self.make_mla_meta_data(
+ qo_indptr,
+ kv_indptr,
+ kv_last_page_len,
+ self.work_metadata,
+ self.work_info_set,
+ self.work_indptr,
+ self.reduce_indptr,
+ self.reduce_final_map,
+ self.reduce_partial_map,
+ max_q_len,
+ fast_mode=fast_mode,
+ max_split_per_batch=num_kv_splits,
+ intra_batch_mode=intra_batch_mode,
+ )
+
+ work_metadata = self.work_metadata
+ work_info_set = self.work_info_set
+ work_indptr = self.work_indptr
+
+ reduce_indptr = self.reduce_indptr
+ reduce_final_map = self.reduce_final_map
+ reduce_partial_map = self.reduce_partial_map
+
+ self.forward_metadata = ForwardMetadata(
+ kv_indptr,
+ kv_indices,
+ qo_indptr,
+ kv_last_page_len,
+ max_q_len,
+ kv_indptr[-1].item(),
+ work_metadata=work_metadata,
+ work_info_set=work_info_set,
+ work_indptr=work_indptr,
+ reduce_indptr=reduce_indptr,
+ reduce_final_map=reduce_final_map,
+ reduce_partial_map=reduce_partial_map,
+ num_kv_splits=num_kv_splits,
+ # num_kv_splits_indptr=num_kv_splits_indptr,
+ )
+
elif forward_mode.is_target_verify():
bs = len(req_pool_indices)
qo_indptr = self.qo_indptr[: bs + 1]
@@ -1117,7 +1180,57 @@ def init_forward_metadata_replay_cuda_graph(
self.req_to_token.stride(0),
)
+ kv_last_page_len = self.cuda_graph_kv_last_page_len[:bs]
+ max_q_len = self.num_draft_tokens
+
+ # if self.kv_cache_dtype == fp8_dtype:
+ if _use_mla_ps_kernel:
+
+ num_kv_splits = self.max_split_per_batch
+
+ self.make_mla_meta_data(
+ qo_indptr,
+ kv_indptr,
+ kv_last_page_len,
+ self.work_metadata,
+ self.work_info_set,
+ self.work_indptr,
+ self.reduce_indptr,
+ self.reduce_final_map,
+ self.reduce_partial_map,
+ max_q_len,
+ fast_mode=fast_mode,
+ max_split_per_batch=num_kv_splits,
+ intra_batch_mode=intra_batch_mode,
+ )
+
+ work_metadata = self.work_metadata
+ work_info_set = self.work_info_set
+ work_indptr = self.work_indptr
+
+ reduce_indptr = self.reduce_indptr
+ reduce_final_map = self.reduce_final_map
+ reduce_partial_map = self.reduce_partial_map
+
+ self.forward_metadata = ForwardMetadata(
+ kv_indptr,
+ kv_indices,
+ qo_indptr,
+ kv_last_page_len,
+ max_q_len,
+ kv_indptr[-1].item(),
+ work_metadata=work_metadata,
+ work_info_set=work_info_set,
+ work_indptr=work_indptr,
+ reduce_indptr=reduce_indptr,
+ reduce_final_map=reduce_final_map,
+ reduce_partial_map=reduce_partial_map,
+ num_kv_splits=num_kv_splits,
+ # num_kv_splits_indptr=num_kv_splits_indptr,
+ )
+
elif forward_mode.is_draft_extend():
+ num_tokens_per_bs = self.speculative_num_steps + 1
seq_lens = seq_lens[:bs]
accept_lens = spec_info.accept_length[:bs]
qo_indptr = self.qo_indptr[: bs + 1]
@@ -1135,6 +1248,54 @@ def init_forward_metadata_replay_cuda_graph(
self.req_to_token.stride(0),
)
+ kv_last_page_len = self.cuda_graph_kv_last_page_len[:bs]
+ max_q_len = num_tokens_per_bs
+
+ if _use_mla_ps_kernel:
+
+ num_kv_splits = self.max_split_per_batch
+
+ self.make_mla_meta_data(
+ qo_indptr,
+ kv_indptr,
+ kv_last_page_len,
+ self.work_metadata,
+ self.work_info_set,
+ self.work_indptr,
+ self.reduce_indptr,
+ self.reduce_final_map,
+ self.reduce_partial_map,
+ max_q_len,
+ fast_mode=fast_mode,
+ max_split_per_batch=num_kv_splits,
+ intra_batch_mode=intra_batch_mode,
+ )
+
+ work_metadata = self.work_metadata
+ work_info_set = self.work_info_set
+ work_indptr = self.work_indptr
+
+ reduce_indptr = self.reduce_indptr
+ reduce_final_map = self.reduce_final_map
+ reduce_partial_map = self.reduce_partial_map
+
+ self.forward_metadata = ForwardMetadata(
+ kv_indptr,
+ kv_indices,
+ qo_indptr,
+ kv_last_page_len,
+ max_q_len,
+ kv_indptr[-1].item(),
+ work_metadata=work_metadata,
+ work_info_set=work_info_set,
+ work_indptr=work_indptr,
+ reduce_indptr=reduce_indptr,
+ reduce_final_map=reduce_final_map,
+ reduce_partial_map=reduce_partial_map,
+ num_kv_splits=num_kv_splits,
+ # num_kv_splits_indptr=num_kv_splits_indptr,
+ )
+
else:
raise ValueError("Invalid forward mode")
@@ -1366,23 +1527,6 @@ def forward_extend(
num_kv_splits = self.forward_metadata.num_kv_splits
- if layer.layer_id == 0 and _use_mla_ps_kernel:
- self.make_mla_meta_data(
- self.forward_metadata.qo_indptr,
- self.forward_metadata.kv_indptr,
- self.forward_metadata.kv_last_page_len,
- work_metadata,
- work_info_set,
- work_indptr,
- reduce_indptr,
- reduce_final_map,
- reduce_partial_map,
- self.forward_metadata.max_q_len,
- fast_mode=fast_mode,
- max_split_per_batch=num_kv_splits,
- intra_batch_mode=intra_batch_mode,
- )
-
mla_decode_fwd(
q,
K_Buffer.view(-1, 1, 1, layer.qk_head_dim),
@@ -1418,23 +1562,6 @@ def forward_extend(
num_kv_splits = self.forward_metadata.num_kv_splits
- if layer.layer_id == 0 and _use_mla_ps_kernel:
- self.make_mla_meta_data(
- self.forward_metadata.qo_indptr,
- self.forward_metadata.kv_indptr,
- self.forward_metadata.kv_last_page_len,
- work_metadata,
- work_info_set,
- work_indptr,
- reduce_indptr,
- reduce_final_map,
- reduce_partial_map,
- self.forward_metadata.max_q_len,
- fast_mode=fast_mode,
- max_split_per_batch=num_kv_splits,
- intra_batch_mode=intra_batch_mode,
- )
-
if self.forward_metadata.run_graph is not True:
bs, q_pad, q_mask = pad_sequence_with_mask(
@@ -1577,23 +1704,6 @@ def forward_decode(
num_kv_splits = self.forward_metadata.num_kv_splits
- if layer.layer_id == 0 and _use_mla_ps_kernel:
- self.make_mla_meta_data(
- self.forward_metadata.qo_indptr,
- self.forward_metadata.kv_indptr,
- self.forward_metadata.kv_last_page_len,
- work_metadata,
- work_info_set,
- work_indptr,
- reduce_indptr,
- reduce_final_map,
- reduce_partial_map,
- self.forward_metadata.max_q_len,
- fast_mode=fast_mode,
- max_split_per_batch=num_kv_splits,
- intra_batch_mode=intra_batch_mode,
- )
-
mla_decode_fwd(
q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
k_buffer.view(-1, 1, 1, layer.qk_head_dim),
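The aiter backend changes above move the MLA work/reduce metadata construction out of the per-layer forward path (previously gated on `layer.layer_id == 0`) into the CUDA-graph replay preparation, so it is built once per step and reused by every layer. A toy sketch of that hoisting pattern, with hypothetical names that are not the backend's real API:

```python
# Toy sketch: build per-step metadata once during replay preparation, then
# have every layer read it, instead of rebuilding it at layer 0 in forward.
from dataclasses import dataclass
from typing import Optional
import torch

@dataclass
class Meta:
    kv_indptr: torch.Tensor
    num_kv_splits: Optional[int]

class ToyBackend:
    def init_forward_metadata_replay(self, seq_lens: torch.Tensor, use_splits: bool):
        bs = seq_lens.shape[0]
        kv_indptr = torch.zeros(bs + 1, dtype=torch.int64)
        kv_indptr[1:] = torch.cumsum(seq_lens, dim=0)
        # Optional split metadata is computed here, once per replay step.
        self.forward_metadata = Meta(kv_indptr, 16 if use_splits else None)

    def forward_decode(self, layer_id: int) -> Meta:
        # Every layer simply reads the precomputed metadata.
        return self.forward_metadata

backend = ToyBackend()
backend.init_forward_metadata_replay(torch.tensor([3, 5, 2]), use_splits=True)
print(backend.forward_decode(layer_id=0).kv_indptr)  # tensor([ 0,  3,  8, 10])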
diff --git a/python/sglang/srt/layers/moe/ep_moe/layer.py b/python/sglang/srt/layers/moe/ep_moe/layer.py
index ebcc696ecf0d..da41c667017e 100644
--- a/python/sglang/srt/layers/moe/ep_moe/layer.py
+++ b/python/sglang/srt/layers/moe/ep_moe/layer.py
@@ -24,7 +24,10 @@
DeepEPLLCombineInput,
DeepEPNormalCombineInput,
)
-from sglang.srt.layers.moe.token_dispatcher.moriep import MoriEPNormalCombineInput
+from sglang.srt.layers.moe.token_dispatcher.moriep import (
+ MoriEPLLCombineInput,
+ MoriEPNormalCombineInput,
+)
from sglang.srt.layers.moe.topk import TopKOutput, TopKOutputChecker
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.quantization.compressed_tensors.schemes import (
@@ -130,13 +133,14 @@ def __init__(
if (
self.deepep_mode.enable_low_latency()
and not _is_npu
+ and not _is_hip
and not (
get_moe_runner_backend().is_flashinfer_cutedsl()
and self.quant_config.get_name() == "modelopt_fp4"
)
):
- # NPU supports low_latency deepep without deepgemm
- # FP4 quantization with flashinfer_cutedsl also supports low_latency deepep without deepgemm
+ # AMD HIP and NPU support low-latency DeepEP without DeepGEMM.
+ # NVIDIA FP4 quantization with flashinfer_cutedsl also supports low-latency DeepEP without DeepGEMM.
assert (
deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM
), f"DeepEP {self.deepep_mode} mode requires deep_gemm"
@@ -246,6 +250,7 @@ def run_moe_core(
if DispatchOutputChecker.format_is_deepep_normal(dispatch_output)
else DeepEPLLCombineInput
)
+
return combine_input_wrapper(
hidden_states=output,
topk_ids=dispatch_output.topk_ids,
@@ -275,8 +280,10 @@ def forward_aiter(
dispatch_output.topk_ids,
dispatch_output.topk_weights,
)
+
if hidden_states.shape[0] == 0:
return hidden_states
+
# in original deepep, idx == -1 meaning invalid and will not be processed.
# aiter does not accept -1, we use a expert mask to make these idx invalid
# (idx == num_local_experts) meaning not used in aiter fused_moe
@@ -592,28 +599,46 @@ def forward(
self,
hidden_states: torch.Tensor,
topk_output: TopKOutput,
- forward_shared_experts=None,
- alt_stream=None,
- disable_sbo=False,
):
num_token = hidden_states.shape[0]
- output_dtype = hidden_states.dtype
+ dispatch_output = self.dispatcher.dispatch(
+ hidden_states=hidden_states, topk_output=topk_output
+ )
+ combine_input = self.run_moe_core(dispatch_output)
+ hidden_states = self.dispatcher.combine(
+ combine_input=combine_input,
+ )
+
+ return hidden_states[:num_token]
+
+ def run_moe_core(
+ self,
+ dispatch_output: DispatchOutput,
+ ):
+ # TODO(billishyahao): check aiter path
+ # billishyahao: for now, fused_moe only supports torch.bfloat16
+ output_dtype = torch.bfloat16
scale = None
is_fp8_quant = isinstance(self.quant_method, Fp8MoEMethod)
is_quark_w4a4 = isinstance(self.scheme, QuarkW4A4MXFp4MoE)
- # dispatch
- dispatch_output = self.dispatcher.dispatch(
- hidden_states, topk_output
- ) # , scale=scale)
-
(
dispatch_a1,
dispatch_scale,
dispatch_ids,
dispatch_weights,
dispatch_recv_token_num,
- ) = dispatch_output
+ origin_topk_ids,
+ origin_topk_weights,
+ ) = (
+ dispatch_output.hidden_states,
+ dispatch_output.hidden_states_scale,
+ dispatch_output.topk_ids,
+ dispatch_output.topk_weights,
+ dispatch_output.num_recv_tokens_per_expert,
+ dispatch_output.origin_topk_ids,
+ dispatch_output.origin_topk_weights,
+ )
w13_weight = self.w13_weight
w2_weight = self.w2_weight
@@ -670,17 +695,19 @@ def forward(
dtype=output_dtype,
)
- combine_input_wrapper = MoriEPNormalCombineInput
- combine_input = combine_input_wrapper(
- hidden_states=hidden_states,
- topk_ids=topk_output.topk_ids,
- topk_weights=topk_output.topk_weights,
- )
+ from sglang.srt.layers.moe.token_dispatcher import DispatchOutputChecker
- # combine
- result = self.dispatcher.combine(combine_input)
+ combine_input_wrapper = (
+ MoriEPNormalCombineInput
+ if DispatchOutputChecker.format_is_deepep_normal(dispatch_output)
+ else MoriEPLLCombineInput
+ )
- return result[:num_token]
+ return combine_input_wrapper(
+ hidden_states=hidden_states,
+ topk_ids=dispatch_output.origin_topk_ids,
+ topk_weights=dispatch_output.origin_topk_weights,
+ )
def get_moe_impl_class(quant_config: Optional[QuantizationConfig]):
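The refactor above splits the MoE forward into dispatch, `run_moe_core`, and combine stages, with the result sliced back to the caller's token count because dispatch may pad or regroup tokens. A toy sketch of that control flow, with identity/scale stand-ins for the real communication and expert kernels:

```python
# Toy sketch of the dispatch -> run_moe_core -> combine split.
import torch

class ToyMoE:
    def dispatch(self, hidden_states: torch.Tensor, topk_ids: torch.Tensor):
        return hidden_states, topk_ids  # stand-in for the all-to-all dispatch

    def run_moe_core(self, hidden_states: torch.Tensor, topk_ids: torch.Tensor):
        return hidden_states * 2.0  # stand-in for the expert GEMMs

    def combine(self, combine_input: torch.Tensor):
        return combine_input  # stand-in for the all-to-all combine

    def forward(self, hidden_states: torch.Tensor, topk_ids: torch.Tensor):
        num_token = hidden_states.shape[0]
        recv_hidden, recv_ids = self.dispatch(hidden_states, topk_ids)
        out = self.run_moe_core(recv_hidden, recv_ids)
        out = self.combine(out)
        # Dispatch may pad the token dimension; slice back to the caller's count.
        return out[:num_token]

moe = ToyMoE()
print(moe.forward(torch.ones(4, 8), torch.zeros(4, 2, dtype=torch.long)).shape)
```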
diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/layer.py b/python/sglang/srt/layers/moe/fused_moe_triton/layer.py
index de8a07ab3000..85b52d4caf32 100644
--- a/python/sglang/srt/layers/moe/fused_moe_triton/layer.py
+++ b/python/sglang/srt/layers/moe/fused_moe_triton/layer.py
@@ -96,9 +96,13 @@ def create_moe_dispatcher(moe_runner_config: MoeRunnerConfig) -> BaseDispatcher:
a2a_backend = get_moe_a2a_backend()
if a2a_backend.is_none():
return StandardDispatcher(moe_runner_config)
- elif a2a_backend.is_deepep() or a2a_backend.is_mooncake():
+ elif a2a_backend.is_deepep() or a2a_backend.is_mooncake() or a2a_backend.is_mori():
return MaybeTboDeepEPDispatcher(
- group=get_tp_group().device_group,
+ group=(
+ get_tp_group().device_group
+ if not a2a_backend.is_mori()
+ else get_tp_group()
+ ),
router_topk=moe_runner_config.top_k,
permute_fusion=True,
num_experts=moe_runner_config.num_experts,
@@ -121,19 +125,7 @@ def create_moe_dispatcher(moe_runner_config: MoeRunnerConfig) -> BaseDispatcher:
hidden_size=moe_runner_config.hidden_size,
params_dtype=moe_runner_config.params_dtype,
)
- elif a2a_backend.is_mori():
- from sglang.srt.layers.moe.token_dispatcher import MoriEPDispatcher
- return MoriEPDispatcher(
- group=get_tp_group(),
- router_topk=moe_runner_config.top_k,
- permute_fusion=True,
- num_experts=moe_runner_config.num_experts,
- num_local_experts=moe_runner_config.num_local_experts,
- hidden_size=moe_runner_config.hidden_size,
- params_dtype=moe_runner_config.params_dtype,
- deepep_mode=get_deepep_mode(),
- )
elif a2a_backend.is_flashinfer():
return FlashinferDispatcher(
group=get_tp_group().device_group,
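The factory change above routes `mori` through the shared `MaybeTboDeepEPDispatcher`, handing it the tensor-parallel group wrapper rather than the raw torch device group used by the other backends. A tiny illustration of that selection rule, with a hypothetical `TpGroup` type standing in for sglang's group wrapper:

```python
# Toy illustration of the group-selection rule; TpGroup is a hypothetical type.
from dataclasses import dataclass

@dataclass
class TpGroup:
    device_group: str

def pick_group(a2a_backend: str, tp_group: TpGroup):
    # MoRI takes the wrapper object itself; other EP backends take the raw
    # torch.distributed device group it wraps.
    return tp_group if a2a_backend == "mori" else tp_group.device_group

print(pick_group("mori", TpGroup(device_group="nccl-group")))
print(pick_group("deepep", TpGroup(device_group="nccl-group")))
```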
diff --git a/python/sglang/srt/layers/moe/token_dispatcher/__init__.py b/python/sglang/srt/layers/moe/token_dispatcher/__init__.py
index 209570073b04..dd40a8d98548 100644
--- a/python/sglang/srt/layers/moe/token_dispatcher/__init__.py
+++ b/python/sglang/srt/layers/moe/token_dispatcher/__init__.py
@@ -28,6 +28,8 @@
)
from sglang.srt.layers.moe.token_dispatcher.moriep import (
MoriEPDispatcher,
+ MoriEPLLCombineInput,
+ MoriEPLLDispatchOutput,
MoriEPNormalCombineInput,
MoriEPNormalDispatchOutput,
)
@@ -53,6 +55,8 @@
"MooncakeEPDispatcher",
"MoriEPNormalDispatchOutput",
"MoriEPNormalCombineInput",
+ "MoriEPLLDispatchOutput",
+ "MoriEPLLCombineInput",
"MoriEPDispatcher",
"StandardDispatcher",
"StandardDispatchOutput",
diff --git a/python/sglang/srt/layers/moe/token_dispatcher/moriep.py b/python/sglang/srt/layers/moe/token_dispatcher/moriep.py
index 6ee2443b4058..dfb6e38368a5 100644
--- a/python/sglang/srt/layers/moe/token_dispatcher/moriep.py
+++ b/python/sglang/srt/layers/moe/token_dispatcher/moriep.py
@@ -12,6 +12,7 @@
DispatchOutput,
DispatchOutputFormat,
)
+from sglang.srt.layers.moe.token_dispatcher.deepep import DeepEPPDispatchHooks
from sglang.srt.layers.moe.topk import TopKOutput
from sglang.srt.layers.moe.utils import DeepEPMode
from sglang.srt.utils import get_bool_env_var, get_int_env_var, is_hip
@@ -40,21 +41,47 @@
logger = logging.getLogger(__name__)
+class MoriEPPDispatchHooks(DeepEPPDispatchHooks):
+
+ def __call__(self, dispatcher: BaseDispatcher):
+ for hook_fun in self.hook_dict.values():
+ hook_fun(dispatcher)
+
+
class MoriEPNormalDispatchOutput(NamedTuple):
- """Mori EP dispatch output."""
+ """Mori EP normal dispatch output."""
hidden_states: torch.Tensor
hidden_states_scale: Optional[torch.Tensor]
topk_ids: torch.Tensor
topk_weights: torch.Tensor
num_recv_tokens_per_expert: List[int]
+ origin_topk_ids: torch.Tensor
+ origin_topk_weights: torch.Tensor
@property
def format(self) -> DispatchOutputFormat:
return DispatchOutputFormat.DEEPEP_NORMAL
+class MoriEPLLDispatchOutput(NamedTuple):
+ """Mori EP low latency dispatch output."""
+
+ hidden_states: torch.Tensor
+ hidden_states_scale: Optional[torch.Tensor]
+ topk_ids: torch.Tensor
+ topk_weights: torch.Tensor
+ num_recv_tokens_per_expert: List[int]
+ origin_topk_ids: torch.Tensor
+ origin_topk_weights: torch.Tensor
+
+ @property
+ def format(self) -> DispatchOutputFormat:
+ return DispatchOutputFormat.DEEPEP_LL
+
+
assert isinstance(MoriEPNormalDispatchOutput, DispatchOutput)
+assert isinstance(MoriEPLLDispatchOutput, DispatchOutput)
class MoriEPNormalCombineInput(NamedTuple):
@@ -69,12 +96,26 @@ def format(self) -> CombineInputFormat:
return CombineInputFormat.DEEPEP_NORMAL
+class MoriEPLLCombineInput(NamedTuple):
+ """Mori EP combine input."""
+
+ hidden_states: torch.Tensor
+ topk_ids: torch.Tensor
+ topk_weights: torch.Tensor
+
+ @property
+ def format(self) -> CombineInputFormat:
+ return CombineInputFormat.DEEPEP_LL
+
+
assert isinstance(MoriEPNormalCombineInput, CombineInput)
+assert isinstance(MoriEPLLCombineInput, CombineInput)
class EpMode(Enum):
INTRA_NODE = "intra_node"
INTER_NODE = "inter_node"
+ LOW_LATENCY = "low_latency"
@dataclass(frozen=True)
@@ -113,12 +154,19 @@ def get_ep_dispatch_configs(num_max_dispatch_tokens_per_rank: int = 4096):
block_num=64,
rdma_block_num=32,
),
+ # TODO(billishyahao): may need to set different configs for intra-node async
+ EpMode.LOW_LATENCY: EpDispatchConfig(
+ kernel_type=mori.ops.EpDispatchCombineKernelType.AsyncLL,
+ warp_num_per_block=8,
+ block_num=64,
+ rdma_block_num=32,
+ ),
}
# init_mori_op only needs do once in model initial stage
# use lru_cache to reuse the same mori_op instance to avoid the init overhead for mori
-@lru_cache(maxsize=1)
+@lru_cache(maxsize=2)
def init_mori_op(
group,
router_topk,
@@ -137,11 +185,16 @@ def init_mori_op(
cpu_group = group.cpu_group
torch._C._distributed_c10d._register_process_group("mori", cpu_group)
mori.shmem.shmem_torch_process_group_init("mori")
+
+ mode = EpMode.INTRA_NODE if world_size <= 8 else EpMode.INTER_NODE
+ async_mode = get_bool_env_var("SGLANG_MORI_ASYNC_MODE", "false")
+ if async_mode:
+ mode = EpMode.LOW_LATENCY
+
logger.info(
- f"[MORI init] {world_size=} {rank=} {hidden_size=} {params_dtype=} {num_max_dispatch_tokens_per_rank=} {num_local_experts=} {router_topk=}"
+ f"[MORI init] {world_size=} {rank=} {hidden_size=} {params_dtype=} {num_max_dispatch_tokens_per_rank=} {num_local_experts=} {router_topk=} {mode=}"
)
- mode = EpMode.INTRA_NODE if world_size <= 8 else EpMode.INTER_NODE
cfg = get_ep_dispatch_configs(num_max_dispatch_tokens_per_rank)[mode]
kernel_type = cfg.kernel_type
@@ -174,6 +227,28 @@ def init_mori_op(
return mori_op
+class CommStreamPool:
+ _streams = {} # key -> torch.cuda.Stream
+
+ @classmethod
+ def _make_key(cls, group):
+ return (torch.cuda.current_device(), id(group))
+
+ @classmethod
+ def getStreamFromPool(cls, group) -> torch.cuda.Stream:
+ key = cls._make_key(group)
+ stream = cls._streams.get(key)
+ if stream is None:
+ stream = torch.cuda.Stream(priority=0)
+ cls._streams[key] = stream
+ return stream
+
+ @classmethod
+ def clear_group(cls, group):
+ key = cls._make_key(group)
+ cls._streams.pop(key, None)
+
+
class _MoriEPDispatcherImplBase:
def __init__(
self,
@@ -184,7 +259,6 @@ def __init__(
num_local_experts: int,
hidden_size: int,
params_dtype: torch.dtype,
- return_recv_hook: bool,
deepep_mode: DeepEPMode,
):
try:
@@ -198,7 +272,6 @@ def __init__(
self.num_local_experts = num_local_experts
self.hidden_size = hidden_size
self.params_dtype = params_dtype
- self.return_recv_hook = return_recv_hook
self.deepep_mode = deepep_mode
self.num_max_dispatch_tokens_per_rank = get_int_env_var(
@@ -215,6 +288,11 @@ def __init__(
num_max_dispatch_tokens_per_rank=self.num_max_dispatch_tokens_per_rank,
)
+ self.quant_config: Optional[dict] = None
+
+ self.overlap_args: Optional[CombineOverlapArgs] = None
+ self.meta_overlap_args: Optional[dict] = None
+
def dispatch_a(
self,
hidden_states: torch.Tensor,
@@ -230,23 +308,46 @@ def combine_a(
hidden_states: torch.Tensor,
topk_ids: torch.Tensor,
topk_weights: torch.Tensor,
- overlap_args: Optional[CombineOverlapArgs] = None,
):
raise NotImplementedError
def combine_b(self, *args, **kwargs):
raise NotImplementedError
- def _get_buffer(self):
- raise NotImplementedError
+ def set_quant_config(self, quant_config: dict) -> None:
+ self.quant_config = quant_config
+
+ def set_overlap_args(
+ self, combine_overlap_args: CombineOverlapArgs, meta_overlap_args: dict
+ ) -> None:
+ self.overlap_args = combine_overlap_args
+ self.meta_overlap_args = meta_overlap_args
+
+ def clear_overlap_args(self) -> None:
+ self.overlap_args = None
+ self.meta_overlap_args = None
class _MoriEPDispatcherImplNormal(_MoriEPDispatcherImplBase):
- def __init__(self, **kwargs):
+ def __init__(self, async_finish: bool, **kwargs):
super().__init__(**kwargs)
+
+ self.async_finish = async_finish
self.quant_config = {}
# [kk TODO] need to support mxfp4 type
self.quant_func = get_hip_quant(QuantType.per_1x128)
+ self.enable_dual_stream = get_bool_env_var("SGLANG_MORI_DUAL_STREAM", "false")
+ self._comm_stream = None
+ if self.enable_dual_stream:
+ self._comm_stream = CommStreamPool.getStreamFromPool(self.group)
+
+ def _capture_event_if_async(self) -> Optional[torch.cuda.Event]:
+ assert self.enable_dual_stream, "dual stream must be enabled"
+ if not self.async_finish:
+ return None
+ ev = torch.cuda.Event(blocking=False, interprocess=False)
+ ev.record(torch.cuda.current_stream())
+ return ev
def dispatch_a(
self,
@@ -255,17 +356,16 @@ def dispatch_a(
):
topk_weights, topk_ids = topk_output.topk_weights, topk_output.topk_ids
- return (
- hidden_states,
- topk_weights,
- topk_ids,
- )
+ previous_event = self._capture_event_if_async() if self._comm_stream else None
+
+ return (hidden_states, topk_weights, topk_ids, previous_event)
def dispatch_b(
self,
hidden_states,
topk_weights,
topk_ids,
+ previous_event,
):
num_token = hidden_states.shape[0]
scale = None
@@ -295,14 +395,254 @@ def dispatch_b(
recv_scales,
recv_topk_ids,
packed_recv_count,
- ) = self._dispatch_core(hidden_states, topk_weights, topk_ids, scale)
+ done_event,
+ ) = self._dispatch_core(
+ hidden_states,
+ topk_weights,
+ topk_ids,
+ scale=scale,
+ previous_event=previous_event,
+ )
+
+ if self._comm_stream and self.async_finish and done_event is not None:
+ torch.cuda.current_stream().wait_event(done_event)
return MoriEPNormalDispatchOutput(
+ hidden_states=packed_recv_hidden,
+ hidden_states_scale=recv_scales,
+ topk_ids=recv_topk_ids,
+ topk_weights=recv_topk_weights,
+ num_recv_tokens_per_expert=packed_recv_count,
+ origin_topk_ids=topk_ids,
+ origin_topk_weights=topk_weights,
+ )
+
+ def _dispatch_core(
+ self,
+ hidden_states: torch.Tensor,
+ topk_weights: torch.Tensor,
+ topk_ids: torch.Tensor,
+ scale: Optional[torch.Tensor] = None,
+ previous_event: Optional[torch.cuda.Event] = None,
+ ):
+ done_event: Optional[torch.cuda.Event] = None
+
+ if self._comm_stream:
+ compute_stream = torch.cuda.current_stream()
+ comm_stream = self._comm_stream # comm stream
+
+ for t in (hidden_states, topk_weights, topk_ids):
+ t.record_stream(comm_stream)
+ if scale is not None:
+ scale.record_stream(comm_stream)
+
+ with torch.cuda.stream(comm_stream):
+ # Wait on the recorded compute-stream event if one was captured;
+ # otherwise block the comm stream on the whole compute stream.
+
+ if previous_event is not None:
+ comm_stream.wait_event(previous_event)
+ else:
+ comm_stream.wait_stream(compute_stream)
+
+ (
+ packed_recv_hidden,
+ recv_topk_weights,
+ recv_scales,
+ recv_topk_ids,
+ packed_recv_count,
+ ) = self.mori_op.dispatch(hidden_states, topk_weights, scale, topk_ids)
+
+ if self.async_finish:
+ done_event = torch.cuda.Event(blocking=False, interprocess=False)
+ done_event.record(comm_stream)
+ else:
+ compute_stream.wait_stream(comm_stream)
+
+ for t in (
+ packed_recv_hidden,
+ recv_topk_weights,
+ recv_scales,
+ recv_topk_ids,
+ ):
+ if t is not None:
+ t.record_stream(comm_stream)
+ else:
+
+ (
+ packed_recv_hidden,
+ recv_topk_weights,
+ recv_scales,
+ recv_topk_ids,
+ packed_recv_count,
+ ) = self.mori_op.dispatch(hidden_states, topk_weights, scale, topk_ids)
+
+ # TODO(billishyahao): EPLB
+ # get_global_expert_distribution_recorder().on_deepep_dispatch_normal(
+
+ return (
packed_recv_hidden,
+ recv_topk_weights,
recv_scales,
recv_topk_ids,
+ packed_recv_count,
+ done_event,
+ )
+
+ def combine_a(
+ self,
+ hidden_states: torch.Tensor,
+ topk_ids: torch.Tensor,
+ topk_weights: torch.Tensor,
+ ):
+ previous_event = self._capture_event_if_async() if self._comm_stream else None
+ return hidden_states, topk_ids, topk_weights, previous_event
+
+ def combine_b(self, hidden_states, topk_ids, topk_weights, previous_event):
+
+ hidden_states, done_event = self._combine_core(
+ hidden_states, topk_ids, topk_weights, previous_event
+ )
+
+ if self._comm_stream and self.async_finish and done_event is not None:
+ torch.cuda.current_stream().wait_event(done_event)
+
+ return hidden_states
+
+ def _combine_core(
+ self,
+ hidden_states: torch.Tensor,
+ topk_ids: torch.Tensor,
+ topk_weights: torch.Tensor,
+ previous_event: Optional[torch.cuda.Event],
+ ):
+ done_event: Optional[torch.cuda.Event] = None
+
+ if self._comm_stream:
+ compute_stream = torch.cuda.current_stream()
+ comm_stream = self._comm_stream
+
+ for t in (hidden_states, topk_ids, topk_weights):
+ t.record_stream(comm_stream)
+
+ with torch.cuda.stream(comm_stream):
+ if previous_event is not None:
+ comm_stream.wait_event(previous_event)
+ else:
+ comm_stream.wait_stream(compute_stream)
+
+ combined_hidden_states = self.mori_op.combine(
+ hidden_states, None, topk_ids
+ )[0]
+
+ if self.async_finish:
+ done_event = torch.cuda.Event(blocking=False, interprocess=False)
+ done_event.record(comm_stream)
+ else:
+ compute_stream.wait_stream(comm_stream)
+
+ combined_hidden_states.record_stream(comm_stream)
+
+ else:
+ combined_hidden_states = self.mori_op.combine(
+ hidden_states, None, topk_ids
+ )[0]
+
+ return combined_hidden_states, done_event
+
+ def set_quant_config(self, quant_config: dict):
+ self.quant_config = quant_config
+
+
+class _MoriEPDispatcherImplLowLatency(_MoriEPDispatcherImplBase):
+ def __init__(self, **kwargs):
+ super().__init__(**kwargs)
+ self.quant_config = {}
+ self.quant_func = get_hip_quant(QuantType.per_1x128)
+
+ def dispatch_a(
+ self,
+ hidden_states: torch.Tensor,
+ topk_output: TopKOutput,
+ ):
+ import mori
+
+ assert (
+ self.mori_op.config.kernel_type
+ is mori.ops.EpDispatchCombineKernelType.AsyncLL
+ ), "mori asyncll mismatch"
+
+ num_tokens = hidden_states.shape[0]
+ scale = None
+
+ fp8_dispatch = get_bool_env_var("SGLANG_MORI_FP8_DISP", "False")
+
+ if fp8_dispatch:
+ # FP8 quant
+ if num_tokens > 0:
+ # NOTE: aiter handles the token=0 case in unit tests, but for an unknown reason it fails in the e2e case. Root cause TBD.
+ hidden_states, scale = self.quant_func(
+ hidden_states, quant_dtype=fp8_dtype
+ )
+ else:
+ hidden_states = torch.empty(
+ hidden_states.shape, dtype=fp8_dtype, device=hidden_states.device
+ )
+ scale = torch.empty(
+ (0, self.hidden_size // 128),
+ dtype=torch.float32,
+ device=hidden_states.device,
+ )
+
+ topk_weights, topk_ids = topk_output.topk_weights, topk_output.topk_ids
+
+ (
+ packed_recv_hidden,
recv_topk_weights,
+ recv_scales,
+ recv_topk_ids,
packed_recv_count,
+ ) = self._dispatch_core(hidden_states, topk_weights, topk_ids, scale=scale)
+
+ return (
+ packed_recv_hidden,
+ recv_topk_weights,
+ recv_topk_ids,
+ recv_scales,
+ packed_recv_count,
+ topk_weights,
+ topk_ids,
+ )
+
+ def dispatch_b(
+ self,
+ hidden_states,
+ recv_topk_weights,
+ recv_topk_ids,
+ recv_scales,
+ packed_recv_count,
+ topk_weights,
+ topk_ids,
+ ):
+
+ # TODO(billishyahao): add assertion here to check async
+ import mori
+
+ assert (
+ self.mori_op.config.kernel_type
+ is mori.ops.EpDispatchCombineKernelType.AsyncLL
+ ), "mori asyncll mismatch"
+
+ self.mori_op.dispatch_recv()
+
+ return MoriEPLLDispatchOutput(
+ hidden_states=hidden_states,
+ hidden_states_scale=recv_scales,
+ topk_ids=recv_topk_ids,
+ topk_weights=recv_topk_weights,
+ num_recv_tokens_per_expert=packed_recv_count,
+ origin_topk_ids=topk_ids,
+ origin_topk_weights=topk_weights,
)
def _dispatch_core(
@@ -312,16 +652,15 @@ def _dispatch_core(
topk_ids: torch.Tensor,
scale: Optional[torch.Tensor] = None,
):
+ # TODO(billishyahao): add assertion here to check async
+
(
packed_recv_hidden,
recv_topk_weights,
recv_scales,
recv_topk_ids,
packed_recv_count,
- ) = self.mori_op.dispatch(hidden_states, topk_weights, scale, topk_ids)
-
- # TODO(billishyahao): EPLB
- # get_global_expert_distribution_recorder().on_deepep_dispatch_normal(
+ ) = self.mori_op.dispatch_send(hidden_states, topk_weights, scale, topk_ids)
return (
packed_recv_hidden,
@@ -338,21 +677,32 @@ def combine_a(
topk_weights: torch.Tensor,
overlap_args: Optional[CombineOverlapArgs] = None,
):
- previous_event = None
- return hidden_states, topk_ids, topk_weights, previous_event
+ hidden_states = self._combine_core(
+ hidden_states,
+ topk_ids,
+ topk_weights,
+ overlap_args=overlap_args,
+ )
+ return hidden_states, topk_ids, topk_weights, overlap_args
def combine_b(self, hidden_states, topk_ids, topk_weights, previous_event):
- hidden_states = self._combine_core(hidden_states, topk_ids, topk_weights)
- return hidden_states
+
+ self.mori_op.combine_recv()
+
+ return hidden_states[0]
def _combine_core(
self,
hidden_states: torch.Tensor,
topk_ids: torch.Tensor,
topk_weights: torch.Tensor,
+ overlap_args: Optional[CombineOverlapArgs] = None,
):
- combined_hidden_states = self.mori_op.combine(hidden_states, None, topk_ids)
- return combined_hidden_states[0]
+ combined_hidden_states = self.mori_op.combine_send(
+ hidden_states, None, topk_ids
+ )
+
+ return combined_hidden_states
def set_quant_config(self, quant_config: dict):
self.quant_config = quant_config
@@ -380,27 +730,43 @@ def __init__(
async_finish: bool = False,
return_recv_hook: bool = False,
):
+ super().__init__()
+
self.deepep_mode = deepep_mode
+ common_kwargs = dict(
+ group=group,
+ router_topk=router_topk,
+ permute_fusion=permute_fusion,
+ num_experts=num_experts,
+ num_local_experts=num_local_experts,
+ hidden_size=hidden_size,
+ params_dtype=params_dtype,
+ deepep_mode=deepep_mode,
+ )
+
+ if self.deepep_mode.enable_low_latency():
+ self._low_latency_dispatcher = _MoriEPDispatcherImplLowLatency(
+ **common_kwargs,
+ )
+
if self.deepep_mode.enable_normal():
self._normal_dispatcher = _MoriEPDispatcherImplNormal(
- group=group,
- router_topk=router_topk,
- permute_fusion=permute_fusion,
- num_experts=num_experts,
- num_local_experts=num_local_experts,
- hidden_size=hidden_size,
- params_dtype=params_dtype,
- return_recv_hook=return_recv_hook,
- deepep_mode=deepep_mode,
+ async_finish=async_finish,
+ **common_kwargs,
)
- if self.deepep_mode.enable_low_latency():
- raise NotImplementedError
self._stage = _Stage.INITIAL
+ self._deepep_dispatch_hooks = MoriEPPDispatchHooks()
- def dispatch(self, *args, **kwargs) -> DispatchOutput:
- self.dispatch_a(*args, **kwargs)
+ def dispatch(
+ self,
+ hidden_states: torch.Tensor,
+ topk_output: TopKOutput,
+ ) -> DispatchOutput:
+ self.dispatch_a(hidden_states, topk_output)
+ if self._deepep_dispatch_hooks is not None:
+ self._deepep_dispatch_hooks(self)
ret = self.dispatch_b()
return ret
@@ -425,16 +791,14 @@ def dispatch_b(self):
def combine(
self,
combine_input: CombineInput,
- overlap_args: Optional[CombineOverlapArgs] = None,
) -> Tuple:
- self.combine_a(combine_input, overlap_args)
+ self.combine_a(combine_input)
ret = self.combine_b()
return ret
def combine_a(
self,
combine_input: CombineInput,
- overlap_args: Optional[CombineOverlapArgs] = None,
):
hidden_states, topk_ids, topk_weights = combine_input
self._update_stage(_Stage.AFTER_DISPATCH_B, _Stage.AFTER_COMBINE_A)
@@ -442,7 +806,6 @@ def combine_a(
hidden_states=hidden_states,
topk_ids=topk_ids,
topk_weights=topk_weights,
- overlap_args=overlap_args,
)
self._combine_intermediate_state = inner_state
@@ -458,7 +821,7 @@ def _get_impl(self) -> _MoriEPDispatcherImplBase:
if resolved_deepep_mode == DeepEPMode.NORMAL:
return self._normal_dispatcher
elif resolved_deepep_mode == DeepEPMode.LOW_LATENCY:
- raise NotImplementedError
+ return self._low_latency_dispatcher
else:
raise ValueError(f"Invalid deepep_mode: {self.deepep_mode}")
@@ -467,7 +830,31 @@ def _update_stage(self, old_stage, new_stage):
self._stage = new_stage
def set_quant_config(self, quant_config: dict):
+ super().set_quant_config(quant_config)
if self.deepep_mode.enable_low_latency():
- raise NotImplementedError
+ self._low_latency_dispatcher.set_quant_config(quant_config)
if self.deepep_mode.enable_normal():
self._normal_dispatcher.set_quant_config(quant_config)
+
+ def set_overlap_args(
+ self, combine_overlap_args: CombineOverlapArgs, meta_overlap_args: dict
+ ):
+ super().set_overlap_args(combine_overlap_args, meta_overlap_args)
+ if self.deepep_mode.enable_low_latency():
+ self._low_latency_dispatcher.set_overlap_args(
+ combine_overlap_args, meta_overlap_args
+ )
+ if self.deepep_mode.enable_normal():
+ self._normal_dispatcher.set_overlap_args(
+ combine_overlap_args, meta_overlap_args
+ )
+
+ def clear_overlap_args(self):
+ super().clear_overlap_args()
+ if self.deepep_mode.enable_low_latency():
+ self._low_latency_dispatcher.clear_overlap_args()
+ if self.deepep_mode.enable_normal():
+ self._normal_dispatcher.clear_overlap_args()
+
+ def register_deepep_dispatch_hook(self, hook):
+ return self._deepep_dispatch_hooks.register_hook(hook)
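The normal-mode dispatcher above can run the mori communication on a dedicated stream from `CommStreamPool`, synchronizing with the compute stream through recorded CUDA events (record on the producer, wait on the consumer) rather than blocking waits. The sketch below shows that record/wait handshake in isolation; the matmuls are toy stand-ins for the mori dispatch and the downstream compute, and it assumes a CUDA or ROCm device is available.

```python
# Minimal sketch of the dual-stream event handshake (toy workloads).
import torch

compute_stream = torch.cuda.current_stream()
comm_stream = torch.cuda.Stream(priority=0)

x = torch.randn(1024, 1024, device="cuda")

# Producer side: mark the point on the compute stream the comm work depends on.
ready = torch.cuda.Event()
ready.record(compute_stream)

with torch.cuda.stream(comm_stream):
    comm_stream.wait_event(ready)      # comm waits only for the recorded point
    x.record_stream(comm_stream)       # keep x alive while the comm stream uses it
    y = x @ x                          # stand-in for mori_op.dispatch(...)
    done = torch.cuda.Event()
    done.record(comm_stream)           # async_finish: publish completion

# Consumer side: compute resumes immediately and only waits where it needs y.
compute_stream.wait_event(done)
z = y.sum()
torch.cuda.synchronize()
print(z.item())
```

This mirrors the pattern in `_dispatch_core`/`_combine_core`: inputs are recorded on the comm stream before use, the comm stream waits on either the captured event or the whole compute stream, and completion is published via a `done` event when `async_finish` is set.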
diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py
index 1583dd78804f..d88c6ecb5207 100644
--- a/python/sglang/srt/models/deepseek_v2.py
+++ b/python/sglang/srt/models/deepseek_v2.py
@@ -951,6 +951,7 @@ def _post_combine_hook(
and self.alt_stream is not None
):
torch.cuda.current_stream().wait_event(shared_event)
+
if shared_output is not None:
x = shared_output
# aiter moe call will handle routed_scaling_factor in the function
@@ -1054,11 +1055,18 @@ def op_combine_b(self, state):
def op_output(self, state):
final_hidden_states = state.pop("hidden_states_after_combine")
+ if get_moe_a2a_backend().is_mori():
+ num_tokens = state.pop("num_tokens")
+ final_hidden_states = final_hidden_states[:num_tokens]
+
if (shared_output := state.pop("shared_output")) is not None:
x = shared_output
- x.add_(final_hidden_states, alpha=self.routed_scaling_factor)
+ if _use_aiter:
+ x.add_(final_hidden_states)
+ else:
+ x.add_(final_hidden_states, alpha=self.routed_scaling_factor)
final_hidden_states = x
- else:
+ elif not _use_aiter:
final_hidden_states *= self.routed_scaling_factor
state.hidden_states_mlp_output = final_hidden_states
@@ -2449,6 +2457,7 @@ def op_comm_prepare_attn(
state.hidden_states_after_comm_pre_attn, state.residual_after_input_ln = (
self.layer_communicator.prepare_attn(hidden_states, residual, forward_batch)
)
+ state.num_tokens = hidden_states.shape[0]
state.update(
dict(
forward_batch=forward_batch,
diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py
index 88cbcf1d5ab0..be89a0fe70de 100644
--- a/python/sglang/srt/server_args.py
+++ b/python/sglang/srt/server_args.py
@@ -2216,15 +2216,17 @@ def _handle_a2a_moe(self):
if self.moe_a2a_backend == "mori":
self.ep_size = self.tp_size
- self.deepep_mode = "normal"
- logger.warning("auto set deepep_mode=`normal` for MORI EP")
logger.warning(
f"MoRI MoE is enabled. The expert parallel size is adjusted to be the same as the tensor parallel size[{self.tp_size}]."
)
- assert (self.chunked_prefill_size) <= get_int_env_var(
- "SGLANG_MORI_NUM_MAX_DISPATCH_TOKENS_PER_RANK", 4096
- ), "SGLANG_MORI_NUM_MAX_DISPATCH_TOKENS_PER_RANK (default 4096) must be larger or equal to chunked_prefill_size"
+ # Check chunked prefill for mori
+ # Skip validation if chunked prefill is disabled (i.e., size <= 0).
+ # Skip validation if disaggregation mode is decode.
+ if self.chunked_prefill_size > 0 and self.disaggregation_mode != "decode":
+ assert self.chunked_prefill_size <= get_int_env_var(
+ "SGLANG_MORI_NUM_MAX_DISPATCH_TOKENS_PER_RANK", 4096
+ ), "SGLANG_MORI_NUM_MAX_DISPATCH_TOKENS_PER_RANK (default 4096) must be greater than or equal to chunked_prefill_size"
def _handle_eplb_and_dispatch(self):
if self.enable_eplb and (self.expert_distribution_recorder_mode is None):
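The guard above only enforces the dispatch-token budget when chunked prefill is active and the node is not a decode-only disaggregation worker. A short, self-contained sketch of the same rule, using the environment variable name from the patch and illustrative values (`get_int_env_var` is approximated with `os.environ`):

```python
# Sketch of the MoRI chunked-prefill budget check; values are illustrative.
import os

def check_mori_prefill_budget(chunked_prefill_size: int, disaggregation_mode: str) -> None:
    max_dispatch = int(
        os.environ.get("SGLANG_MORI_NUM_MAX_DISPATCH_TOKENS_PER_RANK", "4096")
    )
    # Skip the check when chunked prefill is disabled or this is a decode-only node.
    if chunked_prefill_size > 0 and disaggregation_mode != "decode":
        assert chunked_prefill_size <= max_dispatch, (
            "SGLANG_MORI_NUM_MAX_DISPATCH_TOKENS_PER_RANK (default 4096) must be "
            "greater than or equal to chunked_prefill_size"
        )

check_mori_prefill_budget(chunked_prefill_size=4096, disaggregation_mode="null")
check_mori_prefill_budget(chunked_prefill_size=8192, disaggregation_mode="decode")
```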
diff --git a/python/sglang/test/bench_one_batch_server_internal.py b/python/sglang/test/bench_one_batch_server_internal.py
index 7f2f62065532..e2f3ed09b48a 100644
--- a/python/sglang/test/bench_one_batch_server_internal.py
+++ b/python/sglang/test/bench_one_batch_server_internal.py
@@ -749,22 +749,34 @@ def run_benchmark_internal(
else:
tokenizer = get_tokenizer(tokenizer_path)
- # Get token capacity
internal_state = server_info.get("internal_states", [{}])
- skip_token_capacity_threshold = (
- internal_state[0].get("memory_usage", {}).get("token_capacity", 1000000000)
- )
+ dp_size = internal_state[0].get("dp_size", None) or 1
# Get effective max running requests
max_running_requests_per_dp = internal_state[0].get(
"effective_max_running_requests_per_dp", -1
)
- dp_size = server_info.get("dp_size", None) or 1
+
+ # Get token capacity
+ skip_token_capacity_threshold = 0
+
+ for i in range(dp_size):
+ skip_token_capacity_threshold += (
+ internal_state[i]
+ .get("memory_usage", {})
+ .get("token_capacity", 1000000000)
+ )
+
assert (
max_running_requests_per_dp > 0
), f"effective_max_running_requests_per_dp is not set, {max_running_requests_per_dp=}"
skip_max_running_requests_threshold = max_running_requests_per_dp * dp_size
+ print(f"{max_running_requests_per_dp=}")
+ print(f"{dp_size=}")
+ print(f"{skip_max_running_requests_threshold=}")
+ print(f"{skip_token_capacity_threshold=}")
+
# Warmup
if not bench_args.skip_warmup:
print("=" * 8 + " Warmup Begin " + "=" * 8)