Skip to content

ck w4a16#963

Draft
mgehre-amd wants to merge 11 commits into
gfx11from
matthias.ck-w4a16-aiter-bench
Draft

ck w4a16#963
mgehre-amd wants to merge 11 commits into
gfx11from
matthias.ck-w4a16-aiter-bench

Conversation

@mgehre-amd
Copy link
Copy Markdown

No description provided.

marcusr-amd and others added 11 commits May 11, 2026 09:47
Adds csrc/rocm/ck_w4a16.cu wrapping CK's DeviceGemm_BScale_Wmma_CShuffleV3
as torch.ops._rocm_C.ck_w4a16_b_scale_gemm. Targets the Qwen3-4B gate_up_proj
prefill shape (M=3968, N=19456, K=2560, group=128, fp16) on Strix Halo, where
the Triton path leaves significant performance on the table.

Build is gated on -DVLLM_CK_INCLUDE_DIR + -DVLLM_CK_BUILD_INCLUDE_DIR (CK's
ck/config.h is generated by CK's own configure step, so both source and build
include dirs are needed). Without those flags csrc/rocm/ck_w4a16.cu is skipped
and the dispatcher falls through to Triton.

Dispatch lives inside the existing hybrid_w4a16_apply custom op (extended with
optional w_q_ck + ck_target_m args), keeping it opaque to dynamo. CK-format
weights are precomputed once in process_weights_after_loading; the runtime M
check is a plain int compare against a per-layer cached value. Set
VLLM_DISABLE_CK_W4A16=1 to force fall-through for A/B testing.

The vLLM ExLlama [N, K//8] int32 weight layout maps to CK's [K0, N, K1//2]
int8 via a single reshape + axis swap (no nibble re-packing). Scales pass
through unchanged.

Symmetric (uint4b8) only -- the kernel deliberately skips when zero-points
are present, so asymmetric AWQ checkpoints fall through to Triton until
follow-up work adds zero-point support.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

Signed-off-by: Marcus Rosen <marcus.rosen@amd.com>
Adds csrc/rocm/ck_w4a16.cu::ck_w4a16_b_scale_zp_gemm and wires the dispatch
in HybridW4A16LinearKernel so AWQ checkpoints with per-group zero points
(e.g. Qwen/Qwen3-4B-AWQ) reach the CK kernel. Symmetric callers are unchanged.

Implementation reuses the existing symmetric kernel via the identity
  (nibble - zp) * scale = (nibble - 8) * scale - (zp - 8) * scale
The caller precomputes scaled_zp = (zp - 8) * scale per group at weight load
and passes it to the new op; the CK kernel subtracts it inline during
dequant. All tile sizing, scheduler, and threadmap config is shared with
the symmetric path so future tuning benefits both kernels in one place.

Requires the matching CK header changes in
  include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp
  include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp
  include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmmaops_v1.hpp
  include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_common.hpp
  include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_ab_scale.hpp
  include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3_b_scale.hpp
which add an optional zero-point pointer threaded through to the dequant
inner loop. Build is gated on -DVLLM_CK_INCLUDE_DIR/-DVLLM_CK_BUILD_INCLUDE_DIR
exactly as before; without those flags the new op is skipped and asymmetric
callers stay on the Triton path.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

Signed-off-by: Marcus Rosen <marcus.rosen@amd.com>
… default)

Standalone sweep on the same EXP1_FINAL kernel config shows it holds
~30 TFLOPS on M=2048 (default chunked-prefill chunk size on the Qwen3-4B
gate_up_proj N=19456 K=2560 column), within ~3% of the M=3968 number it
was tuned for. Adds M=2048 to the per-layer target-M list so users no
longer need --max-num-batched-tokens 4096 to hit the CK path on that shape.

Generalizes the dispatch table from a single per-layer M to a small list,
threaded through the hybrid_w4a16_apply custom op as SymInt[]?. Membership
test stays inside the opaque custom op; runtime check is still a plain
Python int compare against a 1-2 element list per layer.

