|
| 1 | +import argparse |
| 2 | +import csv |
| 3 | +import os |
| 4 | +import sys |
| 5 | +from pathlib import Path |
| 6 | +from statistics import median |
| 7 | + |
| 8 | +import torch |
| 9 | + |
| 10 | + |
# Directory two levels above this file; used as the root for import lookups.
REPO_ROOT = Path(__file__).resolve().parents[1]
# Path.glob yields entries in an unspecified order, so sort them: when several
# build/lib.* directories exist the same one is picked on every run.
BUILD_LIBS = sorted((REPO_ROOT / "build").glob("lib.*"))
if BUILD_LIBS:
    sys.path.insert(0, os.path.realpath(BUILD_LIBS[0]))
# Inserted last, so the repository root ends up ahead of the build directory
# on sys.path.
sys.path.insert(0, os.path.realpath(REPO_ROOT))
| 16 | + |
| 17 | +import hpc # noqa: E402 |
| 18 | + |
| 19 | + |
# M sizes benchmarked when --m-list is not given.
DEFAULT_M_LIST = [
    1,
    16,
    48,
    96,
    208,
    512,
    1024,
    2048,
    4096,
]

# Display names of the implementations that can be benchmarked; this order is
# also the column order used when result rows are assembled.
PROVIDERS = [
    "hpc-ops-bf16xfp32",
    "FP32(cublas)",
    "TF32(cublas)",
]
| 22 | + |
| 23 | + |
def parse_int_list(text):
    """Parse a comma-separated string into a list of ints.

    Whitespace around each item is ignored and empty items (for example from
    a trailing comma) are skipped. A non-numeric item raises ``ValueError``
    from ``int``.
    """
    pieces = (piece.strip() for piece in text.split(","))
    return [int(piece) for piece in pieces if piece]
| 26 | + |
| 27 | + |
def percentile(values, pct):
    """Return the element of ``values`` nearest to the ``pct``-th percentile.

    The input is sorted into a new list (the argument is not modified) and
    the element at the rounded position ``(len - 1) * pct / 100`` is returned;
    no interpolation is performed.
    """
    ordered = sorted(values)
    position = (len(ordered) - 1) * pct / 100.0
    # round() is applied before int() so the nearest rank is chosen rather
    # than truncating toward zero.
    return ordered[int(round(position))]
| 32 | + |
| 33 | + |
def tflops(m, n, k, us):
    """Convert a GEMM duration into TFLOP/s.

    The operation count is taken as ``2 * m * n * k`` and ``us`` is the
    elapsed time in microseconds.
    """
    flop_count = 2.0 * m * n * k
    seconds = us * 1e-6
    return flop_count * 1e-12 / seconds
| 36 | + |
| 37 | + |
def bench_cuda_events(fn, flush, warmup, iters):
    """Time ``fn`` with CUDA events.

    ``fn`` is called ``warmup`` times untimed, then ``iters`` times with a
    start/stop event pair around each call.

    Returns a tuple ``(median_us, p90_us, last_output)`` where the timings are
    in microseconds and ``last_output`` is the value returned by the final
    timed call (``None`` if no timed call ran).
    """
    for _ in range(warmup):
        fn()
    torch.cuda.synchronize()

    samples_us = []
    last_output = None
    for _ in range(iters):
        # Overwrite the scratch buffer before every timed call; presumably
        # this evicts cached operands so each measurement starts cold --
        # confirm against the intended methodology.
        flush.zero_()
        begin = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        begin.record()
        last_output = fn()
        end.record()
        # Wait for the device so the stop event has completed before reading.
        torch.cuda.synchronize()
        # elapsed_time() reports milliseconds; scale to microseconds.
        samples_us.append(begin.elapsed_time(end) * 1000.0)
    return median(samples_us), percentile(samples_us, 90), last_output
| 55 | + |
| 56 | + |
def error_metrics(out, ref):
    """Compare ``out`` against ``ref`` and summarise the error.

    Both tensors are converted to float32 first. The returned dict holds
    Python floats under the keys ``max_abs``, ``mean_abs``, ``max_rel``,
    ``mean_rel`` and ``cosine`` (cosine similarity of the flattened tensors).
    """
    out_f32 = out.float()
    ref_f32 = ref.float()
    abs_err = torch.abs(out_f32 - ref_f32)
    # Clamp the denominator so reference entries at or near zero do not
    # produce a division by zero.
    rel_err = abs_err / torch.clamp(ref_f32.abs(), min=1e-6)
    similarity = torch.nn.functional.cosine_similarity(
        out_f32.reshape(-1), ref_f32.reshape(-1), dim=0
    )
    metrics = {}
    metrics["max_abs"] = abs_err.max().item()
    metrics["mean_abs"] = abs_err.mean().item()
    metrics["max_rel"] = rel_err.max().item()
    metrics["mean_rel"] = rel_err.mean().item()
    metrics["cosine"] = similarity.item()
    return metrics
