|
| 1 | +from typing import Callable, Tuple |
| 2 | + |
| 3 | +import fire |
| 4 | +import torch |
| 5 | +import triton |
| 6 | +from torch._inductor.utils import do_bench_using_profiling |
| 7 | + |
| 8 | +from torchao.prototype.mx_formats.custom_cast import ( |
| 9 | + triton_to_mxfp8_dim1, |
| 10 | +) |
| 11 | +from torchao.prototype.mx_formats.mx_tensor import to_mx |
| 12 | + |
| 13 | +torch.manual_seed(0) |
| 14 | + |
| 15 | +bytes_per_el_bf16 = 2 |
| 16 | +bytes_per_el_fp8 = 1 |
| 17 | + |
| 18 | + |
| 19 | +def scale_dim0_reference(x_hp, block_size) -> Tuple[torch.Tensor, torch.Tensor]: |
| 20 | + assert x_hp.is_contiguous() |
| 21 | + x_hp_d0_block = x_hp.reshape(-1, block_size) |
| 22 | + x_hp_d0_block_abs = x_hp_d0_block.abs() |
| 23 | + amax_dim0 = torch.amax(x_hp_d0_block_abs, dim=1).unsqueeze(1) |
| 24 | + x_hp_d0_block_normalized = x_hp_d0_block / amax_dim0 |
| 25 | + x_hp_d0_normalized = x_hp_d0_block_normalized.reshape(x_hp.shape) |
| 26 | + return x_hp_d0_normalized, amax_dim0 |
| 27 | + |
| 28 | + |
| 29 | +def scale_dim1_reference(x_hp, block_size) -> Tuple[torch.Tensor, torch.Tensor]: |
| 30 | + assert x_hp.is_contiguous() |
| 31 | + x_hp_d1 = x_hp.t().contiguous() |
| 32 | + x_hp_d1_block = x_hp_d1.reshape(-1, block_size) |
| 33 | + x_hp_d1_block_abs = x_hp_d1_block.abs() |
| 34 | + amax_dim1 = torch.amax(x_hp_d1_block_abs, dim=1).unsqueeze(1) |
| 35 | + x_hp_d1_block_normalized = x_hp_d1_block / amax_dim1 |
| 36 | + x_hp_d1_normalized = x_hp_d1_block_normalized.reshape(x_hp_d1.shape) |
| 37 | + return x_hp_d1_normalized, amax_dim1 |
| 38 | + |
| 39 | + |
| 40 | +def scale_dim0_dim1_reference( |
| 41 | + x_hp: torch.Tensor, block_size |
| 42 | +) -> Tuple[torch.Tensor, torch.Tensor]: |
| 43 | + # normalize across dim0 |
| 44 | + x_hp_d0_normalized, amax_dim0 = scale_dim0_reference(x_hp, block_size) |
| 45 | + # normalize across dim1 |
| 46 | + x_hp_d1_normalized, amax_dim1 = scale_dim1_reference(x_hp, block_size) |
| 47 | + return x_hp_d0_normalized, x_hp_d1_normalized.t(), amax_dim0, amax_dim1 |
| 48 | + |
| 49 | + |
| 50 | +def to_mx_dim0_reference(x_hp, block_size): |
| 51 | + scale_d0, data_d0 = to_mx(x_hp, torch.float8_e4m3fn, block_size) |
| 52 | + return data_d0, scale_d0 |
| 53 | + |
| 54 | + |
| 55 | +def to_mx_dim1_reference(x_hp, block_size): |
| 56 | + x_hp = x_hp.t().contiguous() |
| 57 | + scale_d1, data_d1 = to_mx(x_hp, torch.float8_e4m3fn, block_size) |
| 58 | + return data_d1.t(), scale_d1 |
| 59 | + |
| 60 | + |
| 61 | +def benchmark_cuda_function_in_microseconds(func: Callable, *args, **kwargs) -> float: |
| 62 | + """Thin wrapper around do_bench_using_profiling""" |
| 63 | + no_args = lambda: func(*args, **kwargs) |
| 64 | + time = do_bench_using_profiling(no_args) |
| 65 | + return time * 1e3 |
| 66 | + |
| 67 | + |
| 68 | +def run( |
| 69 | + M: int = 16384, |
| 70 | + K: int = 16384, |
| 71 | + BLOCK_SIZE: int = 32, |
| 72 | + mode: str = "dim0", |
| 73 | +): |
| 74 | + print(f"M {M} K {K} BLOCK_SIZE {BLOCK_SIZE}") |
| 75 | + print(f"GPU: {torch.cuda.get_device_name(0)}") |
| 76 | + print(f"torch version: {torch.__version__}") |
| 77 | + print(f"triton version: {triton.__version__}") |
| 78 | + print(f"mode: {mode}") |
| 79 | + assert mode in ("dim0", "dim1", "dim0_dim1", "dim0_mx", "dim1_mx", "dim1_mx_triton") |
| 80 | + |
| 81 | + x = torch.randn(M, K, dtype=torch.bfloat16, device="cuda") * 1000 |
| 82 | + |
| 83 | + if mode == "dim0": |
| 84 | + scale_dim0_reference_c = torch.compile(scale_dim0_reference) |
| 85 | + y_d0, s_d0 = scale_dim0_reference_c(x, BLOCK_SIZE) |
| 86 | + |
| 87 | + for _ in range(2): |
| 88 | + __ = scale_dim0_reference_c(x, BLOCK_SIZE) |
| 89 | + time_us = benchmark_cuda_function_in_microseconds( |
| 90 | + lambda x, b: scale_dim0_reference_c(x, BLOCK_SIZE), |
| 91 | + x, |
| 92 | + BLOCK_SIZE, |
| 93 | + ) |
| 94 | + |
| 95 | + assert y_d0.dtype == torch.bfloat16 |
| 96 | + assert s_d0.dtype == torch.bfloat16 |
| 97 | + bytes_rw = sum(t.numel() for t in [x, y_d0, s_d0]) * bytes_per_el_bf16 |
| 98 | + bps = bytes_rw / (time_us / 1e6) |
| 99 | + |
| 100 | + elif mode == "dim1": |
| 101 | + scale_dim1_reference_c = torch.compile(scale_dim1_reference) |
| 102 | + y_d1, s_d1 = scale_dim1_reference_c(x, BLOCK_SIZE) |
| 103 | + |
| 104 | + for _ in range(2): |
| 105 | + __ = scale_dim1_reference_c(x, BLOCK_SIZE) |
| 106 | + time_us = benchmark_cuda_function_in_microseconds( |
| 107 | + lambda x, b: scale_dim1_reference_c(x, BLOCK_SIZE), |
| 108 | + x, |
| 109 | + BLOCK_SIZE, |
| 110 | + ) |
| 111 | + |
| 112 | + assert y_d1.dtype == torch.bfloat16 |
| 113 | + assert s_d1.dtype == torch.bfloat16 |
| 114 | + bytes_rw = sum(t.numel() for t in [x, y_d1, s_d1]) * bytes_per_el_bf16 |
| 115 | + bps = bytes_rw / (time_us / 1e6) |
| 116 | + |
| 117 | + elif mode == "dim0_dim1": |
| 118 | + scale_dim0_dim1_reference_c = torch.compile(scale_dim0_dim1_reference) |
| 119 | + y_d0, y_d1, s_d0, s_d1 = scale_dim0_dim1_reference_c(x, BLOCK_SIZE) |
| 120 | + |
| 121 | + for _ in range(2): |
| 122 | + __ = scale_dim0_dim1_reference_c(x, BLOCK_SIZE) |
| 123 | + time_us = benchmark_cuda_function_in_microseconds( |
| 124 | + lambda x, b: scale_dim0_dim1_reference_c(x, BLOCK_SIZE), |
| 125 | + x, |
| 126 | + BLOCK_SIZE, |
| 127 | + ) |
| 128 | + |
| 129 | + assert y_d0.dtype == torch.bfloat16 |
| 130 | + assert s_d0.dtype == torch.bfloat16 |
| 131 | + assert y_d1.dtype == torch.bfloat16 |
| 132 | + assert s_d1.dtype == torch.bfloat16 |
| 133 | + bytes_rw = ( |
| 134 | + sum(t.numel() for t in [x, y_d0, y_d1, s_d0, s_d1]) * bytes_per_el_bf16 |
| 135 | + ) |
| 136 | + bps = bytes_rw / (time_us / 1e6) |
| 137 | + |
| 138 | + elif mode == "dim0_mx": |
| 139 | + to_mx_dim0_reference_c = torch.compile(to_mx_dim0_reference) |
| 140 | + y_d0, s_d0 = to_mx_dim0_reference_c(x, BLOCK_SIZE) |
| 141 | + |
| 142 | + for _ in range(2): |
| 143 | + __ = to_mx_dim0_reference_c(x, BLOCK_SIZE) |
| 144 | + time_us = benchmark_cuda_function_in_microseconds( |
| 145 | + lambda x, b: to_mx_dim0_reference_c(x, BLOCK_SIZE), |
| 146 | + x, |
| 147 | + BLOCK_SIZE, |
| 148 | + ) |
| 149 | + |
| 150 | + assert y_d0.dtype == torch.float8_e4m3fn |
| 151 | + assert s_d0.dtype == torch.uint8 |
| 152 | + bytes_r = x.numel() * bytes_per_el_bf16 |
| 153 | + bytes_w = (y_d0.numel() + s_d0.numel()) * bytes_per_el_fp8 |
| 154 | + bps = (bytes_r + bytes_w) / (time_us / 1e6) |
| 155 | + |
| 156 | + elif mode == "dim1_mx": |
| 157 | + to_mx_dim1_reference_c = torch.compile(to_mx_dim1_reference) |
| 158 | + y_d1, s_d1 = to_mx_dim1_reference_c(x, BLOCK_SIZE) |
| 159 | + |
| 160 | + for _ in range(2): |
| 161 | + __ = to_mx_dim1_reference_c(x, BLOCK_SIZE) |
| 162 | + time_us = benchmark_cuda_function_in_microseconds( |
| 163 | + lambda x, b: to_mx_dim1_reference_c(x, BLOCK_SIZE), |
| 164 | + x, |
| 165 | + BLOCK_SIZE, |
| 166 | + ) |
| 167 | + |
| 168 | + assert y_d1.dtype == torch.float8_e4m3fn |
| 169 | + assert s_d1.dtype == torch.uint8 |
| 170 | + bytes_r = x.numel() * bytes_per_el_bf16 |
| 171 | + bytes_w = (y_d1.numel() + s_d1.numel()) * bytes_per_el_fp8 |
| 172 | + bps = (bytes_r + bytes_w) / (time_us / 1e6) |
| 173 | + |
| 174 | + elif mode == "dim1_mx_triton": |
| 175 | + y_d1, s_d1 = triton_to_mxfp8_dim1(x, inner_block_size=BLOCK_SIZE) |
| 176 | + |
| 177 | + for _ in range(2): |
| 178 | + __ = triton_to_mxfp8_dim1(x, inner_block_size=BLOCK_SIZE) |
| 179 | + time_us = benchmark_cuda_function_in_microseconds( |
| 180 | + lambda x, b: triton_to_mxfp8_dim1(x, inner_block_size=BLOCK_SIZE), |
| 181 | + x, |
| 182 | + BLOCK_SIZE, |
| 183 | + ) |
| 184 | + |
| 185 | + assert y_d1.dtype == torch.float8_e4m3fn |
| 186 | + assert s_d1.dtype == torch.float8_e8m0fnu |
| 187 | + bytes_r = x.numel() * bytes_per_el_bf16 |
| 188 | + bytes_w = (y_d1.numel() + s_d1.numel()) * bytes_per_el_fp8 |
| 189 | + bps = (bytes_r + bytes_w) / (time_us / 1e6) |
| 190 | + |
| 191 | + else: |
| 192 | + raise AssertionError(f"unknown mode {mode}") |
| 193 | + |
| 194 | + print("time_us", time_us) |
| 195 | + print("mem_bw_gbps", bps / 1e9) |
| 196 | + |
| 197 | + |
| 198 | +if __name__ == "__main__": |
| 199 | + fire.Fire(run) |
0 commit comments