Signed-off-by: Marcus Rosen <marcus.rosen@amd.com>
…hreshold

Standalone sweep on the gate_up shape (N=19456, K=2560) shows the same
EXP1_FINAL kernel config holds 28-31 TFLOPS uniformly across M=256-16384,
so a min-M threshold is more accurate than enumerating measured values.
Below ~256 the kernel's fixed launch overhead (~0.4 ms) dominates and
Triton is comparable; the threshold avoids that range.

In particular this dispatches the M=1920 second-chunk case for
chunk=2048+prompt=3968 (which the discrete list missed and fell back
to Triton). E2E on Qwen/Qwen3-4B-AWQ at default chunk=2048:

  before: 1987 ms TTFT (discrete list, M=1920 -> Triton fallback)
  after:  1922 ms TTFT (min_m=256, M=1920 -> CK)

  Per-layer kernel time at M=1920:
    Triton 8.48 ms -> CK 6.14 ms (-2.3 ms/layer * 36 layers = -84 ms;
    realized -65 ms after chunked-prefill overhead).

Also generalizes to arbitrary chunk sizes -- users no longer need to
match a specific M for CK to fire.

Custom op signature: SymInt[]? ck_target_ms -> SymInt ck_min_m=0.

Signed-off-by: Marcus Rosen <marcus.rosen@amd.com>
The same EXP1_FINAL kernel binary handles all four Qwen3-4B prefill
linear columns -- M/N/K are runtime args, only KPerBlock=32 is templated,
and all three additional K values (2560, 4096, 9728) are multiples of
both KPerBlock and Scale_Block_K=128. Standalone CK at the relevant
M (1920, 2048) and CPU verify pass:

  Layer       M     N     K     CK ms   Triton ms (profile)   delta/layer
  qkv         1920  6144  2560  1.96    2.70                  -0.74
  qkv         2048  6144  2560  2.17    2.87                  -0.70
  o_proj      1920  2560  4096  1.73    2.17                  -0.44
  o_proj      2048  2560  4096  1.95    2.14                  -0.19
  down_proj   1920  2560  9728  3.32    5.16                  -1.84
  down_proj   2048  2560  9728  3.76    5.27                  -1.51

E2E on Qwen/Qwen3-4B-AWQ at default chunk=2048 (5 reps each, num_prompts=10):

  gate_up only: TTFT mean 1966 ms, median 1954 ms
  all four:     TTFT mean 1894 ms, median 1868 ms

  -72 ms mean (-86 ms median); run-to-run noise ~+/-40 ms.

Each wired layer adds one CK-format weight copy (~0.92 GB total for
the four Qwen3-4B columns; roughly +6% of available memory on a
16 GB iGPU). Falls through to Triton if VLLM_DISABLE_CK_W4A16=1.

Signed-off-by: Marcus Rosen <marcus.rosen@amd.com>
…aiter)

Removes the CK W4A16 b_scale GEMM kernel sources and build wiring that were
added in 5f4261b (symmetric) and dac4df1 (asymmetric / AWQ). On this
branch the dispatcher will instead soft-import aiter's gemm_w4a16 op as a
single entry point covering both symmetric and asymmetric routes — no
in-tree CK build, no per-build VLLM_CK_INCLUDE_DIR / VLLM_CK_BUILD_INCLUDE_DIR
flags, no aiter-side header patches.

Removed:
  - csrc/rocm/ck_w4a16.cu
  - csrc/rocm/ops.h: ck_w4a16_b_scale_gemm + ck_w4a16_b_scale_zp_gemm decls
  - csrc/rocm/torch_bindings.cpp: VLLM_HAVE_CK_W4A16 op registrations
  - CMakeLists.txt: VLLM_CK_INCLUDE_DIR / VLLM_CK_BUILD_INCLUDE_DIR block
    and the matching target_include_directories / VLLM_HAVE_CK_W4A16 define