| 70 | + |
| 71 | + |
def make_inputs(m, n, k, scale, device):
    """Create random GEMM operands on ``device``.

    Returns ``(x, w, w_high, w_low)``:

    * ``x``      -- ``(m, k)`` bfloat16 activations drawn as float32 normals.
    * ``w``      -- ``(n, k)`` float32 weights.
    * ``w_high`` -- ``w`` rounded to bfloat16.
    * ``w_low``  -- the rounding residual ``w - w_high`` divided by ``scale``,
      stored as bfloat16.
    """
    # x is drawn before w; keep this order so a fixed seed reproduces the
    # same operands.
    x_fp32 = torch.randn((m, k), dtype=torch.float32, device=device)
    x = x_fp32.to(torch.bfloat16)

    w = torch.randn((n, k), dtype=torch.float32, device=device)
    w_high = w.to(torch.bfloat16)
    residual = w - w_high.float()
    w_low = (residual / scale).to(torch.bfloat16)
    return x, w, w_high, w_low
| 78 | + |
| 79 | + |
def build_runner(provider, x, w, w_high, w_low, scale, split_flag):
    """Return a zero-argument callable that runs one GEMM for ``provider``.

    Raises:
        ValueError: if ``provider`` is not one of the known provider names.
    """

    def run_hpc():
        # The two boolean flags and split_flag are passed through unchanged;
        # their meaning is defined by the hpc extension, not visible here.
        return hpc.gemm_bf16xfp32(x, w_high, w_low, scale, True, True, split_flag)

    def run_cublas():
        return torch.matmul(x.float(), w.t())

    if provider == "hpc-ops-bf16xfp32":
        return run_hpc
    # Both torch providers issue the same matmul; nothing in this function
    # distinguishes FP32 from TF32, so any difference comes from state set
    # outside it.
    if provider in ("FP32(cublas)", "TF32(cublas)"):
        return run_cublas
    raise ValueError(f"unknown provider: {provider}")
| 88 | + |
| 89 | + |
def benchmark_shape(m, n, k, providers, args, flush):
    """Benchmark every requested provider for one ``(m, n, k)`` problem.

    Args:
        m, n, k: GEMM dimensions.
        providers: provider names to run, in execution order.
        args: parsed CLI namespace; ``warmup``, ``iters``, ``check``,
            ``max_abs_tol`` and ``mean_abs_tol`` are read.
        flush: scratch tensor forwarded to ``bench_cuda_events``.

    Returns:
        A flat dict with ``m``/``n``/``k``, per-provider timing and error
        columns, and speedup columns when both sides were measured.

    Raises:
        AssertionError: when ``args.check`` is set and an accuracy check fails.
    """
    scale = 1.0 / 256.0
    x, w, w_high, w_low = make_inputs(m, n, k, scale, "cuda")
    # Opaque handle from the extension, forwarded untouched to the hpc runner.
    split_flag = hpc.get_gemm_bf16xfp32_workspace(n)

    matmul_backend = torch.backends.cuda.matmul

    # The reference result is computed with TF32 disabled.
    matmul_backend.allow_tf32 = False
    ref = torch.matmul(x.float(), w.t())
    torch.cuda.synchronize()

    results = {}
    outputs = {}
    for name in providers:
        # TF32 is enabled only while the TF32 provider is being measured.
        matmul_backend.allow_tf32 = name == "TF32(cublas)"
        runner = build_runner(name, x, w, w_high, w_low, scale, split_flag)
        med_us, p90_us, out = bench_cuda_events(runner, flush, args.warmup, args.iters)
        results[name] = {"us": med_us, "p90_us": p90_us, "tflops": tflops(m, n, k, med_us)}
        outputs[name] = out

    # Errors are evaluated only after all timing is done.
    errors = {name: error_metrics(out, ref) for name, out in outputs.items()}

    if args.check:
        fp32_err = errors.get("FP32(cublas)")
        if fp32_err is not None and fp32_err["max_abs"] != 0.0:
            raise AssertionError(f"FP32(cublas) should match reference exactly, got {fp32_err['max_abs']}")
        hpc_err = errors.get("hpc-ops-bf16xfp32")
        if hpc_err is not None and (
            hpc_err["max_abs"] > args.max_abs_tol or hpc_err["mean_abs"] > args.mean_abs_tol
        ):
            raise AssertionError(
                "hpc-ops-bf16xfp32 accuracy check failed: "
                f"max_abs={hpc_err['max_abs']:.6f}, mean_abs={hpc_err['mean_abs']:.6f}"
            )

    row = {"m": m, "n": n, "k": k}
    # Iterate the canonical provider list so column order does not depend on
    # the order given on the command line.
    for name in PROVIDERS:
        if name not in results:
            continue
        key = provider_key(name)
        timing = results[name]
        err = errors[name]
        row.update(
            {
                f"{key}_us": timing["us"],
                f"{key}_p90_us": timing["p90_us"],
                f"{key}_tflops": timing["tflops"],
                f"{key}_max_abs": err["max_abs"],
                f"{key}_mean_abs": err["mean_abs"],
                f"{key}_cosine": err["cosine"],
            }
        )

    if "hpc-ops-bf16xfp32" in results:
        hpc_us = results["hpc-ops-bf16xfp32"]["us"]
        baselines = (
            ("FP32(cublas)", "hpc_vs_fp32_speedup"),
            ("TF32(cublas)", "hpc_vs_tf32_speedup"),
        )
        for baseline, column in baselines:
            if baseline in results:
                row[column] = results[baseline]["us"] / hpc_us
    return row
