Skip to content

Marcusr/aiesw 32176 w4a16 via aiter#935

Draft
marcusr-amd wants to merge 8 commits into
gfx11from
marcusr/aiesw-32176-w4a16-via-aiter
Draft

Marcusr/aiesw 32176 w4a16 via aiter#935
marcusr-amd wants to merge 8 commits into
gfx11from
marcusr/aiesw-32176-w4a16-via-aiter

Conversation

@marcusr-amd
Copy link
Copy Markdown

@marcusr-amd marcusr-amd commented May 14, 2026

Purpose

Test Plan

Test Result


Essential Elements of an Effective PR Description Checklist
  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.
  • (Optional) Release notes update. If your change is user facing, please update the release notes draft in the Google Doc.

Adds csrc/rocm/ck_w4a16.cu wrapping CK's DeviceGemm_BScale_Wmma_CShuffleV3
as torch.ops._rocm_C.ck_w4a16_b_scale_gemm. Targets the Qwen3-4B gate_up_proj
prefill shape (M=3968, N=19456, K=2560, group=128, fp16) on Strix Halo, where
the Triton path leaves significant performance on the table.

Build is gated on -DVLLM_CK_INCLUDE_DIR + -DVLLM_CK_BUILD_INCLUDE_DIR (CK's
ck/config.h is generated by CK's own configure step, so both source and build
include dirs are needed). Without those flags csrc/rocm/ck_w4a16.cu is skipped
and the dispatcher falls through to Triton.

Dispatch lives inside the existing hybrid_w4a16_apply custom op (extended with
optional w_q_ck + ck_target_m args), keeping it opaque to dynamo. CK-format
weights are precomputed once in process_weights_after_loading; the runtime M
check is a plain int compare against a per-layer cached value. Set
VLLM_DISABLE_CK_W4A16=1 to force fall-through for A/B testing.

The vLLM ExLlama [N, K//8] int32 weight layout maps to CK's [K0, N, K1//2]
int8 via a single reshape + axis swap (no nibble re-packing). Scales pass
through unchanged.

Symmetric (uint4b8) only -- the kernel deliberately skips when zero-points
are present, so asymmetric AWQ checkpoints fall through to Triton until
follow-up work adds zero-point support.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

Signed-off-by: Marcus Rosen <marcus.rosen@amd.com>
Adds csrc/rocm/ck_w4a16.cu::ck_w4a16_b_scale_zp_gemm and wires the dispatch
in HybridW4A16LinearKernel so AWQ checkpoints with per-group zero points
(e.g. Qwen/Qwen3-4B-AWQ) reach the CK kernel. Symmetric callers are unchanged.

Implementation reuses the existing symmetric kernel via the identity
  (nibble - zp) * scale = (nibble - 8) * scale - (zp - 8) * scale
The caller precomputes scaled_zp = (zp - 8) * scale per group at weight load
and passes it to the new op; the CK kernel subtracts it inline during
dequant. All tile sizing, scheduler, and threadmap config is shared with
the symmetric path so future tuning benefits both kernels in one place.

Requires the matching CK header changes in
  include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp
  include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp
  include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmmaops_v1.hpp
  include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_common.hpp
  include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_ab_scale.hpp
  include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3_b_scale.hpp
which add an optional zero-point pointer threaded through to the dequant
inner loop. Build is gated on -DVLLM_CK_INCLUDE_DIR/-DVLLM_CK_BUILD_INCLUDE_DIR
exactly as before; without those flags the new op is skipped and asymmetric
callers stay on the Triton path.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

Signed-off-by: Marcus Rosen <marcus.rosen@amd.com>
… default)

Standalone sweep on the same EXP1_FINAL kernel config shows it holds
~30 TFLOPS on M=2048 (default chunked-prefill chunk size on the Qwen3-4B
gate_up_proj N=19456 K=2560 column), within ~3% of the M=3968 number it
was tuned for. Adds M=2048 to the per-layer target-M list so users no
longer need --max-num-batched-tokens 4096 to hit the CK path on that shape.

Generalizes the dispatch table from a single per-layer M to a small list,
threaded through the hybrid_w4a16_apply custom op as SymInt[]?. Membership
test stays inside the opaque custom op; runtime check is still a plain
Python int compare against a 1-2 element list per layer.

