Skip to content

Commit c52ff9e

Browse files
committed
[bench] wvSplitK: per-kernel timing inside captured CUDA graph
Replaces the whole-replay event timing with per-kernel timing recorded inside the captured graph. PyTorch's torch.cuda.Event blocks the path that records queryable timestamps inside graph capture on ROCm (TORCH_CHECK(!external_) in c10/cuda/CUDAEvent.h, AIESW-34641), so the bench drops to a small ctypes shim that calls hipEventRecordWithFlags(hipEventRecordExternal) and hipEventElapsedTime directly on the raw hipEvent_t handle pulled out of torch.cuda.Event.cuda_event. Capture layout uses a single event chain of iters_per_replay+1 events (one per kernel boundary) rather than 2*iters_per_replay start/end pairs, halving the event count in the graph for the same number of per-kernel samples. Reduces run-to-run noise from 0.88% to 0.51% median (6.27% -> 2.54% max) across the full 76-cell sweep, and ~45 s wall per run. Numbers match the in-model profile better than the previous whole-replay median. The ctypes shim can be removed once PyTorch upstream lifts the ROCm external-event guard. Signed-off-by: Matthias Gehre <matthias.gehre@amd.com>
1 parent 0a14f01 commit c52ff9e

1 file changed

Lines changed: 99 additions & 38 deletions

File tree

tests/kernels/quantization/bench_rocm_skinny_gemm.py

