Description
Problem
The current GEMV kernel in tileops/kernels/gemv/gemv.py has a critical memory access inefficiency that severely limits performance on Hopper GPUs (SM90), including H200.
Root Cause: Uncoalesced Matrix B Access
The thread block is configured as threads=(block_n, reduce_threads), which makes tn = threadIdx.x (the fast-varying dimension) and tk = threadIdx.y.
Within a warp (32 consecutive linear thread IDs), tn varies while tk is fixed. The B matrix access pattern becomes:
Thread 0 (tn=0, tk=0) → B[row+0, col] memory offset: 0
Thread 1 (tn=1, tk=0) → B[row+1, col] memory offset: +K × sizeof(dtype)
Thread 2 (tn=2, tk=0) → B[row+2, col] memory offset: +2K × sizeof(dtype)
...
For typical shapes like (n=7168, k=16384), the stride between consecutive thread accesses is 16384 × 2 = 32768 bytes. This means the warp issues 32 independent cache-line requests rather than a single coalesced transaction — effectively wasting ~97% of the available memory bandwidth.
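The waste can be illustrated with a small model of one warp's accesses (assumed 128-byte cache lines and scalar fp16 loads; K taken from the example shape above):

```python
# Model of one warp's B accesses under the current (tn-fast) layout.
# Assumptions: 128-byte cache lines, scalar fp16 (2-byte) loads.
LINE = 128          # bytes per cache line (assumed)
K = 16384           # reduction dimension from the example shape
ELEM = 2            # sizeof(fp16)

# Thread i in the warp reads B[row + i, col]: byte offset i * K * ELEM.
offsets = [i * K * ELEM for i in range(32)]
lines = {off // LINE for off in offsets}

useful = 32 * ELEM                  # bytes the warp actually consumes
fetched = len(lines) * LINE         # bytes the hardware must fetch
print(len(lines))                   # 32 distinct cache lines per warp
print(1 - useful / fetched)         # ~0.98 of fetched bytes wasted
```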
Secondary Issue: Redundant Global Loads of Vector a
Vector a (shape (k,)) is loaded independently by every thread for every output row it contributes to. With no shared memory caching, the same a tile is read block_n times from global memory.
Current Default Config on SM90
{"block_n": 32, "reduce_threads": 8}With reduce_threads=8, each warp spans 4 different output rows — further fragmenting coalescing and leaving hardware shuffle-reduce underutilized.
Proposed Optimizations
O1: Fix Coalescing by Swapping Thread Dimensions (Highest Priority)
Change thread layout from threads=(block_n, reduce_threads) to threads=(reduce_threads, block_n), so that tk = threadIdx.x becomes the fast-varying dimension.
With reduce_threads=32 (full warp), a warp now consists of 32 threads all belonging to the same output row, accessing consecutive columns:
Thread 0 (tk=0, tn=r) → B[row_r, col+ 0: 8] ← 128-bit vector load
Thread 1 (tk=1, tn=r) → B[row_r, col+ 8:16]
...
Thread 31 (tk=31, tn=r) → B[row_r, col+248:256] ← 32×128-bit = 512 bytes coalesced
This single change recovers near-peak HBM3e bandwidth utilization.
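Applying the same model (assumed 128-byte cache lines, 128-bit vectorized loads as in the pattern above) shows the warp's footprint collapsing to a few contiguous, fully-used lines:

```python
# Model of one warp's B accesses after swapping thread dims (tk-fast).
# Assumptions: 128-byte cache lines, 128-bit (16-byte) vector loads.
LINE = 128
VEC = 16            # bytes per 128-bit vectorized load

# Thread i reads B[row_r, col + 8*i : col + 8*i + 8]: contiguous 16-byte chunks.
offsets = [i * VEC for i in range(32)]
lines = {b // LINE for off in offsets for b in range(off, off + VEC)}

print(len(lines))   # 4 contiguous 128-byte lines (512 bytes, all consumed)
```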
O2: Cache Vector a in Shared Memory
Load each a tile once into shared memory per thread block, shared across all block_n rows:
```python
a_shared = T.alloc_shared((block_k,), dtype)
# Only tn==0 threads write (predicated load)
for _k in T.vectorized(tile_k):
    a_shared[tk * tile_k + _k] = a[bk * block_k + tk * tile_k + _k]
T.syncthreads()
# All block_n rows reuse a_shared in the FMA loop
```

This reduces global memory traffic for a by a factor of block_n (e.g., 16× for block_n=16).
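As a plain-Python sanity check on the indexing (with reduce_threads and tile_k as assumed tile parameters), the cooperative load covers the a tile exactly once, with no gaps or overlaps:

```python
# Verify the cooperative-load indexing: the writing threads (tk = 0..reduce_threads-1)
# together cover indices 0..block_k-1 of a_shared exactly once.
reduce_threads = 32      # assumed config value
tile_k = 8               # fp16 elements per thread per 128-bit load
block_k = reduce_threads * tile_k

written = [tk * tile_k + _k
           for tk in range(reduce_threads)
           for _k in range(tile_k)]
assert sorted(written) == list(range(block_k))   # each index written exactly once
print(block_k)   # 256
```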
O3: Use reduce_threads=32 as Default (Full Warp per Row)
Aligns one warp per output row, enabling:
- Hardware-accelerated warp shuffle reduction (replaces the current allreduce)
- Elimination of partial-warp idle cycles
block_k = 32 × 8 = 256 fp16 elements per tile (512 contiguous bytes)
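The warp-level reduction that a full 32-thread warp enables can be sketched as a shuffle-down tree, emulated here in plain Python (the real kernel would use hardware warp shuffles):

```python
# Emulate a warp shuffle-down reduction: log2(32) = 5 halving steps
# instead of a serial 31-add chain.
def warp_shuffle_reduce(vals):
    vals = list(vals)                 # one partial sum per lane
    offset = len(vals) // 2
    while offset:
        # Each lane below `offset` adds the value held by lane + offset,
        # mirroring __shfl_down_sync semantics.
        for lane in range(offset):
            vals[lane] += vals[lane + offset]
        offset //= 2
    return vals[0]                    # lane 0 holds the full sum

print(warp_shuffle_reduce(range(32)))  # 496 == sum(0..31)
```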
O4: Expand Autotune Search Space for SM90/H200
```python
# requires: import itertools
@property
def autotune_configs(self) -> list[dict]:
    block_n_list = [1, 2, 4, 8, 16, 32]
    reduce_threads_list = [32]
    extra = [
        {"block_n": 64, "reduce_threads": 16},
        {"block_n": 128, "reduce_threads": 16},
        {"block_n": 256, "reduce_threads": 32},
    ]
    configs = [{"block_n": bn, "reduce_threads": rt}
               for bn, rt in itertools.product(block_n_list, reduce_threads_list)]
    return configs + extra
```

O5: Update SM90 Default Config
```python
if sm_version in {90}:
    return {"block_n": 16, "reduce_threads": 32}
```

This gives 512 threads/block, a 256-element a tile in shared memory (512 bytes), and enough blocks to saturate the 132 SMs on H200.
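A quick check on the launch geometry this default produces for the headline shape (assumed n=7168 and 132 SMs on H200):

```python
# Grid/occupancy arithmetic for the proposed SM90 default config.
import math

block_n, reduce_threads = 16, 32
n = 7168                         # headline benchmark shape
num_sms = 132                    # H200 SM count (assumed)

threads_per_block = block_n * reduce_threads
num_blocks = math.ceil(n / block_n)

print(threads_per_block)         # 512
print(num_blocks)                # 448 blocks -> several waves over 132 SMs
```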
Expected Impact
| Optimization | Change | Expected Gain |
|---|---|---|
| O1 Fix coalescing | Swap thread dims | Up to ~32× B bandwidth recovery |
| O2 Shared memory for a | T.alloc_shared + syncthreads | ~block_n× reduction in a traffic |
| O3 Full warp reduce | reduce_threads=32 | Eliminates partial-warp waste |
| O4 Autotune expansion | Wider search space | Finds true optimum for each shape |
| O5 New default config | block_n=16, reduce_threads=32 | Better out-of-box performance on H200 |
Benchmark Shapes Affected
From tests/ops/test_gemv.py:
- (n=7168, k=16384) — typical LLM FFN weight shape
- (n=18432, k=7168) — larger FFN gate/up projection
- (n=1024, k=1024) — small test case
For (n=7168, k=16384, fp16) on H200:
- Memory footprint: ~235 MB
- Theoretical min time at 4.8 TB/s bandwidth: ~49 μs
- Current estimate: severely bandwidth-limited due to uncoalesced access
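The footprint and roofline numbers above follow from simple arithmetic (assuming the fp16 B matrix dominates traffic and 4.8 TB/s HBM3e bandwidth):

```python
# Roofline arithmetic for (n=7168, k=16384, fp16) on H200.
n, k = 7168, 16384
bytes_B = n * k * 2              # fp16 matrix B dominates the footprint
bandwidth = 4.8e12               # assumed HBM3e bytes/s

print(bytes_B / 1e6)             # ~234.9 MB
print(bytes_B / bandwidth * 1e6) # ~48.9 us theoretical minimum
```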
Next Steps
A fix PR addressing O1–O5 will follow this issue.