ck w4a16#963
Draft
mgehre-amd wants to merge 11 commits into
Draft
Conversation
Adds csrc/rocm/ck_w4a16.cu wrapping CK's DeviceGemm_BScale_Wmma_CShuffleV3 as torch.ops._rocm_C.ck_w4a16_b_scale_gemm. Targets the Qwen3-4B gate_up_proj prefill shape (M=3968, N=19456, K=2560, group=128, fp16) on Strix Halo, where the Triton path leaves significant performance on the table. Build is gated on -DVLLM_CK_INCLUDE_DIR + -DVLLM_CK_BUILD_INCLUDE_DIR (CK's ck/config.h is generated by CK's own configure step, so both source and build include dirs are needed). Without those flags csrc/rocm/ck_w4a16.cu is skipped and the dispatcher falls through to Triton. Dispatch lives inside the existing hybrid_w4a16_apply custom op (extended with optional w_q_ck + ck_target_m args), keeping it opaque to dynamo. CK-format weights are precomputed once in process_weights_after_loading; the runtime M check is a plain int compare against a per-layer cached value. Set VLLM_DISABLE_CK_W4A16=1 to force fall-through for A/B testing. The vLLM ExLlama [N, K//8] int32 weight layout maps to CK's [K0, N, K1//2] int8 via a single reshape + axis swap (no nibble re-packing). Scales pass through unchanged. Symmetric (uint4b8) only -- the kernel deliberately skips when zero-points are present, so asymmetric AWQ checkpoints fall through to Triton until follow-up work adds zero-point support. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com> Signed-off-by: Marcus Rosen <marcus.rosen@amd.com>
Adds csrc/rocm/ck_w4a16.cu::ck_w4a16_b_scale_zp_gemm and wires the dispatch in HybridW4A16LinearKernel so AWQ checkpoints with per-group zero points (e.g. Qwen/Qwen3-4B-AWQ) reach the CK kernel. Symmetric callers are unchanged. Implementation reuses the existing symmetric kernel via the identity (nibble - zp) * scale = (nibble - 8) * scale - (zp - 8) * scale The caller precomputes scaled_zp = (zp - 8) * scale per group at weight load and passes it to the new op; the CK kernel subtracts it inline during dequant. All tile sizing, scheduler, and threadmap config is shared with the symmetric path so future tuning benefits both kernels in one place. Requires the matching CK header changes in include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmmaops_v1.hpp include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_common.hpp include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_ab_scale.hpp include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3_b_scale.hpp which add an optional zero-point pointer threaded through to the dequant inner loop. Build is gated on -DVLLM_CK_INCLUDE_DIR/-DVLLM_CK_BUILD_INCLUDE_DIR exactly as before; without those flags the new op is skipped and asymmetric callers stay on the Triton path. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com> Signed-off-by: Marcus Rosen <marcus.rosen@amd.com>
… default) Standalone sweep on the same EXP1_FINAL kernel config shows it holds ~30 TFLOPS on M=2048 (default chunked-prefill chunk size on the Qwen3-4B gate_up_proj N=19456 K=2560 column), within ~3% of the M=3968 number it was tuned for. Adds M=2048 to the per-layer target-M list so users no longer need --max-num-batched-tokens 4096 to hit the CK path on that shape. Generalizes the dispatch table from a single per-layer M to a small list, threaded through the hybrid_w4a16_apply custom op as SymInt[]?. Membership test stays inside the opaque custom op; runtime check is still a plain Python int compare against a 1-2 element list per layer. Signed-off-by: Marcus Rosen <marcus.rosen@amd.com>
…hreshold
Standalone sweep on the gate_up shape (N=19456, K=2560) shows the same
EXP1_FINAL kernel config holds 28-31 TFLOPS uniformly across M=256-16384,
so a min-M threshold is more accurate than enumerating measured values.
Below ~256 the kernel's fixed launch overhead (~0.4 ms) dominates and
Triton is comparable; the threshold avoids that range.
In particular this dispatches the M=1920 second-chunk case for
chunk=2048+prompt=3968 (which the discrete list missed and fell back
to Triton). E2E on Qwen/Qwen3-4B-AWQ at default chunk=2048:
before: 1987 ms TTFT (discrete list, M=1920 -> Triton fallback)
after: 1922 ms TTFT (min_m=256, M=1920 -> CK)
Per-layer kernel time at M=1920:
Triton 8.48 ms -> CK 6.14 ms (-2.3 ms/layer * 36 layers = -84 ms;
realized -65 ms after chunked-prefill overhead).
Also generalizes to arbitrary chunk sizes -- users no longer need to
match a specific M for CK to fire.
Custom op signature: SymInt[]? ck_target_ms -> SymInt ck_min_m=0.
Signed-off-by: Marcus Rosen <marcus.rosen@amd.com>
The same EXP1_FINAL kernel binary handles all four Qwen3-4B prefill linear columns -- M/N/K are runtime args, only KPerBlock=32 is templated, and all three additional K values (2560, 4096, 9728) are multiples of both KPerBlock and Scale_Block_K=128. Standalone CK at the relevant M (1920, 2048) and CPU verify pass: Layer M N K CK ms Triton ms (profile) delta/layer qkv 1920 6144 2560 1.96 2.70 -0.74 qkv 2048 6144 2560 2.17 2.87 -0.70 o_proj 1920 2560 4096 1.73 2.17 -0.44 o_proj 2048 2560 4096 1.95 2.14 -0.19 down_proj 1920 2560 9728 3.32 5.16 -1.84 down_proj 2048 2560 9728 3.76 5.27 -1.51 E2E on Qwen/Qwen3-4B-AWQ at default chunk=2048 (5 reps each, num_prompts=10): gate_up only: TTFT mean 1966 ms, median 1954 ms all four: TTFT mean 1894 ms, median 1868 ms -72 ms mean (-86 ms median); run-to-run noise ~+/-40 ms. Each wired layer adds one CK-format weight copy (~0.92 GB total for the four Qwen3-4B columns; roughly +6% of available memory on a 16 GB iGPU). Falls through to Triton if VLLM_DISABLE_CK_W4A16=1. Signed-off-by: Marcus Rosen <marcus.rosen@amd.com>
…aiter) Removes the CK W4A16 b_scale GEMM kernel sources and build wiring that were added in 5f4261b (symmetric) and dac4df1 (asymmetric / AWQ). On this branch the dispatcher will instead soft-import aiter's gemm_w4a16 op as a single entry point covering both symmetric and asymmetric routes — no in-tree CK build, no per-build VLLM_CK_INCLUDE_DIR / VLLM_CK_BUILD_INCLUDE_DIR flags, no aiter-side header patches. Removed: - csrc/rocm/ck_w4a16.cu - csrc/rocm/ops.h: ck_w4a16_b_scale_gemm + ck_w4a16_b_scale_zp_gemm decls - csrc/rocm/torch_bindings.cpp: VLLM_HAVE_CK_W4A16 op registrations - CMakeLists.txt: VLLM_CK_INCLUDE_DIR / VLLM_CK_BUILD_INCLUDE_DIR block and the matching target_include_directories / VLLM_HAVE_CK_W4A16 define The original in-tree build is preserved on branch marcusr/aiesw-32176-w4a16-ck-wmma for reference. The dispatcher Python update (calling aiter instead of torch.ops._rocm_C) is the next commit. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com> Signed-off-by: Marcus Rosen <marcus.rosen@amd.com>
Replaces the in-tree torch.ops._rocm_C.ck_w4a16_b_scale_gemm /
_b_scale_zp_gemm calls in HybridW4A16LinearKernel with a soft-imported
aiter.ops.gemm_w4a16 — a single function that covers both symmetric
(scaled_zp=None) and asymmetric / AWQ (scaled_zp set) routes, with fp16
and bf16 selected from the caller-allocated output dtype.
Detection follows the find_spec("aiter") pattern from vllm/_aiter_ops.py:
a new _has_aiter_w4a16_op() helper checks both that the aiter package is
importable AND that aiter.ops.gemm_w4a16 exists, and the result is cached
once at module import as _HAS_AITER_W4A16_OP so the lookup doesn't happen
in the hot path and torch.compile sees a Python constant.
If aiter (or its gemm_w4a16 op) isn't available, the dispatcher falls
through to Triton — the same fallback path that exists today when
VLLM_DISABLE_CK_W4A16=1 is set or when no shape matches a registered
target. The env var name is preserved so existing benchmark scripts and
A/B harnesses keep working.
Per-layer min-M threshold, _CK_W4A16_TARGET_SHAPES table, weight repack
helper _repack_vllm_to_ck_b_scale, and the load-time precompute of
_hybrid_w_q_ck / _hybrid_w_scaled_zp_ck are unchanged. The custom op
torch.ops.vllm.hybrid_w4a16_apply still wraps the dispatch so it stays
opaque to dynamo.
Verified: importing
vllm.model_executor.kernels.linear.mixed_precision.hybrid_w4a16
in the project venv succeeds, _CK_W4A16_TARGET_SHAPES is populated, and
torch.ops.vllm.hybrid_w4a16_apply is registered. With aiter not installed
in this venv, _HAS_AITER_W4A16_OP=False and dispatch falls through to
Triton as designed — the in-tree CK build artifacts (kept on branch
marcusr/aiesw-32176-w4a16-ck-wmma) are no longer needed on this branch.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: Marcus Rosen <marcus.rosen@amd.com>
…on bf16 The CK kernel builds both fp16 and bf16 instantiations (aiter convention requires both F16 and B16 paths in any new GEMM), but only fp16 entries are registered in the dispatch table. bf16 dispatch is intentionally omitted because Triton W4A16 wins on bf16 at our target hardware. Measured at the gate_up_proj shape (M=2048 N=19456 K=2560 G=128) on Strix Halo (gfx1151): fp16: CK 29.1 TFLOPS vs Triton 21.3 TFLOPS (CK 1.36x faster) bf16: CK 19.4 TFLOPS vs Triton 24.4 TFLOPS (Triton 1.26x faster) Root cause of the bf16 gap is hardware: RDNA3 (gfx11) has no packed bf16 multiply instruction (V_PK_FMA_BF16 doesn't exist in the gfx11 ISA), so CK's bf16 dequant falls back to scalar fp32 conversion. Triton's compiler appears to schedule fp32 dual-issue more aggressively for this workload. Comment-only change. The dispatch table is unchanged (no bf16 entries were ever added); this commit documents WHY they aren't there. Signed-off-by: Marcus Rosen <marcus.rosen@amd.com>
… default-on CK_TRUNCATE for bf16 Makes the CK W4A16 path actually controllable at runtime, expands the dispatch table to cover all non-MoE models in AIESW-32282, and flips the bf16 dequant rounding mode to the new TRUNCATE variant by default. End-to-end TTFT on RedHatAI/Qwen3-8B-quantized.w4a16 (bf16): Triton 3030 ms → CK_RTE 3278 ms (+8.2% loss) → CK_TRUNC 2796 ms (-7.7% win). gsm8k 5-shot strict-match on Qwen3-1.7B-AWQ is 0.620 with both Triton and CK_TRUNC (McNemar paired p=1.000 on n=500 — accuracy-safe). Changes: * Apply-time _ck_disabled() gate restored. The original PR's gate was load-time only (skipped the precompute when disabled); once the precompute was made always-on for the Triton-scaled_zp A/B, VLLM_DISABLE_CK_W4A16 became a no-op and every A/B was secretly CK-vs-CK. _hybrid_w4a16_apply_impl now also checks _ck_disabled() at the per-call CK dispatch site. * aiter dispatch call sites fixed. PR-as-posted called aiter.ops.gemm_w4a16(...) which is the submodule, not the function (TypeError at first call). Replaced with explicit `from aiter.ops.gemm_w4a16 import gemm_w4a16 as _aiter_gemm_w4a16` at both sym and asym branches. * _hybrid_w_scaled_zp_ck precompute always-on (independent of _ck_disabled()) so the Triton-scaled_zp A/B path is reachable. Apply-time _ck_disabled() check is the runtime CK gate. * New VLLM_TRITON_W4A16_SCALED_ZP=1 routes the Triton fallback through the CK-style (nibble - 8) * scale - scaled_zp formulation. Algebraically equivalent to (nibble - zp_raw) * scale; measured wash on gfx1151 (+0.19% TTFT, within noise). Opt-in only. * New VLLM_CK_W4A16_TRUNCATE_BF16, default "1". Routes the CK W4A16 dispatch into the TruncateBf16Round=true template instantiation (aiter+CK in commits 78df4feb9 and 2cfd5509f). For bf16 dispatch the truncate path replaces the IEEE round-to-nearest-even chain (v_add3_u32 +0x7fff bias + v_cmp_o_f32 + v_cndmask_b16 0x7fc0 NaN-quietening) with a >>16 bit-cast; ~4e-3 worst-case ULP error, inside the W4A16 op-test 5e-3 tolerance. CK analog of vLLM PR #953's Triton-side optimization. Without this flip the bf16 CK kernel loses ~8% to Triton; with it CK overtakes Triton by ~8%. fp16 is unaffected (no rounding chain to remove; the Truncate element-op delegates to the non-truncate variant for fp16). * New VLLM_CK_W4A16_PRE_DEQUANT=1 routes into the PreDequantToLDS=true template instantiation. Currently TORCH_CHECKs at dispatch (kernel body is a TODO in aiter csrc) — wires the surface so a follow-up agent can fill in the implementation without touching the dispatcher. * _CK_W4A16_TARGET_SHAPES expanded from 4 to 27 entries covering every (N, K, group, dtype) tuple across all non-MoE AIESW-32282 models: Qwen3-4B / Qwen3-8B-AWQ / Qwen2.5-3B / Qwen2.5-7B / Llama-2-7B / Gemma-2B (fp16 g=128, 21 shapes), Qwen3-VL-4B-AWQ-4bit (fp16 g=32, same 4 shapes), Qwen3-8B-quantized.w4a16 / Qwen3-1.7B / Qwen2.5-VL-7B-quantized.w4a16 (bf16 g=128), SmolLM2-1.7B / Gemma-4-31B-AWQ (bf16 g=32). All entries use (min_M=256, KPerBlock=32) per the EXP1_FINAL kernel config sweep. * _lookup_ck_target() now also accepts any shape that meets the CK config's static constraints (dtype in {fp16, bf16}, group_size in {32, 128}, K % 32 == 0, K % group_size == 0, N % 128 == 0) and returns (min_M=256, KPerBlock=32) for them. The explicit table still wins for per-shape overrides; the fallback reduces the maintenance burden of enumerating every model's shapes by hand. Cross-model TTFT (1 rep, --num-prompts 10 --output-len 1) against the 20260520_233800 regression Triton baselines: Qwen3-4B_3968 -10.3% Qwen3-8B-AWQ_3968 -10.7% Qwen3-8B-CT.w4a16_3968 (forced fp16) -10.1% Qwen3-1.7B_3968 (bf16 native, TRUNC) -5.6% Qwen2.5-3B-Instruct_3968 -15.5% Qwen2.5-7B-Instruct_3968 -11.0% Llama-2-7B-AWQ_1920 -12.3% Gemma-2B_8000 -4.4% SmolLM2-1.7B-Instruct-AWQ_8000 -23.7% Qwen2.5-VL-7B-w4a16-mmstar -3.4% Qwen3-VL-4B-AWQ-xlam -19.1% Qwen3-VL-4B-AWQ-mmstar +8.8% (single-rep, std=354ms) Gemma-4-31B-IT_VLM_AWQ-4bit OOM (also FAIL in regression) 10 wins, 1 loss, 1 OOM; mean -10.0% over the 11 PASS rows. Operational note: vLLM's torch.compile cache at ~/.cache/vllm/torch_compile_cache/ does NOT key on VLLM_DISABLE_CK_W4A16 or the new env vars; bench scripts must clear the cache between every config switch. Companion changes (already committed): aiter 78df4feb9 template extensions + CK submodule bump CK 2cfd5509f threadwise BElementwiseOperation plumbing + bf16 truncate
…he only bf16 behavior in CK now The aiter op no longer accepts truncate_bf16_round (the CK kernel bakes bf16 truncate as the only rounding mode — see aiter commit "AIESW-32282: drop TruncateBf16Round axis"). Removing the vLLM-side env var and the ck_tbt propagation keeps the dispatcher in sync with the aiter API. Justification recap: lm_eval gsm8k 5-shot n=500 on Orion-zhen/Qwen3-1.7B- AWQ shows truncate is statistically indistinguishable from Triton (McNemar p=1.000), and TTFT on RedHatAI/Qwen3-8B-quantized.w4a16 shows truncate is the only setting where CK beats Triton on bf16. There is no production scenario where we'd want RTE. Changes: - hybrid_w4a16.py: remove _ck_truncate_bf16_round() and the VLLM_CK_W4A16_TRUNCATE_BF16 env var; drop ck_tbt from both _aiter_gemm_w4a16(...) call sites. - Tiny pre-existing comment rewordings to satisfy ruff E501 + typos lints that the pre-commit hook surfaced on this file (the lints were already in the prior commit's version of this file — same file, pre-existing issues; no behavior change). Signed-off-by: Matthias Gehre <matthias.gehre@amd.com>
VLLM_TRITON_W4A16_SCALED_ZP=1 was an A/B knob to test whether the CK-style (nibble - 8) * scale - scaled_zp formulation would help the Triton W4A16 fallback. Measurement on Qwen3-4B-AWQ showed the delta is -0.19% (within stddev) -- wash on gfx1151. Removes the dead branch and its env var, kernel template specialization, wrapper parameter, and call-site logic. Kept (CK aiter op still uses scaled_zp as its asym carrier): - w_scaled_zp_ck layer parameter precompute at weight load time. - aiter.ops.gemm_w4a16(..., scaled_zp=w_scaled_zp_ck) call sites. Signed-off-by: Matthias Gehre <matthias.gehre@amd.com>
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
No description provided.