[AMD] support two batch overlapping for mori ep #17953

Open
billishyahao wants to merge 17 commits into sgl-project:main from HaiShaw:mori_ep_tbo

Conversation

@billishyahao (Contributor) commented Jan 29, 2026

Motivation

Co-authored with @kkHuang-amd, @ZhaiFeiyue, and @Duyi-Wang.
cc @HaiShaw

This patch supports TBO (two batch overlapping) for mori EP. It can be divided into the following changes:
(1) We introduce the MORI async API to support a CU-free method for the low-latency scenario.
(2) We introduce multiple HIP streams to enable communication-computation overlapping for the high-throughput scenario.
(3) We introduce the environment variables SGLANG_MORI_ASYNC_MODE, which controls the mori async behaviour, and SGLANG_MORI_DUAL_STREAM, which enables the dual stream (a minimal sketch of how these flags gate the code paths follows below).

Unit tests are to be added.
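
To make the intended behaviour of the two flags concrete, here is a minimal sketch. It is not the PR's code: only the two environment variable names come from this patch, while the _env_flag helper and the dispatch_async/dispatch dispatcher methods are hypothetical placeholders.

# Minimal sketch of the intended env-flag gating; only the environment
# variable names come from this PR, the helper and dispatcher hooks below
# are hypothetical.
import os


def _env_flag(name: str, default: str = "false") -> bool:
    return os.environ.get(name, default).lower() in ("1", "true", "yes")


MORI_ASYNC_MODE = _env_flag("SGLANG_MORI_ASYNC_MODE")    # CU-free async dispatch/combine
MORI_DUAL_STREAM = _env_flag("SGLANG_MORI_DUAL_STREAM")  # separate HIP stream for comms


def dispatch_tokens(dispatcher, hidden_states):
    if MORI_ASYNC_MODE:
        # Low-latency path: issue the dispatch without occupying CUs and
        # return a handle that is resolved later.
        return dispatcher.dispatch_async(hidden_states)
    # Default synchronous path.
    return dispatcher.dispatch(hidden_states)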

Accuracy Tests

Accuracy checks pass on the gsm8k dataset:

DSR1 FP8 EP8 aiter backend + Mori normal mode + fp8 dispatch + eager

SGLANG_MORI_FP8_DISP=true \
MORI_SHMEM_MODE=ISOLATION                         \
SGLANG_MORI_ASYNC_MODE=false                       \
SGLANG_MORI_DUAL_STREAM=false                     \
SGLANG_MORI_NUM_MAX_DISPATCH_TOKENS_PER_RANK=8192 \
NCCL_IB_HCA=ionic_0,ionic_1,ionic_2,ionic_3,ionic_4,ionic_5,ionic_6,ionic_7 \
GLOO_SOCKET_IFNAME=enp81s0f1 \
NCCL_SOCKET_IFNAME=enp81s0f1 \
SGLANG_USE_AITER=1           \
python3 -m sglang.launch_server \
	--model-path /models/DSR1  \
	--tp-size 8 \
	--dp-size 8 \
	--ep-size 8 \
	--moe-a2a-backend mori \
	--deepep-mode normal \
	--enable-dp-attention \
	--decode-log-interval 1 \
	--host 0.0.0.0 \
	--port 8321 \
	--nnodes 1 \
	--node-rank 0 \
	--trust-remote-code \
	--moe-dense-tp-size 1 \
	--enable-dp-lm-head \
	--disable-radix-cache \
	--watchdog-timeout 1000000 \
	--mem-fraction-static 0.8 \
	--max-running-requests 128 \
	--chunked-prefill-size 65536 \
	--disable-cuda-graph \
	--kv-cache-dtype fp8_e4m3 \
	--log-requests \
	--log-requests-level 3 \
	--attention-backend aiter

Accuracy: 0.980
Invalid: 0.000
Latency: 80.601 s
Output throughput: 234.922 token/s

DSR1 FP8 EP8 aiter backend + Mori low_latency mode + fp8 dispatch + eager + non-persist mla

SGLANG_AITER_MLA_PERSIST=0 \
SGLANG_MORI_FP8_DISP=true \
MORI_SHMEM_MODE=ISOLATION                         \
SGLANG_MORI_ASYNC_MODE=true                       \
SGLANG_MORI_DUAL_STREAM=false                     \
SGLANG_MORI_NUM_MAX_DISPATCH_TOKENS_PER_RANK=8192 \
NCCL_IB_HCA=ionic_0,ionic_1,ionic_2,ionic_3,ionic_4,ionic_5,ionic_6,ionic_7 \
GLOO_SOCKET_IFNAME=enp81s0f1 \
NCCL_SOCKET_IFNAME=enp81s0f1 \
SGLANG_USE_AITER=1           \
python3 -m sglang.launch_server \
	--model-path /models/DSR1  \
	--tp-size 8 \
	--dp-size 8 \
	--ep-size 8 \
	--moe-a2a-backend mori \
	--deepep-mode low_latency \
	--enable-dp-attention \
	--decode-log-interval 1 \
	--host 0.0.0.0 \
	--port 8321 \
	--nnodes 1 \
	--node-rank 0 \
	--trust-remote-code \
	--moe-dense-tp-size 1 \
	--enable-dp-lm-head \
	--disable-radix-cache \
	--watchdog-timeout 1000000 \
	--mem-fraction-static 0.8 \
	--max-running-requests 256 \
	--chunked-prefill-size 65536 \
	--cuda-graph-max-bs 32 \
	--kv-cache-dtype fp8_e4m3 \
	--log-requests \
	--log-requests-level 3 \
	--attention-backend aiter

Accuracy: 0.970
Invalid: 0.000
Latency: 49.512 s
Output throughput: 376.373 token/s

DSR1 FP8 EP8 aiter backend + Mori low_latency mode + enable two batch overlap + fp8 dispatch + eager + non-persist mla

SGLANG_AITER_MLA_PERSIST=0 \
SGLANG_MORI_FP8_DISP=true \
MORI_SHMEM_MODE=ISOLATION                         \
SGLANG_MORI_ASYNC_MODE=true                       \
SGLANG_MORI_DUAL_STREAM=false                     \
SGLANG_MORI_NUM_MAX_DISPATCH_TOKENS_PER_RANK=8192 \
NCCL_IB_HCA=ionic_0,ionic_1,ionic_2,ionic_3,ionic_4,ionic_5,ionic_6,ionic_7 \
GLOO_SOCKET_IFNAME=enp81s0f1 \
NCCL_SOCKET_IFNAME=enp81s0f1 \
SGLANG_USE_AITER=1           \
python3 -m sglang.launch_server \
	--model-path /models/DSR1  \
	--tp-size 8 \
	--dp-size 8 \
	--ep-size 8 \
	--moe-a2a-backend mori \
	--deepep-mode low_latency \
	--enable-two-batch-overlap \
	--enable-dp-attention \
	--decode-log-interval 1 \
	--host 0.0.0.0 \
	--port 8321 \
	--nnodes 1 \
	--node-rank 0 \
	--trust-remote-code \
	--moe-dense-tp-size 1 \
	--enable-dp-lm-head \
	--disable-radix-cache \
	--watchdog-timeout 1000000 \
	--mem-fraction-static 0.8 \
	--max-running-requests 128 \
	--chunked-prefill-size 65536 \
	--disable-cuda-graph \
	--kv-cache-dtype fp8_e4m3 \
	--log-requests \
	--log-requests-level 3 \
	--attention-backend aiter

Accuracy: 0.980
Invalid: 0.000
Latency: 71.464 s
Output throughput: 258.213 token/s

DSR1 FP8 EP8 aiter backend + Mori low_latency mode + enable two batch overlap + fp8 dispatch + hip graph + non-persist mla

SGLANG_AITER_MLA_PERSIST=0 \
SGLANG_MORI_FP8_DISP=true \
MORI_SHMEM_MODE=ISOLATION                         \
SGLANG_MORI_ASYNC_MODE=true                       \
SGLANG_MORI_DUAL_STREAM=false                     \
SGLANG_MORI_NUM_MAX_DISPATCH_TOKENS_PER_RANK=8192 \
NCCL_IB_HCA=ionic_0,ionic_1,ionic_2,ionic_3,ionic_4,ionic_5,ionic_6,ionic_7 \
GLOO_SOCKET_IFNAME=enp81s0f1 \
NCCL_SOCKET_IFNAME=enp81s0f1 \
SGLANG_USE_AITER=1           \
python3 -m sglang.launch_server \
	--model-path /models/DSR1  \
	--tp-size 8 \
	--dp-size 8 \
	--ep-size 8 \
	--moe-a2a-backend mori \
	--deepep-mode low_latency \
	--enable-two-batch-overlap \
	--enable-dp-attention \
	--decode-log-interval 1 \
	--host 0.0.0.0 \
	--port 8321 \
	--nnodes 1 \
	--node-rank 0 \
	--trust-remote-code \
	--moe-dense-tp-size 1 \
	--enable-dp-lm-head \
	--disable-radix-cache \
	--watchdog-timeout 1000000 \
	--mem-fraction-static 0.8 \
	--max-running-requests 256 \
	--chunked-prefill-size 65536 \
	--cuda-graph-max-bs 32 \
	--kv-cache-dtype fp8_e4m3 \
	--log-requests \
	--log-requests-level 3 \
	--attention-backend aiter

Accuracy: 0.980
Invalid: 0.000
Latency: 67.590 s
Output throughput: 270.453 token/s

Benchmarking and Profiling

mbs     async      async + TBO    TBO gain (%)
128     1548.43    1360.84        -12%
256     1958.43    2138.30         +9%
512     2518.49    2792.93        +11%
1024    2868.89    3342.79        +17%
2048    2985.92    3735.71        +25%

Checklist

Review Process

  1. Ping Merge Oncalls to start the PR flow. See the PR Merge Process.
  2. Get approvals from CODEOWNERS and other reviewers.
  3. Trigger CI tests with comments or contact authorized users to do so.
    • /tag-run-ci-label, /rerun-failed-ci, /tag-and-rerun-ci
  4. After green CI and required approvals, ask Merge Oncalls to merge.

@gemini-code-assist (Contributor)

Summary of Changes

Hello @billishyahao, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly upgrades the Mori Expert Parallelism (EP) backend by integrating advanced features for performance optimization. It introduces an asynchronous API to reduce latency and leverages multi-HIP streams to enable efficient overlapping of communication and computation. These enhancements are crucial for improving the overall throughput and responsiveness of models utilizing Mori EP, particularly within a two-batch overlapping context, making the system more adaptable to diverse performance demands.

Highlights

  • Mori EP Integration: The mori backend has been added as a supported option for the --moe-a2a-backend argument, expanding the available expert parallelism communication backends.
  • Asynchronous Mori API for Low Latency: An asynchronous API for Mori EP is introduced to support low-latency scenarios, controlled by the new SGLANG_MORI_ASYNC_MODE environment variable. This enables non-blocking operations for improved responsiveness.
  • Multi-HIP Stream for Communication-Computation Overlapping: Multi-HIP stream functionality has been implemented to allow for communication-computation overlapping, enhancing throughput. This feature is enabled via the SGLANG_MORI_DUAL_STREAM environment variable (a minimal overlap sketch follows this list).
  • Two Batch Overlapping (TBO) for Mori EP: The existing two-batch overlapping mechanism has been extended to fully support the Mori EP dispatcher, allowing for more efficient resource utilization by overlapping processing of different batches.
  • Refactored Mori Dispatcher Logic: The Mori EP dispatcher (MoriEPDispatcher) has been refactored to accommodate both normal and low-latency modes, introducing new dispatch and combine input/output structures (MoriEPLLDispatchOutput, MoriEPLLCombineInput) for better organization and functionality.
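
To illustrate the dual-stream idea, a minimal sketch follows. It assumes PyTorch's stream APIs (which map to HIP streams on ROCm builds); dispatch_fn and compute_fn are placeholders rather than functions from this PR.

# Sketch of communication-computation overlap on two streams; torch.cuda
# stream APIs map to HIP streams on ROCm.
import torch


def overlapped_step(dispatch_fn, compute_fn, hidden_states):
    comm_stream = torch.cuda.Stream()
    main_stream = torch.cuda.current_stream()

    # Run the all-to-all dispatch on the side stream once the inputs produced
    # so far on the main stream are visible to it.
    with torch.cuda.stream(comm_stream):
        comm_stream.wait_stream(main_stream)
        dispatched = dispatch_fn(hidden_states)

    # Meanwhile, dense computation keeps the main stream busy.
    local_out = compute_fn(hidden_states)

    # Re-join before anything on the main stream consumes the dispatched data.
    main_stream.wait_stream(comm_stream)
    return dispatched, local_out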


@gemini-code-assist (bot) left a comment


Code Review

This pull request introduces support for two-batch overlapping for the mori expert parallelism backend, primarily targeting AMD GPUs. The changes are extensive, adding an async API for low latency scenarios and multi-hip stream support to overlap communication and computation. Overall, the implementation is solid and aligns with the PR's objectives. I've identified a critical bug that could cause crashes on non-CUDA platforms and have also included a few suggestions to improve code style and maintainability.

total_num_sms = device_properties.multi_processor_count
deep_gemm_num_sms = total_num_sms - DeepEPConfig.get_instance().num_sms
deep_gemm_num_sms = None
if _is_cuda:

high

The calls to torch.cuda.get_device_properties and device_properties.multi_processor_count on lines 86-87 are CUDA-specific and should be moved inside this if _is_cuda: block. As it is, this code will raise an error on non-CUDA platforms.

Here is the corrected code block:

def _compute_moe_deepseek_blog_prefill(layer):
    deep_gemm_num_sms = None
    if _is_cuda:
        device_properties = torch.cuda.get_device_properties(device="cuda")
        total_num_sms = device_properties.multi_processor_count
        deep_gemm_num_sms = total_num_sms - DeepEPConfig.get_instance().num_sms

    return OperationsStrategy(
        deep_gemm_num_sms=deep_gemm_num_sms,
        # ...
    )

total_num_sms = device_properties.multi_processor_count
deep_gemm_num_sms = total_num_sms - DeepEPConfig.get_instance().num_sms
deep_gemm_num_sms = None
if _is_cuda:

high

Similar to the issue in _compute_moe_deepseek_blog_prefill, the CUDA-specific calls on lines 165-166 should be moved inside this if _is_cuda: block to ensure compatibility with non-CUDA environments.

Here is the corrected code block:

def _compute_moe_qwen3_prefill(layer):
    deep_gemm_num_sms = None
    if _is_cuda:
        device_properties = torch.cuda.get_device_properties(device="cuda")
        total_num_sms = device_properties.multi_processor_count
        deep_gemm_num_sms = total_num_sms - DeepEPConfig.get_instance().num_sms

    return OperationsStrategy(
        deep_gemm_num_sms=deep_gemm_num_sms,
        # ...
    )

topk_ids=topk_output.topk_ids,
topk_weights=topk_output.topk_weights,
)
from sglang.srt.layers.moe.token_dispatcher import DispatchOutputChecker

medium

This local import should be moved to the top of the file. This follows standard Python style (PEP 8), improves readability, and avoids potential issues with circular dependencies, although none seem to exist here. Keeping imports at the top level makes dependencies clearer.
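
For reference, the hoisted form is simply a module-level import of the same symbol:

# Top-of-file import instead of importing inside the function body.
from sglang.srt.layers.moe.token_dispatcher import DispatchOutputChecker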

return (torch.cuda.current_device(), id(group))

@classmethod
def getStreamFromPool(cls, group) -> torch.cuda.Stream:

medium

The method name getStreamFromPool does not follow the PEP 8 style guide for function names, which recommends snake_case. Renaming it to get_stream_from_pool would improve consistency with the rest of the codebase.

Suggested change:
- def getStreamFromPool(cls, group) -> torch.cuda.Stream:
+ def get_stream_from_pool(cls, group) -> torch.cuda.Stream:


HaiShaw commented Feb 15, 2026

@billishyahao conflicts?

@billishyahao (Author)

Thanks @HaiShaw for the comments. I have resolved all the conflicts. Feel free to review it.

@HaiShaw HaiShaw self-assigned this Feb 17, 2026

HaiShaw commented Feb 19, 2026

@kkHuang-amd Please review the aiter backend for performance implications.

):
# TODO(billishyahao): check aiter path
# billishyahao: for now, fused_moe only support torch.bfloat16
output_dtype = torch.bfloat16

Do not hard-code this (one possible alternative is sketched below).
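
One hedged way to avoid the inline hard-coding, as a sketch only: resolve_output_dtype and its argument are hypothetical names, and the bfloat16 fallback mirrors the current fused_moe limitation noted in the TODO above.

import torch

# Hypothetical helper: keep the dtype constraint in one place instead of
# hard-coding torch.bfloat16 inline at the call site.
_FUSED_MOE_SUPPORTED_DTYPES = (torch.bfloat16,)


def resolve_output_dtype(hidden_states: torch.Tensor) -> torch.dtype:
    if hidden_states.dtype in _FUSED_MOE_SUPPORTED_DTYPES:
        return hidden_states.dtype
    # Current fused_moe limitation (see the TODO above): fall back to bf16.
    return torch.bfloat16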

x.add_(final_hidden_states, alpha=self.routed_scaling_factor)
final_hidden_states = x
- else:
+ elif not _use_aiter:

What is the else case, i.e., when _use_aiter is true but not covered by the if/elif? (One way to make it explicit is sketched below.)
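
A structural sketch of making the remaining case explicit; the signature and branch bodies are hypothetical, not the PR's code, and the assumption in the final branch would need confirmation.

from typing import Optional

import torch


def combine_routed_output(
    final_hidden_states: torch.Tensor,
    shared_output: Optional[torch.Tensor],
    use_aiter: bool,
    routed_scaling_factor: float,
) -> torch.Tensor:
    # Structural sketch: spell out all three cases so a reader can see what
    # happens when use_aiter is true and the first condition does not hold.
    if shared_output is not None:
        shared_output.add_(final_hidden_states, alpha=routed_scaling_factor)
        return shared_output
    elif not use_aiter:
        return final_hidden_states * routed_scaling_factor
    else:
        # use_aiter path: assumed the scaling is already handled inside the
        # aiter fused kernel, so the tensor is returned unchanged.
        return final_hidden_states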


Labels: deepseek, documentation, run-ci
