Marcusr/aiesw 32176 w4a16 ck wmma#930
Draft
marcusr-amd wants to merge 5 commits into
Adds csrc/rocm/ck_w4a16.cu wrapping CK's DeviceGemm_BScale_Wmma_CShuffleV3 as torch.ops._rocm_C.ck_w4a16_b_scale_gemm. Targets the Qwen3-4B gate_up_proj prefill shape (M=3968, N=19456, K=2560, group=128, fp16) on Strix Halo, where the Triton path leaves significant performance on the table.

Build is gated on -DVLLM_CK_INCLUDE_DIR + -DVLLM_CK_BUILD_INCLUDE_DIR (CK's ck/config.h is generated by CK's own configure step, so both source and build include dirs are needed). Without those flags csrc/rocm/ck_w4a16.cu is skipped and the dispatcher falls through to Triton.

Dispatch lives inside the existing hybrid_w4a16_apply custom op (extended with optional w_q_ck + ck_target_m args), keeping it opaque to dynamo. CK-format weights are precomputed once in process_weights_after_loading; the runtime M check is a plain int compare against a per-layer cached value. Set VLLM_DISABLE_CK_W4A16=1 to force fall-through for A/B testing.

The vLLM ExLlama [N, K//8] int32 weight layout maps to CK's [K0, N, K1//2] int8 via a single reshape + axis swap (no nibble re-packing). Scales pass through unchanged.

Symmetric (uint4b8) only -- the kernel deliberately skips when zero-points are present, so asymmetric AWQ checkpoints fall through to Triton until follow-up work adds zero-point support.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: Marcus Rosen <marcus.rosen@amd.com>
Adds csrc/rocm/ck_w4a16.cu::ck_w4a16_b_scale_zp_gemm and wires the dispatch in HybridW4A16LinearKernel so AWQ checkpoints with per-group zero points (e.g. Qwen/Qwen3-4B-AWQ) reach the CK kernel. Symmetric callers are unchanged.

Implementation reuses the existing symmetric kernel via the identity

    (nibble - zp) * scale = (nibble - 8) * scale - (zp - 8) * scale

The caller precomputes scaled_zp = (zp - 8) * scale per group at weight load and passes it to the new op; the CK kernel subtracts it inline during dequant. All tile sizing, scheduler, and threadmap config is shared with the symmetric path so future tuning benefits both kernels in one place.

Requires the matching CK header changes in

    include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp
    include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp
    include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmmaops_v1.hpp
    include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_common.hpp
    include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_ab_scale.hpp
    include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3_b_scale.hpp

which add an optional zero-point pointer threaded through to the dequant inner loop.

Build is gated on -DVLLM_CK_INCLUDE_DIR/-DVLLM_CK_BUILD_INCLUDE_DIR exactly as before; without those flags the new op is skipped and asymmetric callers stay on the Triton path.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: Marcus Rosen <marcus.rosen@amd.com>
… default)
Standalone sweep on the same EXP1_FINAL kernel config shows it holds ~30 TFLOPS on M=2048 (default chunked-prefill chunk size on the Qwen3-4B gate_up_proj N=19456 K=2560 column), within ~3% of the M=3968 number it was tuned for.

Adds M=2048 to the per-layer target-M list so users no longer need --max-num-batched-tokens 4096 to hit the CK path on that shape.

Generalizes the dispatch table from a single per-layer M to a small list, threaded through the hybrid_w4a16_apply custom op as SymInt[]?. Membership test stays inside the opaque custom op; runtime check is still a plain Python int compare against a 1-2 element list per layer.

Signed-off-by: Marcus Rosen <marcus.rosen@amd.com>
…hreshold
Standalone sweep on the gate_up shape (N=19456, K=2560) shows the same
EXP1_FINAL kernel config holds 28-31 TFLOPS uniformly across M=256-16384,
so a min-M threshold is more accurate than enumerating measured values.
Below ~256 the kernel's fixed launch overhead (~0.4 ms) dominates and
Triton is comparable; the threshold avoids that range.
In particular this dispatches the M=1920 second-chunk case for
chunk=2048+prompt=3968 (which the discrete list missed and fell back
to Triton). E2E on Qwen/Qwen3-4B-AWQ at default chunk=2048:
before: 1987 ms TTFT (discrete list, M=1920 -> Triton fallback)
after: 1922 ms TTFT (min_m=256, M=1920 -> CK)
Per-layer kernel time at M=1920:
Triton 8.48 ms -> CK 6.14 ms (-2.3 ms/layer * 36 layers = -84 ms;
realized -65 ms after chunked-prefill overhead).
Also generalizes to arbitrary chunk sizes -- users no longer need to
match a specific M for CK to fire.
Custom op signature: SymInt[]? ck_target_ms -> SymInt ck_min_m=0.
Signed-off-by: Marcus Rosen <marcus.rosen@amd.com>
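The arithmetic in this commit message can be re-derived with a short sketch. It splits a 3968-token prompt into chunked-prefill chunks of 2048 and multiplies the per-layer kernel delta by the 36 layers. The helper name `prefill_chunks` is hypothetical and not part of vLLM; all timings are the ones quoted above.

```python
def prefill_chunks(prompt_len: int, chunk: int) -> list[int]:
    """Return the M dimension of each prefill chunk (hypothetical helper, not vLLM API)."""
    full, rest = divmod(prompt_len, chunk)
    return [chunk] * full + ([rest] if rest else [])


# A 3968-token prompt at the default chunk size of 2048.
chunks = prefill_chunks(3968, 2048)
print(chunks)  # [2048, 1920]

# Per-layer kernel times at M=1920 as quoted in the commit message (ms).
triton_ms, ck_ms, num_layers = 8.48, 6.14, 36
ideal_saving_ms = (triton_ms - ck_ms) * num_layers
print(f"{ideal_saving_ms:.1f}")  # 84.2

# Realized TTFT change from the E2E measurement (ms).
realized_saving_ms = 1987 - 1922
print(realized_saving_ms)  # 65
```

The gap between the ideal 84 ms and the realized 65 ms is what the commit message attributes to chunked-prefill overhead.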
The same EXP1_FINAL kernel binary handles all four Qwen3-4B prefill linear columns -- M/N/K are runtime args, only KPerBlock=32 is templated, and all three additional K values (2560, 4096, 9728) are multiples of both KPerBlock and Scale_Block_K=128.

Standalone CK at the relevant M (1920, 2048) and CPU verify pass:

    Layer      M     N     K     CK ms  Triton ms (profile)  delta/layer
    qkv        1920  6144  2560  1.96   2.70                 -0.74
    qkv        2048  6144  2560  2.17   2.87                 -0.70
    o_proj     1920  2560  4096  1.73   2.17                 -0.44
    o_proj     2048  2560  4096  1.95   2.14                 -0.19
    down_proj  1920  2560  9728  3.32   5.16                 -1.84
    down_proj  2048  2560  9728  3.76   5.27                 -1.51

E2E on Qwen/Qwen3-4B-AWQ at default chunk=2048 (5 reps each, num_prompts=10):

    gate_up only: TTFT mean 1966 ms, median 1954 ms
    all four:     TTFT mean 1894 ms, median 1868 ms

-72 ms mean (-86 ms median); run-to-run noise ~+/-40 ms.

Each wired layer adds one CK-format weight copy (~0.92 GB total for the four Qwen3-4B columns; roughly +6% of available memory on a 16 GB iGPU). Falls through to Triton if VLLM_DISABLE_CK_W4A16=1.

Signed-off-by: Marcus Rosen <marcus.rosen@amd.com>
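The divisibility claim for K can be checked mechanically. The sketch below uses the (N, K) pairs from this commit message; the function name `k_is_supported` is hypothetical, and the real kernel may impose further constraints (for example on N or M) that are not modeled here.

```python
K_PER_BLOCK = 32      # templated in the kernel, per the commit message
SCALE_BLOCK_K = 128   # quantization group size along K

# (N, K) per Qwen3-4B prefill linear column, as listed in the commit message.
columns = {
    "gate_up_proj": (19456, 2560),
    "qkv": (6144, 2560),
    "o_proj": (2560, 4096),
    "down_proj": (2560, 9728),
}


def k_is_supported(k: int) -> bool:
    """Hypothetical check: K must be a multiple of both block sizes."""
    return k % K_PER_BLOCK == 0 and k % SCALE_BLOCK_K == 0


for name, (n, k) in columns.items():
    # Also report the number of scale groups along K for this column.
    print(name, k_is_supported(k), k // SCALE_BLOCK_K)
```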
Purpose
Closes AIESW-32176.
Adds a CK WMMA W4A16 b_scale GEMM kernel and dispatches it from
HybridW4A16LinearKernel for the Qwen3-4B gate_up_proj prefill shape on
Strix Halo (gfx1151). This is the largest GEMM by FLOPs in Qwen3-4B prefill
(M=3968, N=19456, K=2560, group_size=128) — roughly 2x the next largest — so
optimizing it gives the biggest TTFT win for AWQ / W4A16 Qwen3-4B.
Two variants land in this PR:
- Symmetric (torch.ops._rocm_C.ck_w4a16_b_scale_gemm) for compressed_tensors uint4b8 checkpoints (e.g. RedHatAI/Qwen3-4B-quantized.w4a16).
- Asymmetric (torch.ops._rocm_C.ck_w4a16_b_scale_zp_gemm) for AWQ checkpoints with per-group zero points (e.g. Qwen/Qwen3-4B-AWQ, the AC#3 model). Reuses the symmetric kernel via the identity
  (nibble - zp) * scale = (nibble - 8) * scale - (zp - 8) * scale, with scaled_zp = (zp - 8) * scale precomputed once at weight load. Adds one fp16 fma per dequant pack; both variants share all tile sizing,
  scheduler, and threadmap config.
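The zero-point identity can be verified exhaustively over all 4-bit values. This is a minimal sketch using exact rational arithmetic; the function names are illustrative and are not the kernel's. In fp16 the two sides can differ by rounding, which this sketch does not model.

```python
from fractions import Fraction


def dequant_direct(nibble: int, zp: int, scale: Fraction) -> Fraction:
    """Asymmetric dequant as usually written: (nibble - zp) * scale."""
    return (nibble - zp) * scale


def dequant_via_symmetric(nibble: int, scale: Fraction, scaled_zp: Fraction) -> Fraction:
    """Symmetric (uint4b8) dequant followed by one subtraction of scaled_zp."""
    return (nibble - 8) * scale - scaled_zp


scale = Fraction(3, 256)  # arbitrary example scale
for zp in range(16):
    scaled_zp = (zp - 8) * scale  # precomputed once per group at weight load
    for nibble in range(16):
        assert dequant_direct(nibble, zp, scale) == dequant_via_symmetric(
            nibble, scale, scaled_zp
        )
print("identity holds for all 256 (nibble, zp) pairs")
```

With zp = 8 the correction term is zero, so the asymmetric path reduces to the symmetric one.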
Build is gated on -DVLLM_CK_INCLUDE_DIR + -DVLLM_CK_BUILD_INCLUDE_DIR
(CK's ck/config.h is generated by CK's own configure step). Without those
flags csrc/rocm/ck_w4a16.cu is skipped and the dispatcher falls through to
Triton. VLLM_DISABLE_CK_W4A16=1 forces fall-through at runtime for A/B
benchmarking.
Dispatch decision lives inside the existing hybrid_w4a16_apply custom op
(extended with optional w_q_ck, ck_target_m, w_scaled_zp_ck args),
keeping it opaque to dynamo. CK-format weights are precomputed once in
process_weights_after_loading; the runtime M check is a plain int compare
against a per-layer cached value.
Caveats
Depends on six matching CK header changes that are NOT in this PR.
vLLM consumes CK via include paths at build time; the new
csrc/rocm/ck_w4a16.cu references a DequantPack8WithZp element-op
and an optional p_b_zero_point arg on
DeviceGemm_BScale_Wmma_CShuffleV3::MakeArgument that don't exist in
upstream CK yet. The CK changes are additive and if constexpr-gated
(zero impact on existing symmetric CK callers, verified bit-identical),
but they need to land somewhere this build can find them before this PR
will compile.
Options being discussed: (a) open a companion PR against
ROCm/composable_kernel; (b) vendor the modified CK headers under
csrc/rocm/external/. Either way, this PR is draft/WIP awaiting the CK side.
Build behavior in the meantime:
- VLLM_CK_INCLUDE_DIR unset (default): csrc/rocm/ck_w4a16.cu is skipped, ops are unregistered, dispatcher falls through to Triton. Build succeeds; no behavior change vs gfx11 baseline.
- VLLM_CK_INCLUDE_DIR set, pointing at a CK with the matching changes applied: kernel compiles, ops register, dispatch fires for the target shape.
- VLLM_CK_INCLUDE_DIR set, pointing at upstream CK without the changes: compile error in csrc/rocm/ck_w4a16.cu.
Single-shape (column) specialization: only the Qwen3-4B gate_up_proj
column (N=19456, K=2560) on gfx1151 routes to CK; the M dimension is
validated at M=2048 (default chunked-prefill chunk size) and M=3968 (full
prompt with --max-num-batched-tokens 4096), both ~30 TFLOPS standalone.
Other columns (qkv_proj, o_proj, down_proj, other Qwen variants) stay on
Triton, left for follow-up tuning. Per-layer dispatch carries a small
list of validated M values (ck_target_ms) so adding shapes is a
one-line change to _CK_W4A16_TARGET_SHAPES.
Test Plan
- Triton W4A16 (via benchmarks/kernels/benchmark_hybrid_w4a16_gemm.py, edited to add batch_size=3968 + bf16 providers) and the hipBLASLt fp16 TN roofline (hipblaslt-bench).
- pack → repack to CK layout → call new op → compare to torch fp16 reference using the dequantized weights. Both symmetric and asymmetric variants.
- VLLM_DISABLE_CK_W4A16 env-var toggle on the same vllm bench serve invocation:
  - Qwen/Qwen3-4B-AWQ (asymmetric, AC#3 model)
  - RedHatAI/Qwen3-4B-quantized.w4a16 (symmetric reference)
- per AC#1: TBD.
- "produces the same output as with the triton kernel"): currently confirmed via the per-layer fp16 tolerance smoke tests; full-model output diff TBD.
Test Result
Measured on Strix Halo (gfx1151, Radeon 8060S) against ROCm 7.13.
Standalone GEMM (N=19456, K=2560, group=128, fp16, gfx1151)
At the target shape M=3968:
The same kernel config holds across the M dimension on this column —
both M=2048 (chunked-prefill default chunk size) and M=3968 dispatch
to CK and hit ~30 TFLOPS:
Asymmetric kernel is ~3.7% slower than symmetric per call — exactly the
design promise of one extra fp16 fma per dequant pack.
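The TFLOPS figures follow from kernel time under the usual 2·M·N·K FLOP count for a GEMM. The sketch below applies that convention (an assumption, since the PR does not state how TFLOPS were computed) to the per-layer times at M=1920 quoted in the commit messages; `gemm_tflops` is an illustrative helper, not part of the benchmark script.

```python
def gemm_tflops(m: int, n: int, k: int, time_ms: float) -> float:
    """Throughput in TFLOPS, counting one multiply and one add per MAC."""
    flops = 2 * m * n * k
    return flops / (time_ms * 1e-3) / 1e12


# gate_up_proj column at M=1920, using the per-layer kernel times from the commit message.
ck_tflops = gemm_tflops(1920, 19456, 2560, 6.14)
triton_tflops = gemm_tflops(1920, 19456, 2560, 8.48)
print(f"CK:     {ck_tflops:.1f} TFLOPS")      # CK:     31.1 TFLOPS
print(f"Triton: {triton_tflops:.1f} TFLOPS")  # Triton: 22.6 TFLOPS
```

The CK value lands inside the 28-31 TFLOPS band reported for the standalone sweep.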
E2E (vllm bench serve, max_num_seqs=1, fp16, num_prompts=10)
Qwen/Qwen3-4B-AWQ (asymmetric, AC#3 model). The CK dispatch is
keyed per-layer by a min-M threshold (ck_min_m=256 for the gate_up
column, where standalone sweep shows the same kernel config holds
28-31 TFLOPS uniformly across M=256-16384). The default
chunked-prefill chunk (max_num_batched_tokens=2048) now hits CK
on both prefill chunks (M=2048 and M=1920); the prior
--max-num-batched-tokens 4096 workaround is not required:
The default chunk + CK ON config (1922 ms) wins both A/Bs:
- VLLM_DISABLE_CK_W4A16=1 at the same chunk: −217 ms (−10%)
- chunk=4096 + CK ON config: −244 ms (−11%); splitting the prefill into two chunks (both CK) is faster than one
  big CK call. Per-dispatch profile shows Triton W4A16's qkv/o_proj/
  down_proj scale super-linearly with M, so chunked prefill saves
  ~270 ms of downstream Triton GEMM time per prompt by feeding those
  ops smaller M dimensions.
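The min-M dispatch rule described in this section can be sketched as a plain predicate. This is a hypothetical illustration, not the code inside hybrid_w4a16_apply: the function name and the `has_ck_weights` parameter are made up, and treating ck_min_m=0 as "CK disabled for this layer" is an assumption based on the op's default value.

```python
import os


def use_ck_w4a16(m: int, ck_min_m: int, has_ck_weights: bool) -> bool:
    """Hypothetical dispatch predicate; the real check lives inside the custom op."""
    # Runtime kill switch for A/B benchmarking.
    if os.environ.get("VLLM_DISABLE_CK_W4A16") == "1":
        return False
    # Assumption: no CK-format weights or ck_min_m == 0 means the layer is not wired to CK.
    if not has_ck_weights or ck_min_m <= 0:
        return False
    # Plain int compare against the per-layer threshold.
    return m >= ck_min_m


# Decode (M=1) and small batches stay on the existing path; both prefill chunks go to CK.
for m in (1, 128, 1920, 2048):
    print(m, use_ck_w4a16(m, ck_min_m=256, has_ck_weights=True))
```

A threshold needs no per-shape bookkeeping, which is why the M=1920 tail chunk is covered without being listed explicitly.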
RedHatAI/Qwen3-4B-quantized.w4a16 (symmetric reference, prior
chunk=4096 measurement):
Decode (TPOT) is unchanged in all A/Bs — CK only fires on prefill
chunks at the registered M values, leaving the decode path on the
existing skinny / Triton kernels.
Numerical correctness smoke tests
Both well within the fp16 GEMM tolerance class used by
tests/kernels/quantization/.
MMLU regression check (AC#1)
TBD: will post lm_eval MMLU before/after on Qwen/Qwen3-4B-AWQ.