The original in-tree build is preserved on branch
marcusr/aiesw-32176-w4a16-ck-wmma for reference. The dispatcher Python
update (calling aiter instead of torch.ops._rocm_C) is the next commit.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

Signed-off-by: Marcus Rosen <marcus.rosen@amd.com>
Replaces the in-tree torch.ops._rocm_C.ck_w4a16_b_scale_gemm /
_b_scale_zp_gemm calls in HybridW4A16LinearKernel with a soft-imported
aiter.ops.gemm_w4a16 — a single function that covers both symmetric
(scaled_zp=None) and asymmetric / AWQ (scaled_zp set) routes, with fp16
and bf16 selected from the caller-allocated output dtype.

Detection follows the find_spec("aiter") pattern from vllm/_aiter_ops.py:
a new _has_aiter_w4a16_op() helper checks both that the aiter package is
importable AND that aiter.ops.gemm_w4a16 exists, and the result is cached
once at module import as _HAS_AITER_W4A16_OP so the lookup doesn't happen
in the hot path and torch.compile sees a Python constant.

If aiter (or its gemm_w4a16 op) isn't available, the dispatcher falls
through to Triton — the same fallback path that exists today when
VLLM_DISABLE_CK_W4A16=1 is set or when no shape matches a registered
target. The env var name is preserved so existing benchmark scripts and
A/B harnesses keep working.

Per-layer min-M threshold, _CK_W4A16_TARGET_SHAPES table, weight repack
helper _repack_vllm_to_ck_b_scale, and the load-time precompute of
_hybrid_w_q_ck / _hybrid_w_scaled_zp_ck are unchanged. The custom op
torch.ops.vllm.hybrid_w4a16_apply still wraps the dispatch so it stays
opaque to dynamo.

Verified: importing
vllm.model_executor.kernels.linear.mixed_precision.hybrid_w4a16
in the project venv succeeds, _CK_W4A16_TARGET_SHAPES is populated, and
torch.ops.vllm.hybrid_w4a16_apply is registered. With aiter not installed
in this venv, _HAS_AITER_W4A16_OP=False and dispatch falls through to
Triton as designed — the in-tree CK build artifacts (kept on branch
marcusr/aiesw-32176-w4a16-ck-wmma) are no longer needed on this branch.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

Signed-off-by: Marcus Rosen <marcus.rosen@amd.com>
…on bf16

The CK kernel builds both fp16 and bf16 instantiations (aiter convention
requires both F16 and B16 paths in any new GEMM), but only fp16 entries
are registered in the dispatch table. bf16 dispatch is intentionally
omitted because Triton W4A16 wins on bf16 at our target hardware.

Measured at the gate_up_proj shape (M=2048 N=19456 K=2560 G=128) on
Strix Halo (gfx1151):

  fp16: CK 29.1 TFLOPS vs Triton 21.3 TFLOPS  (CK 1.36x faster)
  bf16: CK 19.4 TFLOPS vs Triton 24.4 TFLOPS  (Triton 1.26x faster)