| 145 | + |
| 146 | + |
def provider_key(provider):
    """Map a provider display name to the short prefix used in row keys.

    Raises:
        KeyError: if ``provider`` is not a known provider name.
    """
    prefixes = {
        "hpc-ops-bf16xfp32": "hpc",
        "FP32(cublas)": "torch_fp32",
        "TF32(cublas)": "torch_tf32",
    }
    return prefixes[provider]
| 153 | + |
| 154 | + |
def print_tflops_table(rows, providers):
    """Print a right-aligned text table of TFLOP/s per M for ``providers``.

    Column widths come from the header text (minimum 8 characters); values
    are formatted with two decimals.
    """
    header_cells = ["M"]
    header_cells.extend(f"{p} TFLOP/s" for p in providers)
    col_widths = [max(len(cell), 8) for cell in header_cells]

    def render(cells):
        # Right-justify each cell to its column width.
        return " ".join(cell.rjust(width) for cell, width in zip(cells, col_widths))

    print("\n" + render(header_cells))
    print(" ".join("-" * width for width in col_widths))
    for row in rows:
        cells = [str(row["m"])]
        cells.extend(f"{row[provider_key(p) + '_tflops']:.2f}" for p in providers)
        print(render(cells))
| 165 | + |
| 166 | + |
def print_csv(rows):
    """Print benchmark rows to stdout as CSV, preceded by a blank line.

    The header is generated from the same column list that is used to look
    up the values, so header names always equal the row keys (and the names
    ``write_csv`` uses for these columns). Metrics missing from a row are
    printed as ``nan``.

    Raises:
        KeyError: if a row lacks ``m``, ``n`` or ``k``.
    """
    # (row key, format spec) for every metric column, in output order.
    metric_columns = [
        ("hpc_us", ".2f"),
        ("hpc_p90_us", ".2f"),
        ("torch_fp32_us", ".2f"),
        ("torch_fp32_p90_us", ".2f"),
        ("torch_tf32_us", ".2f"),
        ("torch_tf32_p90_us", ".2f"),
        ("hpc_vs_fp32_speedup", ".2f"),
        ("hpc_vs_tf32_speedup", ".2f"),
        ("hpc_tflops", ".2f"),
        ("torch_fp32_tflops", ".2f"),
        ("torch_tf32_tflops", ".2f"),
        ("hpc_max_abs", ".6f"),
        ("hpc_mean_abs", ".6f"),
        ("torch_tf32_max_abs", ".6f"),
        ("torch_tf32_mean_abs", ".6f"),
    ]
    shape_columns = ["m", "n", "k"]

    header = ",".join(shape_columns + [name for name, _ in metric_columns])
    print("\n" + header)

    missing = float("nan")
    for row in rows:
        # Shape columns are mandatory; index directly so a malformed row fails loudly.
        cells = [str(row[name]) for name in shape_columns]
        cells.extend(format(row.get(name, missing), spec) for name, spec in metric_columns)
        print(",".join(cells))
