Commit fc83799

[None][fix] Fix moe_chunking_tokens during MoE A2A (NVIDIA#12929)
Signed-off-by: Wanli Jiang <35160485+Wanli-Jiang@users.noreply.github.com>
1 parent 4825da7 commit fc83799

File tree

1 file changed: +9 additions, -2 deletions


tensorrt_llm/_torch/modules/fused_moe/configurable_moe.py

Lines changed: 9 additions & 2 deletions
```diff
@@ -347,10 +347,17 @@ def _get_quant_config_dict(self, model_config: ModelConfig) -> Optional[Dict]:
 
     def calculate_num_chunks(self, all_rank_num_tokens: List[int]) -> int:
         """
-        Calculate how many chunks are needed
+        Calculate how many chunks are needed.
 
+        Uses ep_size * max(all_rank_num_tokens) when A2A communication is active,
+        because the A2A recv buffer is shaped [ep_size, max_tokens_per_rank, hidden]
+        regardless of how tokens are distributed across ranks. This matches the
+        actual memory footprint of the MoE GEMM workspace.
         """
-        num_rows = sum(all_rank_num_tokens)
+        if self.use_dp and self.comm is not None:
+            num_rows = self.mapping.moe_ep_size * max(all_rank_num_tokens)
+        else:
+            num_rows = sum(all_rank_num_tokens)
         return (num_rows + self.moe_max_num_tokens - 1) // self.moe_max_num_tokens
 
     def split_chunk(self, split_token_num: int, split_num_chunks: int) -> List[int]:
```