Lines changed: 99 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,10 @@
77
at batch sizes 1-4. Validates accuracy against torch.mm. Dynamically
88
determines iteration count per shape based on IQR convergence.
99
10+
Per-kernel timestamps are recorded inside a captured CUDA graph via a small
11+
ctypes shim that calls hipEventRecordWithFlags(hipEventRecordExternal) —
12+
PyTorch's high-level torch.cuda.Event blocks this path on ROCm (AIESW-34641).
13+
1014
Usage:
1115
python tests/kernels/quantization/bench_rocm_skinny_gemm.py
1216
python tests/kernels/quantization/bench_rocm_skinny_gemm.py --dtype bf16
@@ -15,7 +19,9 @@
1519
"""
1620

1721
import argparse
22+
import ctypes
1823
import math
24+
import os
1925
import time
2026

2127
import torch
@@ -27,6 +33,68 @@
2733
# Use a conservative estimate to ensure we bust L3.
2834
CACHE_SIZE_BYTES = 64 * 1024 * 1024
2935

36+
37+
# ---------------------------------------------------------------------------
38+
# HIP ctypes shim — workaround for PyTorch's blanket disable of
39+
# cudaEventRecordExternal on ROCm (see AIESW-34641). Lets us record per-kernel
40+
# events inside a captured CUDA graph and read back queryable timestamps.
41+
# Remove once PyTorch upstream lifts the TORCH_CHECK in c10/cuda/CUDAEvent.h.
42+
# ---------------------------------------------------------------------------
43+
HIP_EVENT_RECORD_EXTERNAL = 0x01
44+
45+
46+
def _load_hip():
47+
site = os.path.dirname(os.path.dirname(torch.__file__))
48+
for sub in ("_rocm_sdk_core/lib", "_rocm_sdk_devel/lib"):
49+
for name in ("libamdhip64.so.7", "libamdhip64.so"):
50+
p = os.path.join(site, sub, name)
51+
if os.path.exists(p):
52+
lib = ctypes.CDLL(p)
53+
lib.hipEventRecordWithFlags.argtypes = [
54+
ctypes.c_void_p,
55+
ctypes.c_void_p,
56+
ctypes.c_uint,
57+
]
58+
lib.hipEventRecordWithFlags.restype = ctypes.c_int
59+
lib.hipEventElapsedTime.argtypes = [
60+
ctypes.POINTER(ctypes.c_float),
61+
ctypes.c_void_p,
62+
ctypes.c_void_p,
63+
]
64+
lib.hipEventElapsedTime.restype = ctypes.c_int
65+
return lib
66+
raise RuntimeError("libamdhip64 not found under torch site-packages")
67+
68+
69+
_HIP = _load_hip()
70+
71+
72+
def _record_external(ev: torch.cuda.Event, stream) -> None:
73+
"""Record `ev` on `stream` with hipEventRecordExternal (graph-safe)."""
74+
err = _HIP.hipEventRecordWithFlags(
75+
int(ev.cuda_event), int(stream.cuda_stream), HIP_EVENT_RECORD_EXTERNAL
76+
)
77+
if err != 0:
78+
raise RuntimeError(f"hipEventRecordWithFlags returned {err}")
79+
80+
81+
def _elapsed_ms(start_ev: torch.cuda.Event, end_ev: torch.cuda.Event) -> float:
82+
ms = ctypes.c_float(-1.0)
83+
err = _HIP.hipEventElapsedTime(
84+
ctypes.byref(ms), int(start_ev.cuda_event), int(end_ev.cuda_event)
85+
)
86+
if err != 0:
87+
raise RuntimeError(f"hipEventElapsedTime returned {err}")
88+
return ms.value
89+
90+
91+
def _make_event():
92+
"""Create a timing event and force lazy hipEventCreate by recording once."""
93+
e = torch.cuda.Event(enable_timing=True)
94+
e.record()
95+
return e
96+
97+
3098
SHAPES = [
3199
# Qwen3-4B / Qwen3-VL-4B (identical backbone)
32100
(6144, 2560, "Qwen3-4B qkv"),
@@ -71,19 +139,20 @@ def _median_se(times_sorted):
71139
def bench_dynamic(
72140
fn,
73141
target_se_pct=0.2,
74-
min_replays=8,
142+
min_replays=4,
75143
max_replays=40,
76144
max_time_s=1.0,
77145
target_replay_ms=20.0,
78146
):
79-
"""Benchmark fn by capturing many launches into a CUDA graph.
147+
"""Benchmark fn with per-kernel timing inside a captured CUDA graph.
80148
81149
Probes the kernel time, sizes one capture so a replay runs ~target_replay_ms
82150
(so the GPU stays continuously busy and DVFS doesn't drop the clock between
83-
launches), captures `iters_per_replay` calls of fn(0..iters-1), and times
84-
repeated replays. fn(i) lets callers rotate weight buffers.
151+
launches), captures `iters_per_replay` calls of fn(0..iters-1), each
152+
bracketed by hipEventRecord(EXTERNAL) so per-kernel timestamps are queryable
153+
on replay. fn(i) lets callers rotate weight buffers.
85154
86-
Returns (median_ms_per_kernel, num_kernels_timed, se_pct).
155+
Returns (median_ms_per_kernel, num_samples, se_pct).
87156
"""
88157
# 1) Probe one kernel to size the graph.
89158
fn(0)
@@ -97,46 +166,44 @@ def bench_dynamic(
97166
probe_ms = max(probe_start.elapsed_time(probe_end), 1e-3)
98167
iters_per_replay = max(2, min(2000, int(target_replay_ms / probe_ms)))
99168

100-
# 2) Warm + capture on a side stream. Run a few more launches inside the
101-
# capture stream before recording so the caching allocator is settled.
102-
s = torch.cuda.Stream()
103-
s.wait_stream(torch.cuda.current_stream())
104-
with torch.cuda.stream(s):
105-
for i in range(5):
106-
fn(i)
107-
torch.cuda.current_stream().wait_stream(s)
169+
# 2) Allocate a chain of iters_per_replay+1 events. The i-th per-kernel
170+
# time is events[i].elapsed_time(events[i+1]). Force handle creation
171+
# on the default stream so the underlying hipEvent_t exists before
172+
# the stream-capture region.
173+
events = [_make_event() for _ in range(iters_per_replay + 1)]
108174
torch.accelerator.synchronize()
109175

176+
# 3) Capture on a side stream, recording the event chain with EXTERNAL.
177+
s = torch.cuda.Stream()
110178
g = torch.cuda.CUDAGraph()
111179
with torch.cuda.graph(g, stream=s):
180+
_record_external(events[0], s)
112181
for i in range(iters_per_replay):
113182
fn(i)
183+
_record_external(events[i + 1], s)
114184

115185
# Warm one replay to absorb first-launch cost.
116186
g.replay()
117187
torch.accelerator.synchronize()
118188

119-
# 3) Time replays adaptively.
120-
times = []
121-
start_ev = torch.Event(enable_timing=True)
122-
end_ev = torch.Event(enable_timing=True)
189+
# 4) Time replays adaptively. Each replay yields iters_per_replay samples.
190+
samples = []
123191
wall_start = time.monotonic()
124192
for r in range(max_replays):
125-
start_ev.record()
126193
g.replay()
127-
end_ev.record()
128194
torch.accelerator.synchronize()
129-
times.append(start_ev.elapsed_time(end_ev) / iters_per_replay)
195+
for i in range(iters_per_replay):
196+
samples.append(_elapsed_ms(events[i], events[i + 1]))
130197

131-
if len(times) >= min_replays and len(times) % 5 == 0:
132-
med, se_pct = _median_se(sorted(times))
198+
if r + 1 >= min_replays:
199+
med, se_pct = _median_se(sorted(samples))
133200
if se_pct < target_se_pct:
134-
return med, len(times) * iters_per_replay, se_pct
201+
return med, len(samples), se_pct
135202
if time.monotonic() - wall_start > max_time_s:
136-
return med, len(times) * iters_per_replay, se_pct
203+
return med, len(samples), se_pct
137204

138-
med, se_pct = _median_se(sorted(times))
139-
return med, len(times) * iters_per_replay, se_pct
205+
med, se_pct = _median_se(sorted(samples))
206+
return med, len(samples), se_pct
140207

141208

142209
def parse_shape(s):
@@ -156,11 +223,8 @@ def run_bench(shapes, batch_sizes, dtype, target_se_pct):
156223
print(f"Shapes: {len(shapes)}, Batch sizes: {batch_sizes}")
157224
print()
158225

159-
print(
160-
f"{'N':>2} {'M':>6}x{'K':<6} {'Label':<22} "
161-
f"{'time_us':>9} {'BW GiB/s':>9} {'bufs':>5} {'iters':>6} {'SE%':>5}"
162-
)
163-
print("-" * 80)
226+
print(f"{'N':>2} {'M':>6}x{'K':<6} {'Label':<22} {'med_us':>9} {'med_GiB/s':>10}")
227+
print("-" * 60)
164228

165229
t0 = time.time()
166230
for M, K, label in shapes:
@@ -184,17 +248,14 @@ def run_bench(shapes, batch_sizes, dtype, target_se_pct):
184248
fn = lambda i, ws=weights, a=activation: ops.wvSplitK(
185249
ws[i % len(ws)], a, cu_count
186250
)
187-
med_ms, iters, se_pct = bench_dynamic(
251+
med_ms, _, _ = bench_dynamic(
188252
fn,
189253
target_se_pct=target_se_pct,
190254
)
191-
time_us = med_ms * 1000
192-
bw_gibs = weight_bytes / (med_ms * 1e-3) / (1 << 30)
255+
med_us = med_ms * 1000
256+
med_bw = weight_bytes / (med_ms * 1e-3) / (1 << 30)
193257

194-
print(
195-
f"{N:>2} {M:>6}x{K:<6} {label:<22} "
196-
f"{time_us:>8.1f} {bw_gibs:>8.1f} {n_bufs:>5} {iters:>6} {se_pct:>5.2f}"
197-
)
258+
print(f"{N:>2} {M:>6}x{K:<6} {label:<22} {med_us:>8.1f} {med_bw:>9.1f}")
198259

199260
elapsed = time.time() - t0
200261
print()

0 commit comments

Comments
 (0)