Root cause of the bf16 gap is hardware: RDNA3 (gfx11) has no packed bf16
multiply instruction (V_PK_FMA_BF16 doesn't exist in the gfx11 ISA), so
CK's bf16 dequant falls back to scalar fp32 conversion. Triton's compiler
appears to schedule fp32 dual-issue more aggressively for this workload.

Comment-only change. The dispatch table is unchanged (no bf16 entries
were ever added); this commit documents WHY they aren't there.

Signed-off-by: Marcus Rosen <marcus.rosen@amd.com>
… default-on CK_TRUNCATE for bf16

Makes the CK W4A16 path actually controllable at runtime, expands the
dispatch table to cover all non-MoE models in AIESW-32282, and flips
the bf16 dequant rounding mode to the new TRUNCATE variant by default.

End-to-end TTFT on RedHatAI/Qwen3-8B-quantized.w4a16 (bf16): Triton
3030 ms → CK_RTE 3278 ms (+8.2% loss) → CK_TRUNC 2796 ms (-7.7% win).
gsm8k 5-shot strict-match on Qwen3-1.7B-AWQ is 0.620 with both Triton
and CK_TRUNC (McNemar paired p=1.000 on n=500 — accuracy-safe).

Changes:

* Apply-time _ck_disabled() gate restored. The original PR's gate was
  load-time only (skipped the precompute when disabled); once the
  precompute was made always-on for the Triton-scaled_zp A/B,
  VLLM_DISABLE_CK_W4A16 became a no-op and every A/B was secretly
  CK-vs-CK. _hybrid_w4a16_apply_impl now also checks _ck_disabled()
  at the per-call CK dispatch site.

* aiter dispatch call sites fixed. PR-as-posted called
  aiter.ops.gemm_w4a16(...) which is the submodule, not the function
  (TypeError at first call). Replaced with explicit
  `from aiter.ops.gemm_w4a16 import gemm_w4a16 as _aiter_gemm_w4a16`
  at both sym and asym branches.

* _hybrid_w_scaled_zp_ck precompute always-on (independent of
  _ck_disabled()) so the Triton-scaled_zp A/B path is reachable.
  Apply-time _ck_disabled() check is the runtime CK gate.

* New VLLM_TRITON_W4A16_SCALED_ZP=1 routes the Triton fallback
  through the CK-style (nibble - 8) * scale - scaled_zp formulation.
  Algebraically equivalent to (nibble - zp_raw) * scale; measured wash
  on gfx1151 (+0.19% TTFT, within noise). Opt-in only.

* New VLLM_CK_W4A16_TRUNCATE_BF16, default "1". Routes the CK W4A16
  dispatch into the TruncateBf16Round=true template instantiation
  (aiter+CK in commits 78df4feb9 and 2cfd5509f). For bf16 dispatch
  the truncate path replaces the IEEE round-to-nearest-even chain
  (v_add3_u32 +0x7fff bias + v_cmp_o_f32 + v_cndmask_b16 0x7fc0
  NaN-quietening) with a >>16 bit-cast; ~4e-3 worst-case ULP error,
  inside the W4A16 op-test 5e-3 tolerance. CK analog of vLLM PR #953's
  Triton-side optimization. Without this flip the bf16 CK kernel
  loses ~8% to Triton; with it CK overtakes Triton by ~8%. fp16 is
  unaffected (no rounding chain to remove; the Truncate element-op
  delegates to the non-truncate variant for fp16).

* New VLLM_CK_W4A16_PRE_DEQUANT=1 routes into the PreDequantToLDS=true
  template instantiation. Currently TORCH_CHECKs at dispatch (kernel
  body is a TODO in aiter csrc) — wires the surface so a follow-up
  agent can fill in the implementation without touching the dispatcher.

* _CK_W4A16_TARGET_SHAPES expanded from 4 to 27 entries covering every
  (N, K, group, dtype) tuple across all non-MoE AIESW-32282 models:
  Qwen3-4B / Qwen3-8B-AWQ / Qwen2.5-3B / Qwen2.5-7B / Llama-2-7B /
  Gemma-2B (fp16 g=128, 21 shapes), Qwen3-VL-4B-AWQ-4bit (fp16 g=32,
  same 4 shapes), Qwen3-8B-quantized.w4a16 / Qwen3-1.7B /
  Qwen2.5-VL-7B-quantized.w4a16 (bf16 g=128), SmolLM2-1.7B /
  Gemma-4-31B-AWQ (bf16 g=32). All entries use (min_M=256,
  KPerBlock=32) per the EXP1_FINAL kernel config sweep.

* _lookup_ck_target() now also accepts any shape that meets the CK
  config's static constraints (dtype in {fp16, bf16}, group_size in
  {32, 128}, K % 32 == 0, K % group_size == 0, N % 128 == 0) and
  returns (min_M=256, KPerBlock=32) for them. The explicit table
  still wins for per-shape overrides; the fallback reduces the
  maintenance burden of enumerating every model's shapes by hand.