Signed-off-by: Marcus Rosen <marcus.rosen@amd.com>
…hreshold

Standalone sweep on the gate_up shape (N=19456, K=2560) shows the same
EXP1_FINAL kernel config holds 28-31 TFLOPS uniformly across M=256-16384,
so a min-M threshold is more accurate than enumerating measured values.
Below ~256 the kernel's fixed launch overhead (~0.4 ms) dominates and
Triton is comparable; the threshold avoids that range.

In particular this dispatches the M=1920 second-chunk case for
chunk=2048+prompt=3968 (which the discrete list missed and fell back
to Triton). E2E on Qwen/Qwen3-4B-AWQ at default chunk=2048:

  before: 1987 ms TTFT (discrete list, M=1920 -> Triton fallback)
  after:  1922 ms TTFT (min_m=256, M=1920 -> CK)

  Per-layer kernel time at M=1920:
    Triton 8.48 ms -> CK 6.14 ms (-2.3 ms/layer * 36 layers = -84 ms;
    realized -65 ms after chunked-prefill overhead).

Also generalizes to arbitrary chunk sizes -- users no longer need to
match a specific M for CK to fire.

Custom op signature: SymInt[]? ck_target_ms -> SymInt ck_min_m=0.

Signed-off-by: Marcus Rosen <marcus.rosen@amd.com>
The same EXP1_FINAL kernel binary handles all four Qwen3-4B prefill
linear columns -- M/N/K are runtime args, only KPerBlock=32 is templated,
and all three additional K values (2560, 4096, 9728) are multiples of
both KPerBlock and Scale_Block_K=128. Standalone CK at the relevant
M (1920, 2048) and CPU verify pass:

  Layer       M     N     K     CK ms   Triton ms (profile)   delta/layer
  qkv         1920  6144  2560  1.96    2.70                  -0.74
  qkv         2048  6144  2560  2.17    2.87                  -0.70
  o_proj      1920  2560  4096  1.73    2.17                  -0.44
  o_proj      2048  2560  4096  1.95    2.14                  -0.19
  down_proj   1920  2560  9728  3.32    5.16                  -1.84
  down_proj   2048  2560  9728  3.76    5.27                  -1.51

E2E on Qwen/Qwen3-4B-AWQ at default chunk=2048 (5 reps each, num_prompts=10):

  gate_up only: TTFT mean 1966 ms, median 1954 ms
  all four:     TTFT mean 1894 ms, median 1868 ms

  -72 ms mean (-86 ms median); run-to-run noise ~+/-40 ms.

Each wired layer adds one CK-format weight copy (~0.92 GB total for
the four Qwen3-4B columns; roughly +6% of available memory on a
16 GB iGPU). Falls through to Triton if VLLM_DISABLE_CK_W4A16=1.

Signed-off-by: Marcus Rosen <marcus.rosen@amd.com>
…aiter)

Removes the CK W4A16 b_scale GEMM kernel sources and build wiring that were
added in 5f4261b (symmetric) and dac4df1 (asymmetric / AWQ). On this
branch the dispatcher will instead soft-import aiter's gemm_w4a16 op as a
single entry point covering both symmetric and asymmetric routes — no
in-tree CK build, no per-build VLLM_CK_INCLUDE_DIR / VLLM_CK_BUILD_INCLUDE_DIR
flags, no aiter-side header patches.

Removed:
  - csrc/rocm/ck_w4a16.cu
  - csrc/rocm/ops.h: ck_w4a16_b_scale_gemm + ck_w4a16_b_scale_zp_gemm decls
  - csrc/rocm/torch_bindings.cpp: VLLM_HAVE_CK_W4A16 op registrations
  - CMakeLists.txt: VLLM_CK_INCLUDE_DIR / VLLM_CK_BUILD_INCLUDE_DIR block
    and the matching target_include_directories / VLLM_HAVE_CK_W4A16 define

The original in-tree build is preserved on branch
marcusr/aiesw-32176-w4a16-ck-wmma for reference. The dispatcher Python
update (calling aiter instead of torch.ops._rocm_C) is the next commit.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

