[PERF][GEMV] fix uncoalesced memory access and add H200-specific optimizations #232

@superAngGao

Description

Problem

The current GEMV kernel in tileops/kernels/gemv/gemv.py has a critical memory access inefficiency that severely limits performance on Hopper GPUs (SM90), including H200.

Root Cause: Uncoalesced Matrix B Access

The thread block is configured as threads=(block_n, reduce_threads), which makes tn = threadIdx.x (the fast-varying dimension) and tk = threadIdx.y.

Within a warp (32 consecutive linear thread IDs), tn varies while tk is fixed. The B matrix access pattern becomes:

Thread 0 (tn=0, tk=0) → B[row+0, col]   memory offset: 0
Thread 1 (tn=1, tk=0) → B[row+1, col]   memory offset: +K × sizeof(dtype)
Thread 2 (tn=2, tk=0) → B[row+2, col]   memory offset: +2K × sizeof(dtype)
...

For typical shapes like (n=7168, k=16384), the stride between consecutive thread accesses is 16384 × 2 = 32768 bytes. This means the warp issues 32 independent cache-line requests rather than a single coalesced transaction — effectively wasting ~97% of the available memory bandwidth.
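The cache-line count can be sanity-checked with a few lines of Python. The 128-byte line size and fp16 element size are assumptions about the target hardware; the model just counts distinct lines touched by one warp:

```python
# Model of cache lines touched by one warp reading B.
# Assumptions: 128-byte cache lines, fp16 (2-byte) elements, k = 16384.
LINE_BYTES = 128
ELEM_BYTES = 2  # fp16
K = 16384

def lines_touched(addresses):
    """Count distinct cache lines covered by a set of byte addresses."""
    return len({addr // LINE_BYTES for addr in addresses})

# Current layout: thread i reads B[row + i, col] -> stride of K elements.
uncoalesced = lines_touched([i * K * ELEM_BYTES for i in range(32)])

# Coalesced layout: thread i reads B[row, col + i] -> stride of 1 element.
coalesced = lines_touched([i * ELEM_BYTES for i in range(32)])

print(uncoalesced, coalesced)  # 32 lines vs. 1 line per warp
```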

Secondary Issue: Redundant Global Loads of Vector a

Vector a (shape (k,)) is loaded independently by every thread for every output row it contributes to. With no shared memory caching, the same a tile is read block_n times from global memory.

Current Default Config on SM90

{"block_n": 32, "reduce_threads": 8}

With reduce_threads=8, the eight partial sums for each output row live in eight different warps (tk = threadIdx.y advances only once per 32 linear thread IDs), so the reduction cannot use intra-warp shuffles — further fragmenting coalescing and leaving hardware shuffle-reduce underutilized.


Proposed Optimizations

O1: Fix Coalescing by Swapping Thread Dimensions (Highest Priority)

Change thread layout from threads=(block_n, reduce_threads) to threads=(reduce_threads, block_n), so that tk = threadIdx.x becomes the fast-varying dimension.

With reduce_threads=32 (full warp), a warp now consists of 32 threads all belonging to the same output row, accessing consecutive columns:

Thread 0 (tk=0, tn=r) → B[row_r, col+ 0: 8]   ← 128-bit vector load
Thread 1 (tk=1, tn=r) → B[row_r, col+ 8:16]
...
Thread 31 (tk=31, tn=r) → B[row_r, col+248:256] ← 32×128-bit = 512 bytes coalesced

This single change recovers near-peak HBM3e bandwidth utilization.
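The effect of the swap on warp composition follows from CUDA's linearization rule (threadIdx.x varies fastest). The helper below is a hypothetical model of that rule, not kernel code:

```python
def warp_rows(dim_x, tn_is_x):
    """Set of output rows (tn values) covered by the first warp of
    32 linear thread IDs, given the x-dimension size and whether
    tn maps to threadIdx.x."""
    rows = set()
    for tid in range(32):
        x, y = tid % dim_x, tid // dim_x  # CUDA linearization: x fastest
        rows.add(x if tn_is_x else y)
    return rows

# Before: threads=(block_n=32, reduce_threads=8) -> tn = threadIdx.x
print(len(warp_rows(32, tn_is_x=True)))   # 32 rows per warp, none coalesced

# After: threads=(reduce_threads=32, block_n) -> tk = threadIdx.x
print(len(warp_rows(32, tn_is_x=False)))  # 1 row per warp, fully coalesced
```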

O2: Cache Vector a in Shared Memory

Load each a tile once into shared memory per thread block, shared across all block_n rows:

a_shared = T.alloc_shared((block_k,), dtype)
# Predicated load: only the tn == 0 threads write, one vector per tk lane
if tn == 0:
    for _k in T.vectorized(tile_k):
        a_shared[tk * tile_k + _k] = a[bk * block_k + tk * tile_k + _k]
T.syncthreads()
# All block_n rows reuse a_shared in the FMA loop

This reduces global-memory traffic for a by a factor of block_n (e.g., 16× for block_n=16).
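The traffic reduction can be expressed as a simple counting model (the function and its parameters are illustrative, not part of the kernel):

```python
def global_loads_of_a(block_n, block_k, cached_in_shared):
    """Global-memory elements of `a` read per thread block for one k-tile."""
    if cached_in_shared:
        return block_k            # loaded once, reused by all block_n rows
    return block_n * block_k      # each row's threads re-read the same tile

before = global_loads_of_a(16, 256, cached_in_shared=False)
after = global_loads_of_a(16, 256, cached_in_shared=True)
print(before // after)  # 16x reduction, matching block_n
```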

O3: Use reduce_threads=32 as Default (Full Warp per Row)

Aligns one warp per output row, enabling:

  • Hardware-accelerated warp shuffle reduction (replaces the current allreduce)
  • Elimination of partial-warp idle cycles
  • block_k = 32 × 8 = 256 fp16 elements per tile (512 bytes, i.e., four consecutive 128-byte cache lines)
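The shuffle reduction a full warp enables can be sketched as a pure-Python simulation — lane values stand in for registers, and list indexing models `__shfl_down_sync`:

```python
def warp_shuffle_reduce(lanes):
    """Simulate a 32-lane shuffle-down tree reduction (5 steps)."""
    assert len(lanes) == 32
    vals = list(lanes)
    offset = 16
    while offset > 0:
        # Each lane adds the value held `offset` lanes above it,
        # mirroring __shfl_down_sync(0xffffffff, v, offset).
        vals = [vals[i] + (vals[i + offset] if i + offset < 32 else 0)
                for i in range(32)]
        offset //= 2
    return vals[0]  # lane 0 ends up holding the row's partial dot product

print(warp_shuffle_reduce(list(range(32))))  # 496 == sum(range(32))
```

With reduce_threads=8, only a quarter of each warp participates in such a tree; a full warp completes the reduction in five shuffle steps with no idle lanes.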

O4: Expand Autotune Search Space for SM90/H200

@property
def autotune_configs(self) -> list[dict]:
    # NOTE: requires `import itertools` at module scope.
    block_n_list = [1, 2, 4, 8, 16, 32]
    reduce_threads_list = [32]
    extra = [
        {"block_n": 64,  "reduce_threads": 16},
        {"block_n": 128, "reduce_threads": 16},
        {"block_n": 256, "reduce_threads": 32},
    ]
    configs = [{"block_n": bn, "reduce_threads": rt}
               for bn, rt in itertools.product(block_n_list, reduce_threads_list)]
    return configs + extra

O5: Update SM90 Default Config

if sm_version in {90}:
    return {"block_n": 16, "reduce_threads": 32}

This yields 512 threads per block, a 256-element a tile in shared memory (512 bytes), and enough thread blocks to saturate the 132 SMs on H200.
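The figures quoted for this default follow from straightforward arithmetic (the 8-element vector width assumes 128-bit fp16 loads, as in O1):

```python
block_n, reduce_threads = 16, 32
vec_width = 8  # fp16 elements per 128-bit vectorized load (assumption)

threads_per_block = block_n * reduce_threads   # 512 threads/block
block_k = reduce_threads * vec_width           # 256 elements of `a` per tile
a_tile_bytes = block_k * 2                     # fp16 -> 512 bytes of shared mem
print(threads_per_block, block_k, a_tile_bytes)  # 512 256 512

# Thread blocks needed to cover n = 7168 output rows:
n = 7168
blocks = -(-n // block_n)
print(blocks)  # 448 blocks -> comfortably more than 132 SMs
```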


Expected Impact

| Optimization | Change | Expected Gain |
| --- | --- | --- |
| O1: Fix coalescing | Swap thread dims | Up to ~32× B bandwidth recovery |
| O2: Shared memory for a | T.alloc_shared + syncthreads | ~block_n× reduction in a traffic |
| O3: Full warp reduce | reduce_threads=32 | Eliminates partial-warp waste |
| O4: Autotune expansion | Wider search space | Finds true optimum for each shape |
| O5: New default config | block_n=16, reduce_threads=32 | Better out-of-box performance on H200 |

Benchmark Shapes Affected

From tests/ops/test_gemv.py:

  • (n=7168, k=16384) — typical LLM FFN weight shape
  • (n=18432, k=7168) — larger FFN gate/up projection
  • (n=1024, k=1024) — small test case

For (n=7168, k=16384, fp16) on H200:

  • Memory footprint: ~235 MB
  • Theoretical min time at 4.8 TB/s bandwidth: ~49 μs
  • Current kernel: expected to fall well short of this bound, since uncoalesced access wastes most of the delivered bandwidth
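The footprint and roofline numbers above can be reproduced directly; 4.8 TB/s is the H200 HBM3e peak bandwidth cited here, and B dominates traffic (a and the output vector are negligible at these shapes):

```python
n, k = 7168, 16384
elem = 2  # fp16

# Dominant traffic: reading B exactly once.
footprint_bytes = n * k * elem
footprint_mb = footprint_bytes / 1e6
print(round(footprint_mb))  # ~235 MB

bandwidth = 4.8e12  # bytes/s, H200 HBM3e peak
t_min_us = footprint_bytes / bandwidth * 1e6
print(round(t_min_us, 1))  # ~48.9 us theoretical minimum
```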

Next Steps

A fix PR addressing O1–O5 will follow this issue.
