
Commit 14550ae

update trtllm-gen moe benchmark scripts; add cutlass fp4 mm benchmark scripts
Signed-off-by: Siyuan Fu <siyuanf@nvidia.com>
1 parent 48966b6 commit 14550ae

2 files changed

Lines changed: 367 additions & 98 deletions

File tree

benchmarks/bench_mm_fp4.py

Lines changed: 142 additions & 0 deletions
@@ -0,0 +1,142 @@
import torch
from flashinfer import (
    SfLayout,
    autotune,
    mm_fp4,
    nvfp4_quantize,
    mxfp4_quantize,
)
from flashinfer.testing.utils import bench_gpu_time
from flashinfer.utils import get_compute_capability

import logging
import numpy as np
from typing import Literal

from functools import partial


def _bench_mm_fp4(
    m: int,
    n: int,
    k: int,
    res_dtype: torch.dtype,
    backend: Literal["cudnn", "trtllm", "cutlass", "auto"],
    use_128x4_sf_layout: bool,
    fp4_type: str,
    do_autotune: bool = False,
    warmups: int = 100,
    iterations: int = 100,
) -> tuple[float, float] | None:
    use_nvfp4 = fp4_type == "nvfp4"

    compute_capability = get_compute_capability(torch.device(device="cuda"))
    compute_capability_number = compute_capability[0] * 10 + compute_capability[1]
    if not mm_fp4.is_backend_supported(backend, compute_capability_number):
        print(
            f"Skipping test for {backend} because it is not supported on compute capability {compute_capability_number}."
        )
        return None

    if backend == "trtllm":
        if res_dtype == torch.float16:
            print("Skipping test for trtllm fp4 with float16")
            return None
        if compute_capability[0] in [11, 12]:
            print("trtllm gemm does not support SM110/SM120/SM121 GPUs.")
            return None
    if not use_128x4_sf_layout and backend != "trtllm":
        print("Skipping test for non-trtllm fp4 with use_128x4_sf_layout=False")
        return None
    if not use_nvfp4 and backend not in ["cudnn", "auto"]:
        print("mx_fp4 is only supported for cudnn and auto backends")
        return None

    input = torch.randn([m, k], device="cuda", dtype=torch.bfloat16)
    mat2 = torch.randn([n, k], device="cuda", dtype=torch.bfloat16)
    a_sf_layout = SfLayout.layout_128x4 if use_128x4_sf_layout else SfLayout.layout_8x4

    # Global scale factors: 448 is the max of FP8 E4M3 (the scale dtype) and 6 is
    # the max of FP4 E2M1, so the tensor's amax maps onto the largest
    # representable scale-times-value product.
    global_sf_input = (448 * 6) / input.float().abs().nan_to_num().max()
    global_sf_mat2 = (448 * 6) / mat2.float().abs().nan_to_num().max()

    # for trtllm, we need to shuffle mat2 because we swap A, B.
    do_shuffle_b = backend == "trtllm"

    # nvfp4 scales in 16-element blocks, mxfp4 in 32-element blocks.
    block_size = 16 if use_nvfp4 else 32
    has_alpha = fp4_type == "mxfp4_alpha" or fp4_type == "nvfp4"

    if use_nvfp4:
        input_fp4, input_inv_s = nvfp4_quantize(
            input, global_sf_input, sfLayout=a_sf_layout, do_shuffle=False
        )
        mat2_fp4, mat2_inv_s = nvfp4_quantize(
            mat2,
            global_sf_mat2,
            sfLayout=SfLayout.layout_128x4,
            do_shuffle=do_shuffle_b,
        )
    else:
        input_fp4, input_inv_s = mxfp4_quantize(input)
        mat2_fp4, mat2_inv_s = mxfp4_quantize(mat2)

    alpha = 1.0 / (global_sf_input * global_sf_mat2) if has_alpha else None

    res = torch.empty([m, n], device="cuda", dtype=res_dtype)

    fn = partial(
        mm_fp4,
        alpha=alpha,
        out_dtype=res_dtype,
        out=res,
        block_size=block_size,
        use_8x4_sf_layout=not use_128x4_sf_layout,
        backend=backend,
        use_nvfp4=use_nvfp4,
    )

    def bench(do_autotune: bool) -> float:
        # Run once up front to warm up and, when requested, autotune.
        with autotune(do_autotune):
            fn(
                a=input_fp4,
                b=mat2_fp4.T,
                a_descale=input_inv_s,
                b_descale=mat2_inv_s.T,
            )
        ms_list = bench_gpu_time(
            fn,
            dry_run_iters=warmups,
            repeat_iters=iterations,
            use_cuda_graph=True,
            input_kwargs={
                "a": input_fp4,
                "b": mat2_fp4.T,
                "a_descale": input_inv_s,
                "b_descale": mat2_inv_s.T,
            },
            cold_l2_cache=True,
        )
        median_ms = np.median(ms_list)
        return median_ms

    ms = bench(do_autotune=do_autotune)
    # 2*m*n*k FLOPs per GEMM; FLOPs * 1e-9 / ms equals TFLOP/s.
    tflops = 2 * m * n * k * 1e-9 / ms
    return ms, tflops


logging.basicConfig(level="WARNING")  # suppress autotuner's logs

if __name__ == "__main__":
    for m in [1, 2, 4, 8, 16, 32, 64]:
        for n in [2560, 5120, 8192]:
            for k in [16384, 32768]:
                print(f"m={m}, n={n}, k={k}".center(100, "-"))
                for backend in ["cudnn", "trtllm", "cutlass"]:
                    print(f"  {backend}:")
                    result = _bench_mm_fp4(
                        m, n, k, torch.bfloat16, backend, True, "nvfp4", False
                    )
                    if result is None:  # backend skipped on this GPU
                        continue
                    ms, tflops = result
                    print(f"    w/o autotune: {ms:.3f} ms, {tflops:.3f} TFLOPs/s")
                    ms, tflops = _bench_mm_fp4(
                        m, n, k, torch.bfloat16, backend, True, "nvfp4", True
                    )
                    print(f"    with autotune: {ms:.3f} ms, {tflops:.3f} TFLOPs/s")
