77at batch sizes 1-4. Validates accuracy against torch.mm. Dynamically
88determines iteration count per shape based on IQR convergence.
99
10+ Per-kernel timestamps are recorded inside a captured CUDA graph via a small
11+ ctypes shim that calls hipEventRecordWithFlags(hipEventRecordExternal) —
12+ PyTorch's high-level torch.cuda.Event blocks this path on ROCm (AIESW-34641).
13+
1014Usage:
1115 python tests/kernels/quantization/bench_rocm_skinny_gemm.py
1216 python tests/kernels/quantization/bench_rocm_skinny_gemm.py --dtype bf16
1519"""
1620
1721import argparse
22+ import ctypes
1823import math
24+ import os
1925import time
2026
2127import torch
2733# Use a conservative estimate to ensure we bust L3.
2834CACHE_SIZE_BYTES = 64 * 1024 * 1024
2935
36+
37+ # ---------------------------------------------------------------------------
38+ # HIP ctypes shim — workaround for PyTorch's blanket disable of
39+ # cudaEventRecordExternal on ROCm (see AIESW-34641). Lets us record per-kernel
40+ # events inside a captured CUDA graph and read back queryable timestamps.
41+ # Remove once PyTorch upstream lifts the TORCH_CHECK in c10/cuda/CUDAEvent.h.
42+ # ---------------------------------------------------------------------------
43+ HIP_EVENT_RECORD_EXTERNAL = 0x01
44+
45+
46+ def _load_hip ():
47+ site = os .path .dirname (os .path .dirname (torch .__file__ ))
48+ for sub in ("_rocm_sdk_core/lib" , "_rocm_sdk_devel/lib" ):
49+ for name in ("libamdhip64.so.7" , "libamdhip64.so" ):
50+ p = os .path .join (site , sub , name )
51+ if os .path .exists (p ):
52+ lib = ctypes .CDLL (p )
53+ lib .hipEventRecordWithFlags .argtypes = [
54+ ctypes .c_void_p ,
55+ ctypes .c_void_p ,
56+ ctypes .c_uint ,
57+ ]
58+ lib .hipEventRecordWithFlags .restype = ctypes .c_int
59+ lib .hipEventElapsedTime .argtypes = [
60+ ctypes .POINTER (ctypes .c_float ),
61+ ctypes .c_void_p ,
62+ ctypes .c_void_p ,
63+ ]
64+ lib .hipEventElapsedTime .restype = ctypes .c_int
65+ return lib
66+ raise RuntimeError ("libamdhip64 not found under torch site-packages" )
67+
68+
69+ _HIP = _load_hip ()
70+
71+
72+ def _record_external (ev : torch .cuda .Event , stream ) -> None :
73+ """Record `ev` on `stream` with hipEventRecordExternal (graph-safe)."""
74+ err = _HIP .hipEventRecordWithFlags (
75+ int (ev .cuda_event ), int (stream .cuda_stream ), HIP_EVENT_RECORD_EXTERNAL
76+ )
77+ if err != 0 :
78+ raise RuntimeError (f"hipEventRecordWithFlags returned { err } " )
79+
80+
81+ def _elapsed_ms (start_ev : torch .cuda .Event , end_ev : torch .cuda .Event ) -> float :
82+ ms = ctypes .c_float (- 1.0 )
83+ err = _HIP .hipEventElapsedTime (
84+ ctypes .byref (ms ), int (start_ev .cuda_event ), int (end_ev .cuda_event )
85+ )
86+ if err != 0 :
87+ raise RuntimeError (f"hipEventElapsedTime returned { err } " )
88+ return ms .value
89+
90+
91+ def _make_event ():
92+ """Create a timing event and force lazy hipEventCreate by recording once."""
93+ e = torch .cuda .Event (enable_timing = True )
94+ e .record ()
95+ return e
96+
97+
3098SHAPES = [
3199 # Qwen3-4B / Qwen3-VL-4B (identical backbone)
32100 (6144 , 2560 , "Qwen3-4B qkv" ),
@@ -71,19 +139,20 @@ def _median_se(times_sorted):
71139def bench_dynamic (
72140 fn ,
73141 target_se_pct = 0.2 ,
74- min_replays = 8 ,
142+ min_replays = 4 ,
75143 max_replays = 40 ,
76144 max_time_s = 1.0 ,
77145 target_replay_ms = 20.0 ,
78146):
79- """Benchmark fn by capturing many launches into a CUDA graph.
147+ """Benchmark fn with per-kernel timing inside a captured CUDA graph.
80148
81149 Probes the kernel time, sizes one capture so a replay runs ~target_replay_ms
82150 (so the GPU stays continuously busy and DVFS doesn't drop the clock between
83- launches), captures `iters_per_replay` calls of fn(0..iters-1), and times
84- repeated replays. fn(i) lets callers rotate weight buffers.
151+ launches), captures `iters_per_replay` calls of fn(0..iters-1), each
152+ bracketed by hipEventRecord(EXTERNAL) so per-kernel timestamps are queryable
153+ on replay. fn(i) lets callers rotate weight buffers.
85154
86- Returns (median_ms_per_kernel, num_kernels_timed , se_pct).
155+ Returns (median_ms_per_kernel, num_samples , se_pct).
87156 """
88157 # 1) Probe one kernel to size the graph.
89158 fn (0 )
@@ -97,46 +166,44 @@ def bench_dynamic(
97166 probe_ms = max (probe_start .elapsed_time (probe_end ), 1e-3 )
98167 iters_per_replay = max (2 , min (2000 , int (target_replay_ms / probe_ms )))
99168
100- # 2) Warm + capture on a side stream. Run a few more launches inside the
101- # capture stream before recording so the caching allocator is settled.
102- s = torch .cuda .Stream ()
103- s .wait_stream (torch .cuda .current_stream ())
104- with torch .cuda .stream (s ):
105- for i in range (5 ):
106- fn (i )
107- torch .cuda .current_stream ().wait_stream (s )
169+ # 2) Allocate a chain of iters_per_replay+1 events. The i-th per-kernel
170+ # time is events[i].elapsed_time(events[i+1]). Force handle creation
171+ # on the default stream so the underlying hipEvent_t exists before
172+ # the stream-capture region.
173+ events = [_make_event () for _ in range (iters_per_replay + 1 )]
108174 torch .accelerator .synchronize ()
109175
176+ # 3) Capture on a side stream, recording the event chain with EXTERNAL.
177+ s = torch .cuda .Stream ()
110178 g = torch .cuda .CUDAGraph ()
111179 with torch .cuda .graph (g , stream = s ):
180+ _record_external (events [0 ], s )
112181 for i in range (iters_per_replay ):
113182 fn (i )
183+ _record_external (events [i + 1 ], s )
114184
115185 # Warm one replay to absorb first-launch cost.
116186 g .replay ()
117187 torch .accelerator .synchronize ()
118188
119- # 3) Time replays adaptively.
120- times = []
121- start_ev = torch .Event (enable_timing = True )
122- end_ev = torch .Event (enable_timing = True )
189+ # 4) Time replays adaptively. Each replay yields iters_per_replay samples.
190+ samples = []
123191 wall_start = time .monotonic ()
124192 for r in range (max_replays ):
125- start_ev .record ()
126193 g .replay ()
127- end_ev .record ()
128194 torch .accelerator .synchronize ()
129- times .append (start_ev .elapsed_time (end_ev ) / iters_per_replay )
195+ for i in range (iters_per_replay ):
196+ samples .append (_elapsed_ms (events [i ], events [i + 1 ]))
130197
131- if len ( times ) >= min_replays and len ( times ) % 5 == 0 :
132- med , se_pct = _median_se (sorted (times ))
198+ if r + 1 >= min_replays :
199+ med , se_pct = _median_se (sorted (samples ))
133200 if se_pct < target_se_pct :
134- return med , len (times ) * iters_per_replay , se_pct
201+ return med , len (samples ) , se_pct
135202 if time .monotonic () - wall_start > max_time_s :
136- return med , len (times ) * iters_per_replay , se_pct
203+ return med , len (samples ) , se_pct
137204
138- med , se_pct = _median_se (sorted (times ))
139- return med , len (times ) * iters_per_replay , se_pct
205+ med , se_pct = _median_se (sorted (samples ))
206+ return med , len (samples ) , se_pct
140207
141208
142209def parse_shape (s ):
@@ -156,11 +223,8 @@ def run_bench(shapes, batch_sizes, dtype, target_se_pct):
156223 print (f"Shapes: { len (shapes )} , Batch sizes: { batch_sizes } " )
157224 print ()
158225
159- print (
160- f"{ 'N' :>2} { 'M' :>6} x{ 'K' :<6} { 'Label' :<22} "
161- f"{ 'time_us' :>9} { 'BW GiB/s' :>9} { 'bufs' :>5} { 'iters' :>6} { 'SE%' :>5} "
162- )
163- print ("-" * 80 )
226+ print (f"{ 'N' :>2} { 'M' :>6} x{ 'K' :<6} { 'Label' :<22} { 'med_us' :>9} { 'med_GiB/s' :>10} " )
227+ print ("-" * 60 )
164228
165229 t0 = time .time ()
166230 for M , K , label in shapes :
@@ -184,17 +248,14 @@ def run_bench(shapes, batch_sizes, dtype, target_se_pct):
184248 fn = lambda i , ws = weights , a = activation : ops .wvSplitK (
185249 ws [i % len (ws )], a , cu_count
186250 )
187- med_ms , iters , se_pct = bench_dynamic (
251+ med_ms , _ , _ = bench_dynamic (
188252 fn ,
189253 target_se_pct = target_se_pct ,
190254 )
191- time_us = med_ms * 1000
192- bw_gibs = weight_bytes / (med_ms * 1e-3 ) / (1 << 30 )
255+ med_us = med_ms * 1000
256+ med_bw = weight_bytes / (med_ms * 1e-3 ) / (1 << 30 )
193257
194- print (
195- f"{ N :>2} { M :>6} x{ K :<6} { label :<22} "
196- f"{ time_us :>8.1f} { bw_gibs :>8.1f} { n_bufs :>5} { iters :>6} { se_pct :>5.2f} "
197- )
258+ print (f"{ N :>2} { M :>6} x{ K :<6} { label :<22} { med_us :>8.1f} { med_bw :>9.1f} " )
198259
199260 elapsed = time .time () - t0
200261 print ()
0 commit comments