| 191 | + |
| 192 | + |
def write_csv(path, rows):
    """Write benchmark rows to ``path`` as a CSV file with a header line.

    Nothing is written (and no file is created) when ``rows`` is empty. Row
    keys that are not listed as columns are dropped.
    """
    if not rows:
        return

    columns = (
        "m n k "
        "hpc_us hpc_p90_us "
        "torch_fp32_us torch_fp32_p90_us "
        "torch_tf32_us torch_tf32_p90_us "
        "hpc_vs_fp32_speedup hpc_vs_tf32_speedup "
        "hpc_tflops torch_fp32_tflops torch_tf32_tflops "
        "hpc_max_abs hpc_mean_abs "
        "torch_tf32_max_abs torch_tf32_mean_abs "
        "hpc_cosine torch_fp32_cosine torch_tf32_cosine"
    ).split()

    # newline="" lets the csv module control line endings itself.
    with Path(path).open("w", newline="") as handle:
        out = csv.DictWriter(handle, fieldnames=columns, extrasaction="ignore")
        out.writeheader()
        out.writerows(rows)
| 224 | + |
| 225 | + |
def parse_args():
    """Build the command-line parser and return the parsed arguments."""
    parser = argparse.ArgumentParser(
        description="Benchmark hpc-ops BF16xFP32 GEMM vs cuBLAS.",
    )
    add = parser.add_argument

    # Problem sizes.
    add("--m-list", default=",".join(map(str, DEFAULT_M_LIST)))
    add("--n", type=int, default=192)
    add("--k", type=int, default=4096)

    # Measurement controls.
    add("--warmup", type=int, default=8)
    add("--iters", type=int, default=30)
    add("--seed", type=int, default=10086)
    add("--flush-mb", type=int, default=128)
    add("--providers", nargs="+", default=PROVIDERS, choices=PROVIDERS)

    # Output and accuracy checking; --check also accepts --no-check.
    add("--csv", type=str, default="", help="Optional output CSV path.")
    add("--check", action=argparse.BooleanOptionalAction, default=True)
    add("--max-abs-tol", type=float, default=0.01)
    add("--mean-abs-tol", type=float, default=0.001)
    add("--print-csv", action="store_true", help="Print machine-readable CSV rows.")

    return parser.parse_args()
| 242 | + |
| 243 | + |
def main():
    """Run the benchmark for every requested M and report the results.

    Raises:
        RuntimeError: if CUDA is unavailable or the device's major compute
            capability is not 9.
    """
    args = parse_args()
    if not torch.cuda.is_available():
        raise RuntimeError("CUDA is required")
    major = torch.cuda.get_device_capability()[0]
    if major != 9:
        raise RuntimeError("This benchmark is tuned for SM90 GPUs")

    for seed_fn in (torch.manual_seed, torch.cuda.manual_seed):
        seed_fn(args.seed)

    m_values = parse_int_list(args.m_list)
    # int32 elements are 4 bytes each, so this buffer occupies --flush-mb MiB.
    flush_elems = args.flush_mb * 1024 * 1024 // 4
    flush = torch.empty(flush_elems, dtype=torch.int32, device="cuda")

    matmul_backend = torch.backends.cuda.matmul
    saved_tf32 = matmul_backend.allow_tf32
    rows = []
    try:
        print(f"Device: {torch.cuda.get_device_name()} N={args.n} K={args.k}")
        print(f"Providers: {', '.join(args.providers)}")
        rows.extend(
            benchmark_shape(m, args.n, args.k, args.providers, args, flush)
            for m in m_values
        )
    finally:
        # benchmark_shape toggles the global TF32 switch; put it back even
        # when a run fails.
        matmul_backend.allow_tf32 = saved_tf32

    print_tflops_table(rows, args.providers)
    if args.print_csv:
        print_csv(rows)
    if args.csv:
        write_csv(args.csv, rows)
        print(f"\nWrote CSV: {args.csv}")
    print("\nBenchmark finished!")


if __name__ == "__main__":
    main()
0 commit comments