Signed-off-by: Marcus Rosen <marcus.rosen@amd.com>
Replaces the in-tree torch.ops._rocm_C.ck_w4a16_b_scale_gemm /
_b_scale_zp_gemm calls in HybridW4A16LinearKernel with a soft-imported
aiter.ops.gemm_w4a16 — a single function that covers both symmetric
(scaled_zp=None) and asymmetric / AWQ (scaled_zp set) routes, with fp16
and bf16 selected from the caller-allocated output dtype.

Detection follows the find_spec("aiter") pattern from vllm/_aiter_ops.py:
a new _has_aiter_w4a16_op() helper checks both that the aiter package is
importable AND that aiter.ops.gemm_w4a16 exists, and the result is cached
once at module import as _HAS_AITER_W4A16_OP so the lookup doesn't happen
in the hot path and torch.compile sees a Python constant.

If aiter (or its gemm_w4a16 op) isn't available, the dispatcher falls
through to Triton — the same fallback path that exists today when
VLLM_DISABLE_CK_W4A16=1 is set or when no shape matches a registered
target. The env var name is preserved so existing benchmark scripts and
A/B harnesses keep working.

Per-layer min-M threshold, _CK_W4A16_TARGET_SHAPES table, weight repack
helper _repack_vllm_to_ck_b_scale, and the load-time precompute of
_hybrid_w_q_ck / _hybrid_w_scaled_zp_ck are unchanged. The custom op
torch.ops.vllm.hybrid_w4a16_apply still wraps the dispatch so it stays
opaque to dynamo.

Verified: importing
vllm.model_executor.kernels.linear.mixed_precision.hybrid_w4a16
in the project venv succeeds, _CK_W4A16_TARGET_SHAPES is populated, and
torch.ops.vllm.hybrid_w4a16_apply is registered. With aiter not installed
in this venv, _HAS_AITER_W4A16_OP=False and dispatch falls through to
Triton as designed — the in-tree CK build artifacts (kept on branch
marcusr/aiesw-32176-w4a16-ck-wmma) are no longer needed on this branch.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

Signed-off-by: Marcus Rosen <marcus.rosen@amd.com>
…on bf16

The CK kernel builds both fp16 and bf16 instantiations (aiter convention
requires both F16 and B16 paths in any new GEMM), but only fp16 entries
are registered in the dispatch table. bf16 dispatch is intentionally
omitted because Triton W4A16 wins on bf16 at our target hardware.

Measured at the gate_up_proj shape (M=2048 N=19456 K=2560 G=128) on
Strix Halo (gfx1151):

  fp16: CK 29.1 TFLOPS vs Triton 21.3 TFLOPS  (CK 1.36x faster)
  bf16: CK 19.4 TFLOPS vs Triton 24.4 TFLOPS  (Triton 1.26x faster)

Root cause of the bf16 gap is hardware: RDNA3 (gfx11) has no packed bf16
multiply instruction (V_PK_FMA_BF16 doesn't exist in the gfx11 ISA), so
CK's bf16 dequant falls back to scalar fp32 conversion. Triton's compiler
appears to schedule fp32 dual-issue more aggressively for this workload.

Comment-only change. The dispatch table is unchanged (no bf16 entries
were ever added); this commit documents WHY they aren't there.

Signed-off-by: Marcus Rosen <marcus.rosen@amd.com>
# gate_up M=2048 N=19456 K=2560 G=128 on gfx1151:
# fp16: CK 29.1 vs Triton 21.3 TFLOPS (CK 1.36x faster) -> dispatch CK
# bf16: CK 19.4 vs Triton 24.4 TFLOPS (Triton 1.26x faster) -> Triton
# Root cause of the bf16 gap: RDNA3 (gfx11) lacks a packed bf16 multiply
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe ideas from #909 helps with the bf16 issue?

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Any maybe llvm/llvm-project#186179 can help once it's fixed.

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Leaving a link to #953 (comment) to review.

Also from Matthias

For the ck side, I wound a way for bf16 with zero points that makes perf close to fp16; this does the conversion from fp32 to bf16 by truncation (upper 16 bit) instead of round-to-nearest.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants