 from typing import Optional, Literal
 import torch
 import numpy as np
+from functools import partial
 from flashinfer import (
     RoutingMethodType,
     GatedActType,
@@ -143,27 +144,26 @@ def bench_trtllm_gen_fused_moe_autotuner_fp8(
             else tune_max_num_tokens,
         )
     else:
-        fn = lambda: trtllm_fp8_per_tensor_scale_moe(
-            routing_logits,
-            None,  # routing_bias
-            hidden_states,
-            w13,
-            output1_scale_scalar,
-            output1_scales_gate_scalar,
-            w2,
-            output2_scale_scalar,
-            num_experts,
-            top_k,
-            None,  # n_group
-            None,  # topk_group
-            intermediate_size,
-            0,  # local_expert_offset
-            num_experts,
-            1.0,  # routed_scaling_factor
-            False,  # use_routing_scales_on_input
-            RoutingMethodType.TopK.value,
-            enable_pdl,
-            num_tokens if tune_max_num_tokens is None else tune_max_num_tokens,
+        fn = partial(
+            trtllm_fp8_per_tensor_scale_moe,
+            routing_logits=routing_logits,
+            routing_bias=None,  # routing_bias
+            output1_scale_scalar=output1_scale_scalar,
+            output1_scales_gate_scalar=output1_scales_gate_scalar,
+            output2_scale_scalar=output2_scale_scalar,
+            num_experts=num_experts,
+            top_k=top_k,
+            n_group=None,  # n_group
+            topk_group=None,  # topk_group
+            intermediate_size=intermediate_size,
+            local_expert_offset=0,  # local_expert_offset
+            routed_scaling_factor=1.0,  # routed_scaling_factor
+            use_routing_scales_on_input=False,  # use_routing_scales_on_input
+            routing_method_type=RoutingMethodType.TopK.value,
+            enable_pdl=enable_pdl,
+            tune_max_num_tokens=num_tokens
+            if tune_max_num_tokens is None
+            else tune_max_num_tokens,
         )
 
     def bench(do_autotune):
@@ -173,6 +173,14 @@ def bench(do_autotune):
             fn,
             dry_run_iters=warmups,
             repeat_iters=iterations,
+            enable_cupti=True,
+            use_cuda_graph=True,
+            input_kwargs={
+                "hidden_states": hidden_states,
+                "gemm1_weights": w13,
+                "gemm2_weights": w2,
+            },
+            cold_l2_cache=True,
         )
         median_ms = np.median(ms_list)
         return median_ms
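
The hunks above establish the pattern used for all three precisions in this commit: scalar configuration and flags are frozen into a functools.partial, while the large per-iteration tensors (hidden_states, w13, w2) reach the benchmark helper separately through input_kwargs. Below is a minimal sketch of that split, under the assumption that the helper ultimately calls fn(**input_kwargs) on each iteration; the helper itself is not part of this diff, so its exact merge behavior is an assumption.

from functools import partial

import torch


def moe_kernel(hidden_states, gemm1_weights, gemm2_weights, *, top_k, enable_pdl):
    # Stand-in for trtllm_fp8_per_tensor_scale_moe; only the calling
    # convention matters for this sketch.
    return hidden_states


# Scalars and flags are frozen once, mirroring the fp8 hunk above.
fn = partial(moe_kernel, top_k=8, enable_pdl=True)

# Per-iteration tensors travel separately, so a timing loop can swap or
# re-allocate them between runs.
input_kwargs = {
    "hidden_states": torch.randn(4, 1024),
    "gemm1_weights": torch.randn(2048, 1024),
    "gemm2_weights": torch.randn(1024, 1024),
}
out = fn(**input_kwargs)

Compared with the old zero-argument lambda, the hot inputs remain explicit call-time arguments, which is presumably what per-iteration options such as cold_l2_cache and use_cuda_graph rely on.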
@@ -280,37 +288,31 @@ def bench_trtllm_gen_fused_moe_autotuner_fp4(
     output2_scale_scalar = torch.tensor(
         [hidden_states_global_scale * w2_global_scale] * num_experts, device=device
     )
-    fn = lambda: trtllm_fp4_block_scale_moe(
-        routing_logits,
-        None,  # routing_bias
-        hidden_states,
-        hidden_states_scale,
-        w13,
-        w13_scale,
-        bias13,
-        None,  # gemm1_alpha
-        None,  # gemm1_beta
-        None,  # gemm1_clamp_limit
-        w2,
-        w2_scale,
-        bias2,
-        output1_scale_scalar,
-        output1_scale_gate_scalar,
-        output2_scale_scalar,
-        num_experts,
-        top_k,
-        None,  # n_group
-        None,  # topk_group
-        intermediate_size,
-        0,  # local_expert_offset
-        num_experts,
-        None,  # routed_scaling_factor
-        RoutingMethodType.Renormalize.value,
-        True,
-        enable_pdl,
-        GatedActType.SwiGlu.value,  # gated_act_type
-        None,
-        num_tokens if tune_max_num_tokens is None else tune_max_num_tokens,
+    fn = partial(
+        trtllm_fp4_block_scale_moe,
+        routing_logits=routing_logits,
+        routing_bias=None,  # routing_bias
+        gemm1_alpha=None,  # gemm1_alpha
+        gemm1_beta=None,  # gemm1_beta
+        gemm1_clamp_limit=None,  # gemm1_clamp_limit
+        output1_scale_scalar=output1_scale_scalar,
+        output1_scale_gate_scalar=output1_scale_gate_scalar,
+        output2_scale_scalar=output2_scale_scalar,
+        num_experts=num_experts,
+        top_k=top_k,
+        n_group=None,  # n_group
+        topk_group=None,  # topk_group
+        intermediate_size=intermediate_size,
+        local_expert_offset=0,  # local_expert_offset
+        routed_scaling_factor=None,  # routed_scaling_factor
+        routing_method_type=RoutingMethodType.Renormalize.value,
+        do_finalize=True,
+        enable_pdl=enable_pdl,
+        gated_act_type=GatedActType.SwiGlu.value,  # gated_act_type
+        output=None,
+        tune_max_num_tokens=num_tokens
+        if tune_max_num_tokens is None
+        else tune_max_num_tokens,
     )
 
     def bench(do_autotune):
@@ -320,6 +322,18 @@ def bench(do_autotune):
             fn,
             dry_run_iters=warmups,
             repeat_iters=iterations,
+            enable_cupti=True,
+            use_cuda_graph=True,
+            input_kwargs={
+                "hidden_states": hidden_states,
+                "gemm1_weights": w13,
+                "gemm1_weights_scale": w13_scale,
+                "gemm2_weights": w2,
+                "gemm2_weights_scale": w2_scale,
+                "gemm1_bias": bias13,
+                "gemm2_bias": bias2,
+            },
+            cold_l2_cache=True,
         )
         median_ms = np.median(ms_list)
         return median_ms
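
The fp4 hunks follow the same split, additionally routing the block-scale tensors and biases (w13_scale, w2_scale, bias13, bias2) through input_kwargs. A small side benefit of partial over a lambda is that the frozen arguments stay inspectable, which allows a quick sanity check; a sketch with a hypothetical kernel stand-in:

from functools import partial


def kernel(hidden_states, gemm1_weights, *, top_k, routed_scaling_factor):
    return hidden_states


fn = partial(kernel, top_k=8, routed_scaling_factor=None)

# partial exposes what it has captured, so it is easy to assert that no
# per-iteration tensor leaked into the frozen arguments.
assert fn.func is kernel
assert "hidden_states" not in fn.keywords
print(sorted(fn.keywords))  # ['routed_scaling_factor', 'top_k']

If the helper forwards input_kwargs as keyword arguments, a tensor bound in both the partial and input_kwargs would raise a duplicate-keyword TypeError at call time; the mxint4 hunk below binds hidden_states in the partial while its input_kwargs also lists it, so whether that combination is intended depends on the helper's merge behavior, which this diff does not show.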
@@ -370,29 +384,27 @@ def bench_trtllm_gen_fused_moe_autotuner_mxint4(
         intermediate_size // 32,
     )
 
-    fn = lambda: trtllm_mxint4_block_scale_moe(
-        routing_logits,
-        routing_bias,
-        hidden_states,
-        w13,
-        w13_scale,
-        None,  # gemm1_alpha
-        None,  # gemm1_beta
-        None,  # gemm1_clamp_limit
-        w2,
-        w2_scale,
-        num_experts,
-        top_k,
-        1,  # n_group
-        1,  # topk_group
-        intermediate_size,
-        0,  # local_expert_offset
-        num_experts,
-        None,  # routed_scaling_factor
-        RoutingMethodType.DeepSeekV3.value,
-        enable_pdl,
-        None,
-        num_tokens if tune_max_num_tokens is None else tune_max_num_tokens,
+    fn = partial(
+        trtllm_mxint4_block_scale_moe,
+        routing_logits=routing_logits,
+        routing_bias=routing_bias,
+        hidden_states=hidden_states,
+        gemm1_alpha=None,  # gemm1_alpha
+        gemm1_beta=None,  # gemm1_beta
+        gemm1_clamp_limit=None,  # gemm1_clamp_limit
+        num_experts=num_experts,
+        top_k=top_k,
+        n_group=1,  # n_group
+        topk_group=1,  # topk_group
+        intermediate_size=intermediate_size,
+        local_expert_offset=0,  # local_expert_offset
+        routed_scaling_factor=None,  # routed_scaling_factor
+        routing_method_type=RoutingMethodType.DeepSeekV3.value,
+        enable_pdl=enable_pdl,
+        output=None,
+        tune_max_num_tokens=num_tokens
+        if tune_max_num_tokens is None
+        else tune_max_num_tokens,
     )
 
     def bench(do_autotune):
@@ -402,6 +414,16 @@ def bench(do_autotune):
             fn,
             dry_run_iters=warmups,
             repeat_iters=iterations,
+            enable_cupti=True,
+            use_cuda_graph=True,
+            input_kwargs={
+                "hidden_states": hidden_states,
+                "gemm1_weights": w13,
+                "gemm1_weights_scale": w13_scale,
+                "gemm2_weights": w2,
+                "gemm2_weights_scale": w2_scale,
+            },
+            cold_l2_cache=True,
         )
         median_ms = np.median(ms_list)
         return median_ms
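
The new helper options (enable_cupti, use_cuda_graph, input_kwargs, cold_l2_cache) are consumed by the timing helper, whose implementation is outside this diff. As a rough illustration of what the fn/input_kwargs split enables, here is a hypothetical, simplified timing loop; every name and behavior in it is an assumption made for illustration, and CUPTI-based timing and CUDA-graph capture are omitted for brevity.

import numpy as np
import torch


def bench_kernel(fn, *, dry_run_iters, repeat_iters, input_kwargs, cold_l2_cache=False):
    # Hypothetical harness: warm-up calls are not timed, so autotuning and
    # lazy initialization stay out of the measurement.
    for _ in range(dry_run_iters):
        fn(**input_kwargs)
    torch.cuda.synchronize()

    # Touching a large scratch buffer between runs evicts prior activations
    # and weights from L2, approximating a "cold" cache.
    scratch = torch.empty(64 * 1024 * 1024 // 4, dtype=torch.float32, device="cuda")

    times_ms = []
    start = torch.cuda.Event(enable_timing=True)
    stop = torch.cuda.Event(enable_timing=True)
    for _ in range(repeat_iters):
        if cold_l2_cache:
            scratch.zero_()
        start.record()
        fn(**input_kwargs)
        stop.record()
        torch.cuda.synchronize()
        times_ms.append(start.elapsed_time(stop))
    return times_ms


def _dummy(hidden_states):
    return hidden_states * 2


times = bench_kernel(
    _dummy,
    dry_run_iters=3,
    repeat_iters=10,
    input_kwargs={"hidden_states": torch.randn(16, 1024, device="cuda")},
    cold_l2_cache=True,
)
print(f"median: {np.median(times):.3f} ms")  # median over repeats, as in the diff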