1- # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
1+ # SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
22# SPDX-License-Identifier: Apache-2.0
33#
44# Licensed under the Apache License, Version 2.0 (the "License");
@@ -273,6 +273,110 @@ def _run_moe_with_alltoall(
273273 return combined [:local_num_tokens ]
274274
275275
def _run_trtllm_gen_nvfp4_moe_with_alltoall(
    x: torch.Tensor,
    selected_experts: torch.Tensor,
    routing_weights: torch.Tensor,
    fc1_expert_weights_fp4: torch.Tensor,
    fc2_expert_weights_fp4: torch.Tensor,
    fc1_weight_blockscale_fp8: torch.Tensor,
    fc2_weight_blockscale_fp8: torch.Tensor,
    fc1_act_global_scale: torch.Tensor,
    fc1_scale_c: torch.Tensor,
    fc1_alpha: torch.Tensor,
    fc2_alpha: torch.Tensor,
    mapping: Mapping,
    max_num_tokens: int,
    act_type: int,
) -> torch.Tensor:
    """Run TRTLLM-Gen NVFP4 MoE through the all-to-all dispatch/combine path.

    Local tokens are padded up to ``max_num_tokens``, dispatched across the
    expert-parallel ranks via :class:`MoeAlltoAll`, pushed through the NVFP4
    block-scale MoE kernel for this rank's expert shard, combined back, and
    finally trimmed to the caller's original token count.

    Args:
        x: Local activations of shape ``(num_tokens, hidden_size)``.
        selected_experts: Per-token expert ids, shape ``(num_tokens, top_k)``.
        routing_weights: Per-token routing weights, shape ``(num_tokens, top_k)``.
        fc1_expert_weights_fp4 / fc2_expert_weights_fp4: FP4 expert weights
            for this rank's local experts.
        fc1_weight_blockscale_fp8 / fc2_weight_blockscale_fp8: FP8 block
            scales matching the FP4 weights.
        fc1_act_global_scale: Global scale used to FP4-quantize activations.
        fc1_scale_c, fc1_alpha, fc2_alpha: Per-layer scale tensors forwarded
            to the kernel.
        mapping: Parallelism mapping (supplies EP size/rank).
        max_num_tokens: Per-rank token capacity used for dispatch sizing.
        act_type: Activation selector; ``1`` means non-gated (see below).

    Returns:
        Combined MoE output for the local tokens, shape matching ``x``.
    """
    ep_size = mapping.moe_ep_size
    ep_rank = mapping.moe_ep_rank
    num_local_tokens = x.shape[0]
    hidden_size = x.shape[-1]
    top_k = selected_experts.shape[1]

    experts_per_rank = int(fc1_expert_weights_fp4.shape[0])
    num_expert_slots = experts_per_rank * ep_size

    a2a = MoeAlltoAll(
        mapping=mapping,
        max_num_tokens=max_num_tokens,
        top_k=top_k,
        num_slots=num_expert_slots,
        workspace_size_per_rank=MoeAlltoAll.calculate_required_workspace_size(
            ep_size, top_k, max_num_tokens, hidden_size, x.dtype
        ),
        num_experts=None,
    )

    # Pad every rank to the same token count so dispatch shapes agree across
    # ranks. Padded rows are routed to this rank's first expert so they stay
    # local; they are dropped again by the final slice.
    first_local_expert = ep_rank * experts_per_rank
    num_pad_rows = max_num_tokens - num_local_tokens
    if num_pad_rows > 0:
        pad = torch.nn.functional.pad
        x = pad(x, (0, 0, 0, num_pad_rows))
        selected_experts = pad(
            selected_experts, (0, 0, 0, num_pad_rows), value=first_local_expert
        )
        routing_weights = pad(routing_weights, (0, 0, 0, num_pad_rows))

    recv = a2a.dispatch(
        selected_experts,
        [x.contiguous(), selected_experts.contiguous(), routing_weights.contiguous()],
        max_num_tokens,
        # An out-of-range expert id (== total slot count) marks invalid tokens.
        invalid_token_expert_id=num_expert_slots,
        expert_id_payload_index=1,
    )
    gathered_x = recv[0].reshape(-1, hidden_size)
    gathered_ids = recv[1].reshape(-1, top_k).to(torch.int32).contiguous()
    gathered_weights = recv[2].reshape(-1, top_k).to(torch.bfloat16).contiguous()

    # FP4-quantize the gathered activations with the fc1 global scale.
    x_fp4, x_blockscale = torch.ops.trtllm.fp4_quantize(
        gathered_x, fc1_act_global_scale, TRTLLM_NVFP4_SCALING_VECTOR_SIZE, False, False
    )

    # Gated MLPs pack two projections into fc1, so the true intermediate size
    # is half the stored dimension; act_type == 1 is the non-gated case.
    gate_factor = 2 if act_type != 1 else 1
    intermediate_size = int(fc1_expert_weights_fp4.shape[1] // gate_factor)

    outputs = torch.ops.trtllm.fp4_block_scale_moe_runner(
        None,
        None,
        x_fp4,
        x_blockscale.view(torch.float8_e4m3fn),
        fc1_expert_weights_fp4,
        fc1_weight_blockscale_fp8.view(torch.float8_e4m3fn),
        None,
        None,
        None,
        None,
        fc2_expert_weights_fp4,
        fc2_weight_blockscale_fp8.view(torch.float8_e4m3fn),
        None,
        fc1_scale_c,
        fc1_alpha,
        fc2_alpha,
        num_expert_slots,
        top_k,
        1,
        1,
        intermediate_size,
        first_local_expert,  # offset of this rank's expert shard
        experts_per_rank,
        1.0,
        int(RoutingMethodType.DeepSeekV3),
        do_finalize=True,
        act_type=act_type,
        topk_weights=gathered_weights,
        topk_ids=gathered_ids,
    )

    # Reshape per-rank outputs, combine across EP ranks, then drop padding.
    per_rank_out = outputs[0].view(ep_size, max_num_tokens, hidden_size)
    combined = a2a.combine(per_rank_out, max_num_tokens)
    return combined[:num_local_tokens]
378+
379+
276380@torch .library .custom_op ("auto_deploy::trtllm_moe_fused" , mutates_args = ())
277381def trtllm_moe_fused (
278382 x : torch .Tensor ,
@@ -825,7 +929,7 @@ def trtllm_nvfp4_trtllm_gen_moe_fused(
825929 fc2_alpha : torch .Tensor ,
826930 is_gated_mlp : bool = True ,
827931 act_fn : int = int (ActivationType .Silu ),
828- mapping_config : str = "" , # https://github.com/NVIDIA/TensorRT-LLM/issues/12008 Add mapping config support
932+ mapping_config : str = "" ,
829933 max_num_tokens : int = 0 ,
830934 apply_routing_on_input : bool = False ,
831935) -> torch .Tensor :
@@ -839,9 +943,6 @@ def trtllm_nvfp4_trtllm_gen_moe_fused(
839943 pad_size = expected_hidden - int (x2d .shape [- 1 ])
840944 if pad_size > 0 :
841945 x2d = torch .nn .functional .pad (x2d , (0 , pad_size ))
842- x_q_fp4 , x_sf = torch .ops .trtllm .fp4_quantize (
843- x2d , fc1_act_global_scale , TRTLLM_NVFP4_SCALING_VECTOR_SIZE , False , False
844- )
845946
846947 if act_fn in (ActivationType .Silu , ActivationType .Swiglu ):
847948 act_type = 0
@@ -855,6 +956,32 @@ def trtllm_nvfp4_trtllm_gen_moe_fused(
855956 factor = 1 if act_type == 1 else 2
856957 intermediate_size = int (fc1_expert_weights_fp4 .shape [1 ] // factor )
857958 routing_method_type = int (RoutingMethodType .DeepSeekV3 )
959+ mapping , enable_alltoall = _check_moe_alltoall (mapping_config , max_num_tokens )
960+
961+ if enable_alltoall :
962+ final_hidden_states = _run_trtllm_gen_nvfp4_moe_with_alltoall (
963+ x = x2d ,
964+ selected_experts = selected_experts .to (torch .int32 ),
965+ routing_weights = routing_weights .to (torch .float32 ),
966+ fc1_expert_weights_fp4 = fc1_expert_weights_fp4 ,
967+ fc2_expert_weights_fp4 = fc2_expert_weights_fp4 ,
968+ fc1_weight_blockscale_fp8 = fc1_weight_blockscale_fp8 ,
969+ fc2_weight_blockscale_fp8 = fc2_weight_blockscale_fp8 ,
970+ fc1_act_global_scale = fc1_act_global_scale ,
971+ fc1_scale_c = fc1_scale_c ,
972+ fc1_alpha = fc1_alpha ,
973+ fc2_alpha = fc2_alpha ,
974+ mapping = mapping ,
975+ max_num_tokens = max_num_tokens ,
976+ act_type = act_type ,
977+ )
978+ if final_hidden_states .shape [1 ] > x_shape [- 1 ]:
979+ final_hidden_states = final_hidden_states [:, : x_shape [- 1 ]].contiguous ()
980+ return final_hidden_states .view (x_shape )
981+
982+ x_q_fp4 , x_sf = torch .ops .trtllm .fp4_quantize (
983+ x2d , fc1_act_global_scale , TRTLLM_NVFP4_SCALING_VECTOR_SIZE , False , False
984+ )
858985
859986 outputs = torch .ops .trtllm .fp4_block_scale_moe_runner (
860987 None ,
0 commit comments