Skip to content

Commit 9299fdb

Browse files
committed
test: align moe alltoall combine coverage with model specs
Cover the combine path with parameter sets that mirror representative MoE model configurations and add focused Qwen coverage for dtype and workspace staging. This keeps the regression matrix closer to real-world routing shapes while preserving targeted edge-case checks.
1 parent c21d2df commit 9299fdb

File tree

1 file changed

+16
-9
lines changed

1 file changed

+16
-9
lines changed

tests/comm/test_trtllm_moe_alltoall.py

Lines changed: 16 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -52,15 +52,22 @@ def setup_test_environment():
5252
(8, 16), # 8 ranks
5353
]
5454

55+
# (world_size, num_tokens, vector_dim, top_k, dtype, payload_in_workspace)
5556
COMBINE_PARAMS = [
56-
(2, 64, 8, 2, torch.bfloat16, True), # Small input, 2 ranks
57-
(4, 32, 32768, 4, torch.bfloat16, True), # Large input, 4 ranks
58-
(8, 16, 2048, 8, torch.bfloat16, True), # Medium input, 8 ranks
59-
(8, 16, 2048, 8, torch.bfloat16, False), # Medium input, 8 ranks
60-
(2, 64, 8, 2, torch.float16, True), # Small input, 2 ranks
61-
(4, 32, 32768, 4, torch.float16, True), # Large input, 4 ranks
62-
(8, 16, 2048, 8, torch.float16, True), # Medium input, 8 ranks
63-
(8, 16, 2048, 8, torch.float16, False), # Medium input, 8 ranks
57+
# Coverage for popular model specifications
58+
(4, 16, 4096, 2, torch.bfloat16, True), # Mixtral-8x7B
59+
(4, 16, 2880, 4, torch.bfloat16, True), # GPT-OSS-120B
60+
(8, 16, 5120, 6, torch.bfloat16, True), # DeepSeek-V2
61+
(8, 16, 7168, 8, torch.bfloat16, True), # DeepSeek-V3
62+
(8, 16, 4096, 8, torch.bfloat16, True), # Qwen3-235B-A22B
63+
(8, 16, 4096, 10, torch.bfloat16, True), # Qwen3.5-397B-A17B
64+
(8, 16, 4096, 22, torch.bfloat16, True), # Nemotron-3-Super-120B-A12B
65+
# Coverage for num_tokens
66+
(8, 1, 4096, 8, torch.bfloat16, True),
67+
# Coverage for dtype
68+
(8, 16, 4096, 8, torch.float16, True),
69+
# Coverage for payload_in_workspace
70+
(8, 16, 4096, 8, torch.bfloat16, False),
6471
]
6572

6673

@@ -465,7 +472,7 @@ def test_moe_combine_multi_rank_single_gpu(
465472
):
466473
torch.cuda.set_device(0)
467474
check_sufficient_sm_count(num_tokens, world_size)
468-
max_world_size = 8
475+
max_world_size = 16
469476
assert world_size <= max_world_size, (
470477
f"should run with world_size at most {max_world_size}"
471478
)

0 commit comments

Comments (0)