feat: Support flashinfer_cutedsl MoE runner with flashinfer alltoall backend #22669
Open: samuellees wants to merge 1 commit into sgl-project:main from samuellees:feat/enable-fp4cutedslmoe+a2a
Changes from all commits:
```diff
@@ -100,7 +100,7 @@ def __init__(
         # TODO: Can this be a server arg and shared with deepep/mooncakeep?
         self.max_num_tokens = (
-            get_int_env_var("SGLANG_FLASHINFER_NUM_MAX_DISPATCH_TOKENS_PER_RANK", 1024)
+            get_int_env_var("SGLANG_FLASHINFER_NUM_MAX_DISPATCH_TOKENS_PER_RANK", 8192)
             * self.ep_size
         )
```
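This change raises the default per-rank dispatch-token budget from 1024 to 8192; the workspace capacity is then that budget times the EP world size. A minimal runnable sketch of the sizing arithmetic (the inlined `get_int_env_var` below is a simplified stand-in for `sglang.srt.utils.get_int_env_var`, and `ep_size = 4` is just an example matching the new CI test):

```python
import os


def get_int_env_var(name: str, default: int) -> int:
    # Simplified stand-in for sglang.srt.utils.get_int_env_var.
    value = os.environ.get(name)
    return int(value) if value is not None else default


ep_size = 4  # example: the EP=4 setup used by the new 4-GPU CI test

# Mirrors the updated __init__ logic: each rank reserves room for the
# worst case where every one of the ep_size ranks dispatches its full
# per-rank token budget.
max_num_tokens = (
    get_int_env_var("SGLANG_FLASHINFER_NUM_MAX_DISPATCH_TOKENS_PER_RANK", 8192)
    * ep_size
)
print(max_num_tokens)  # 32768 with the new default, vs 4096 before
```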
```diff
@@ -178,15 +178,11 @@ def dispatch(
         topk_ids = topk_output.topk_ids
         topk_weights = topk_output.topk_weights

-        # Handle case where there are no tokens on this DP worker
-        # moe_a2a.dispatch requires at least one token
-        self.has_dummy_token = False
-        if x.shape[0] == 0:
-            logger.warning("No tokens on this DP worker, using dummy token")
-            self.has_dummy_token = True
-            x = self.dummy_x
-            topk_ids = self.dummy_topk_ids
-            topk_weights = self.dummy_topk_weights
+        # Track if this DP worker has no tokens (idle rank).
+        # Unlike the old dummy-token approach, we pass 0-size tensors directly
+        # to the alltoall kernel, which handles local_num_tokens=0 natively
+        # (same as TRT-LLM). The kernel keeps 1 thread alive for sync.
+        self.has_dummy_token = x.shape[0] == 0

         global_scale = self.quant_config.get("input_global_scale", None)
         if global_scale is not None:
```

Contributor: Maybe we could rename it, something like …
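The rewrite removes the dummy-token workaround: an idle DP rank now passes 0-row tensors straight to the alltoall kernel, which, per the new comment, handles `local_num_tokens=0` natively (as in TRT-LLM). A minimal sketch of the new shape contract, where `dispatch_stub` is a hypothetical stand-in for `moe_a2a.dispatch`:

```python
import torch


def dispatch_stub(x: torch.Tensor, topk_ids: torch.Tensor) -> None:
    # Hypothetical stand-in for moe_a2a.dispatch: the real kernel accepts
    # local_num_tokens == 0 and keeps one thread alive for synchronization.
    assert x.shape[0] == topk_ids.shape[0]
    print(f"dispatching {x.shape[0]} local tokens")


hidden, topk = 4096, 8
x = torch.empty(0, hidden)                        # idle DP rank: zero local tokens
topk_ids = torch.empty(0, topk, dtype=torch.int64)

# Old path: substitute self.dummy_x / self.dummy_topk_ids and strip the fake
# row after combine. New path: only record idleness, pass 0-size tensors.
has_dummy_token = x.shape[0] == 0                 # now just a flag, no substitution
dispatch_stub(x, topk_ids)                        # prints "dispatching 0 local tokens"
```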
```diff
@@ -216,9 +212,10 @@ def dispatch(
             else x.shape[0]
         )
         recv_tensors = self.moe_a2a.dispatch(
-            self.dummy_topk_ids_current_rank if self.has_dummy_token else topk_ids,
+            topk_ids,
             payloads,
             self.runtime_max_tokens_per_rank,
+            invalid_token_expert_id=self.num_experts,
             expert_id_payload_index=expert_id_payload_index,
         )
         if x_sf is not None:
```
```diff
@@ -257,10 +254,6 @@ def combine(self, combine_input: FlashinferCombineInput) -> torch.Tensor:
             payload_in_workspace=self.payload_in_workspace,
         )

-        # Remove dummy token if it was added in dispatch
-        if self.has_dummy_token:
-            hidden_states = hidden_states[1:, :]
-
         del self.runtime_max_tokens_per_rank
         del self.has_dummy_token
         return hidden_states
```
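With the dummy row gone, `combine` returns the kernel output unchanged and only clears the per-call attributes that `dispatch` set. A simplified sketch of that set-in-dispatch, delete-in-combine lifecycle (`A2AStateSketch` is a hypothetical stand-in, not the real dispatcher class):

```python
class A2AStateSketch:
    """Hypothetical stand-in for the dispatcher's per-call state.

    dispatch() records state for the current batch; combine() consumes
    and clears it, so the attributes never outlive one dispatch/combine pair.
    """

    def dispatch(self, num_local_tokens: int) -> None:
        self.has_dummy_token = num_local_tokens == 0  # just a flag now
        self.runtime_max_tokens_per_rank = 8192       # placeholder value

    def combine(self) -> str:
        hidden_states = "<kernel output>"  # returned as-is: no dummy row to strip
        del self.runtime_max_tokens_per_rank
        del self.has_dummy_token
        return hidden_states


state = A2AStateSketch()
state.dispatch(num_local_tokens=0)
print(state.combine())  # per-call attributes are gone after combine()
```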
New test file (+91 lines):

```python
"""Test CuteDSL FP4 MoE + FlashInfer alltoall on B200 with DP attention.

Config: Qwen3.5-397B-A17B-NVFP4, B200x4, EP=4 DP=4, cutedsl + flashinfer a2a.
"""

import unittest
from types import SimpleNamespace

import torch

from sglang.srt.utils import kill_process_tree
from sglang.test.ci.ci_register import register_cuda_ci
from sglang.test.run_eval import run_eval
from sglang.test.test_utils import (
    DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
    DEFAULT_URL_FOR_TEST,
    CustomTestCase,
    popen_launch_server,
)

register_cuda_ci(est_time=600, suite="stage-c-test-4-gpu-b200")

MODEL = "nvidia/Qwen3.5-397B-A17B-NVFP4"

SKIP_TEST = torch.cuda.get_device_capability() < (10, 0)
SKIP_REASON = "Requires Blackwell (B200, sm_100a) or above."


@unittest.skipIf(SKIP_TEST, SKIP_REASON)
class TestCuteDslFlashinferA2A(CustomTestCase):
    """CuteDSL FP4 MoE + FlashInfer one-sided alltoall + DP4 EP4 on B200."""

    @classmethod
    def setUpClass(cls):
        cls.model = MODEL
        cls.base_url = DEFAULT_URL_FOR_TEST
        cls.process = popen_launch_server(
            cls.model,
            cls.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH * 3,
            other_args=[
                "--trust-remote-code",
                "--quantization",
                "modelopt_fp4",
                "--tp",
                "4",
                "--ep-size",
                "4",
                "--dp",
                "4",
                "--enable-dp-attention",
                "--enable-dp-lm-head",
                "--moe-runner-backend",
                "flashinfer_cutedsl",
                "--moe-a2a-backend",
                "flashinfer",
                "--disable-radix-cache",
                "--disable-flashinfer-autotune",
                "--watchdog-timeout",
                "900",
            ],
            env={
                "FLASHINFER_DISABLE_VERSION_CHECK": "1",
                "SGLANG_MOE_NVFP4_DISPATCH": "0",
            },
        )

    @classmethod
    def tearDownClass(cls):
        kill_process_tree(cls.process.pid)

    def test_gsm8k(self):
        args = SimpleNamespace(
            base_url=self.base_url,
            eval_name="gsm8k",
            num_examples=1319,
            max_tokens=10240,
            repeat=1,
            num_threads=1319,
            num_shots=8,
            temperature=0.6,
            top_p=0.95,
            top_k=20,
        )
        metrics = run_eval(args)
        print(metrics)
        self.assertGreater(metrics["score"], 0.90)


if __name__ == "__main__":
    unittest.main()
```
Reviewer: Optional: are we able to verify this path? If it is not supported, could we consider forcibly disabling NVFP4_DISPATCH in server_args when cutedsl + flashinfer a2a is detected, and emitting a warning?
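A sketch of what the reviewer's suggestion might look like: detect the cutedsl + flashinfer a2a combination during server-args handling, force NVFP4 dispatch off, and warn. All names below (`_adjust_moe_flags` and the `server_args` fields) are hypothetical placeholders, not SGLang's actual schema:

```python
import logging
from types import SimpleNamespace

logger = logging.getLogger(__name__)


def _adjust_moe_flags(server_args) -> None:
    # Hypothetical guard: field names are placeholders, not SGLang's
    # actual server_args attributes.
    if (
        server_args.moe_runner_backend == "flashinfer_cutedsl"
        and server_args.moe_a2a_backend == "flashinfer"
        and server_args.enable_nvfp4_dispatch
    ):
        logger.warning(
            "NVFP4 dispatch is not supported with flashinfer_cutedsl + "
            "flashinfer alltoall; forcibly disabling it."
        )
        server_args.enable_nvfp4_dispatch = False


args = SimpleNamespace(
    moe_runner_backend="flashinfer_cutedsl",
    moe_a2a_backend="flashinfer",
    enable_nvfp4_dispatch=True,
)
_adjust_moe_flags(args)
print(args.enable_nvfp4_dispatch)  # False: the unsupported combo was disabled
```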