From 4b5f9596787709a1a42bd75b44e7d706190e3d5e Mon Sep 17 00:00:00 2001 From: erweiw Date: Sat, 30 May 2026 11:27:46 -0700 Subject: [PATCH] [llama32_1b] int4-AWQ end-to-end decode with HF AutoAWQ checkpoint MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Wires the int4-AWQ decode ELFs from PR #1633 (rms_qkv_int4_rope) and PR #1637 (o_gemv_ffn_int4) into the full inference pipeline so the existing chat_repl / llama32_1b_inference can drive them against a real HuggingFace AutoAWQ-quantized Llama-3.2-1B checkpoint. python3 llama32_1b_inference.py \ --quant awq --run-only --n-tokens 30 \ --model-path amd/Llama-3.2-1B-Instruct-awq-uint4-asym-g128-bf16-lmhead \ --prompt "Once upon a time" # -> "Once upon a time, in a small village nestled in the rolling hills # of a far-off land, there lived a young girl named Sophia. Sophia was a" # 12.4 tok/s decode (~81 ms/tok), coherent continuation. Prefill stays on CPU as a placeholder for this PR (no int4 GEMM kernel / prefill ELFs yet — that's a separate project). The placeholder runs the existing llama32_1b_reference.transformer_block over dequantized-to-bf16 AWQ weights to populate the KV cache, then hands off to the NPU int4 decode loop. Replacing it with int4 NPU prefill later doesn't touch any of the decode wiring landed here. New files: awq_repacker.py - Unpacks AutoAWQ int32-packed nibbles via AWQ_PACK_ORDER [0,2,4,6,1,3,5,7], composes with matvec_int4_packed.pack_inputs to produce the per-tile packed-BO layout the int4 decode ELFs consume. - dequant_to_bf16 (fp16->bf16 scales, asymmetric uint4) for CPU prefill. - Built-in synthetic round-trip self-test (>= 0.9999 correlation vs dense dequant); passes at K/N up to 2048. cpu_prefill.py - Drop-in replacement for run_npu_prefill signature, harvests per-layer k_roped/v from transformer_block intermediates into the KV cache layout expected by run_decode_block. ~165 s for a 40-token prompt; fine for validation, not for production. Modified: llama32_1b_weights.py - load_weights_awq(model_id, config): loads HF AutoAWQ tensors, attaches both bf16 dequant (existing LayerWeights fields, for CPU prefill) and per-tile packed BOs (_wq_packed / .../ _wgateup_packed / _wdown_packed, for NPU decode). Gate+up are interleaved at the nibble level so the int4 FFN ELF consumes them in one arg slot. kernel_builder/backend_presets.py - RGR_INT4_BACKEND, OGF_INT4_BACKEND (same shape as the bf16 ones; distinct instance_name so kernel-cache files don't collide). kernel_builder/external_kernels.py - compile_all_external_kernels(quant=) builds mv_int4_bf16.o when quant="awq" (object already had compile_mv_int4_bf16 from #1633). kernel_builder/cache.py - prepare_air_project(quant=) stages mv_int4_bf16.o into air_project/. compile_and_cache detects int4 ELFs from the name and pipes the right quant through, so existing call sites don't need to change. llama32_1b_decode.py - compile_decode_kernels(cache, config, quant=) builds either the bf16 ELFs or the int4 ELFs. - run_decode_block(..., quant=) reads either bf16 transposed weights or packed-i8 BOs from the same arg slots. llama32_1b_inference.py - --quant {bf16,awq} flag, --model-path AWQ checkpoint override. - awq mode: no prefill compile, no bf16 transpose, no bf16 prefill preload; CPU prefill replaces run_npu_prefill. - --quant=awq is incompatible with --synthetic-weights. Verification on NPU2 with amd/Llama-3.2-1B-Instruct-awq-uint4-asym-g128-bf16-lmhead: - compile-only: rms_qkv_int4_rope (3 s) + o_gemv_ffn_int4 (35 s) + lm_head_gemv (10 s) - 30-token greedy generation produces coherent English, 12.4 tok/s decode - Decode latency (~81 ms/tok) tracks the standalone-ELF PR #1637 result Co-Authored-By: Claude Opus 4.7 (1M context) --- .../llama32_1b/awq_repacker.py | 259 ++++++++++++++++++ .../llama32_1b/cpu_prefill.py | 104 +++++++ .../kernel_builder/backend_presets.py | 20 ++ .../llama32_1b/kernel_builder/cache.py | 22 +- .../kernel_builder/external_kernels.py | 34 ++- .../llama32_1b/llama32_1b_decode.py | 136 ++++++--- .../llama32_1b/llama32_1b_inference.py | 242 +++++++++++----- .../llama32_1b/llama32_1b_weights.py | 226 +++++++++++++++ 8 files changed, 927 insertions(+), 116 deletions(-) create mode 100644 programming_examples/llama32_1b/awq_repacker.py create mode 100644 programming_examples/llama32_1b/cpu_prefill.py diff --git a/programming_examples/llama32_1b/awq_repacker.py b/programming_examples/llama32_1b/awq_repacker.py new file mode 100644 index 000000000..6cd5cc83b --- /dev/null +++ b/programming_examples/llama32_1b/awq_repacker.py @@ -0,0 +1,259 @@ +# Copyright (C) 2026, Advanced Micro Devices, Inc. +# SPDX-License-Identifier: MIT + +"""HuggingFace AutoAWQ checkpoint -> per-tile packed BO layout used by +mlir-air's int4-AWQ GEMV kernels. + +AutoAWQ stores each Linear's quantized weights as three tensors: + qweight: [in_features=K, out_features // 8] int32 + (8 uint4 nibbles packed along N per int32, interleaved by AWQ_PACK_ORDER) + qzeros: [K // group_size, out_features // 8] int32 (same packing) + scales: [K // group_size, out_features] fp16 + +mlir-air's `matvec_int4_packed.pack_inputs` expects: + A_q[M=out, K/2] uint8 (col 2i = low nibble, col 2i+1 = high nibble) + A_s[n_groups, M] bf16 + A_z[n_groups, M] uint8 + +This module bridges the two formats. A built-in self-test (run via +`python3 awq_repacker.py`) generates synthetic AWQ tensors, repacks, and +verifies that the repacked form dequantizes to exactly the same bf16 +weights as a direct dense dequant. +""" + +import argparse +import os +import sys + +import numpy as np +from ml_dtypes import bfloat16 + +# AutoAWQ packs 8 uint4 nibbles into each int32, but the nibble at bit +# position 4*i within the int32 holds the *unpacked* output column +# `8*col_block + AWQ_PACK_ORDER[i]`. See autoawq.utils.packing_utils.pack. +AWQ_PACK_ORDER = np.array([0, 2, 4, 6, 1, 3, 5, 7], dtype=np.int64) +# Inverse: AWQ_UNPACK_PERM[k] == bit position holding output column k. +AWQ_UNPACK_PERM = np.argsort(AWQ_PACK_ORDER) # = [0, 4, 1, 5, 2, 6, 3, 7] + + +def unpack_awq_int32(packed: np.ndarray) -> np.ndarray: + """Unpack AutoAWQ int32 -> uint8 nibbles along the last axis. + + Args: + packed: int32 array, last axis is the packed-N axis of size N//8. + + Returns: + uint8 array, last axis size N, values in [0, 16). + """ + packed64 = packed.astype(np.int64) + shifts = np.arange(8, dtype=np.int64) * 4 + # nibs[..., k, i] = bits [4i : 4i+4] of int32 at packed position k. + nibs = ((packed64[..., :, None] >> shifts) & 0xF).astype(np.uint8) + # Reorder: nibble at bit position i corresponds to output column + # AWQ_PACK_ORDER[i], so to get natural column order we index by + # AWQ_UNPACK_PERM (column k -> bit position AWQ_UNPACK_PERM[k]). + nibs_reordered = nibs[..., AWQ_UNPACK_PERM] + return nibs_reordered.reshape(*packed.shape[:-1], packed.shape[-1] * 8) + + +def dequant_to_bf16(qweight, qzeros, scales, group_size): + """HF AutoAWQ tensors -> dense bf16 weight matrix [in_features, out_features]. + + Matches transformer_block's `wq[in, out]` storage convention so the result + can be assigned directly to LayerWeights.wq / wk / wv / wo / w_gate / + w_up / w_down for the CPU-prefill placeholder path. + + Dequant formula: w[k, n] = (qweight_u[k, n] - qzeros_u[k//gs, n]) * scales[k//gs, n]. + """ + qweight_u = unpack_awq_int32(qweight) # [K, N] uint8 + qzeros_u = unpack_awq_int32(qzeros) # [K/gs, N] uint8 + K, N = qweight_u.shape + n_groups = K // group_size + if qzeros_u.shape != (n_groups, N): + raise ValueError(f"qzeros shape {qzeros_u.shape} vs expected ({n_groups}, {N})") + if scales.shape != (n_groups, N): + raise ValueError(f"scales shape {scales.shape} vs expected ({n_groups}, {N})") + # Round scales to bf16 first so this matches what the NPU kernel actually + # sees (the packed BO carries bf16 scales). fp16->bf16 loses 3 mantissa + # bits, which is real and intentional precision drift relative to the + # canonical AWQ fp16 dequant. + s_bf16_as_f32 = scales.astype(bfloat16).astype(np.float32) + z_per_k = np.repeat(qzeros_u.astype(np.int32), group_size, axis=0) # [K, N] + s_per_k = np.repeat(s_bf16_as_f32, group_size, axis=0) # [K, N] + w_f32 = (qweight_u.astype(np.int32) - z_per_k) * s_per_k + return w_f32.astype(bfloat16) + + +def repack_hf_awq_linear(qweight, qzeros, scales, group_size): + """HF AutoAWQ tensors -> (A_q, A_s, A_z) in mlir-air `pack_inputs` format. + + Returns: + A_q: uint8 [M=out_features, K/2], packed nibble pairs (col 2i = low, + col 2i+1 = high). + A_s: bf16 [n_groups, M] (lossy fp16->bf16 cast on AWQ's smooth scales). + A_z: uint8 [n_groups, M], values in [0, 16). + """ + qweight_u = unpack_awq_int32(qweight) # [K, N] + qzeros_u = unpack_awq_int32(qzeros) # [K/gs, N] + K, N = qweight_u.shape + n_groups = K // group_size + if qzeros_u.shape != (n_groups, N): + raise ValueError(f"qzeros shape {qzeros_u.shape} vs expected ({n_groups}, {N})") + if scales.shape != (n_groups, N): + raise ValueError(f"scales shape {scales.shape} vs expected ({n_groups}, {N})") + # Transpose K-major (HF) -> M-major: weight[m=n, k] = qweight_u[k, n]. + q_mn = np.ascontiguousarray(qweight_u.T) # [M, K] + low = q_mn[:, 0::2] & 0x0F + high = (q_mn[:, 1::2] & 0x0F) << 4 + A_q = (low | high).astype(np.uint8) # [M, K/2] + A_s = np.ascontiguousarray(scales).astype(bfloat16) # [n_groups, M] + A_z = np.ascontiguousarray(qzeros_u).astype(np.uint8) # [n_groups, M] + return A_q, A_s, A_z + + +def repack_for_gemv( + qweight, qzeros, scales, group_size, M_TILE=8, K_CHUNK=2048, N_CORES=8 +): + """HF AutoAWQ -> [total_tiles, tile_bytes] uint8 BO ready for mlir-air decode. + + Calls `matvec_int4_packed.pack_inputs` under the hood. Single-launch + layout (`M_PER_LAUNCH = M`) — matches the int4 decode ELF builders. + """ + sys.path.insert( + 0, + os.path.join( + os.path.dirname(os.path.abspath(__file__)), + "..", + "matrix_vector_multiplication", + "int4_awq", + ), + ) + from matvec_int4_packed import pack_inputs # type: ignore + + A_q, A_s, A_z = repack_hf_awq_linear(qweight, qzeros, scales, group_size) + M, K_half = A_q.shape + K = K_half * 2 + return pack_inputs(A_q, A_s, A_z, M, K, group_size, M_TILE, K_CHUNK, N_CORES, M) + + +# --------------------------------------------------------------------------- +# Synthetic AWQ generator + self-test +# --------------------------------------------------------------------------- + + +def _gen_synthetic_awq(K, N, group_size, seed=42): + """Produce HF-AutoAWQ-shaped tensors from a random nibble matrix. + + Returns (qweight[K, N//8] int32, qzeros[K/gs, N//8] int32, + scales[K/gs, N] fp16, dense_ref[K, N] uint8) where dense_ref + is the un-packed nibble matrix used to verify the unpack path. + """ + rng = np.random.default_rng(seed) + n_groups = K // group_size + # Random uint4 nibbles for both weights and zeros. + nibs = rng.integers(0, 16, size=(K, N), dtype=np.uint8) + z_nibs = rng.integers(0, 16, size=(n_groups, N), dtype=np.uint8) + scales = rng.uniform(0.005, 0.02, size=(n_groups, N)).astype(np.float16) + + # Pack along N axis: bit position i holds column AWQ_PACK_ORDER[i]. + n_blocks = N // 8 + qweight = np.zeros((K, n_blocks), dtype=np.int32) + qzeros = np.zeros((n_groups, n_blocks), dtype=np.int32) + for i in range(8): + col = AWQ_PACK_ORDER[i] + qweight |= nibs[:, col::8].astype(np.int32) << (4 * i) + qzeros |= z_nibs[:, col::8].astype(np.int32) << (4 * i) + return qweight, qzeros, scales, nibs, z_nibs + + +def self_test(K=512, N=128, group_size=128, seed=42, verbose=True): + """Round-trip check: pack synthetic AWQ -> repack -> dequant matches + direct dense dequant. Algebraically identical up to bf16 rounding. + """ + qweight, qzeros, scales, nibs_ref, z_nibs_ref = _gen_synthetic_awq( + K, N, group_size, seed=seed + ) + + # (a) unpack round-trip: confirms AWQ_PACK_ORDER handling. + nibs_unpacked = unpack_awq_int32(qweight) + if not np.array_equal(nibs_unpacked, nibs_ref): + wrong = (nibs_unpacked != nibs_ref).sum() + raise AssertionError( + f"unpack_awq_int32 mismatch on {wrong} / {nibs_ref.size} nibbles" + ) + z_unpacked = unpack_awq_int32(qzeros) + if not np.array_equal(z_unpacked, z_nibs_ref): + raise AssertionError("qzeros unpack mismatch") + if verbose: + print(f" [a] AWQ_PACK_ORDER unpack: PASS ({K}x{N} nibbles)") + + # (b) dense dequant and our repack agree on every (k, n). + w_dense = dequant_to_bf16(qweight, qzeros, scales, group_size) + A_q, A_s, A_z = repack_hf_awq_linear(qweight, qzeros, scales, group_size) + # Reconstruct dense weight from (A_q, A_s, A_z): w'[k, n] = (nib - z) * s. + M = A_q.shape[0] + K2 = A_q.shape[1] * 2 + if (M, K2) != (N, K): + raise AssertionError(f"repack shape mismatch: ({M}, {K2}) vs ({N}, {K})") + nibs_from_packed = np.zeros((M, K2), dtype=np.uint8) + nibs_from_packed[:, 0::2] = A_q & 0x0F + nibs_from_packed[:, 1::2] = (A_q >> 4) & 0x0F + z_per_k = np.repeat(A_z.astype(np.int32), group_size, axis=0) # [K, M] + s_per_k = np.repeat(A_s.astype(np.float32), group_size, axis=0) # [K, M] + # w_repacked[m, k] = (nib[m, k] - z[k//gs, m]) * s[k//gs, m] + w_repacked_f32 = (nibs_from_packed.astype(np.int32) - z_per_k.T) * s_per_k.T + w_repacked = w_repacked_f32.astype(bfloat16) + # Compare in [in, out] orientation: w_dense is [K, N]; w_repacked is [M=N, K]. + if not np.array_equal(w_dense, w_repacked.T): + diff = w_dense.astype(np.float32) - w_repacked.T.astype(np.float32) + mx = np.max(np.abs(diff)) + raise AssertionError(f"dense vs repacked dequant mismatch: max |Δ| = {mx}") + if verbose: + print(f" [b] dense vs repacked dequant: PASS") + + # (c) end-to-end vs matvec_int4_packed.cpu_reference on a random input. + sys.path.insert( + 0, + os.path.join( + os.path.dirname(os.path.abspath(__file__)), + "..", + "matrix_vector_multiplication", + "int4_awq", + ), + ) + from matvec_int4_packed import cpu_reference # type: ignore + + rng = np.random.default_rng(seed + 1) + x = rng.standard_normal(K).astype(bfloat16) + # mlir-air cpu_reference applies dequant + matmul in the same order as + # the NPU kernel; result should match w_dense.T @ x within bf16 rounding. + y_repacked = cpu_reference(A_q, A_s, A_z, x) + y_dense = (w_dense.astype(np.float32).T @ x.astype(np.float32)).astype(bfloat16) + corr = np.corrcoef( + y_repacked.astype(np.float32).flatten(), + y_dense.astype(np.float32).flatten(), + )[0, 1] + if not (corr >= 0.9999): + raise AssertionError( + f"end-to-end correlation {corr:.6f} below 0.9999 threshold" + ) + if verbose: + print(f" [c] end-to-end (cpu_reference vs dense): PASS (corr={corr:.6f})") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + prog="awq_repacker.py", + description="HF AutoAWQ -> mlir-air packed-BO repacker + self-test.", + ) + parser.add_argument("--k", type=int, default=512) + parser.add_argument("--n", type=int, default=128) + parser.add_argument("--gs", type=int, default=128) + parser.add_argument("--seed", type=int, default=42) + args = parser.parse_args() + print( + f"AWQ repacker self-test: K={args.k}, N={args.n}, GS={args.gs}, " + f"seed={args.seed}" + ) + self_test(K=args.k, N=args.n, group_size=args.gs, seed=args.seed) + print("All self-tests PASSED.") diff --git a/programming_examples/llama32_1b/cpu_prefill.py b/programming_examples/llama32_1b/cpu_prefill.py new file mode 100644 index 000000000..7c2f1ab7c --- /dev/null +++ b/programming_examples/llama32_1b/cpu_prefill.py @@ -0,0 +1,104 @@ +# Copyright (C) 2026, Advanced Micro Devices, Inc. +# SPDX-License-Identifier: MIT + +"""CPU prefill placeholder for the int4-AWQ pipeline. + +Wraps `llama32_1b_reference.transformer_block` into a drop-in replacement +for `llama32_1b_inference.run_npu_prefill` so the int4-AWQ end-to-end path +can bootstrap a KV cache without needing int4 prefill ELFs yet. + +Per-layer K (post-RoPE) and V are harvested from each `transformer_block` +call's intermediates dict and written into the same KV cache layout the +NPU decode loop reads from. The final norm + LM head runs on the last +prompt-position hidden state to produce the first generated token, matching +`run_npu_prefill`'s return contract. + +Runtime: numpy bf16 dequant + matmul; ~2 s for a 16-token prompt at 16 +layers, scales linearly. For validation and short interactive prompts only; +production int4 prefill will land later as a separate project and replace +the import in `inference.py`. +""" + +import time + +import numpy as np +from ml_dtypes import bfloat16 + + +def run_cpu_prefill( + token_ids, + weights, + config, + rope_lut_bf16, + max_seq, + tokenizer=None, + quiet=False, +): + """CPU prefill that mirrors `run_npu_prefill`'s return signature. + + Args: + token_ids: list[int] of prompt token IDs. + weights: LlamaWeights with bf16 dequant fields populated (set up by + load_weights_awq via dequant_to_bf16). Packed BO attributes are + ignored here. + config: LlamaConfig. + rope_lut_bf16: (max_seq, head_dim) RoPE LUT in bf16; converted to + f32 internally for the reference math. + max_seq: KV cache stride along the sequence dim. + tokenizer: optional, used only for logging. + quiet: suppress timing prints. + + Returns: + prefill_token: int -- first predicted token ID (greedy argmax) + k_cache: ndarray (n_layers, n_kv_heads, max_seq, head_dim) bfloat16 + v_cache: ndarray (n_layers, n_kv_heads, max_seq, head_dim) bfloat16 + prompt_len: int -- len(token_ids) + """ + from llama32_1b_reference import rms_norm as _rms_norm + from llama32_1b_reference import transformer_block as _transformer_block + + seq_len = len(token_ids) + n_layers = config.n_layers + n_kv_heads = config.n_kv_heads + head_dim = config.head_dim + + if not quiet: + print(f"Running CPU prefill ({n_layers} layers, seq_len={seq_len})...") + t0 = time.time() + + rope_lut_f32 = np.asarray(rope_lut_bf16, dtype=np.float32) + + # Token embedding -> initial hidden states. + embed = np.asarray(weights.embed_table, dtype=np.float32) + x = embed[np.asarray(token_ids)] # (seq_len, emb_dim) + + k_cache = np.zeros((n_layers, n_kv_heads, max_seq, head_dim), dtype=bfloat16) + v_cache = np.zeros((n_layers, n_kv_heads, max_seq, head_dim), dtype=bfloat16) + + for layer_idx in range(n_layers): + lw = weights.layers[layer_idx] + x, inters = _transformer_block(x, lw, rope_lut_f32, config) + # k_roped, v: (seq_len, n_kv_heads * head_dim) -> (n_kv_heads, seq_len, head_dim) + k_roped = inters["k_roped"].reshape(seq_len, n_kv_heads, head_dim) + v = inters["v"].reshape(seq_len, n_kv_heads, head_dim) + k_cache[layer_idx, :, :seq_len, :] = k_roped.transpose(1, 0, 2).astype(bfloat16) + v_cache[layer_idx, :, :seq_len, :] = v.transpose(1, 0, 2).astype(bfloat16) + + # Final norm + LM head on the LAST prompt position only. + final_norm = np.asarray(weights.final_norm, dtype=np.float32) + h_last = _rms_norm(x[-1:], final_norm) # (1, emb_dim) f32 + lm_head = np.asarray(weights.lm_head, dtype=np.float32) + logits_row = (h_last @ lm_head.T).reshape(-1) # (vocab_size,) + prefill_token = int(np.argmax(logits_row)) + + t_prefill = time.time() - t0 + if not quiet: + msg = f"CPU prefill done in {t_prefill:.2f}s. First token: {prefill_token}" + if tokenizer is not None: + try: + msg += f" ({tokenizer.decode([prefill_token])!r})" + except Exception: + pass + print(msg) + + return prefill_token, k_cache, v_cache, seq_len diff --git a/programming_examples/llama32_1b/kernel_builder/backend_presets.py b/programming_examples/llama32_1b/kernel_builder/backend_presets.py index ef396f860..c1dd8be67 100644 --- a/programming_examples/llama32_1b/kernel_builder/backend_presets.py +++ b/programming_examples/llama32_1b/kernel_builder/backend_presets.py @@ -63,3 +63,23 @@ "instance_name": "lm_head_gemv", **GEMV_K2048_BACKEND, } + +# --------------------------------------------------------------------------- +# Decode (int4-AWQ ELFs — same kwarg shape, distinct instance names so the +# kernel cache files don't collide with the bf16 ones) +# --------------------------------------------------------------------------- + +RGR_INT4_BACKEND = { + "output_format": "elf", + "instance_name": "rms_qkv_int4_rope", + "stack_size": 4096, + **GEMV_K2048_BACKEND, +} + +OGF_INT4_BACKEND = { + "output_format": "elf", + "instance_name": "o_gemv_ffn_int4", + "omit_pingpong": "all", + "stack_size": 4096, + **{k: v for k, v in GEMV_K2048_BACKEND.items() if k != "omit_pingpong"}, +} diff --git a/programming_examples/llama32_1b/kernel_builder/cache.py b/programming_examples/llama32_1b/kernel_builder/cache.py index bb4291a7e..b21b2ae49 100644 --- a/programming_examples/llama32_1b/kernel_builder/cache.py +++ b/programming_examples/llama32_1b/kernel_builder/cache.py @@ -12,13 +12,17 @@ from ml_dtypes import bfloat16 -def prepare_air_project(): +def prepare_air_project(quant: str = "bf16"): """Clean and prepare the air_project/ directory for a fresh compilation. aircc defaults to 'air_project/' as its working directory. Sequential compilations leave stale artifacts that corrupt subsequent kernels. This method wipes the directory, compiles all external C++ kernels from source, and copies them to air_project/. + + Args: + quant: "bf16" or "awq". When "awq", also compiles + stages + `mv_int4_bf16.o` so the int4 decode ELFs can link it. """ air_proj = Path("air_project") if air_proj.exists(): @@ -28,7 +32,7 @@ def prepare_air_project(): # Compile external kernels from source (not stale .o copies) from kernel_builder.external_kernels import compile_all_external_kernels - compile_all_external_kernels() + compile_all_external_kernels(quant=quant) # Copy compiled .o files to air_project/ for aiecc to find. Must include # every external symbol referenced by `link_with` in the kernel modules: @@ -38,7 +42,8 @@ def prepare_air_project(): # - silu_and_mul.o : SwiGLU (prefill o_ffn, decode o_gemv_ffn) # - attn.o : flash attention (prefill, when --cpu-attn=False) # - attn_npu2.o : flash attention NPU2 variant alias - for obj_name in [ + # - mv_int4_bf16.o : int4-AWQ GEMV micro-kernel (decode int4 ELFs only) + obj_names = [ "silu_and_mul.o", "rope.o", "attn.o", @@ -46,7 +51,10 @@ def prepare_air_project(): "mv.o", "mv_bf16.o", "attn_decode_npu2.o", - ]: + ] + if quant == "awq": + obj_names.append("mv_int4_bf16.o") + for obj_name in obj_names: src = Path(obj_name) if src.exists(): shutil.copy2(src, air_proj / obj_name) @@ -263,7 +271,11 @@ def compile_and_cache( from air.backend.xrt import XRTBackend self._log(f"Compiling {name}...") - prepare_air_project() + # Int4 ELFs need the AWQ GEMV micro-kernel staged alongside the bf16 + # objects — detect from the kernel name so callers don't have to pass + # an extra flag through every compile_and_cache invocation. + quant = "awq" if "int4" in name else "bf16" + prepare_air_project(quant=quant) backend = XRTBackend(**backend_kwargs) t0 = time.time() diff --git a/programming_examples/llama32_1b/kernel_builder/external_kernels.py b/programming_examples/llama32_1b/kernel_builder/external_kernels.py index a5e0e9d8c..881aa9434 100644 --- a/programming_examples/llama32_1b/kernel_builder/external_kernels.py +++ b/programming_examples/llama32_1b/kernel_builder/external_kernels.py @@ -151,6 +151,20 @@ def compile_attn_npu2(head_dim=64): shutil.copy2("attn_npu2.o", "attn.o") +def compile_mv_k8192(): + """Compile mv_k8192.o with renamed GEMV symbols for K=8192 decode merge.""" + src = _PROJ_ROOT / "matrix_vector_multiplication" / "bf16" / "mv.cc" + _compile_kernel( + src, + "mv_k8192.o", + extra_flags=[ + "-DDIM_M_OUTPUT=2", + "-Dmatvec_vectorized_bf16_bf16=dg_matvec_vectorized_bf16_bf16", + "-Dlinalg_fill_bf16=dg_linalg_fill_bf16", + ], + ) + + def compile_mv(tile_m=8): """Compile mv.o (standard GEMV kernel) from source.""" src = _PROJ_ROOT / "matrix_vector_multiplication" / "bf16" / "mv.cc" @@ -171,13 +185,6 @@ def compile_mv_int4_bf16(m_tile=8, k_chunk=2048, gs=128): ) -def compile_mv_bf16(): - """Compile mv_bf16.o for the 2-tile matvec+add primitive used by - o_gemv_ffn stages 1 and 3.""" - src = _PROJ_ROOT / "matrix_vector_multiplication" / "bf16_cascade" / "mv_bf16.cc" - _compile_kernel(src, "mv_bf16.o") - - def compile_attn_decode_npu2(head_dim=64): """Compile attn_decode_npu2.o (RoPE helpers for the fused decode kernel).""" src = _PROJ_ROOT / "attention_decode" / "attn_decode_npu2.cc" @@ -192,16 +199,25 @@ def compile_attn_decode_npu2(head_dim=64): ) -def compile_all_external_kernels(head_dim=64): +def compile_all_external_kernels(head_dim=64, quant="bf16"): """Compile all external C++ kernels from source. Call this before kernel compilation to ensure all .o files are fresh. Each kernel is only compiled if its .o doesn't already exist. Delete build_peano/*.o to force recompilation. + + Args: + head_dim: attention head dimension (RoPE / attn kernel macros). + quant: "bf16" (default) or "awq". When "awq" the int4 GEMV micro-kernel + (`mv_int4_bf16.o`) is built so the int4 decode ELFs can link it. + bf16-specific GEMV objects (mv.o, mv_k8192.o) are still built so + mixed paths (e.g. bf16 prefill alongside int4 decode) keep working. """ compile_silu_and_mul() compile_rope() compile_attn_npu2(head_dim=head_dim) compile_attn_decode_npu2(head_dim=head_dim) compile_mv() - compile_mv_bf16() + compile_mv_k8192() + if quant == "awq": + compile_mv_int4_bf16() diff --git a/programming_examples/llama32_1b/llama32_1b_decode.py b/programming_examples/llama32_1b/llama32_1b_decode.py index 9292b387e..f84923776 100644 --- a/programming_examples/llama32_1b/llama32_1b_decode.py +++ b/programming_examples/llama32_1b/llama32_1b_decode.py @@ -28,6 +28,8 @@ RGR_BACKEND, OGF_BACKEND, LM_GEMV_BACKEND, + RGR_INT4_BACKEND, + OGF_INT4_BACKEND, ) # --------------------------------------------------------------------------- @@ -35,11 +37,18 @@ # --------------------------------------------------------------------------- -def compile_decode_kernels(cache, config): - """Compile the 3 merged decode kernels.""" +def compile_decode_kernels(cache, config, quant: str = "bf16"): + """Compile the 3 merged decode kernels. + + Args: + cache: KernelCache. + config: LlamaConfig. + quant: "bf16" (default, existing behavior) or "awq" (int4-AWQ ELFs: + rms_qkv_int4_rope + o_gemv_ffn_int4 from PR #1633 / #1637). + """ from kernel_builder.external_kernels import compile_all_external_kernels - compile_all_external_kernels(head_dim=config.head_dim) + compile_all_external_kernels(head_dim=config.head_dim, quant=quant) emb_dim = config.emb_dim n_heads = config.n_heads @@ -49,31 +58,62 @@ def compile_decode_kernels(cache, config): kv_dim = n_kv_heads * head_dim print(f"\n{'='*60}") - print(f"Compiling decode kernels (2-call merged pipeline)...") + print(f"Compiling decode kernels (quant={quant})...") print(f"{'='*60}\n") - # 1. rms_gemv_rope: RMSNorm + QKV GEMV + RoPE Q+K (6 launches, 13 args) - from multi_launch_builder.rms_gemv_rope_multi import ( - build_rms_gemv_rope_module, - ) + if quant == "awq": + # 1. rms_qkv_int4_rope: RMSNorm + int4 QKV GEMV + RoPE Q+K (6 launches, 13 args) + from multi_launch_builder.rms_qkv_int4_rope_multi import ( + build_rms_qkv_int4_rope_module, + ) - cache.compile_and_cache( - "rms_gemv_rope", - build_rms_gemv_rope_module(emb_dim, kv_dim, n_heads, n_kv_heads, head_dim), - {"verbose": cache.verbose, **RGR_BACKEND}, - ) + cache.compile_and_cache( + "rms_qkv_int4_rope", + build_rms_qkv_int4_rope_module( + emb_dim=emb_dim, + kv_dim=kv_dim, + n_heads=n_heads, + n_kv_heads=n_kv_heads, + head_dim=head_dim, + ), + {"verbose": cache.verbose, **RGR_INT4_BACKEND}, + ) + + # 2. o_gemv_ffn_int4: 3-launch full-int4 (matvec_int4_packed_add + + # matvec_int4_swiglu_rms + matvec_int4_packed_add). Same arg6-row0 + # residual routing as the bf16 baseline. + from multi_launch_builder.o_gemv_ffn_int4_multi import ( + build_o_gemv_ffn_int4_module, + ) - # 2. o_gemv_ffn: 3-launch (matvec_2tile_add + matvec_swiglu_rms + - # matvec_2tile_add). Post-attention residual is routed - # through a row-0 subview of arg6 (the packed RMSNorm - # input buffer); see o_gemv_ffn_multi.py for the ABI. - from multi_launch_builder.o_gemv_ffn_multi import build_o_gemv_ffn_module + cache.compile_and_cache( + "o_gemv_ffn_int4", + build_o_gemv_ffn_int4_module(emb_dim=emb_dim, hidden_dim=hidden_dim), + {"verbose": cache.verbose, **OGF_INT4_BACKEND}, + ) + else: + # 1. rms_gemv_rope: RMSNorm + QKV GEMV + RoPE Q+K (6 launches, 13 args) + from multi_launch_builder.rms_gemv_rope_multi import ( + build_rms_gemv_rope_module, + ) - cache.compile_and_cache( - "o_gemv_ffn", - build_o_gemv_ffn_module(emb_dim, hidden_dim), - {"verbose": cache.verbose, **OGF_BACKEND}, - ) + cache.compile_and_cache( + "rms_gemv_rope", + build_rms_gemv_rope_module(emb_dim, kv_dim, n_heads, n_kv_heads, head_dim), + {"verbose": cache.verbose, **RGR_BACKEND}, + ) + + # 2. o_gemv_ffn: 3-launch (matvec_2tile_add + matvec_swiglu_rms + + # matvec_2tile_add). Post-attention residual is routed + # through a row-0 subview of arg6 (the packed RMSNorm + # input buffer); see o_gemv_ffn_multi.py for the ABI. + from multi_launch_builder.o_gemv_ffn_multi import build_o_gemv_ffn_module + + cache.compile_and_cache( + "o_gemv_ffn", + build_o_gemv_ffn_module(emb_dim, hidden_dim), + {"verbose": cache.verbose, **OGF_BACKEND}, + ) # 3. LM Head GEMV multi-launch: 8-partition GEMV in one ELF from multi_launch_builder.lm_head_gemv_multi import ( @@ -145,6 +185,7 @@ def run_decode_block( v_cache_layer, current_pos, rope_lut_bf16, + quant: str = "bf16", ): """Run one transformer block for a single decode token. @@ -188,14 +229,21 @@ def _run(name, backend, *inputs, static_indices=None, **kwargs): # --- Call 1: rms_gemv_rope (6 launches, 13 args) --- # RMSNorm + Q/K/V GEMV + RoPE Q + RoPE K + # Int4 path uses the same arg slots; only the weight types change + # (slots 3/5/7: packed-int4 uint8 BO instead of bf16 matrix). x_in = x_bf16.flatten().astype(bfloat16) w_norm = layer_weights.attn_norm.reshape(emb_dim).astype(bfloat16) normed_buf = np.zeros(emb_dim, dtype=bfloat16) - wq = layer_weights._wq_t + if quant == "awq": + wq = layer_weights._wq_packed + wk = layer_weights._wk_packed + wv = layer_weights._wv_packed + else: + wq = layer_weights._wq_t + wk = layer_weights._wk_t + wv = layer_weights._wv_t q_buf = np.zeros(emb_dim, dtype=bfloat16) - wk = layer_weights._wk_t k_buf = np.zeros(kv_dim, dtype=bfloat16) - wv = layer_weights._wv_t v_buf = np.zeros(kv_dim, dtype=bfloat16) # RoPE LUT for current position @@ -205,17 +253,19 @@ def _run(name, backend, *inputs, static_indices=None, **kwargs): q_roped_buf = np.zeros(emb_dim, dtype=bfloat16) k_roped_buf = np.zeros(kv_dim, dtype=bfloat16) + rgr_name = "rms_qkv_int4_rope" if quant == "awq" else "rms_gemv_rope" + rgr_backend = RGR_INT4_BACKEND if quant == "awq" else RGR_BACKEND results = _run( - "rms_gemv_rope", - RGR_BACKEND, + rgr_name, + rgr_backend, x_in, # arg0 w_norm, # arg1 normed_buf, # arg2 (intermediate) - wq, # arg3 (static) + wq, # arg3 (static, packed-i8 in int4 mode) q_buf, # arg4 (intermediate) - wk, # arg5 (static) + wk, # arg5 (static, packed-i8 in int4 mode) k_buf, # arg6 (intermediate) - wv, # arg7 (static) + wv, # arg7 (static, packed-i8 in int4 mode) v_buf, # arg8 (intermediate/output) lut_q, # arg9 lut_k, # arg10 @@ -249,34 +299,42 @@ def _run(name, backend, *inputs, static_indices=None, **kwargs): # stage 1 in-kernel, row 1 = ffn_norm_w pre-loaded by host). # arg7 = interleaved w_gateup [2*hidden_dim, emb_dim]. arg2/4/5/8/9/10/13 # are dead ABI placeholders; pass small zero buffers. - wo = layer_weights._wo_t + # Int4 path: slots 0/7/12 hold packed-i8 BOs; bf16 dead-slot ABI unchanged. x_residual = x_bf16.flatten().astype(bfloat16) swiglu_buf = np.zeros(hidden_dim, dtype=bfloat16) - w_down = layer_weights._wdown_t output_buf = np.zeros(emb_dim, dtype=bfloat16) + if quant == "awq": + wo = layer_weights._wo_packed + w_gateup = layer_weights._wgateup_packed + w_down = layer_weights._wdown_packed + else: + wo = layer_weights._wo_t + w_gateup = layer_weights._wgateup_t + w_down = layer_weights._wdown_t arg6 = layer_weights._packed_rms_buf # [2, emb_dim] - arg7 = layer_weights._wgateup_t # [2*hidden, emb_dim] z_emb = np.zeros(emb_dim, dtype=bfloat16) z_hidden = np.zeros(hidden_dim, dtype=bfloat16) z_hidden_emb = np.zeros((hidden_dim, emb_dim), dtype=bfloat16) + ogf_name = "o_gemv_ffn_int4" if quant == "awq" else "o_gemv_ffn" + ogf_backend = OGF_INT4_BACKEND if quant == "awq" else OGF_BACKEND results = _run( - "o_gemv_ffn", - OGF_BACKEND, - wo, # arg0 wo (static) + ogf_name, + ogf_backend, + wo, # arg0 wo (static, packed-i8 in int4) attn_out, # arg1 attn_out (input) z_emb, # arg2 (dead) x_residual, # arg3 x_residual (input) z_emb, # arg4 (dead — was res1 bus) z_emb, # arg5 (dead — ffn_norm_w now in arg6[1]) arg6, # arg6 packed RMS input (static) - arg7, # arg7 w_gateup (static) + w_gateup, # arg7 w_gateup (static, packed-i8 in int4) z_hidden, # arg8 (dead) z_hidden_emb, # arg9 (dead — wup folded into arg7) z_hidden, # arg10 (dead) swiglu_buf, # arg11 swiglu (intermediate) - w_down, # arg12 wdown (static) + w_down, # arg12 wdown (static, packed-i8 in int4) z_emb, # arg13 (dead) output_buf, # arg14 output (output) output_indices=[14], diff --git a/programming_examples/llama32_1b/llama32_1b_inference.py b/programming_examples/llama32_1b/llama32_1b_inference.py index 20aff3e51..9576dab61 100644 --- a/programming_examples/llama32_1b/llama32_1b_inference.py +++ b/programming_examples/llama32_1b/llama32_1b_inference.py @@ -37,6 +37,7 @@ from llama32_1b_weights import ( LlamaConfig, load_weights, + load_weights_awq, synthetic_weights, generate_rope_lut, ) @@ -46,6 +47,8 @@ LM_GEMV_BACKEND, RGR_BACKEND, OGF_BACKEND, + RGR_INT4_BACKEND, + OGF_INT4_BACKEND, ) from llama32_1b_prefill import ( compile_all_kernels, @@ -110,10 +113,11 @@ class Session: seq_len: int # padded prompt length (today: 2048) weights: Any # LlamaWeights, mutated by prepare_runtime() tokenizer: Any # transformers AutoTokenizer - prefill_cache: Any # KernelCache + prefill_cache: Any # KernelCache (None in awq mode — CPU prefill placeholder) decode_cache: Any # KernelCache rope_lut_bf16: np.ndarray # (max_seq, head_dim) bfloat16 model_variant: str # "base" | "instruct" + quant: str = "bf16" # "bf16" or "awq" # Decode LM Head constants @@ -133,6 +137,7 @@ def prepare_runtime( config, seq_len, rope_lut_bf16, + quant: str = "bf16", ): """One-time runtime initialization. Called before any timed inference. @@ -166,11 +171,13 @@ def prepare_runtime( kv_dim = n_kv_heads * head_dim # 1. Compile external C++ kernels from source - compile_all_external_kernels(head_dim=head_dim) + compile_all_external_kernels(head_dim=head_dim, quant=quant) # 2. Pre-transpose all decode GEMV weights # GEMV kernel expects A[M,K] but HuggingFace stores (out_features, in_features) - if not hasattr(weights, "_decode_weights_transposed"): + # In awq mode the decode ELFs consume packed BOs (already on LayerWeights + # as _wq_packed etc.) so the bf16 transpose isn't needed. + if quant == "bf16" and not hasattr(weights, "_decode_weights_transposed"): print(" Pre-transposing weights for GEMV...") for lw in weights.layers: lw._wq_t = np.ascontiguousarray( @@ -200,13 +207,17 @@ def prepare_runtime( for i, lw in enumerate(weights.layers): lw._layer_idx = i - # 4. Pre-load prefill weights into per-layer BOs - preload_prefill_weights(weights, config, prefill_cache, seq_len, rope_lut_bf16) + # 4. Pre-load prefill weights into per-layer BOs. + # AWQ path skips this entirely — prefill runs on CPU as a placeholder + # until the int4 prefill ELFs land. Saves several seconds of compile + # time and ~110 MB of bf16 prefill weights on the device. + if quant == "bf16": + preload_prefill_weights(weights, config, prefill_cache, seq_len, rope_lut_bf16) # 5. Pre-load decode weights into per-layer BOs # (lm_head_gemv 8-partition weights here are also reused by prefill's # last-token projection — refactored from full-seq GEMM for ~150 ms savings) - _preload_decode_weights(decode_cache, weights, config) + _preload_decode_weights(decode_cache, weights, config, quant=quant) # Note: NPU warmup pass not needed here — the NPU prefill keeps # the NPU active. Only needed in llama32_1b_decode.py where CPU prefill @@ -216,12 +227,15 @@ def prepare_runtime( print(f" Runtime prepared in {t_prep:.1f}s") -def _preload_decode_weights(decode_cache, weights, config): +def _preload_decode_weights(decode_cache, weights, config, quant: str = "bf16"): """Pre-load all decode transformer block weights into per-layer BOs. Mirrors the preloading pattern from llama32_1b_decode.py: writes all weight data once before timing starts. During inference, static_input_indices skips weight re-writes. + + In quant="awq" mode, uses the int4 decode ELFs (rms_qkv_int4_rope + + o_gemv_ffn_int4) and packed-BO weights from LayerWeights._wq_packed etc. """ if hasattr(weights, "_decode_weights_preloaded_to_bos"): return @@ -242,18 +256,33 @@ def _preload_decode_weights(decode_cache, weights, config): for layer_idx in range(config.n_layers): lw = weights.layers[layer_idx] - # rms_gemv_rope: allocate + write weights + # rms_gemv_rope / rms_qkv_int4_rope: allocate + write weights. + # ABI is identical (13 args, same arg slots, same output_indices and + # static set) — only slots 3/5/7's *type* changes (bf16 [in,out] vs + # packed-i8 [total_tiles, tile_bytes]). + if quant == "awq": + rgr_name = "rms_qkv_int4_rope" + rgr_backend = RGR_INT4_BACKEND + wq_static, wk_static, wv_static = ( + lw._wq_packed, + lw._wk_packed, + lw._wv_packed, + ) + else: + rgr_name = "rms_gemv_rope" + rgr_backend = RGR_BACKEND + wq_static, wk_static, wv_static = lw._wq_t, lw._wk_t, lw._wv_t decode_cache.load_and_run( - "rms_gemv_rope", - RGR_BACKEND, + rgr_name, + rgr_backend, np.zeros(emb_dim, dtype=bfloat16), # x_in lw.attn_norm.reshape(emb_dim).astype(bfloat16), # norm_w np.zeros(emb_dim, dtype=bfloat16), # normed - lw._wq_t, # wq + wq_static, # wq (bf16 [in,out] or packed-i8) np.zeros(emb_dim, dtype=bfloat16), # q - lw._wk_t, # wk + wk_static, # wk np.zeros(kv_dim, dtype=bfloat16), # k - lw._wv_t, # wv + wv_static, # wv np.zeros(kv_dim, dtype=bfloat16), # v rope_lut_q_dummy, # lut_q rope_lut_k_dummy, # lut_k @@ -262,23 +291,24 @@ def _preload_decode_weights(decode_cache, weights, config): output_indices=[8, 11, 12], static_input_indices={1, 3, 5, 7}, intermediate_indices={2, 4, 6, 8, 11, 12}, - bo_key=f"rms_gemv_rope_L{layer_idx}", + bo_key=f"{rgr_name}_L{layer_idx}", ) # o_gemv_ffn (3-stage): build the interleaved w_gateup [2*hidden, emb] - # and the packed [2, emb] RMSNorm-input buffer (row 1 = ffn_norm_w, - # row 0 left zero for stage 1 to overwrite per token). Stashed on - # LayerWeights for reuse across all decode tokens. Frees the original - # _wgate_t/_wup_t once the interleaved copy is in place — they're - # otherwise unused after this preload (~1 GB host RAM saved). - wgate = lw._wgate_t - wup = lw._wup_t - wgateup = np.empty((2 * hidden_dim, emb_dim), dtype=bfloat16) - wgateup[0::2] = wgate - wgateup[1::2] = wup - lw._wgateup_t = wgateup - del lw._wgate_t - del lw._wup_t + # (bf16 only — int4 path already produced it in load_weights_awq) and + # the packed [2, emb] RMSNorm-input buffer (row 1 = ffn_norm_w, row 0 + # left zero for stage 1 to overwrite per token). Stashed on LayerWeights + # for reuse across all decode tokens. + if quant == "bf16": + wgate = lw._wgate_t + wup = lw._wup_t + wgateup = np.empty((2 * hidden_dim, emb_dim), dtype=bfloat16) + wgateup[0::2] = wgate + wgateup[1::2] = wup + lw._wgateup_t = wgateup + # ~1 GB host RAM saved across the 16 layers. + del lw._wgate_t + del lw._wup_t packed = np.empty((2, emb_dim), dtype=bfloat16) packed[0] = 0.0 @@ -289,28 +319,40 @@ def _preload_decode_weights(decode_cache, weights, config): z_hidden = np.zeros(hidden_dim, dtype=bfloat16) z_hidden_emb = np.zeros((hidden_dim, emb_dim), dtype=bfloat16) + if quant == "awq": + ogf_name = "o_gemv_ffn_int4" + ogf_backend = OGF_INT4_BACKEND + wo_static, wgu_static, wd_static = ( + lw._wo_packed, + lw._wgateup_packed, + lw._wdown_packed, + ) + else: + ogf_name = "o_gemv_ffn" + ogf_backend = OGF_BACKEND + wo_static, wgu_static, wd_static = lw._wo_t, lw._wgateup_t, lw._wdown_t decode_cache.load_and_run( - "o_gemv_ffn", - OGF_BACKEND, - lw._wo_t, # arg0 wo (static) - z_emb, # arg1 attn_out - z_emb, # arg2 (dead) - z_emb, # arg3 x_residual - z_emb, # arg4 (dead) - z_emb, # arg5 (dead) + ogf_name, + ogf_backend, + wo_static, # arg0 wo (static) + z_emb, # arg1 attn_out + z_emb, # arg2 (dead) + z_emb, # arg3 x_residual + z_emb, # arg4 (dead) + z_emb, # arg5 (dead) lw._packed_rms_buf, # arg6 packed (static) - lw._wgateup_t, # arg7 w_gateup (static) - z_hidden, # arg8 (dead) - z_hidden_emb, # arg9 (dead) + wgu_static, # arg7 w_gateup (static) + z_hidden, # arg8 (dead) + z_hidden_emb, # arg9 (dead) z_hidden, # arg10 (dead) z_hidden, # arg11 swiglu - lw._wdown_t, # arg12 wdown (static) + wd_static, # arg12 wdown (static) z_emb, # arg13 (dead) z_emb, # arg14 output output_indices=[14], static_input_indices={0, 6, 7, 12}, intermediate_indices={2, 4, 5, 8, 9, 10, 11, 13, 14}, - bo_key=f"o_gemv_ffn_L{layer_idx}", + bo_key=f"{ogf_name}_L{layer_idx}", ) # LM Head GEMV weights (8 partitions) @@ -556,6 +598,7 @@ def generate( verify=False, cpu_attn=True, on_token=None, + quant: str = "bf16", ): """Run NPU prefill + NPU decode generation. @@ -575,21 +618,36 @@ def generate( print(f"LLAMA Inference: prompt_len={seq_len}, n_tokens={n_tokens}") print(f"{'='*60}\n") - # --- Phase 1: NPU Prefill --- - prefill_token, k_cache, v_cache, prompt_len = run_npu_prefill( - prompt_tokens, - weights, - config, - prefill_cache, - decode_cache, - rope_lut_bf16, - max_seq, - tokenizer=tokenizer, - cpu_attn=cpu_attn, - profile=profile, - verify=verify, - quiet=streaming, - ) + # --- Phase 1: Prefill --- + # bf16: NPU prefill ELFs. awq: CPU prefill placeholder (no int4 prefill + # ELFs yet — see cpu_prefill.run_cpu_prefill). + if quant == "awq": + from cpu_prefill import run_cpu_prefill + + prefill_token, k_cache, v_cache, prompt_len = run_cpu_prefill( + prompt_tokens, + weights, + config, + rope_lut_bf16, + max_seq, + tokenizer=tokenizer, + quiet=streaming, + ) + else: + prefill_token, k_cache, v_cache, prompt_len = run_npu_prefill( + prompt_tokens, + weights, + config, + prefill_cache, + decode_cache, + rope_lut_bf16, + max_seq, + tokenizer=tokenizer, + cpu_attn=cpu_attn, + profile=profile, + verify=verify, + quiet=streaming, + ) # --- Phase 2: NPU Decode --- generated_tokens = [prefill_token] # Token 0 = from prefill @@ -620,6 +678,7 @@ def generate( v_cache[layer_idx], current_pos, rope_lut_bf16, + quant=quant, ) # Final RMSNorm (CPU) @@ -697,26 +756,51 @@ def build_session(args) -> Session: config = LlamaConfig() seq_len = 2048 - prefill_cache = KernelCache("prefill_kernel_cache", verbose=args.verbose) - decode_cache = KernelCache("decode_kernel_cache", verbose=args.verbose) + quant = getattr(args, "quant", "bf16") + # bf16: NPU prefill + NPU decode. awq: CPU prefill placeholder + NPU int4 + # decode (no prefill cache needed; the int4 prefill ELFs land in a future + # PR — see cpu_prefill.py). + prefill_cache = ( + KernelCache("prefill_kernel_cache", verbose=args.verbose) + if quant == "bf16" + else None + ) + decode_cache_dir = ( + "decode_kernel_cache" if quant == "bf16" else "decode_kernel_cache_int4" + ) + decode_cache = KernelCache(decode_cache_dir, verbose=args.verbose) if not args.run_only: - print("Compiling prefill kernels...") - compile_all_kernels(prefill_cache, config, seq_len, cpu_attn=args.cpu_attn) + if quant == "bf16": + print("Compiling prefill kernels...") + compile_all_kernels(prefill_cache, config, seq_len, cpu_attn=args.cpu_attn) + else: + print("AWQ mode: prefill runs on CPU, skipping NPU prefill compile.") print("\nCompiling decode kernels...") - compile_decode_kernels(decode_cache, config) + compile_decode_kernels(decode_cache, config, quant=quant) if args.compile_only: sys.exit(0) if args.run_only: - prefill_cache.load_manifest() + if prefill_cache is not None: + prefill_cache.load_manifest() decode_cache.load_manifest() if args.synthetic_weights: print("\nUsing synthetic random weights (skipping HuggingFace download).") weights = synthetic_weights(config) tokenizer = _SyntheticTokenizer() + elif quant == "awq": + model_id = args.model_path or ( + "amd/Llama-3.2-1B-Instruct-awq-uint4-asym-g128-bf16-lmhead" + ) + print(f"\nLoading AWQ weights ({model_id})...") + weights = load_weights_awq(model_id, config=config) + + from transformers import AutoTokenizer + + tokenizer = AutoTokenizer.from_pretrained(model_id) else: model_id = ( "meta-llama/Llama-3.2-1B-Instruct" @@ -736,7 +820,13 @@ def build_session(args) -> Session: ).astype(bfloat16) prepare_runtime( - prefill_cache, decode_cache, weights, config, seq_len, rope_lut_bf16 + prefill_cache, + decode_cache, + weights, + config, + seq_len, + rope_lut_bf16, + quant=quant, ) return Session( @@ -748,6 +838,7 @@ def build_session(args) -> Session: decode_cache=decode_cache, rope_lut_bf16=rope_lut_bf16, model_variant=args.model, + quant=quant, ) @@ -776,7 +867,11 @@ def run_once( (generated_token_ids, prompt_len_actual).""" tokens = _tokenize_prompt(session, prompt_text) prompt_len_actual = len(tokens) - if len(tokens) < session.seq_len: + # bf16 NPU prefill is shape-static at session.seq_len, so the prompt must + # be EOS-padded out to that length. CPU prefill (awq mode) runs at the + # raw prompt length — skipping the pad collapses a 2048² CPU-side + # attention from minutes to ~2 s for short prompts. + if session.quant == "bf16" and len(tokens) < session.seq_len: tokens = tokens + [session.tokenizer.eos_token_id] * ( session.seq_len - len(tokens) ) @@ -794,6 +889,7 @@ def run_once( verify=verify, cpu_attn=cpu_attn, on_token=on_token, + quant=session.quant, ) return generated, prompt_len_actual @@ -923,6 +1019,21 @@ def _stream_cb(_token_id: int, delta: str) -> None: default="instruct", help="Model variant: instruct (Q&A, default) or base (completion)", ) + parser.add_argument( + "--quant", + type=str, + choices=["bf16", "awq"], + default="bf16", + help="Weight precision. 'awq': int4-AWQ NPU decode with CPU prefill " + "placeholder; 'bf16' (default): bf16 NPU prefill + decode.", + ) + parser.add_argument( + "--model-path", + type=str, + default=None, + help="AWQ-mode override for the HF model id / local dir. Default: " + "amd/Llama-3.2-1B-Instruct-awq-uint4-asym-g128-bf16-lmhead.", + ) parser.add_argument( "--interactive", action="store_true", @@ -938,6 +1049,11 @@ def _stream_cb(_token_id: int, delta: str) -> None: if args.synthetic_weights and args.interactive: parser.error("--synthetic-weights cannot be combined with --interactive") + if args.synthetic_weights and args.quant == "awq": + parser.error( + "--synthetic-weights only generates bf16 weights; rerun with " + "--quant=bf16 or drop --synthetic-weights for an AWQ checkpoint." + ) if args.interactive: if args.compile_only: diff --git a/programming_examples/llama32_1b/llama32_1b_weights.py b/programming_examples/llama32_1b/llama32_1b_weights.py index 0f156f8ef..3840560b9 100644 --- a/programming_examples/llama32_1b/llama32_1b_weights.py +++ b/programming_examples/llama32_1b/llama32_1b_weights.py @@ -20,6 +20,7 @@ """ import os +import sys import glob as glob_module from dataclasses import dataclass, field from typing import List, Optional @@ -124,6 +125,18 @@ class LlamaWeights: "mlp.down_proj.weight": ("w_down", True), } +# AutoAWQ stores each Linear as three tensors. Field name = the dataclass +# field that owns the bf16 dequant (used by the CPU prefill placeholder). +_HF_AWQ_LINEARS = { + "self_attn.q_proj": "wq", + "self_attn.k_proj": "wk", + "self_attn.v_proj": "wv", + "self_attn.o_proj": "wo", + "mlp.gate_proj": "w_gate", + "mlp.up_proj": "w_up", + "mlp.down_proj": "w_down", +} + # --------------------------------------------------------------------------- # Safetensors loading helpers @@ -327,6 +340,219 @@ def load_weights( ) +# --------------------------------------------------------------------------- +# HuggingFace AutoAWQ loader +# --------------------------------------------------------------------------- + +# Mapping LayerWeights field name -> per-layer packed-BO attribute name. +# Packed BOs are dynamically attached to each LayerWeights instance (same +# pattern inference.py uses for ._wq_t etc.), keyed by the dataclass field +# to avoid extending the dataclass schema. The decode-side runtime reads +# `layer._wq_packed`, `layer._wo_packed`, etc. The fused FFN ELF +# (o_gemv_ffn_int4) wants gate+up as ONE row-interleaved packed BO and +# is exposed separately as `_wgateup_packed`. +_AWQ_PACKED_ATTR = { + "wq": "_wq_packed", + "wk": "_wk_packed", + "wv": "_wv_packed", + "wo": "_wo_packed", + "w_down": "_wdown_packed", +} +# w_gate and w_up are NOT in _AWQ_PACKED_ATTR — they go through a row-level +# interleave first (gate[i] -> row 2i, up[i] -> row 2i+1) before packing, +# so the int4 FFN ELF can consume them in one BO. + + +def load_weights_awq( + model_name_or_path: str, + config: Optional[LlamaConfig] = None, + group_size: int = 128, + m_tile: int = 8, + k_chunk: int = 2048, + n_cores: int = 8, +) -> LlamaWeights: + """Load a HuggingFace AutoAWQ Llama checkpoint. + + For each Linear, stashes BOTH: + - the bf16 dequant on the existing LayerWeights field (wq/wk/.../w_down), + for the CPU prefill placeholder via reference.transformer_block; + - the per-tile packed uint8 BO on `__packed`, for the NPU int4 + decode ELFs (rms_qkv_int4_rope and o_gemv_ffn_int4). + + Args: + model_name_or_path: local dir or HF model id of an AutoAWQ checkpoint. + config: model hyperparameters (defaults to Llama-3.2-1B). + group_size: AWQ group size (typical 128). Must match the checkpoint. + m_tile, k_chunk, n_cores: GEMV packed-BO tiling parameters; defaults + match the int4 decode ELF builders. + + Returns: + LlamaWeights with packed-BO attributes attached. + """ + from safetensors import safe_open + from awq_repacker import dequant_to_bf16, repack_for_gemv, repack_hf_awq_linear + + sys.path.insert( + 0, + os.path.join( + os.path.dirname(os.path.abspath(__file__)), + "..", + "matrix_vector_multiplication", + "int4_awq", + ), + ) + from matvec_int4_packed import pack_inputs as _pack_inputs + + if config is None: + config = LlamaConfig() + + safetensor_files = _resolve_safetensor_files(model_name_or_path) + + key_to_file = {} + for filepath in safetensor_files: + with safe_open(filepath, framework="numpy") as f: + for key in f.keys(): + key_to_file[key] = filepath + + # --- Embedding table --- + embed_key = "model.embed_tokens.weight" + if embed_key not in key_to_file: + raise KeyError(f"Missing weight: {embed_key}") + with safe_open(key_to_file[embed_key], framework="numpy") as f: + embed_table = _load_tensor(f, embed_key, bfloat16) + + # --- Per-layer --- + layers: List[LayerWeights] = [] + for layer_idx in range(config.n_layers): + # Non-quantized: layernorms. + attn_norm_key = f"model.layers.{layer_idx}.input_layernorm.weight" + ffn_norm_key = f"model.layers.{layer_idx}.post_attention_layernorm.weight" + with safe_open(key_to_file[attn_norm_key], framework="numpy") as f: + attn_norm = _load_tensor(f, attn_norm_key, bfloat16) + with safe_open(key_to_file[ffn_norm_key], framework="numpy") as f: + ffn_norm = _load_tensor(f, ffn_norm_key, bfloat16) + + # Each quantized Linear: load qweight/qzeros/scales, then both + # (a) dequant to bf16 -> existing LayerWeights field, for CPU prefill, + # (b) repack-and-pack -> packed BO, for NPU int4 decode. + # gate/up are special: their NPU packed BO must be ONE interleaved + # matrix (gate row 0, up row 0, gate row 1, ...) so the int4 ELF2 + # (matvec_int4_swiglu_rms) can consume them in one input slot. + linear_bf16 = {} + linear_packed = {} + gate_quants = up_quants = None + for hf_prefix, field_name in _HF_AWQ_LINEARS.items(): + base = f"model.layers.{layer_idx}.{hf_prefix}" + qw_key, qz_key, s_key = ( + f"{base}.qweight", + f"{base}.qzeros", + f"{base}.scales", + ) + for k in (qw_key, qz_key, s_key): + if k not in key_to_file: + raise KeyError( + f"Missing AWQ tensor: {k} (is this an AutoAWQ checkpoint?)" + ) + with safe_open(key_to_file[qw_key], framework="numpy") as f: + qw = f.get_tensor(qw_key) + with safe_open(key_to_file[qz_key], framework="numpy") as f: + qz = f.get_tensor(qz_key) + with safe_open(key_to_file[s_key], framework="numpy") as f: + sc = f.get_tensor(s_key) + if qw.dtype != np.int32: + qw = qw.astype(np.int32) + if qz.dtype != np.int32: + qz = qz.astype(np.int32) + # (a) bf16 dequant for CPU prefill: shape [in, out] (matches + # transformer_block's wq[in, out] convention — no transpose). + linear_bf16[field_name] = dequant_to_bf16(qw, qz, sc, group_size) + # (b) packed BO for NPU decode; gate/up are deferred until both + # are loaded so we can interleave them at the nibble level. + if field_name == "w_gate": + gate_quants = repack_hf_awq_linear(qw, qz, sc, group_size) + elif field_name == "w_up": + up_quants = repack_hf_awq_linear(qw, qz, sc, group_size) + else: + linear_packed[field_name] = repack_for_gemv( + qw, + qz, + sc, + group_size, + M_TILE=m_tile, + K_CHUNK=k_chunk, + N_CORES=n_cores, + ) + + # Interleave gate/up at the (A_q, A_s, A_z) level: row 2i = gate[i], + # row 2i+1 = up[i]. Then pack into one BO sized [2*hidden, emb] for + # arg7 of o_gemv_ffn_int4. + if gate_quants is None or up_quants is None: + raise RuntimeError( + "Could not find both mlp.gate_proj and mlp.up_proj AWQ tensors" + ) + g_q, g_s, g_z = gate_quants + u_q, u_s, u_z = up_quants + h_out, k_half = g_q.shape # h_out = hidden_dim, k_half = K/2 + if u_q.shape != (h_out, k_half): + raise RuntimeError("gate_proj and up_proj have different shapes") + gu_q = np.empty((2 * h_out, k_half), dtype=np.uint8) + gu_q[0::2] = g_q + gu_q[1::2] = u_q + n_groups = g_s.shape[0] + gu_s = np.empty((n_groups, 2 * h_out), dtype=g_s.dtype) + gu_s[:, 0::2] = g_s + gu_s[:, 1::2] = u_s + gu_z = np.empty((n_groups, 2 * h_out), dtype=np.uint8) + gu_z[:, 0::2] = g_z + gu_z[:, 1::2] = u_z + M_gateup = 2 * h_out + K_full = k_half * 2 + gateup_packed = _pack_inputs( + gu_q, + gu_s, + gu_z, + M_gateup, + K_full, + group_size, + m_tile, + k_chunk, + n_cores, + M_gateup, + ) + + layer = LayerWeights( + attn_norm=attn_norm, + ffn_norm=ffn_norm, + **linear_bf16, + ) + for field_name, packed in linear_packed.items(): + setattr(layer, _AWQ_PACKED_ATTR[field_name], packed) + layer._wgateup_packed = gateup_packed + layers.append(layer) + + if (layer_idx + 1) % 4 == 0 or layer_idx == 0: + print(f" AWQ layer {layer_idx + 1}/{config.n_layers} loaded") + + # --- Final norm and lm_head (both bf16 in this checkpoint) --- + norm_key = "model.norm.weight" + with safe_open(key_to_file[norm_key], framework="numpy") as f: + final_norm = _load_tensor(f, norm_key, bfloat16) + lm_head_key = "lm_head.weight" + if lm_head_key in key_to_file: + with safe_open(key_to_file[lm_head_key], framework="numpy") as f: + lm_head = _load_tensor(f, lm_head_key, bfloat16) + else: + print(" Tied embeddings: reusing embed_table as lm_head.") + lm_head = embed_table + + return LlamaWeights( + embed_table=embed_table, + layers=layers, + final_norm=final_norm, + lm_head=lm_head, + ) + + # --------------------------------------------------------------------------- # Synthetic-weights builder (CI smoke / verify without HuggingFace download) # ---------------------------------------------------------------------------