[kernels] wvSplitK: gated v_permlanex16_b32 for half-wave reduction #982
Draft
mgehre-amd wants to merge 1 commit into
The wave32 cross-row reduction __shfl_xor(sum, 16) in the three wvSplitK bf16/fp16 kernels lowers to ds_bpermute_b32: an LDS round-trip plus an s_waitcnt lgkmcnt(0) before the dependent FMA can issue. v_permlanex16_b32 does the same lane swap in one VALU op with no LDS traffic and no waitcnt.

Switching unconditionally is net-neutral on bench_rocm_skinny_gemm_bf16: gains on the AC=8/16 paths (Qwen3-4B/Qwen2.5-VL families, up to -5.8%) are roughly cancelled by ~+1-3% regressions on the K=4096 AC=32 fast paths recently tuned in 51b596e / d862a50. Those AC=32 kernels are VALU-dense (16 pk_dot per thread per k1 iter); the LDS bpermute used to overlap with VALU on the DS unit, and replacing it with an extra VALU op adds to the critical schedule.

Gating the swap on `A_CHUNK != 32` keeps ds_bpermute for the AC=32 family and uses permlanex16 elsewhere, getting the best of both.

Changes:
- skinny_gemms.cu: add shfl_xor16_f32 helper (gfx10+/RDNA wave32), route the three wvSplitK_hf_{sml,_,big} reduction tails through it via if constexpr (A_CHUNK == 32) ? __shfl_xor : shfl_xor16_f32.
- skinny_gemms.cu: extend wvSplitK_sweep to support achunk=32 (used to confirm the dispatcher's AC=32 entries remain optimal; they do, and the best AC=16 candidate trails AC=32 by 0-5%).
- sweep_bf16_kernel.py: add 32 to the ACHUNKS sweep grid.

Verified on Strix Halo (gfx1151, 20 CU) with bench_rocm_skinny_gemm_bf16.py over 19 shapes x N=1..4 = 76 cases: ungated permlanex16 had median +0.09% (20 wins / 17 losses); gated has median -0.16% (25 wins / 9 losses). AC=32 family median moves from +1.13% (ungated) to +0.00% (gated); AC=8 family median moves from -0.03% to -0.35% (Qwen3-4B lm_head -5.0 to -5.8%).

Signed-off-by: Matthias Gehre <matthias.gehre@amd.com>
Summary
The wave32 cross-row reduction __shfl_xor(sum, 16) in the three wvSplitK
bf16/fp16 kernels (wvSplitK_hf_{sml,_,big}) lowers to ds_bpermute_b32:
an LDS round-trip plus s_waitcnt lgkmcnt(0) before the dependent FMA can
issue. v_permlanex16_b32 does the same lane swap in 1 VALU op with no LDS
traffic and no waitcnt.
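To make the lane swap concrete, here is a minimal host-side C++ model (not device code) of what this reduction step computes under either instruction: each lane i of a 32-lane wave reads the value held by lane i ^ 16 and adds it to its own partial sum. The names `xor16_exchange` and `half_wave_reduce` are hypothetical and exist only for this illustration; they are not part of skinny_gemms.cu.

```cpp
#include <array>
#include <cstddef>

constexpr std::size_t kWaveSize = 32;  // wave32, as on RDNA

// Hypothetical model: lane i receives the value held by lane i ^ 16,
// which is the data movement that __shfl_xor(sum, 16) performs.
std::array<float, kWaveSize> xor16_exchange(
    const std::array<float, kWaveSize>& lanes) {
  std::array<float, kWaveSize> out{};
  for (std::size_t i = 0; i < kWaveSize; ++i) {
    out[i] = lanes[i ^ 16];
  }
  return out;
}

// Hypothetical model of the reduction step: sum += shfl_xor(sum, 16).
// Afterwards lanes i and i ^ 16 hold the same combined value.
std::array<float, kWaveSize> half_wave_reduce(
    const std::array<float, kWaveSize>& lanes) {
  const auto swapped = xor16_exchange(lanes);
  std::array<float, kWaveSize> out{};
  for (std::size_t i = 0; i < kWaveSize; ++i) {
    out[i] = lanes[i] + swapped[i];
  }
  return out;
}
```

The model only captures which lane reads which; it says nothing about latency, which is where ds_bpermute_b32 and v_permlanex16_b32 actually differ.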
Switching unconditionally is net-neutral on bench_rocm_skinny_gemm_bf16.py:
gains on the AC=8/16 paths (Qwen3-4B / Qwen2.5-VL families, up to -5.8 %)
are roughly cancelled by ~+1–3 % regressions on the K=4096 AC=32 fast paths
recently tuned in 51b596e533 / d862a50107. Those AC=32 kernels are
VALU-dense (16 pk_dot per thread per k1 iter); the LDS bpermute used
to overlap with VALU on the DS unit, so replacing it with an extra VALU op
adds to the critical schedule.
Gating the swap on A_CHUNK != 32 keeps ds_bpermute for the AC=32
family and uses permlanex16 elsewhere, getting the best of both.
Changes
- csrc/rocm/skinny_gemms.cu: add shfl_xor16_f32 helper (gfx10+/RDNA wave32), route the three wvSplitK_hf_{sml,_,big} reduction tails through it via if constexpr (A_CHUNK == 32) ? __shfl_xor : shfl_xor16_f32.
- csrc/rocm/skinny_gemms.cu: extend wvSplitK_sweep to support achunk=32 (used to confirm the dispatcher's AC=32 entries remain optimal; they do, and the best AC=16 candidate trails AC=32 by 0–5 %).
- benchmarks/kernels/sweep_bf16_kernel.py: add 32 to the ACHUNKS sweep grid.
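The compile-time gating in the first item can be sketched in plain C++17. This is a host-side illustration under stated assumptions, not the code in skinny_gemms.cu: `SwapPath` and `select_swap_path` are hypothetical names standing in for the choice between __shfl_xor(sum, 16) and the shfl_xor16_f32 helper.

```cpp
// Hypothetical stand-ins for the two lowerings. In the real kernels the
// branches would call __shfl_xor(sum, 16) or shfl_xor16_f32(sum).
enum class SwapPath { DsBpermute, Permlanex16 };

// A_CHUNK is a template parameter, so the branch is resolved at compile
// time and each kernel instantiation contains only one of the two paths.
template <int A_CHUNK>
constexpr SwapPath select_swap_path() {
  if constexpr (A_CHUNK == 32) {
    return SwapPath::DsBpermute;   // keep the LDS path for the AC=32 family
  } else {
    return SwapPath::Permlanex16;  // single VALU op everywhere else
  }
}

static_assert(select_swap_path<32>() == SwapPath::DsBpermute);
static_assert(select_swap_path<16>() == SwapPath::Permlanex16);
static_assert(select_swap_path<8>() == SwapPath::Permlanex16);
```

Because the selection is a constant expression, the gate adds no runtime branch to the reduction tail.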
ISA
Per-A_CHUNK tally from the rebuilt _rocm_C.abi3.so (columns: v_permlanex16_b32, ds_bpermute_b32):
Test plan
tests/kernels/quantization/bench_rocm_skinny_gemm_bf16.py on Strix Halo (gfx1151, 20 CU): 19 shapes × N=1..4 = 76 cases, target SE 0.1 %.
TODO: remeasure
End-to-end (Qwen3-4B bf16, vLLM serving)
Two paired reps on each side, max-concurrency=1, input 128 / output 256, 1 prompt:
TODO: remeasure
Reproduce
AI assistance was used to author this change.