Skip to content

Commit 3555a31

Browse files
author
aidenren
committed
refactor: refactor launch cofig and benchmark
1 parent 49fcbd6 commit 3555a31

5 files changed

Lines changed: 322 additions & 217 deletions

File tree

Lines changed: 276 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,276 @@
1+
import argparse
2+
import csv
3+
import os
4+
import sys
5+
from pathlib import Path
6+
from statistics import median
7+
8+
import torch
9+
10+
11+
REPO_ROOT = Path(__file__).resolve().parents[1]
12+
BUILD_LIBS = list((REPO_ROOT / "build").glob("lib.*"))
13+
if BUILD_LIBS:
14+
sys.path.insert(0, os.path.realpath(BUILD_LIBS[0]))
15+
sys.path.insert(0, os.path.realpath(REPO_ROOT))
16+
17+
import hpc # noqa: E402
18+
19+
20+
DEFAULT_M_LIST = [1, 16, 48, 96, 208, 512, 1024, 2048, 4096]
21+
PROVIDERS = ["hpc-ops-bf16xfp32", "FP32(cublas)", "TF32(cublas)"]
22+
23+
24+
def parse_int_list(text):
25+
return [int(x.strip()) for x in text.split(",") if x.strip()]
26+
27+
28+
def percentile(values, pct):
29+
values = sorted(values)
30+
idx = int(round((len(values) - 1) * pct / 100.0))
31+
return values[idx]
32+
33+
34+
def tflops(m, n, k, us):
35+
return (2.0 * m * n * k) * 1e-12 / (us * 1e-6)
36+
37+
38+
def bench_cuda_events(fn, flush, warmup, iters):
39+
for _ in range(warmup):
40+
fn()
41+
torch.cuda.synchronize()
42+
43+
times = []
44+
out = None
45+
for _ in range(iters):
46+
flush.zero_()
47+
start = torch.cuda.Event(enable_timing=True)
48+
stop = torch.cuda.Event(enable_timing=True)
49+
start.record()
50+
out = fn()
51+
stop.record()
52+
torch.cuda.synchronize()
53+
times.append(start.elapsed_time(stop) * 1000.0)
54+
return median(times), percentile(times, 90), out
55+
56+
57+
def error_metrics(out, ref):
58+
out = out.float()
59+
ref = ref.float()
60+
diff = (out - ref).abs()
61+
rel = diff / ref.abs().clamp_min(1e-6)
62+
cosine = torch.nn.functional.cosine_similarity(out.flatten(), ref.flatten(), dim=0)
63+
return {
64+
"max_abs": diff.max().item(),
65+
"mean_abs": diff.mean().item(),
66+
"max_rel": rel.max().item(),
67+
"mean_rel": rel.mean().item(),
68+
"cosine": cosine.item(),
69+
}
70+
71+
72+
def make_inputs(m, n, k, scale, device):
73+
x = torch.randn((m, k), dtype=torch.float32, device=device).to(torch.bfloat16)
74+
w = torch.randn((n, k), dtype=torch.float32, device=device)
75+
w_high = w.to(torch.bfloat16)
76+
w_low = ((w - w_high.float()) / scale).to(torch.bfloat16)
77+
return x, w, w_high, w_low
78+
79+
80+
def build_runner(provider, x, w, w_high, w_low, scale, split_flag):
81+
if provider == "hpc-ops-bf16xfp32":
82+
return lambda: hpc.gemm_bf16xfp32(x, w_high, w_low, scale, True, True, split_flag)
83+
if provider == "FP32(cublas)":
84+
return lambda: torch.matmul(x.float(), w.t())
85+
if provider == "TF32(cublas)":
86+
return lambda: torch.matmul(x.float(), w.t())
87+
raise ValueError(f"unknown provider: {provider}")
88+
89+
90+
def benchmark_shape(m, n, k, providers, args, flush):
91+
scale = 1.0 / 256.0
92+
x, w, w_high, w_low = make_inputs(m, n, k, scale, "cuda")
93+
split_flag = hpc.get_gemm_bf16xfp32_workspace(n)
94+
95+
torch.backends.cuda.matmul.allow_tf32 = False
96+
ref = torch.matmul(x.float(), w.t())
97+
torch.cuda.synchronize()
98+
99+
results = {}
100+
outputs = {}
101+
for provider in providers:
102+
torch.backends.cuda.matmul.allow_tf32 = provider == "TF32(cublas)"
103+
run = build_runner(provider, x, w, w_high, w_low, scale, split_flag)
104+
us, p90_us, out = bench_cuda_events(run, flush, args.warmup, args.iters)
105+
results[provider] = {
106+
"us": us,
107+
"p90_us": p90_us,
108+
"tflops": tflops(m, n, k, us),
109+
}
110+
outputs[provider] = out
111+
112+
errors = {}
113+
for provider, out in outputs.items():
114+
errors[provider] = error_metrics(out, ref)
115+
116+
if args.check:
117+
fp32_err = errors.get("FP32(cublas)")
118+
if fp32_err is not None and fp32_err["max_abs"] != 0.0:
119+
raise AssertionError(f"FP32(cublas) should match reference exactly, got {fp32_err['max_abs']}")
120+
hpc_err = errors.get("hpc-ops-bf16xfp32")
121+
if hpc_err is not None:
122+
if hpc_err["max_abs"] > args.max_abs_tol or hpc_err["mean_abs"] > args.mean_abs_tol:
123+
raise AssertionError(
124+
"hpc-ops-bf16xfp32 accuracy check failed: "
125+
f"max_abs={hpc_err['max_abs']:.6f}, mean_abs={hpc_err['mean_abs']:.6f}"
126+
)
127+
128+
row = {"m": m, "n": n, "k": k}
129+
for provider in PROVIDERS:
130+
if provider not in results:
131+
continue
132+
prefix = provider_key(provider)
133+
row[f"{prefix}_us"] = results[provider]["us"]
134+
row[f"{prefix}_p90_us"] = results[provider]["p90_us"]
135+
row[f"{prefix}_tflops"] = results[provider]["tflops"]
136+
row[f"{prefix}_max_abs"] = errors[provider]["max_abs"]
137+
row[f"{prefix}_mean_abs"] = errors[provider]["mean_abs"]
138+
row[f"{prefix}_cosine"] = errors[provider]["cosine"]
139+
140+
if "hpc-ops-bf16xfp32" in results and "FP32(cublas)" in results:
141+
row["hpc_vs_fp32_speedup"] = results["FP32(cublas)"]["us"] / results["hpc-ops-bf16xfp32"]["us"]
142+
if "hpc-ops-bf16xfp32" in results and "TF32(cublas)" in results:
143+
row["hpc_vs_tf32_speedup"] = results["TF32(cublas)"]["us"] / results["hpc-ops-bf16xfp32"]["us"]
144+
return row
145+
146+
147+
def provider_key(provider):
148+
return {
149+
"hpc-ops-bf16xfp32": "hpc",
150+
"FP32(cublas)": "torch_fp32",
151+
"TF32(cublas)": "torch_tf32",
152+
}[provider]
153+
154+
155+
def print_tflops_table(rows, providers):
156+
headers = ["M"] + [f"{p} TFLOP/s" for p in providers]
157+
widths = [max(len(h), 8) for h in headers]
158+
print("\n" + " ".join(h.rjust(w) for h, w in zip(headers, widths)))
159+
print(" ".join("-" * w for w in widths))
160+
for row in rows:
161+
values = [str(row["m"])]
162+
for provider in providers:
163+
values.append(f"{row[provider_key(provider) + '_tflops']:.2f}")
164+
print(" ".join(v.rjust(w) for v, w in zip(values, widths)))
165+
166+
167+
def print_csv(rows):
168+
print(
169+
"\n"
170+
"m,n,k,hpc_us,hpc_p90_us,torch_fp32_us,torch_fp32_p90_us,"
171+
"torch_tf32_us,torch_tf32_p90_us,hpc_vs_fp32,hpc_vs_tf32,"
172+
"hpc_tflops,torch_fp32_tflops,torch_tf32_tflops,"
173+
"hpc_max_abs,hpc_mean_abs,tf32_max_abs,tf32_mean_abs"
174+
)
175+
for row in rows:
176+
print(
177+
f"{row['m']},{row['n']},{row['k']},"
178+
f"{row.get('hpc_us', float('nan')):.2f},{row.get('hpc_p90_us', float('nan')):.2f},"
179+
f"{row.get('torch_fp32_us', float('nan')):.2f},{row.get('torch_fp32_p90_us', float('nan')):.2f},"
180+
f"{row.get('torch_tf32_us', float('nan')):.2f},{row.get('torch_tf32_p90_us', float('nan')):.2f},"
181+
f"{row.get('hpc_vs_fp32_speedup', float('nan')):.2f},"
182+
f"{row.get('hpc_vs_tf32_speedup', float('nan')):.2f},"
183+
f"{row.get('hpc_tflops', float('nan')):.2f},"
184+
f"{row.get('torch_fp32_tflops', float('nan')):.2f},"
185+
f"{row.get('torch_tf32_tflops', float('nan')):.2f},"
186+
f"{row.get('hpc_max_abs', float('nan')):.6f},"
187+
f"{row.get('hpc_mean_abs', float('nan')):.6f},"
188+
f"{row.get('torch_tf32_max_abs', float('nan')):.6f},"
189+
f"{row.get('torch_tf32_mean_abs', float('nan')):.6f}"
190+
)
191+
192+
193+
def write_csv(path, rows):
194+
if not rows:
195+
return
196+
fieldnames = [
197+
"m",
198+
"n",
199+
"k",
200+
"hpc_us",
201+
"hpc_p90_us",
202+
"torch_fp32_us",
203+
"torch_fp32_p90_us",
204+
"torch_tf32_us",
205+
"torch_tf32_p90_us",
206+
"hpc_vs_fp32_speedup",
207+
"hpc_vs_tf32_speedup",
208+
"hpc_tflops",
209+
"torch_fp32_tflops",
210+
"torch_tf32_tflops",
211+
"hpc_max_abs",
212+
"hpc_mean_abs",
213+
"torch_tf32_max_abs",
214+
"torch_tf32_mean_abs",
215+
"hpc_cosine",
216+
"torch_fp32_cosine",
217+
"torch_tf32_cosine",
218+
]
219+
with Path(path).open("w", newline="") as f:
220+
writer = csv.DictWriter(f, fieldnames=fieldnames, extrasaction="ignore")
221+
writer.writeheader()
222+
for row in rows:
223+
writer.writerow(row)
224+
225+
226+
def parse_args():
227+
parser = argparse.ArgumentParser(description="Benchmark hpc-ops BF16xFP32 GEMM vs cuBLAS.")
228+
parser.add_argument("--m-list", default=",".join(str(x) for x in DEFAULT_M_LIST))
229+
parser.add_argument("--n", type=int, default=192)
230+
parser.add_argument("--k", type=int, default=4096)
231+
parser.add_argument("--warmup", type=int, default=8)
232+
parser.add_argument("--iters", type=int, default=30)
233+
parser.add_argument("--seed", type=int, default=10086)
234+
parser.add_argument("--flush-mb", type=int, default=128)
235+
parser.add_argument("--providers", nargs="+", default=PROVIDERS, choices=PROVIDERS)
236+
parser.add_argument("--csv", type=str, default="", help="Optional output CSV path.")
237+
parser.add_argument("--check", action=argparse.BooleanOptionalAction, default=True)
238+
parser.add_argument("--max-abs-tol", type=float, default=0.01)
239+
parser.add_argument("--mean-abs-tol", type=float, default=0.001)
240+
parser.add_argument("--print-csv", action="store_true", help="Print machine-readable CSV rows.")
241+
return parser.parse_args()
242+
243+
244+
def main():
245+
args = parse_args()
246+
if not torch.cuda.is_available():
247+
raise RuntimeError("CUDA is required")
248+
if torch.cuda.get_device_capability()[0] != 9:
249+
raise RuntimeError("This benchmark is tuned for SM90 GPUs")
250+
251+
torch.manual_seed(args.seed)
252+
torch.cuda.manual_seed(args.seed)
253+
254+
m_values = parse_int_list(args.m_list)
255+
flush = torch.empty(args.flush_mb * 1024 * 1024 // 4, dtype=torch.int32, device="cuda")
256+
old_tf32 = torch.backends.cuda.matmul.allow_tf32
257+
rows = []
258+
try:
259+
print(f"Device: {torch.cuda.get_device_name()} N={args.n} K={args.k}")
260+
print(f"Providers: {', '.join(args.providers)}")
261+
for m in m_values:
262+
rows.append(benchmark_shape(m, args.n, args.k, args.providers, args, flush))
263+
finally:
264+
torch.backends.cuda.matmul.allow_tf32 = old_tf32
265+
266+
print_tflops_table(rows, args.providers)
267+
if args.print_csv:
268+
print_csv(rows)
269+
if args.csv:
270+
write_csv(args.csv, rows)
271+
print(f"\nWrote CSV: {args.csv}")
272+
print("\nBenchmark finished!")
273+
274+
275+
if __name__ == "__main__":
276+
main()

src/gemm/gemm.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,8 @@ namespace gemm {
1111

1212
bool gemm_bf16xfp32_async(void *y_ptr, void *splitk_y_ptr, void *split_flag_ptr, const void *x_ptr,
1313
const void *w_high_ptr, const void *w_low_ptr, int m, int n, int k,
14-
float scale, bool use_fp32_output, int splitk, cudaStream_t stream);
14+
float scale, bool use_fp32_output, int splitk, int kTileM, int wgn,
15+
cudaStream_t stream);
1516
} // namespace gemm
1617
} // namespace hpc
1718

