Skip to content

Commit 5cc0ccd

Browse files
[https://nvbugs/5973199][fix] support attn-dp TRTLLM-Gen NVFP4 MoE fu… (NVIDIA#12156)
Signed-off-by: Tal Cherckez <127761168+tcherckez-nvidia@users.noreply.github.com>
1 parent 7c64f4b commit 5cc0ccd

File tree

2 files changed

+132
-13
lines changed

2 files changed

+132
-13
lines changed

tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/trtllm_moe.py

Lines changed: 132 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
1+
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
22
# SPDX-License-Identifier: Apache-2.0
33
#
44
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -273,6 +273,110 @@ def _run_moe_with_alltoall(
273273
return combined[:local_num_tokens]
274274

275275

def _run_trtllm_gen_nvfp4_moe_with_alltoall(
    x: torch.Tensor,
    selected_experts: torch.Tensor,
    routing_weights: torch.Tensor,
    fc1_expert_weights_fp4: torch.Tensor,
    fc2_expert_weights_fp4: torch.Tensor,
    fc1_weight_blockscale_fp8: torch.Tensor,
    fc2_weight_blockscale_fp8: torch.Tensor,
    fc1_act_global_scale: torch.Tensor,
    fc1_scale_c: torch.Tensor,
    fc1_alpha: torch.Tensor,
    fc2_alpha: torch.Tensor,
    mapping: Mapping,
    max_num_tokens: int,
    act_type: int,
) -> torch.Tensor:
    """Run TRTLLM-Gen NVFP4 MoE through the all-to-all dispatch/combine path.

    Pads the local token batch up to ``max_num_tokens``, dispatches tokens to
    expert-parallel ranks via :class:`MoeAlltoAll`, quantizes the dispatched
    activations to NVFP4, runs the TRTLLM-Gen FP4 block-scale MoE kernel on the
    local experts, and combines the per-rank outputs back to this rank.

    Args:
        x: 2-D activation tensor ``[local_num_tokens, hidden_size]``
            (shape inferred from the ``x.shape[0]`` / ``x.shape[-1]`` usage below).
        selected_experts: per-token top-k expert ids, ``[local_num_tokens, top_k]``.
        routing_weights: per-token top-k routing weights, same leading shape.
        fc1_expert_weights_fp4 / fc2_expert_weights_fp4: NVFP4-packed expert
            weights for the two MLP projections; dim 0 is the local expert count.
        fc1_weight_blockscale_fp8 / fc2_weight_blockscale_fp8: per-block FP8
            weight scales (reinterpreted as ``float8_e4m3fn`` before the kernel).
        fc1_act_global_scale: global activation scale used by ``fp4_quantize``.
        fc1_scale_c, fc1_alpha, fc2_alpha: kernel scale/alpha tensors
            (passed through to ``fp4_block_scale_moe_runner`` unchanged).
        mapping: parallel mapping; only ``moe_ep_size`` / ``moe_ep_rank`` are read.
        max_num_tokens: fixed per-rank token capacity used for padding and as
            the runtime token budget of the all-to-all.
        act_type: activation selector; ``1`` means a non-gated MLP here (the
            ``factor = 1 if act_type == 1 else 2`` split of fc1's dim 1).

    Returns:
        Combined MoE output for this rank's original tokens,
        ``[local_num_tokens, hidden_size]`` (padding rows are sliced off).
    """

    top_k = selected_experts.shape[1]
    hidden_size = x.shape[-1]
    local_num_experts = int(fc1_expert_weights_fp4.shape[0])
    # Experts are sharded evenly across EP ranks, so the global count is
    # local count times EP world size.
    global_num_experts = local_num_experts * mapping.moe_ep_size
    workspace_size = MoeAlltoAll.calculate_required_workspace_size(
        mapping.moe_ep_size, top_k, max_num_tokens, hidden_size, x.dtype
    )
    # Every rank is padded to the same token count, so the runtime budget
    # equals the static maximum.
    runtime_max_tokens_per_rank = max_num_tokens

    # NOTE(review): MoeAlltoAll is constructed on every call here; if the
    # constructor allocates the workspace this is per-call overhead — confirm
    # whether it caches internally.
    moe_a2a = MoeAlltoAll(
        mapping=mapping,
        max_num_tokens=max_num_tokens,
        top_k=top_k,
        num_slots=global_num_experts,
        workspace_size_per_rank=workspace_size,
        num_experts=None,
    )

    # An out-of-range expert id (== global_num_experts) marks padding tokens
    # so receivers can drop them.
    invalid_expert_id = global_num_experts
    local_num_tokens = x.shape[0]
    # Padding tokens are routed to this rank's first local expert so they
    # never cross ranks; presumably they are then masked via
    # invalid_token_expert_id — TODO confirm against MoeAlltoAll.dispatch.
    pad_expert_id = mapping.moe_ep_rank * local_num_experts
    pad_size = runtime_max_tokens_per_rank - local_num_tokens
    if pad_size > 0:
        # Pad rows (last-dim untouched: pad spec is (left, right, top, bottom)).
        x = torch.nn.functional.pad(x, (0, 0, 0, pad_size))
        selected_experts = torch.nn.functional.pad(
            selected_experts, (0, 0, 0, pad_size), value=pad_expert_id
        )
        routing_weights = torch.nn.functional.pad(routing_weights, (0, 0, 0, pad_size))

    # Dispatch three payloads; index 1 (selected_experts) carries the expert
    # ids the all-to-all routes by, hence expert_id_payload_index=1.
    recv_results = moe_a2a.dispatch(
        selected_experts,
        [x.contiguous(), selected_experts.contiguous(), routing_weights.contiguous()],
        runtime_max_tokens_per_rank,
        invalid_token_expert_id=invalid_expert_id,
        expert_id_payload_index=1,
    )

    # Flatten the [ep_size, tokens, ...] receive buffers into one token axis;
    # the kernel expects int32 ids and bf16 weights.
    dispatched_x = recv_results[0].reshape(-1, hidden_size)
    dispatched_selected = recv_results[1].reshape(-1, top_k).to(torch.int32).contiguous()
    dispatched_weights = recv_results[2].reshape(-1, top_k).to(torch.bfloat16).contiguous()

    # Quantize dispatched activations to NVFP4 + per-block scale factors.
    x_q_fp4, x_sf = torch.ops.trtllm.fp4_quantize(
        dispatched_x, fc1_act_global_scale, TRTLLM_NVFP4_SCALING_VECTOR_SIZE, False, False
    )
    # Gated MLPs store gate+up fused in fc1, doubling dim 1; act_type == 1
    # is the non-gated case (factor 1).
    factor = 1 if act_type == 1 else 2
    intermediate_size = int(fc1_expert_weights_fp4.shape[1] // factor)
    # This rank owns experts [offset, offset + local_num_experts).
    local_expert_offset = mapping.moe_ep_rank * local_num_experts
    routing_method_type = int(RoutingMethodType.DeepSeekV3)

    # Positional None slots are unused optional scale inputs of the kernel —
    # see the fp4_block_scale_moe_runner op schema for their meaning.
    outputs = torch.ops.trtllm.fp4_block_scale_moe_runner(
        None,
        None,
        x_q_fp4,
        x_sf.view(torch.float8_e4m3fn),
        fc1_expert_weights_fp4,
        fc1_weight_blockscale_fp8.view(torch.float8_e4m3fn),
        None,
        None,
        None,
        None,
        fc2_expert_weights_fp4,
        fc2_weight_blockscale_fp8.view(torch.float8_e4m3fn),
        None,
        fc1_scale_c,
        fc1_alpha,
        fc2_alpha,
        global_num_experts,
        top_k,
        1,
        1,
        intermediate_size,
        local_expert_offset,
        local_num_experts,
        1.0,
        routing_method_type,
        do_finalize=True,
        act_type=act_type,
        topk_weights=dispatched_weights,
        topk_ids=dispatched_selected,
    )

    # Reshape back to per-source-rank layout, reverse the all-to-all, and
    # drop the padding rows added above.
    moe_out = outputs[0].view(mapping.moe_ep_size, runtime_max_tokens_per_rank, hidden_size)
    combined = moe_a2a.combine(moe_out, runtime_max_tokens_per_rank)
    return combined[:local_num_tokens]
276380
@torch.library.custom_op("auto_deploy::trtllm_moe_fused", mutates_args=())
277381
def trtllm_moe_fused(
278382
x: torch.Tensor,
@@ -825,7 +929,7 @@ def trtllm_nvfp4_trtllm_gen_moe_fused(
825929
fc2_alpha: torch.Tensor,
826930
is_gated_mlp: bool = True,
827931
act_fn: int = int(ActivationType.Silu),
828-
mapping_config: str = "", # https://github.com/NVIDIA/TensorRT-LLM/issues/12008 Add mapping config support
932+
mapping_config: str = "",
829933
max_num_tokens: int = 0,
830934
apply_routing_on_input: bool = False,
831935
) -> torch.Tensor:
@@ -839,9 +943,6 @@ def trtllm_nvfp4_trtllm_gen_moe_fused(
839943
pad_size = expected_hidden - int(x2d.shape[-1])
840944
if pad_size > 0:
841945
x2d = torch.nn.functional.pad(x2d, (0, pad_size))
842-
x_q_fp4, x_sf = torch.ops.trtllm.fp4_quantize(
843-
x2d, fc1_act_global_scale, TRTLLM_NVFP4_SCALING_VECTOR_SIZE, False, False
844-
)
845946

846947
if act_fn in (ActivationType.Silu, ActivationType.Swiglu):
847948
act_type = 0
@@ -855,6 +956,32 @@ def trtllm_nvfp4_trtllm_gen_moe_fused(
855956
factor = 1 if act_type == 1 else 2
856957
intermediate_size = int(fc1_expert_weights_fp4.shape[1] // factor)
857958
routing_method_type = int(RoutingMethodType.DeepSeekV3)
959+
mapping, enable_alltoall = _check_moe_alltoall(mapping_config, max_num_tokens)
960+
961+
if enable_alltoall:
962+
final_hidden_states = _run_trtllm_gen_nvfp4_moe_with_alltoall(
963+
x=x2d,
964+
selected_experts=selected_experts.to(torch.int32),
965+
routing_weights=routing_weights.to(torch.float32),
966+
fc1_expert_weights_fp4=fc1_expert_weights_fp4,
967+
fc2_expert_weights_fp4=fc2_expert_weights_fp4,
968+
fc1_weight_blockscale_fp8=fc1_weight_blockscale_fp8,
969+
fc2_weight_blockscale_fp8=fc2_weight_blockscale_fp8,
970+
fc1_act_global_scale=fc1_act_global_scale,
971+
fc1_scale_c=fc1_scale_c,
972+
fc1_alpha=fc1_alpha,
973+
fc2_alpha=fc2_alpha,
974+
mapping=mapping,
975+
max_num_tokens=max_num_tokens,
976+
act_type=act_type,
977+
)
978+
if final_hidden_states.shape[1] > x_shape[-1]:
979+
final_hidden_states = final_hidden_states[:, : x_shape[-1]].contiguous()
980+
return final_hidden_states.view(x_shape)
981+
982+
x_q_fp4, x_sf = torch.ops.trtllm.fp4_quantize(
983+
x2d, fc1_act_global_scale, TRTLLM_NVFP4_SCALING_VECTOR_SIZE, False, False
984+
)
858985

859986
outputs = torch.ops.trtllm.fp4_block_scale_moe_runner(
860987
None,

tensorrt_llm/_torch/auto_deploy/transform/library/fused_moe.py

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2393,14 +2393,6 @@ def _shuffle_scale_stack(scale_3d_u8: torch.Tensor, is_gated: bool) -> torch.Ten
23932393

23942394
matched_nodes = [node for node in graph.nodes if is_op(node, replaced_op)]
23952395
for node in matched_nodes:
2396-
mapping_config = node.kwargs.get("mapping_config", "") if node.kwargs else ""
2397-
if mapping_config:
2398-
ad_logger.debug_once(
2399-
"Skip TRTLLM-Gen NVFP4 fusion: mapping_config is not supported.",
2400-
key="trtllm_gen_nvfp4_skip_mapping_config",
2401-
)
2402-
continue
2403-
24042396
(
24052397
hidden_states,
24062398
selected_experts,

0 commit comments

Comments (0)