Cross-model TTFT (1 rep, --num-prompts 10 --output-len 1) against the
20260520_233800 regression Triton baselines:

  Qwen3-4B_3968                          -10.3%
  Qwen3-8B-AWQ_3968                      -10.7%
  Qwen3-8B-CT.w4a16_3968 (forced fp16)   -10.1%
  Qwen3-1.7B_3968 (bf16 native, TRUNC)    -5.6%
  Qwen2.5-3B-Instruct_3968               -15.5%
  Qwen2.5-7B-Instruct_3968               -11.0%
  Llama-2-7B-AWQ_1920                    -12.3%
  Gemma-2B_8000                           -4.4%
  SmolLM2-1.7B-Instruct-AWQ_8000         -23.7%
  Qwen2.5-VL-7B-w4a16-mmstar              -3.4%
  Qwen3-VL-4B-AWQ-xlam                   -19.1%
  Qwen3-VL-4B-AWQ-mmstar                  +8.8%  (single-rep, std=354ms)
  Gemma-4-31B-IT_VLM_AWQ-4bit             OOM    (also FAIL in regression)

10 wins, 1 loss, 1 OOM; mean -10.0% over the 11 PASS rows.

Operational note: vLLM's torch.compile cache at
~/.cache/vllm/torch_compile_cache/ does NOT key on
VLLM_DISABLE_CK_W4A16 or the new env vars; bench scripts must clear
the cache between every config switch.

Companion changes (already committed):
  aiter  78df4feb9  template extensions + CK submodule bump
  CK     2cfd5509f  threadwise BElementwiseOperation plumbing + bf16 truncate
…he only bf16 behavior in CK now

The aiter op no longer accepts truncate_bf16_round (the CK kernel bakes
bf16 truncate as the only rounding mode — see aiter commit
"AIESW-32282: drop TruncateBf16Round axis"). Removing the vLLM-side env
var and the ck_tbt propagation keeps the dispatcher in sync with the
aiter API.

Justification recap: lm_eval gsm8k 5-shot n=500 on Orion-zhen/Qwen3-1.7B-
AWQ shows truncate is statistically indistinguishable from Triton
(McNemar p=1.000), and TTFT on RedHatAI/Qwen3-8B-quantized.w4a16 shows
truncate is the only setting where CK beats Triton on bf16. There is no
production scenario where we'd want RTE.

Changes:
- hybrid_w4a16.py: remove _ck_truncate_bf16_round() and the
  VLLM_CK_W4A16_TRUNCATE_BF16 env var; drop ck_tbt from both
  _aiter_gemm_w4a16(...) call sites.
- Tiny pre-existing comment rewordings to satisfy ruff E501 + typos
  lints that the pre-commit hook surfaced on this file (the lints
  were already in the prior commit's version of this file — same
  file, pre-existing issues; no behavior change).

Signed-off-by: Matthias Gehre <matthias.gehre@amd.com>
VLLM_TRITON_W4A16_SCALED_ZP=1 was an A/B knob to test whether the
CK-style (nibble - 8) * scale - scaled_zp formulation would help the
Triton W4A16 fallback. Measurement on Qwen3-4B-AWQ showed the delta is
-0.19% (within stddev) -- wash on gfx1151. Removes the dead branch and
its env var, kernel template specialization, wrapper parameter, and
call-site logic.

Kept (CK aiter op still uses scaled_zp as its asym carrier):
- w_scaled_zp_ck layer parameter precompute at weight load time.
- aiter.ops.gemm_w4a16(..., scaled_zp=w_scaled_zp_ck) call sites.

Signed-off-by: Matthias Gehre <matthias.gehre@amd.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants