[kernels] wvSplitK: gated v_permlanex16_b32 for half-wave reduction #982

Draft
mgehre-amd wants to merge 1 commit into gfx11 from matthias.wvsplitk-permlanex16

mgehre-amd commented May 29, 2026

Summary

The wave32 cross-row reduction __shfl_xor(sum, 16) in the three wvSplitK
bf16/fp16 kernels (wvSplitK_hf_{sml,_,big}) lowers to ds_bpermute_b32:
an LDS round-trip plus an s_waitcnt lgkmcnt(0) before the dependent FMA can
issue. v_permlanex16_b32 does the same lane swap in one VALU op with no LDS
traffic and no waitcnt.
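
To illustrate why the two instructions are interchangeable for this reduction, here is a minimal host-side C++ model (not device code, and not taken from the PR). It assumes the usual description of v_permlanex16_b32: each lane reads from the opposite 16-lane row at an in-row index given by a 4-bit selector, with selectors for in-row lanes 0..7 packed in one 32-bit word and 8..15 in another. The names `permlanex16_model`, `shfl_xor16_model`, `reduce_across_rows`, and the selector constants are hypothetical.

```cpp
#include <array>
#include <cassert>
#include <cstdint>

constexpr int kWaveSize = 32;  // wave32: two rows of 16 lanes
using Wave = std::array<float, kWaveSize>;

// Model of __shfl_xor(v, 16) on a wave32: lane i reads lane i ^ 16.
Wave shfl_xor16_model(const Wave& v) {
  Wave out{};
  for (int lane = 0; lane < kWaveSize; ++lane) out[lane] = v[lane ^ 16];
  return out;
}

// Model of v_permlanex16_b32 (assumed semantics, see lead-in): every lane
// reads from the opposite row, at the in-row index chosen by its selector.
Wave permlanex16_model(const Wave& v, std::uint32_t sel_lo, std::uint32_t sel_hi) {
  Wave out{};
  for (int lane = 0; lane < kWaveSize; ++lane) {
    const int in_row = lane & 15;                         // position within the row
    const std::uint32_t word = in_row < 8 ? sel_lo : sel_hi;
    const int sel = (word >> (4 * (in_row & 7))) & 0xF;   // 4-bit selector
    const int other_row_base = (lane & 16) ^ 16;          // 0 -> 16, 16 -> 0
    out[lane] = v[other_row_base + sel];
  }
  return out;
}

// Identity selectors: in-row lane i picks in-row lane i of the other row.
constexpr std::uint32_t kIdentityLo = 0x76543210u;
constexpr std::uint32_t kIdentityHi = 0xFEDCBA98u;

// Half-wave reduction tail: every lane ends up with v[lane] + v[lane ^ 16].
Wave reduce_across_rows(const Wave& v) {
  const Wave swapped = permlanex16_model(v, kIdentityLo, kIdentityHi);
  Wave out{};
  for (int lane = 0; lane < kWaveSize; ++lane) out[lane] = v[lane] + swapped[lane];
  return out;
}
```

With identity selectors the model reproduces the XOR-16 swap exactly, which is the property the kernel change relies on; the model says nothing about latency or scheduling.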

Switching unconditionally is net-neutral on bench_rocm_skinny_gemm_bf16.py:
gains on the AC=8/16 paths (Qwen3-4B / Qwen2.5-VL families, up to -5.8 %)
are roughly cancelled by ~+1–3 % regressions on the K=4096 AC=32 fast paths
recently tuned in 51b596e533 / d862a50107. Those AC=32 kernels are
VALU-dense (16 pk_dot per thread per k1 iter); the LDS bpermute ran on
the DS unit and overlapped with VALU work, so replacing it with an extra
VALU op lengthens the critical schedule.

Gating the swap on A_CHUNK != 32 keeps ds_bpermute for the AC=32
family and uses permlanex16 elsewhere — best of both.
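
The gating can be sketched as a compile-time branch on the template parameter. This is a host-side illustration only: `shfl_xor_stub` and `shfl_xor16_f32_stub` are hypothetical stand-ins that record which path was chosen, and `cross_row_swap` is an invented name. The real helper body in skinny_gemms.cu is not shown in this description; on device it would presumably wrap a compiler builtin such as `__builtin_amdgcn_permlanex16`, which is an assumption here.

```cpp
#include <cassert>

// Which instruction the reduction tail is expected to lower to.
enum class SwapPath { DsBpermute, Permlanex16 };

struct SwapResult {
  float value;
  SwapPath path;
};

// Hypothetical host stand-ins for the two device-side swaps. They return the
// input unchanged and only tag the path, so the selection logic is testable.
inline SwapResult shfl_xor_stub(float v) { return {v, SwapPath::DsBpermute}; }
inline SwapResult shfl_xor16_f32_stub(float v) { return {v, SwapPath::Permlanex16}; }

// Compile-time gate: AC=32 keeps the ds_bpermute path, everything else
// takes the permlanex16 path. Only the chosen branch is instantiated.
template <int A_CHUNK>
SwapResult cross_row_swap(float sum) {
  if constexpr (A_CHUNK == 32) {
    return shfl_xor_stub(sum);
  } else {
    return shfl_xor16_f32_stub(sum);
  }
}
```

Because the branch is resolved per template instantiation, each kernel variant contains exactly one of the two instructions and pays no runtime check.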

Changes

  • csrc/rocm/skinny_gemms.cu: add shfl_xor16_f32 helper (gfx10+/RDNA
    wave32) and route the three wvSplitK_hf_{sml,_,big} reduction tails
    through an if constexpr (A_CHUNK == 32) branch: __shfl_xor for AC=32,
    shfl_xor16_f32 otherwise.
  • csrc/rocm/skinny_gemms.cu: extend wvSplitK_sweep to support
    achunk=32 (used to confirm the dispatcher's AC=32 entries remain
    optimal — they do; best AC=16 candidate trails AC=32 by 0–5 %).
  • benchmarks/kernels/sweep_bf16_kernel.py: add 32 to the ACHUNKS
    sweep grid.

ISA

Per-A_CHUNK tally from the rebuilt _rocm_C.abi3.so:

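| A_CHUNK | kernels | v_permlanex16_b32 | ds_bpermute_b32 |
|--------:|--------:|------------------:|----------------:|
|       8 |     304 |              1912 |               0 |
|      16 |      80 |               280 |               0 |
|      32 |      96 |                 0 |             240 |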

Test plan

tests/kernels/quantization/bench_rocm_skinny_gemm_bf16.py on Strix Halo
(gfx1151, 20 CU): 19 shapes × N=1..4 = 76 cases, target SE 0.1 %.

TODO: remeasure

End-to-end (Qwen3-4B bf16, vLLM serving)

Two paired reps per side, max-concurrency=1, input 128 / output 256, 1 prompt:

TODO: remeasure

Reproduce

```bash
# baseline:    git checkout d862a50107 && build-vllm && bench
# AC-gated:    git checkout 9f18849da4 && build-vllm && bench

SITE_PACKAGES=$(.venv/bin/python -c "import site; print(site.getsitepackages()[0])")
export PYTHONPATH=$SITE_PACKAGES/_rocm_sdk_core/share/amd_smi

gpu-lock .venv/bin/python tests/kernels/quantization/bench_rocm_skinny_gemm_bf16.py \
    --batch-sizes 1 2 3 4 --dtype bf16
```

AI assistance was used to author this change.

The wave32 cross-row reduction __shfl_xor(sum, 16) in the three wvSplitK
bf16/fp16 kernels lowers to ds_bpermute_b32 — an LDS round-trip plus
an s_waitcnt lgkmcnt(0) before the dependent FMA can issue.
v_permlanex16_b32 does the same lane swap in one VALU op with no LDS
traffic and no waitcnt.

Switching unconditionally is net-neutral on bench_rocm_skinny_gemm_bf16:
gains on the AC=8/16 paths (Qwen3-4B/Qwen2.5-VL families, up to -5.8%)
are roughly cancelled by ~+1-3% regressions on the K=4096 AC=32
fast-paths recently tuned in 51b596e / d862a50. Those AC=32
kernels are VALU-dense (16 pk_dot per thread per k1 iter); the LDS
bpermute ran on the DS unit and overlapped with VALU work, so
replacing it with an extra VALU op lengthens the critical schedule.

Gating the swap on `A_CHUNK != 32` keeps ds_bpermute for the AC=32
family and uses permlanex16 elsewhere — best of both.

Changes:
- skinny_gemms.cu: add shfl_xor16_f32 helper (gfx10+/RDNA wave32) and
  route the three wvSplitK_hf_{sml,_,big} reduction tails through an
  if constexpr (A_CHUNK == 32) branch: __shfl_xor for AC=32,
  shfl_xor16_f32 otherwise.
- skinny_gemms.cu: extend wvSplitK_sweep to support achunk=32 (used
  to confirm the dispatcher's AC=32 entries remain optimal — they do;
  best AC=16 candidate trails AC=32 by 0-5%).
- sweep_bf16_kernel.py: add 32 to the ACHUNKS sweep grid.

Verified on Strix Halo (gfx1151, 20 CU) with
bench_rocm_skinny_gemm_bf16.py over 19 shapes x N=1..4 = 76 cases:
ungated permlanex16 had median +0.09% (20 wins / 17 losses); gated
has median -0.16% (25 wins / 9 losses). AC=32 family median moves
from +1.13% (ungated) to +0.00% (gated); AC=8 family median moves
from -0.03% to -0.35% (Qwen3-4B lm_head -5.0 to -5.8%).
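
For reference, a summary of this shape (median delta plus win/loss counts) can be computed as in the following sketch. The commit message does not state how wins and losses are counted; since 20 + 17 and 25 + 9 are both below 76, the sketch assumes cases inside a noise band count as neither. The band width, the `Summary` type, and `summarize` are hypothetical, and no measured data from this PR is reproduced here.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <vector>

struct Summary {
  double median_pct;  // median of per-case runtime deltas, in percent
  int wins;           // cases faster than baseline by more than the noise band
  int losses;         // cases slower than baseline by more than the noise band
};

// deltas_pct: per-case runtime change vs. baseline in percent (negative = faster).
// noise_pct: assumed tie band; |delta| <= noise_pct is neither a win nor a loss.
Summary summarize(std::vector<double> deltas_pct, double noise_pct) {
  Summary s{0.0, 0, 0};
  if (deltas_pct.empty()) return s;

  for (double d : deltas_pct) {
    if (d < -noise_pct) {
      ++s.wins;
    } else if (d > noise_pct) {
      ++s.losses;
    }
  }

  // Median: middle element for odd counts, mean of the two middle ones otherwise.
  std::sort(deltas_pct.begin(), deltas_pct.end());
  const std::size_t n = deltas_pct.size();
  s.median_pct = (n % 2 == 1)
                     ? deltas_pct[n / 2]
                     : 0.5 * (deltas_pct[n / 2 - 1] + deltas_pct[n / 2]);
  return s;
}
```

Calling `summarize` once per A_CHUNK family, and once over all 76 cases, would yield the kind of per-family and overall medians quoted above.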

Signed-off-by: Matthias Gehre <matthias.gehre@amd.com>