src/gemm/sm90/entry.cc

Lines changed: 29 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,9 @@ static inline int ceil_div_int(int a, int b) { return (a + b - 1) / b; }
2121
static inline int normalized_m(int m, int n, int k) {
2222
constexpr int kRefN = 192;
2323
constexpr int kRefK = 4096;
24-
return static_cast<int>((static_cast<long long>(m) * n * kRefK + static_cast<long long>(kRefN) * k - 1) /
25-
(static_cast<long long>(kRefN) * k));
24+
return static_cast<int>(
25+
(static_cast<long long>(m) * n * kRefK + static_cast<long long>(kRefN) * k - 1) /
26+
(static_cast<long long>(kRefN) * k));
2627
}
2728

2829
static inline int select_split_k_by_work(int norm_m) {
@@ -48,51 +49,33 @@ static inline int select_tile16_wgn(int m, int n, int split_k) {
4849
}
4950

5051
static inline KernelConfig select_config(int m, int n, int k, bool use_splitk) {
51-
constexpr int kN192 = 192;
52-
constexpr int kN512 = 512;
53-
constexpr int kN1024 = 1024;
54-
constexpr int kN2048 = 2048;
55-
constexpr int kMThreshold128 = 128;
56-
constexpr int kDefaultKtm64 = 64;
57-
constexpr int kDefaultKtm16 = 16;
58-
constexpr int kDefaultSk8 = 8;
59-
constexpr int kDefaultSk4 = 4;
60-
constexpr int kDefaultSk2 = 2;
61-
constexpr int kDefaultWgn = 2;
62-
constexpr int kDefaultKtmForLargeM = 64;
63-
64-
if (n == kN192) {
65-
const int norm_m = normalized_m(m, n, k);
66-
if (norm_m > 624 && norm_m <= 832) {
67-
return {kDefaultSk2, 1, kDefaultKtm64};
68-
}
69-
if (norm_m > 1024 && norm_m <= 2048) {
70-
return {kDefaultSk4, 1, kDefaultKtm64};
71-
}
72-
if (norm_m > 2048) {
73-
return {1, 1, kDefaultKtm64};
74-
}
52+
const int norm_m = normalized_m(m, n, k);
7553

76-
const int split_k = select_split_k_by_work(norm_m);
77-
return {split_k, select_tile16_wgn(m, n, split_k), kDefaultKtm16};
54+
if (norm_m > 624 && norm_m <= 832) {
55+
return {2, 1, 64};
7856
}
79-
80-
// Fallback for n != 192: preserve the original heuristic.
81-
int sk = 1;
82-
if (use_splitk && m <= kMThreshold128) {
83-
if (n == kN512) {
84-
sk = kDefaultSk8;
85-
} else if (n == kN1024) {
86-
sk = kDefaultSk4;
87-
} else if (n == kN2048) {
88-
sk = kDefaultSk2;
89-
}
57+
if (norm_m > 832 && norm_m <= 896) {
58+
return {2, 2, 16};
59+
}
60+
if (norm_m > 1024 && norm_m <= 1088) {
61+
return {1, 2, 16};
62+
}
63+
if (norm_m > 1088 && norm_m <= 1152) {
64+
return {4, 1, 64};
65+
}
66+
if (norm_m > 1152 && norm_m <= 1536) {
67+
return {1, 1, 64};
68+
}
69+
if (norm_m > 1536 && norm_m <= 2048) {
70+
return {4, 1, 64};
71+
}
72+
if (norm_m > 2048) {
73+
return {1, 1, 64};
9074
}
9175

92-
int wgn = kDefaultWgn;
93-
int ktm = (m > kMThreshold128) ? kDefaultKtmForLargeM : kDefaultKtm16;
94-
95-
return {sk, wgn, ktm};
76+
// kTileM=16 path: select split_k by workload, then wgn by occupancy.
77+
const int split_k = select_split_k_by_work(norm_m);
78+
return {split_k, select_tile16_wgn(m, n, split_k), 16};
9679
}
9780

9881
torch::Tensor gemm_bf16xfp32_entry(const torch::Tensor &x, const torch::Tensor &w_high,
@@ -148,9 +131,9 @@ torch::Tensor gemm_bf16xfp32_entry(const torch::Tensor &x, const torch::Tensor &
148131
const auto *w_low_ptr = w_low.const_data_ptr();
149132
auto *y_ptr = y.mutable_data_ptr();
150133

151-
bool running =
152-
gemm_bf16xfp32_async(y_ptr, split_y_ptr, split_flag_ptr, x_ptr, w_high_ptr, w_low_ptr, m, n,
153-
k, scale, use_fp32_output, cfg.split_k, stream);
134+
bool running = gemm_bf16xfp32_async(y_ptr, split_y_ptr, split_flag_ptr, x_ptr, w_high_ptr,
135+
w_low_ptr, m, n, k, scale, use_fp32_output, cfg.split_k,
136+
cfg.kTileM, cfg.k_warpgroup_n, stream);
154137

155138
TORCH_CHECK(running, "gemm_bf16xfp32 launch failed!");
156139

0 commit comments

Comments
 (0)