[AMD] support two batch overlapping for mori ep #17953
billishyahao wants to merge 17 commits into sgl-project:main
Conversation
Summary of Changes
Hello @billishyahao, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed! This pull request significantly upgrades the Mori Expert Parallelism (EP) backend by integrating advanced features for performance optimization. It introduces an asynchronous API to reduce latency and leverages multiple HIP streams to enable efficient overlapping of communication and computation. These enhancements are crucial for improving the overall throughput and responsiveness of models utilizing Mori EP, particularly within a two-batch overlapping context, making the system more adaptable to diverse performance demands.
Code Review
This pull request introduces support for two-batch overlapping for the mori expert parallelism backend, primarily targeting AMD GPUs. The changes are extensive, adding an async API for low-latency scenarios and multi-HIP-stream support to overlap communication and computation. Overall, the implementation is solid and aligns with the PR's objectives. I've identified a critical bug that could cause crashes on non-CUDA platforms and have also included a few suggestions to improve code style and maintainability.
```diff
         total_num_sms = device_properties.multi_processor_count
-        deep_gemm_num_sms = total_num_sms - DeepEPConfig.get_instance().num_sms
+        deep_gemm_num_sms = None
+        if _is_cuda:
```
The calls to torch.cuda.get_device_properties and device_properties.multi_processor_count on lines 86-87 are CUDA-specific and should be moved inside this if _is_cuda: block. As it is, this code will raise an error on non-CUDA platforms.
Here is the corrected code block:
```python
def _compute_moe_deepseek_blog_prefill(layer):
    deep_gemm_num_sms = None
    if _is_cuda:
        device_properties = torch.cuda.get_device_properties(device="cuda")
        total_num_sms = device_properties.multi_processor_count
        deep_gemm_num_sms = total_num_sms - DeepEPConfig.get_instance().num_sms
    return OperationsStrategy(
        deep_gemm_num_sms=deep_gemm_num_sms,
        # ...
    )
```

```diff
         total_num_sms = device_properties.multi_processor_count
-        deep_gemm_num_sms = total_num_sms - DeepEPConfig.get_instance().num_sms
+        deep_gemm_num_sms = None
+        if _is_cuda:
```
Similar to the issue in _compute_moe_deepseek_blog_prefill, the CUDA-specific calls on lines 165-166 should be moved inside this if _is_cuda: block to ensure compatibility with non-CUDA environments.
Here is the corrected code block:
```python
def _compute_moe_qwen3_prefill(layer):
    deep_gemm_num_sms = None
    if _is_cuda:
        device_properties = torch.cuda.get_device_properties(device="cuda")
        total_num_sms = device_properties.multi_processor_count
        deep_gemm_num_sms = total_num_sms - DeepEPConfig.get_instance().num_sms
    return OperationsStrategy(
        deep_gemm_num_sms=deep_gemm_num_sms,
        # ...
    )
```

```diff
             topk_ids=topk_output.topk_ids,
             topk_weights=topk_output.topk_weights,
         )
+        from sglang.srt.layers.moe.token_dispatcher import DispatchOutputChecker
```
```python
        return (torch.cuda.current_device(), id(group))

    @classmethod
    def getStreamFromPool(cls, group) -> torch.cuda.Stream:
```
The method name getStreamFromPool does not follow the PEP 8 style guide for function names, which recommends snake_case. Renaming it to get_stream_from_pool would improve consistency with the rest of the codebase.
```diff
-    def getStreamFromPool(cls, group) -> torch.cuda.Stream:
+    def get_stream_from_pool(cls, group) -> torch.cuda.Stream:
```
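For reference, here is a minimal sketch of a pool along these lines using the suggested snake_case name. The `StreamPool` class, its caching dict, and the `_pool_key` helper are assumptions for illustration, not the PR's actual implementation:

```python
import torch


class StreamPool:
    """Sketch: cache one side stream per (device, process group).

    On ROCm builds of PyTorch, torch.cuda.Stream is backed by a HIP
    stream, so the same code path covers the AMD case this PR targets.
    """

    _streams: dict = {}

    @classmethod
    def _pool_key(cls, group):
        # Key on the current device plus the group's identity, matching
        # the key construction visible in the excerpt above.
        return (torch.cuda.current_device(), id(group))

    @classmethod
    def get_stream_from_pool(cls, group) -> torch.cuda.Stream:
        key = cls._pool_key(group)
        if key not in cls._streams:
            cls._streams[key] = torch.cuda.Stream()
        return cls._streams[key]
```

Keying on `(device, group)` gives each process group its own side stream per device, so streams are reused rather than recreated on every call.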
Force-pushed 2429529 to 37eb77f
@billishyahao conflicts?
Force-pushed 12c53b6 to 0a709fa
Force-pushed 0a709fa to d95089d
Thanks @HaiShaw for the comments. I have addressed all the conflicts. Feel free to review it.
@kkHuang-amd Please review the aiter backend for performance implications.
```python
        ):
            # TODO(billishyahao): check aiter path
            # billishyahao: for now, fused_moe only supports torch.bfloat16
            output_dtype = torch.bfloat16
```
```diff
             x.add_(final_hidden_states, alpha=self.routed_scaling_factor)
             final_hidden_states = x
-        else:
+        elif not _use_aiter:
```
What happens in the else case, i.e. when `_use_aiter` is true but not covered by the `if`/`elif` branches?
Motivation
Co-authored with @kkHuang-amd, @ZhaiFeiyue, and @Duyi-Wang.
cc @HaiShaw
This patch adds support for TBO, a.k.a. the two-batch overlapping feature, for mori EP. It can be divided into the following changes (see the sketch after this list):
(1) We introduce the MORI async API to support a CU-free method for low-latency scenarios.
(2) We introduce multiple HIP streams to enable communication-computation overlapping for high-throughput scenarios.
(3) We introduce the ENVs `SGLANG_MORI_ASYNC_MODE` for controlling the mori async behaviour and `SGLANG_MORI_DUAL_STREAM` for enabling dual stream.

A unit test is to be added.
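For intuition, here is a minimal sketch (not the PR's actual code) of how `SGLANG_MORI_DUAL_STREAM` could gate a communication side stream. The `overlapped_dispatch`, `dispatch_fn`, and `compute_fn` names are hypothetical; on ROCm builds, `torch.cuda.Stream` is backed by a HIP stream:

```python
import os

import torch

# Hypothetical gate mirroring the new ENV described above.
_DUAL_STREAM = os.environ.get("SGLANG_MORI_DUAL_STREAM", "0") == "1"


def overlapped_dispatch(hidden_states, dispatch_fn, compute_fn):
    """Sketch: run EP dispatch on a side stream while the main
    stream keeps computing, then synchronize via stream waits.

    dispatch_fn/compute_fn are placeholders for the real mori
    dispatch and the overlapped computation.
    """
    if not _DUAL_STREAM:
        # Serial fallback: dispatch, then compute, on one stream.
        out = dispatch_fn(hidden_states)
        return out, compute_fn()

    comm_stream = torch.cuda.Stream()  # a HIP stream on ROCm
    # Make the side stream wait for work already queued on the
    # current stream before it touches hidden_states.
    comm_stream.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(comm_stream):
        out = dispatch_fn(hidden_states)
    # Overlapped computation proceeds on the main stream meanwhile.
    result = compute_fn()
    # Before consuming `out`, the main stream must wait for comm.
    # (Real code would also handle tensor lifetimes across streams,
    # e.g. via Tensor.record_stream.)
    torch.cuda.current_stream().wait_stream(comm_stream)
    return out, result
```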
Accuracy Tests
Accuracy checks pass on the gsm8k dataset:
| Configuration | Accuracy | Invalid | Latency | Output throughput |
| --- | --- | --- | --- | --- |
| DSR1 FP8 EP8 aiter backend + Mori normal mode + fp8 dispatch + eager | 0.980 | 0.000 | 80.601 s | 234.922 token/s |
| DSR1 FP8 EP8 aiter backend + Mori low_latency mode + fp8 dispatch + eager + non-persist MLA | 0.970 | 0.000 | 49.512 s | 376.373 token/s |
| DSR1 FP8 EP8 aiter backend + Mori low_latency mode + two batch overlap + fp8 dispatch + eager + non-persist MLA | 0.980 | 0.000 | 71.464 s | 258.213 token/s |
| DSR1 FP8 EP8 aiter backend + Mori low_latency mode + two batch overlap + fp8 dispatch + hip graph + non-persist MLA | 0.980 | 0.000 | 67.590 s | 270.453 token/s |
Benchmarking and Profiling
+25%
Checklist
Review Process
`/tag-run-ci-label`, `/rerun-failed-ci`, `/tag-and-rerun-ci`