import logging
from functools import partial
from typing import Literal

import numpy as np
import torch

from flashinfer import (
    SfLayout,
    autotune,
    mm_fp4,
    nvfp4_quantize,
    mxfp4_quantize,
)
from flashinfer.testing.utils import bench_gpu_time
from flashinfer.utils import get_compute_capability


def _bench_mm_fp4(
    m: int,
    n: int,
    k: int,
    res_dtype: torch.dtype,
    backend: Literal["cudnn", "trtllm", "cutlass", "auto"],
    use_128x4_sf_layout: bool,
    fp4_type: str,
    do_autotune: bool = False,
    warmups: int = 100,
    iterations: int = 100,
) -> tuple[float, float] | None:
    """Benchmark mm_fp4 for a single (m, n, k) problem on one backend.

    Returns (median latency in ms, achieved TFLOPs/s), or None if the
    configuration is skipped on the current GPU / backend combination.
    """
    use_nvfp4 = fp4_type == "nvfp4"

    compute_capability = get_compute_capability(torch.device(device="cuda"))
    compute_capability_number = compute_capability[0] * 10 + compute_capability[1]
    if not mm_fp4.is_backend_supported(backend, compute_capability_number):
        print(
            f"Skipping test for {backend} because it is not supported on compute capability {compute_capability_number}."
        )
        return None

    if backend == "trtllm":
        if res_dtype == torch.float16:
            print("Skipping test for trtllm fp4 with float16")
            return None
        if compute_capability[0] in [11, 12]:
            print("trtllm gemm does not support SM110/SM120/SM121 GPUs.")
            return None
    if not use_128x4_sf_layout and backend != "trtllm":
        print("Skipping test for non-trtllm fp4 with use_128x4_sf_layout=False")
        return None
    if not use_nvfp4 and backend not in ["cudnn", "auto"]:
        print("mx_fp4 is only supported for cudnn and auto backends")
        return None

    input = torch.randn([m, k], device="cuda", dtype=torch.bfloat16)
    mat2 = torch.randn([n, k], device="cuda", dtype=torch.bfloat16)
    a_sf_layout = SfLayout.layout_128x4 if use_128x4_sf_layout else SfLayout.layout_8x4

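    # Per-tensor global scale used by NVFP4 quantization: 448 (max FP8 E4M3
    # scale value) * 6 (max FP4 E2M1 value), divided by the tensor's absolute max.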
    global_sf_input = (448 * 6) / input.float().abs().nan_to_num().max()
    global_sf_mat2 = (448 * 6) / mat2.float().abs().nan_to_num().max()

    # for trtllm, we need to shuffle mat2 because we swap A, B.
    do_shuffle_b = backend == "trtllm"

    block_size = 16 if use_nvfp4 else 32
    has_alpha = fp4_type == "mxfp4_alpha" or fp4_type == "nvfp4"

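    # nvfp4: one FP8 (E4M3) scale per 16-element block plus the global scale;
    # mxfp4: one E8M0 (power-of-two) scale per 32-element block.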
    if use_nvfp4:
        input_fp4, input_inv_s = nvfp4_quantize(
            input, global_sf_input, sfLayout=a_sf_layout, do_shuffle=False
        )
        mat2_fp4, mat2_inv_s = nvfp4_quantize(
            mat2,
            global_sf_mat2,
            sfLayout=SfLayout.layout_128x4,
            do_shuffle=do_shuffle_b,
        )
    else:
        input_fp4, input_inv_s = mxfp4_quantize(input)
        mat2_fp4, mat2_inv_s = mxfp4_quantize(mat2)

    alpha = 1.0 / (global_sf_input * global_sf_mat2) if has_alpha else None

    res = torch.empty([m, n], device="cuda", dtype=res_dtype)

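    # Bind everything except the quantized operands so the same callable can be
    # used for the autotuning warm-up call and for the timed runs below.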
    fn = partial(
        mm_fp4,
        alpha=alpha,
        out_dtype=res_dtype,
        out=res,
        block_size=block_size,
        use_8x4_sf_layout=not use_128x4_sf_layout,
        backend=backend,
        use_nvfp4=use_nvfp4,
    )

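    # One call inside the autotune() context selects (or tunes) the kernel
    # configuration; bench_gpu_time then times the call under CUDA graph
    # capture with the L2 cache flushed between iterations and returns the
    # per-iteration latencies.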
    def bench(do_autotune: bool) -> float:
        with autotune(do_autotune):
            fn(
                a=input_fp4,
                b=mat2_fp4.T,
                a_descale=input_inv_s,
                b_descale=mat2_inv_s.T,
            )
        ms_list = bench_gpu_time(
            fn,
            dry_run_iters=warmups,
            repeat_iters=iterations,
            use_cuda_graph=True,
            input_kwargs={
                "a": input_fp4,
                "b": mat2_fp4.T,
                "a_descale": input_inv_s,
                "b_descale": mat2_inv_s.T,
            },
            cold_l2_cache=True,
        )
        median_ms = np.median(ms_list)
        return median_ms

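    # bench() returns the median latency in milliseconds; one GEMM is 2*m*n*k
    # FLOPs, so TFLOPs/s = 2 * m * n * k / (ms * 1e9).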
    ms = bench(do_autotune=do_autotune)
    tflops = 2 * m * n * k * 1e-9 / ms
    return ms, tflops


logging.basicConfig(level="WARNING")  # suppress autotuner's logs

if __name__ == "__main__":
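    # Sweep a grid of GEMM shapes and compare backends, benchmarking each one
    # with and without autotuning.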
    for m in [1, 2, 4, 8, 16, 32, 64]:
        for n in [2560, 5120, 8192]:
            for k in [16384, 32768]:
                print(f"m={m}, n={n}, k={k}".center(100, "-"))
                for backend in ["cudnn", "trtllm", "cutlass"]:
                    print(f" {backend}:")
                    result = _bench_mm_fp4(
                        m, n, k, torch.bfloat16, backend, True, "nvfp4", False
                    )
                    if result is None:
                        # Unsupported configuration; the reason was already printed.
                        continue
                    ms, tflops = result
                    print(f" w/o autotune: {ms:.3f} ms, {tflops:.3f} TFLOPs/s")
                    # The skip conditions do not depend on do_autotune, so this
                    # call returns a result whenever the one above did.
                    ms, tflops = _bench_mm_fp4(
                        m, n, k, torch.bfloat16, backend, True, "nvfp4", True
                    )
                    print(f" with autotune: {ms:.3f} ms, {tflops:.3f} TFLOPs/s")