Commit b3c8b1e

better profiling for trtllm-gen moe benchmark
Signed-off-by: Siyuan Fu <siyuanf@nvidia.com>
1 parent 2b5e1f9 commit b3c8b1e

1 file changed: benchmarks/bench_trtllm_gen_fused_moe_autotuner.py
97 additions & 75 deletions
@@ -2,6 +2,7 @@
 from typing import Optional, Literal
 import torch
 import numpy as np
+from functools import partial
 from flashinfer import (
     RoutingMethodType,
     GatedActType,
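
The only import change is functools.partial, but it drives the rest of the diff: each benchmark's lambda closure is replaced by a partial whose bound arguments are plain, inspectable data (fn.func, fn.keywords) rather than values frozen inside a closure, so the timing harness can supply the large input tensors itself at call time. A minimal sketch of the difference, with a toy kernel standing in for the fused MoE call (names here are illustrative, not from the benchmark):

from functools import partial

def kernel(x, weight, scale=1.0):
    # Stand-in for a fused MoE kernel launch.
    return x * weight * scale

# A zero-argument lambda bakes every input into an opaque closure:
fn_opaque = lambda: kernel(1.0, 2.0, scale=0.5)

# A partial keeps the call open and its arguments inspectable:
fn = partial(kernel, weight=2.0, scale=0.5)
print(fn.keywords)  # {'weight': 2.0, 'scale': 0.5}
print(fn(x=1.0))    # harness supplies per-iteration inputs -> 1.0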
@@ -143,27 +144,26 @@ def bench_trtllm_gen_fused_moe_autotuner_fp8(
             else tune_max_num_tokens,
         )
     else:
-        fn = lambda: trtllm_fp8_per_tensor_scale_moe(
-            routing_logits,
-            None,  # routing_bias
-            hidden_states,
-            w13,
-            output1_scale_scalar,
-            output1_scales_gate_scalar,
-            w2,
-            output2_scale_scalar,
-            num_experts,
-            top_k,
-            None,  # n_group
-            None,  # topk_group
-            intermediate_size,
-            0,  # local_expert_offset
-            num_experts,
-            1.0,  # routed_scaling_factor
-            False,  # use_routing_scales_on_input
-            RoutingMethodType.TopK.value,
-            enable_pdl,
-            num_tokens if tune_max_num_tokens is None else tune_max_num_tokens,
+        fn = partial(
+            trtllm_fp8_per_tensor_scale_moe,
+            routing_logits=routing_logits,
+            routing_bias=None,  # routing_bias
+            output1_scale_scalar=output1_scale_scalar,
+            output1_scales_gate_scalar=output1_scales_gate_scalar,
+            output2_scale_scalar=output2_scale_scalar,
+            num_experts=num_experts,
+            top_k=top_k,
+            n_group=None,  # n_group
+            topk_group=None,  # topk_group
+            intermediate_size=intermediate_size,
+            local_expert_offset=0,  # local_expert_offset
+            routed_scaling_factor=1.0,  # routed_scaling_factor
+            use_routing_scales_on_input=False,  # use_routing_scales_on_input
+            routing_method_type=RoutingMethodType.TopK.value,
+            enable_pdl=enable_pdl,
+            tune_max_num_tokens=num_tokens
+            if tune_max_num_tokens is None
+            else tune_max_num_tokens,
         )

     def bench(do_autotune):
@@ -173,6 +173,14 @@ def bench(do_autotune):
             fn,
             dry_run_iters=warmups,
             repeat_iters=iterations,
+            enable_cupti=True,
+            use_cuda_graph=True,
+            input_kwargs={
+                "hidden_states": hidden_states,
+                "gemm1_weights": w13,
+                "gemm2_weights": w2,
+            },
+            cold_l2_cache=True,
         )
         median_ms = np.median(ms_list)
         return median_ms
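
The new bench_gpu_time options above come straight from the diff: enable_cupti=True (kernel-level timing via CUPTI), use_cuda_graph=True, cold_l2_cache=True, and input_kwargs naming the tensors the harness now owns instead of the partial. A common way to realize cold_l2_cache is to rotate those inputs through a pool of clones so each timed call reads memory that is not already resident in L2. The sketch below shows that idea with torch.cuda.Event timing; it is an assumption about the mechanism, not flashinfer's implementation:

import torch

def bench_cold_l2(fn, input_kwargs, repeat_iters=20, pool_size=8):
    # Clone each harness-owned input several times so consecutive calls
    # touch different memory and the L2 cache stays cold between timings.
    pools = {
        name: [t.clone() for _ in range(pool_size)]
        for name, t in input_kwargs.items()
    }
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    times_ms = []
    for i in range(repeat_iters):
        rotated = {name: pool[i % pool_size] for name, pool in pools.items()}
        start.record()
        fn(**rotated)
        end.record()
        torch.cuda.synchronize()
        times_ms.append(start.elapsed_time(end))
    return times_ms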
@@ -280,37 +288,31 @@ def bench_trtllm_gen_fused_moe_autotuner_fp4(
     output2_scale_scalar = torch.tensor(
         [hidden_states_global_scale * w2_global_scale] * num_experts, device=device
     )
-    fn = lambda: trtllm_fp4_block_scale_moe(
-        routing_logits,
-        None,  # routing_bias
-        hidden_states,
-        hidden_states_scale,
-        w13,
-        w13_scale,
-        bias13,
-        None,  # gemm1_alpha
-        None,  # gemm1_beta
-        None,  # gemm1_clamp_limit
-        w2,
-        w2_scale,
-        bias2,
-        output1_scale_scalar,
-        output1_scale_gate_scalar,
-        output2_scale_scalar,
-        num_experts,
-        top_k,
-        None,  # n_group
-        None,  # topk_group
-        intermediate_size,
-        0,  # local_expert_offset
-        num_experts,
-        None,  # routed_scaling_factor
-        RoutingMethodType.Renormalize.value,
-        True,
-        enable_pdl,
-        GatedActType.SwiGlu.value,  # gated_act_type
-        None,
-        num_tokens if tune_max_num_tokens is None else tune_max_num_tokens,
+    fn = partial(
+        trtllm_fp4_block_scale_moe,
+        routing_logits=routing_logits,
+        routing_bias=None,  # routing_bias
+        gemm1_alpha=None,  # gemm1_alpha
+        gemm1_beta=None,  # gemm1_beta
+        gemm1_clamp_limit=None,  # gemm1_clamp_limit
+        output1_scale_scalar=output1_scale_scalar,
+        output1_scale_gate_scalar=output1_scale_gate_scalar,
+        output2_scale_scalar=output2_scale_scalar,
+        num_experts=num_experts,
+        top_k=top_k,
+        n_group=None,  # n_group
+        topk_group=None,  # topk_group
+        intermediate_size=intermediate_size,
+        local_expert_offset=0,  # local_expert_offset
+        routed_scaling_factor=None,  # routed_scaling_factor
+        routing_method_type=RoutingMethodType.Renormalize.value,
+        do_finalize=True,
+        enable_pdl=enable_pdl,
+        gated_act_type=GatedActType.SwiGlu.value,  # gated_act_type
+        output=None,
+        tune_max_num_tokens=num_tokens
+        if tune_max_num_tokens is None
+        else tune_max_num_tokens,
     )

     def bench(do_autotune):
@@ -320,6 +322,18 @@ def bench(do_autotune):
             fn,
             dry_run_iters=warmups,
             repeat_iters=iterations,
+            enable_cupti=True,
+            use_cuda_graph=True,
+            input_kwargs={
+                "hidden_states": hidden_states,
+                "gemm1_weights": w13,
+                "gemm1_weights_scale": w13_scale,
+                "gemm2_weights": w2,
+                "gemm2_weights_scale": w2_scale,
+                "gemm1_bias": bias13,
+                "gemm2_bias": bias2,
+            },
+            cold_l2_cache=True,
         )
         median_ms = np.median(ms_list)
         return median_ms
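
use_cuda_graph=True suggests the harness captures the kernel launch once and then times graph replays, which strips Python and launch overhead out of the measurement; capture also requires the input tensors to sit at stable addresses, which fits the harness owning them via input_kwargs. Below is a sketch of the general capture-and-replay timing pattern using the standard PyTorch API; it is not flashinfer's code, and how bench_gpu_time reconciles replay with L2-cache rotation is internal to that utility:

import torch

def time_with_cuda_graph(fn, warmup_iters=3, replays=20):
    # Warm up on a side stream so lazy initialization happens
    # before capture (the standard PyTorch capture recipe).
    side = torch.cuda.Stream()
    side.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(side):
        for _ in range(warmup_iters):
            fn()
    torch.cuda.current_stream().wait_stream(side)

    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph):  # capture a single invocation
        fn()

    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(replays):
        graph.replay()
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / replays  # mean ms per replay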
@@ -370,29 +384,27 @@ def bench_trtllm_gen_fused_moe_autotuner_mxint4(
         intermediate_size // 32,
     )

-    fn = lambda: trtllm_mxint4_block_scale_moe(
-        routing_logits,
-        routing_bias,
-        hidden_states,
-        w13,
-        w13_scale,
-        None,  # gemm1_alpha
-        None,  # gemm1_beta
-        None,  # gemm1_clamp_limit
-        w2,
-        w2_scale,
-        num_experts,
-        top_k,
-        1,  # n_group
-        1,  # topk_group
-        intermediate_size,
-        0,  # local_expert_offset
-        num_experts,
-        None,  # routed_scaling_factor
-        RoutingMethodType.DeepSeekV3.value,
-        enable_pdl,
-        None,
-        num_tokens if tune_max_num_tokens is None else tune_max_num_tokens,
+    fn = partial(
+        trtllm_mxint4_block_scale_moe,
+        routing_logits=routing_logits,
+        routing_bias=routing_bias,
+        hidden_states=hidden_states,
+        gemm1_alpha=None,  # gemm1_alpha
+        gemm1_beta=None,  # gemm1_beta
+        gemm1_clamp_limit=None,  # gemm1_clamp_limit
+        num_experts=num_experts,
+        top_k=top_k,
+        n_group=1,  # n_group
+        topk_group=1,  # topk_group
+        intermediate_size=intermediate_size,
+        local_expert_offset=0,  # local_expert_offset
+        routed_scaling_factor=None,  # routed_scaling_factor
+        routing_method_type=RoutingMethodType.DeepSeekV3.value,
+        enable_pdl=enable_pdl,
+        output=None,
+        tune_max_num_tokens=num_tokens
+        if tune_max_num_tokens is None
+        else tune_max_num_tokens,
     )

     def bench(do_autotune):
@@ -402,6 +414,16 @@ def bench(do_autotune):
             fn,
             dry_run_iters=warmups,
             repeat_iters=iterations,
+            enable_cupti=True,
+            use_cuda_graph=True,
+            input_kwargs={
+                "hidden_states": hidden_states,
+                "gemm1_weights": w13,
+                "gemm1_weights_scale": w13_scale,
+                "gemm2_weights": w2,
+                "gemm2_weights_scale": w2_scale,
+            },
+            cold_l2_cache=True,
         )
         median_ms = np.median(ms_list)
         return median_ms
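
One detail unchanged across all three benchmarks: the per-iteration samples are reduced with np.median rather than a mean, which keeps one-off stragglers (clock ramp-up, a context switch) from inflating the reported latency. Illustrative numbers only:

import numpy as np

ms_list = [0.212, 0.215, 0.213, 0.981, 0.214]  # one straggler sample
print(np.mean(ms_list))    # 0.367 -> skewed by the outlier
print(np.median(ms_list))  # 0.214 -> the robust summary reported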
