diff --git a/programming_examples/llama32_1b/awq_repacker.py b/programming_examples/llama32_1b/awq_repacker.py new file mode 100644 index 000000000..6cd5cc83b --- /dev/null +++ b/programming_examples/llama32_1b/awq_repacker.py @@ -0,0 +1,259 @@ +# Copyright (C) 2026, Advanced Micro Devices, Inc. +# SPDX-License-Identifier: MIT + +"""HuggingFace AutoAWQ checkpoint -> per-tile packed BO layout used by +mlir-air's int4-AWQ GEMV kernels. + +AutoAWQ stores each Linear's quantized weights as three tensors: + qweight: [in_features=K, out_features // 8] int32 + (8 uint4 nibbles packed along N per int32, interleaved by AWQ_PACK_ORDER) + qzeros: [K // group_size, out_features // 8] int32 (same packing) + scales: [K // group_size, out_features] fp16 + +mlir-air's `matvec_int4_packed.pack_inputs` expects: + A_q[M=out, K/2] uint8 (col 2i = low nibble, col 2i+1 = high nibble) + A_s[n_groups, M] bf16 + A_z[n_groups, M] uint8 + +This module bridges the two formats. A built-in self-test (run via +`python3 awq_repacker.py`) generates synthetic AWQ tensors, repacks, and +verifies that the repacked form dequantizes to exactly the same bf16 +weights as a direct dense dequant. +""" + +import argparse +import os +import sys + +import numpy as np +from ml_dtypes import bfloat16 + +# AutoAWQ packs 8 uint4 nibbles into each int32, but the nibble at bit +# position 4*i within the int32 holds the *unpacked* output column +# `8*col_block + AWQ_PACK_ORDER[i]`. See autoawq.utils.packing_utils.pack. +AWQ_PACK_ORDER = np.array([0, 2, 4, 6, 1, 3, 5, 7], dtype=np.int64) +# Inverse: AWQ_UNPACK_PERM[k] == bit position holding output column k. +AWQ_UNPACK_PERM = np.argsort(AWQ_PACK_ORDER) # = [0, 4, 1, 5, 2, 6, 3, 7] + + +def unpack_awq_int32(packed: np.ndarray) -> np.ndarray: + """Unpack AutoAWQ int32 -> uint8 nibbles along the last axis. + + Args: + packed: int32 array, last axis is the packed-N axis of size N//8. + + Returns: + uint8 array, last axis size N, values in [0, 16). + """ + packed64 = packed.astype(np.int64) + shifts = np.arange(8, dtype=np.int64) * 4 + # nibs[..., k, i] = bits [4i : 4i+4] of int32 at packed position k. + nibs = ((packed64[..., :, None] >> shifts) & 0xF).astype(np.uint8) + # Reorder: nibble at bit position i corresponds to output column + # AWQ_PACK_ORDER[i], so to get natural column order we index by + # AWQ_UNPACK_PERM (column k -> bit position AWQ_UNPACK_PERM[k]). + nibs_reordered = nibs[..., AWQ_UNPACK_PERM] + return nibs_reordered.reshape(*packed.shape[:-1], packed.shape[-1] * 8) + + +def dequant_to_bf16(qweight, qzeros, scales, group_size): + """HF AutoAWQ tensors -> dense bf16 weight matrix [in_features, out_features]. + + Matches transformer_block's `wq[in, out]` storage convention so the result + can be assigned directly to LayerWeights.wq / wk / wv / wo / w_gate / + w_up / w_down for the CPU-prefill placeholder path. + + Dequant formula: w[k, n] = (qweight_u[k, n] - qzeros_u[k//gs, n]) * scales[k//gs, n]. + """ + qweight_u = unpack_awq_int32(qweight) # [K, N] uint8 + qzeros_u = unpack_awq_int32(qzeros) # [K/gs, N] uint8 + K, N = qweight_u.shape + n_groups = K // group_size + if qzeros_u.shape != (n_groups, N): + raise ValueError(f"qzeros shape {qzeros_u.shape} vs expected ({n_groups}, {N})") + if scales.shape != (n_groups, N): + raise ValueError(f"scales shape {scales.shape} vs expected ({n_groups}, {N})") + # Round scales to bf16 first so this matches what the NPU kernel actually + # sees (the packed BO carries bf16 scales). fp16->bf16 loses 3 mantissa + # bits, which is real and intentional precision drift relative to the + # canonical AWQ fp16 dequant. + s_bf16_as_f32 = scales.astype(bfloat16).astype(np.float32) + z_per_k = np.repeat(qzeros_u.astype(np.int32), group_size, axis=0) # [K, N] + s_per_k = np.repeat(s_bf16_as_f32, group_size, axis=0) # [K, N] + w_f32 = (qweight_u.astype(np.int32) - z_per_k) * s_per_k + return w_f32.astype(bfloat16) + + +def repack_hf_awq_linear(qweight, qzeros, scales, group_size): + """HF AutoAWQ tensors -> (A_q, A_s, A_z) in mlir-air `pack_inputs` format. + + Returns: + A_q: uint8 [M=out_features, K/2], packed nibble pairs (col 2i = low, + col 2i+1 = high). + A_s: bf16 [n_groups, M] (lossy fp16->bf16 cast on AWQ's smooth scales). + A_z: uint8 [n_groups, M], values in [0, 16). + """ + qweight_u = unpack_awq_int32(qweight) # [K, N] + qzeros_u = unpack_awq_int32(qzeros) # [K/gs, N] + K, N = qweight_u.shape + n_groups = K // group_size + if qzeros_u.shape != (n_groups, N): + raise ValueError(f"qzeros shape {qzeros_u.shape} vs expected ({n_groups}, {N})") + if scales.shape != (n_groups, N): + raise ValueError(f"scales shape {scales.shape} vs expected ({n_groups}, {N})") + # Transpose K-major (HF) -> M-major: weight[m=n, k] = qweight_u[k, n]. + q_mn = np.ascontiguousarray(qweight_u.T) # [M, K] + low = q_mn[:, 0::2] & 0x0F + high = (q_mn[:, 1::2] & 0x0F) << 4 + A_q = (low | high).astype(np.uint8) # [M, K/2] + A_s = np.ascontiguousarray(scales).astype(bfloat16) # [n_groups, M] + A_z = np.ascontiguousarray(qzeros_u).astype(np.uint8) # [n_groups, M] + return A_q, A_s, A_z + + +def repack_for_gemv( + qweight, qzeros, scales, group_size, M_TILE=8, K_CHUNK=2048, N_CORES=8 +): + """HF AutoAWQ -> [total_tiles, tile_bytes] uint8 BO ready for mlir-air decode. + + Calls `matvec_int4_packed.pack_inputs` under the hood. Single-launch + layout (`M_PER_LAUNCH = M`) — matches the int4 decode ELF builders. + """ + sys.path.insert( + 0, + os.path.join( + os.path.dirname(os.path.abspath(__file__)), + "..", + "matrix_vector_multiplication", + "int4_awq", + ), + ) + from matvec_int4_packed import pack_inputs # type: ignore + + A_q, A_s, A_z = repack_hf_awq_linear(qweight, qzeros, scales, group_size) + M, K_half = A_q.shape + K = K_half * 2 + return pack_inputs(A_q, A_s, A_z, M, K, group_size, M_TILE, K_CHUNK, N_CORES, M) + + +# --------------------------------------------------------------------------- +# Synthetic AWQ generator + self-test +# --------------------------------------------------------------------------- + + +def _gen_synthetic_awq(K, N, group_size, seed=42): + """Produce HF-AutoAWQ-shaped tensors from a random nibble matrix. + + Returns (qweight[K, N//8] int32, qzeros[K/gs, N//8] int32, + scales[K/gs, N] fp16, dense_ref[K, N] uint8) where dense_ref + is the un-packed nibble matrix used to verify the unpack path. + """ + rng = np.random.default_rng(seed) + n_groups = K // group_size + # Random uint4 nibbles for both weights and zeros. + nibs = rng.integers(0, 16, size=(K, N), dtype=np.uint8) + z_nibs = rng.integers(0, 16, size=(n_groups, N), dtype=np.uint8) + scales = rng.uniform(0.005, 0.02, size=(n_groups, N)).astype(np.float16) + + # Pack along N axis: bit position i holds column AWQ_PACK_ORDER[i]. + n_blocks = N // 8 + qweight = np.zeros((K, n_blocks), dtype=np.int32) + qzeros = np.zeros((n_groups, n_blocks), dtype=np.int32) + for i in range(8): + col = AWQ_PACK_ORDER[i] + qweight |= nibs[:, col::8].astype(np.int32) << (4 * i) + qzeros |= z_nibs[:, col::8].astype(np.int32) << (4 * i) + return qweight, qzeros, scales, nibs, z_nibs + + +def self_test(K=512, N=128, group_size=128, seed=42, verbose=True): + """Round-trip check: pack synthetic AWQ -> repack -> dequant matches + direct dense dequant. Algebraically identical up to bf16 rounding. + """ + qweight, qzeros, scales, nibs_ref, z_nibs_ref = _gen_synthetic_awq( + K, N, group_size, seed=seed + ) + + # (a) unpack round-trip: confirms AWQ_PACK_ORDER handling. + nibs_unpacked = unpack_awq_int32(qweight) + if not np.array_equal(nibs_unpacked, nibs_ref): + wrong = (nibs_unpacked != nibs_ref).sum() + raise AssertionError( + f"unpack_awq_int32 mismatch on {wrong} / {nibs_ref.size} nibbles" + ) + z_unpacked = unpack_awq_int32(qzeros) + if not np.array_equal(z_unpacked, z_nibs_ref): + raise AssertionError("qzeros unpack mismatch") + if verbose: + print(f" [a] AWQ_PACK_ORDER unpack: PASS ({K}x{N} nibbles)") + + # (b) dense dequant and our repack agree on every (k, n). + w_dense = dequant_to_bf16(qweight, qzeros, scales, group_size) + A_q, A_s, A_z = repack_hf_awq_linear(qweight, qzeros, scales, group_size) + # Reconstruct dense weight from (A_q, A_s, A_z): w'[k, n] = (nib - z) * s. + M = A_q.shape[0] + K2 = A_q.shape[1] * 2 + if (M, K2) != (N, K): + raise AssertionError(f"repack shape mismatch: ({M}, {K2}) vs ({N}, {K})") + nibs_from_packed = np.zeros((M, K2), dtype=np.uint8) + nibs_from_packed[:, 0::2] = A_q & 0x0F + nibs_from_packed[:, 1::2] = (A_q >> 4) & 0x0F + z_per_k = np.repeat(A_z.astype(np.int32), group_size, axis=0) # [K, M] + s_per_k = np.repeat(A_s.astype(np.float32), group_size, axis=0) # [K, M] + # w_repacked[m, k] = (nib[m, k] - z[k//gs, m]) * s[k//gs, m] + w_repacked_f32 = (nibs_from_packed.astype(np.int32) - z_per_k.T) * s_per_k.T + w_repacked = w_repacked_f32.astype(bfloat16) + # Compare in [in, out] orientation: w_dense is [K, N]; w_repacked is [M=N, K]. + if not np.array_equal(w_dense, w_repacked.T): + diff = w_dense.astype(np.float32) - w_repacked.T.astype(np.float32) + mx = np.max(np.abs(diff)) + raise AssertionError(f"dense vs repacked dequant mismatch: max |Δ| = {mx}") + if verbose: + print(f" [b] dense vs repacked dequant: PASS") + + # (c) end-to-end vs matvec_int4_packed.cpu_reference on a random input. + sys.path.insert( + 0, + os.path.join( + os.path.dirname(os.path.abspath(__file__)), + "..", + "matrix_vector_multiplication", + "int4_awq", + ), + ) + from matvec_int4_packed import cpu_reference # type: ignore + + rng = np.random.default_rng(seed + 1) + x = rng.standard_normal(K).astype(bfloat16) + # mlir-air cpu_reference applies dequant + matmul in the same order as + # the NPU kernel; result should match w_dense.T @ x within bf16 rounding. + y_repacked = cpu_reference(A_q, A_s, A_z, x) + y_dense = (w_dense.astype(np.float32).T @ x.astype(np.float32)).astype(bfloat16) + corr = np.corrcoef( + y_repacked.astype(np.float32).flatten(), + y_dense.astype(np.float32).flatten(), + )[0, 1] + if not (corr >= 0.9999): + raise AssertionError( + f"end-to-end correlation {corr:.6f} below 0.9999 threshold" + ) + if verbose: + print(f" [c] end-to-end (cpu_reference vs dense): PASS (corr={corr:.6f})") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + prog="awq_repacker.py", + description="HF AutoAWQ -> mlir-air packed-BO repacker + self-test.", + ) + parser.add_argument("--k", type=int, default=512) + parser.add_argument("--n", type=int, default=128) + parser.add_argument("--gs", type=int, default=128) + parser.add_argument("--seed", type=int, default=42) + args = parser.parse_args() + print( + f"AWQ repacker self-test: K={args.k}, N={args.n}, GS={args.gs}, " + f"seed={args.seed}" + ) + self_test(K=args.k, N=args.n, group_size=args.gs, seed=args.seed) + print("All self-tests PASSED.") diff --git a/programming_examples/llama32_1b/cpu_prefill.py b/programming_examples/llama32_1b/cpu_prefill.py new file mode 100644 index 000000000..7c2f1ab7c --- /dev/null +++ b/programming_examples/llama32_1b/cpu_prefill.py @@ -0,0 +1,104 @@ +# Copyright (C) 2026, Advanced Micro Devices, Inc. +# SPDX-License-Identifier: MIT + +"""CPU prefill placeholder for the int4-AWQ pipeline. + +Wraps `llama32_1b_reference.transformer_block` into a drop-in replacement +for `llama32_1b_inference.run_npu_prefill` so the int4-AWQ end-to-end path +can bootstrap a KV cache without needing int4 prefill ELFs yet. + +Per-layer K (post-RoPE) and V are harvested from each `transformer_block` +call's intermediates dict and written into the same KV cache layout the +NPU decode loop reads from. The final norm + LM head runs on the last +prompt-position hidden state to produce the first generated token, matching +`run_npu_prefill`'s return contract. + +Runtime: numpy bf16 dequant + matmul; ~2 s for a 16-token prompt at 16 +layers, scales linearly. For validation and short interactive prompts only; +production int4 prefill will land later as a separate project and replace +the import in `inference.py`. +""" + +import time + +import numpy as np +from ml_dtypes import bfloat16 + + +def run_cpu_prefill( + token_ids, + weights, + config, + rope_lut_bf16, + max_seq, + tokenizer=None, + quiet=False, +): + """CPU prefill that mirrors `run_npu_prefill`'s return signature. + + Args: + token_ids: list[int] of prompt token IDs. + weights: LlamaWeights with bf16 dequant fields populated (set up by + load_weights_awq via dequant_to_bf16). Packed BO attributes are + ignored here. + config: LlamaConfig. + rope_lut_bf16: (max_seq, head_dim) RoPE LUT in bf16; converted to + f32 internally for the reference math. + max_seq: KV cache stride along the sequence dim. + tokenizer: optional, used only for logging. + quiet: suppress timing prints. + + Returns: + prefill_token: int -- first predicted token ID (greedy argmax) + k_cache: ndarray (n_layers, n_kv_heads, max_seq, head_dim) bfloat16 + v_cache: ndarray (n_layers, n_kv_heads, max_seq, head_dim) bfloat16 + prompt_len: int -- len(token_ids) + """ + from llama32_1b_reference import rms_norm as _rms_norm + from llama32_1b_reference import transformer_block as _transformer_block + + seq_len = len(token_ids) + n_layers = config.n_layers + n_kv_heads = config.n_kv_heads + head_dim = config.head_dim + + if not quiet: + print(f"Running CPU prefill ({n_layers} layers, seq_len={seq_len})...") + t0 = time.time() + + rope_lut_f32 = np.asarray(rope_lut_bf16, dtype=np.float32) + + # Token embedding -> initial hidden states. + embed = np.asarray(weights.embed_table, dtype=np.float32) + x = embed[np.asarray(token_ids)] # (seq_len, emb_dim) + + k_cache = np.zeros((n_layers, n_kv_heads, max_seq, head_dim), dtype=bfloat16) + v_cache = np.zeros((n_layers, n_kv_heads, max_seq, head_dim), dtype=bfloat16) + + for layer_idx in range(n_layers): + lw = weights.layers[layer_idx] + x, inters = _transformer_block(x, lw, rope_lut_f32, config) + # k_roped, v: (seq_len, n_kv_heads * head_dim) -> (n_kv_heads, seq_len, head_dim) + k_roped = inters["k_roped"].reshape(seq_len, n_kv_heads, head_dim) + v = inters["v"].reshape(seq_len, n_kv_heads, head_dim) + k_cache[layer_idx, :, :seq_len, :] = k_roped.transpose(1, 0, 2).astype(bfloat16) + v_cache[layer_idx, :, :seq_len, :] = v.transpose(1, 0, 2).astype(bfloat16) + + # Final norm + LM head on the LAST prompt position only. + final_norm = np.asarray(weights.final_norm, dtype=np.float32) + h_last = _rms_norm(x[-1:], final_norm) # (1, emb_dim) f32 + lm_head = np.asarray(weights.lm_head, dtype=np.float32) + logits_row = (h_last @ lm_head.T).reshape(-1) # (vocab_size,) + prefill_token = int(np.argmax(logits_row)) + + t_prefill = time.time() - t0 + if not quiet: + msg = f"CPU prefill done in {t_prefill:.2f}s. First token: {prefill_token}" + if tokenizer is not None: + try: + msg += f" ({tokenizer.decode([prefill_token])!r})" + except Exception: + pass + print(msg) + + return prefill_token, k_cache, v_cache, seq_len diff --git a/programming_examples/llama32_1b/kernel_builder/backend_presets.py b/programming_examples/llama32_1b/kernel_builder/backend_presets.py index ef396f860..c1dd8be67 100644 --- a/programming_examples/llama32_1b/kernel_builder/backend_presets.py +++ b/programming_examples/llama32_1b/kernel_builder/backend_presets.py @@ -63,3 +63,23 @@ "instance_name": "lm_head_gemv", **GEMV_K2048_BACKEND, } + +# --------------------------------------------------------------------------- +# Decode (int4-AWQ ELFs — same kwarg shape, distinct instance names so the +# kernel cache files don't collide with the bf16 ones) +# --------------------------------------------------------------------------- + +RGR_INT4_BACKEND = { + "output_format": "elf", + "instance_name": "rms_qkv_int4_rope", + "stack_size": 4096, + **GEMV_K2048_BACKEND, +} + +OGF_INT4_BACKEND = { + "output_format": "elf", + "instance_name": "o_gemv_ffn_int4", + "omit_pingpong": "all", + "stack_size": 4096, + **{k: v for k, v in GEMV_K2048_BACKEND.items() if k != "omit_pingpong"}, +} diff --git a/programming_examples/llama32_1b/kernel_builder/cache.py b/programming_examples/llama32_1b/kernel_builder/cache.py index bb4291a7e..b21b2ae49 100644 --- a/programming_examples/llama32_1b/kernel_builder/cache.py +++ b/programming_examples/llama32_1b/kernel_builder/cache.py @@ -12,13 +12,17 @@ from ml_dtypes import bfloat16 -def prepare_air_project(): +def prepare_air_project(quant: str = "bf16"): """Clean and prepare the air_project/ directory for a fresh compilation. aircc defaults to 'air_project/' as its working directory. Sequential compilations leave stale artifacts that corrupt subsequent kernels. This method wipes the directory, compiles all external C++ kernels from source, and copies them to air_project/. + + Args: + quant: "bf16" or "awq". When "awq", also compiles + stages + `mv_int4_bf16.o` so the int4 decode ELFs can link it. """ air_proj = Path("air_project") if air_proj.exists(): @@ -28,7 +32,7 @@ def prepare_air_project(): # Compile external kernels from source (not stale .o copies) from kernel_builder.external_kernels import compile_all_external_kernels - compile_all_external_kernels() + compile_all_external_kernels(quant=quant) # Copy compiled .o files to air_project/ for aiecc to find. Must include # every external symbol referenced by `link_with` in the kernel modules: @@ -38,7 +42,8 @@ def prepare_air_project(): # - silu_and_mul.o : SwiGLU (prefill o_ffn, decode o_gemv_ffn) # - attn.o : flash attention (prefill, when --cpu-attn=False) # - attn_npu2.o : flash attention NPU2 variant alias - for obj_name in [ + # - mv_int4_bf16.o : int4-AWQ GEMV micro-kernel (decode int4 ELFs only) + obj_names = [ "silu_and_mul.o", "rope.o", "attn.o", @@ -46,7 +51,10 @@ def prepare_air_project(): "mv.o", "mv_bf16.o", "attn_decode_npu2.o", - ]: + ] + if quant == "awq": + obj_names.append("mv_int4_bf16.o") + for obj_name in obj_names: src = Path(obj_name) if src.exists(): shutil.copy2(src, air_proj / obj_name) @@ -263,7 +271,11 @@ def compile_and_cache( from air.backend.xrt import XRTBackend self._log(f"Compiling {name}...") - prepare_air_project() + # Int4 ELFs need the AWQ GEMV micro-kernel staged alongside the bf16 + # objects — detect from the kernel name so callers don't have to pass + # an extra flag through every compile_and_cache invocation. + quant = "awq" if "int4" in name else "bf16" + prepare_air_project(quant=quant) backend = XRTBackend(**backend_kwargs) t0 = time.time() diff --git a/programming_examples/llama32_1b/kernel_builder/external_kernels.py b/programming_examples/llama32_1b/kernel_builder/external_kernels.py index a5e0e9d8c..881aa9434 100644 --- a/programming_examples/llama32_1b/kernel_builder/external_kernels.py +++ b/programming_examples/llama32_1b/kernel_builder/external_kernels.py @@ -151,6 +151,20 @@ def compile_attn_npu2(head_dim=64): shutil.copy2("attn_npu2.o", "attn.o") +def compile_mv_k8192(): + """Compile mv_k8192.o with renamed GEMV symbols for K=8192 decode merge.""" + src = _PROJ_ROOT / "matrix_vector_multiplication" / "bf16" / "mv.cc" + _compile_kernel( + src, + "mv_k8192.o", + extra_flags=[ + "-DDIM_M_OUTPUT=2", + "-Dmatvec_vectorized_bf16_bf16=dg_matvec_vectorized_bf16_bf16", + "-Dlinalg_fill_bf16=dg_linalg_fill_bf16", + ], + ) + + def compile_mv(tile_m=8): """Compile mv.o (standard GEMV kernel) from source.""" src = _PROJ_ROOT / "matrix_vector_multiplication" / "bf16" / "mv.cc" @@ -171,13 +185,6 @@ def compile_mv_int4_bf16(m_tile=8, k_chunk=2048, gs=128): ) -def compile_mv_bf16(): - """Compile mv_bf16.o for the 2-tile matvec+add primitive used by - o_gemv_ffn stages 1 and 3.""" - src = _PROJ_ROOT / "matrix_vector_multiplication" / "bf16_cascade" / "mv_bf16.cc" - _compile_kernel(src, "mv_bf16.o") - - def compile_attn_decode_npu2(head_dim=64): """Compile attn_decode_npu2.o (RoPE helpers for the fused decode kernel).""" src = _PROJ_ROOT / "attention_decode" / "attn_decode_npu2.cc" @@ -192,16 +199,25 @@ def compile_attn_decode_npu2(head_dim=64): ) -def compile_all_external_kernels(head_dim=64): +def compile_all_external_kernels(head_dim=64, quant="bf16"): """Compile all external C++ kernels from source. Call this before kernel compilation to ensure all .o files are fresh. Each kernel is only compiled if its .o doesn't already exist. Delete build_peano/*.o to force recompilation. + + Args: + head_dim: attention head dimension (RoPE / attn kernel macros). + quant: "bf16" (default) or "awq". When "awq" the int4 GEMV micro-kernel + (`mv_int4_bf16.o`) is built so the int4 decode ELFs can link it. + bf16-specific GEMV objects (mv.o, mv_k8192.o) are still built so + mixed paths (e.g. bf16 prefill alongside int4 decode) keep working. """ compile_silu_and_mul() compile_rope() compile_attn_npu2(head_dim=head_dim) compile_attn_decode_npu2(head_dim=head_dim) compile_mv() - compile_mv_bf16() + compile_mv_k8192() + if quant == "awq": + compile_mv_int4_bf16() diff --git a/programming_examples/llama32_1b/llama32_1b_decode.py b/programming_examples/llama32_1b/llama32_1b_decode.py index 9292b387e..f84923776 100644 --- a/programming_examples/llama32_1b/llama32_1b_decode.py +++ b/programming_examples/llama32_1b/llama32_1b_decode.py @@ -28,6 +28,8 @@ RGR_BACKEND, OGF_BACKEND, LM_GEMV_BACKEND, + RGR_INT4_BACKEND, + OGF_INT4_BACKEND, ) # --------------------------------------------------------------------------- @@ -35,11 +37,18 @@ # --------------------------------------------------------------------------- -def compile_decode_kernels(cache, config): - """Compile the 3 merged decode kernels.""" +def compile_decode_kernels(cache, config, quant: str = "bf16"): + """Compile the 3 merged decode kernels. + + Args: + cache: KernelCache. + config: LlamaConfig. + quant: "bf16" (default, existing behavior) or "awq" (int4-AWQ ELFs: + rms_qkv_int4_rope + o_gemv_ffn_int4 from PR #1633 / #1637). + """ from kernel_builder.external_kernels import compile_all_external_kernels - compile_all_external_kernels(head_dim=config.head_dim) + compile_all_external_kernels(head_dim=config.head_dim, quant=quant) emb_dim = config.emb_dim n_heads = config.n_heads @@ -49,31 +58,62 @@ def compile_decode_kernels(cache, config): kv_dim = n_kv_heads * head_dim print(f"\n{'='*60}") - print(f"Compiling decode kernels (2-call merged pipeline)...") + print(f"Compiling decode kernels (quant={quant})...") print(f"{'='*60}\n") - # 1. rms_gemv_rope: RMSNorm + QKV GEMV + RoPE Q+K (6 launches, 13 args) - from multi_launch_builder.rms_gemv_rope_multi import ( - build_rms_gemv_rope_module, - ) + if quant == "awq": + # 1. rms_qkv_int4_rope: RMSNorm + int4 QKV GEMV + RoPE Q+K (6 launches, 13 args) + from multi_launch_builder.rms_qkv_int4_rope_multi import ( + build_rms_qkv_int4_rope_module, + ) - cache.compile_and_cache( - "rms_gemv_rope", - build_rms_gemv_rope_module(emb_dim, kv_dim, n_heads, n_kv_heads, head_dim), - {"verbose": cache.verbose, **RGR_BACKEND}, - ) + cache.compile_and_cache( + "rms_qkv_int4_rope", + build_rms_qkv_int4_rope_module( + emb_dim=emb_dim, + kv_dim=kv_dim, + n_heads=n_heads, + n_kv_heads=n_kv_heads, + head_dim=head_dim, + ), + {"verbose": cache.verbose, **RGR_INT4_BACKEND}, + ) + + # 2. o_gemv_ffn_int4: 3-launch full-int4 (matvec_int4_packed_add + + # matvec_int4_swiglu_rms + matvec_int4_packed_add). Same arg6-row0 + # residual routing as the bf16 baseline. + from multi_launch_builder.o_gemv_ffn_int4_multi import ( + build_o_gemv_ffn_int4_module, + ) - # 2. o_gemv_ffn: 3-launch (matvec_2tile_add + matvec_swiglu_rms + - # matvec_2tile_add). Post-attention residual is routed - # through a row-0 subview of arg6 (the packed RMSNorm - # input buffer); see o_gemv_ffn_multi.py for the ABI. - from multi_launch_builder.o_gemv_ffn_multi import build_o_gemv_ffn_module + cache.compile_and_cache( + "o_gemv_ffn_int4", + build_o_gemv_ffn_int4_module(emb_dim=emb_dim, hidden_dim=hidden_dim), + {"verbose": cache.verbose, **OGF_INT4_BACKEND}, + ) + else: + # 1. rms_gemv_rope: RMSNorm + QKV GEMV + RoPE Q+K (6 launches, 13 args) + from multi_launch_builder.rms_gemv_rope_multi import ( + build_rms_gemv_rope_module, + ) - cache.compile_and_cache( - "o_gemv_ffn", - build_o_gemv_ffn_module(emb_dim, hidden_dim), - {"verbose": cache.verbose, **OGF_BACKEND}, - ) + cache.compile_and_cache( + "rms_gemv_rope", + build_rms_gemv_rope_module(emb_dim, kv_dim, n_heads, n_kv_heads, head_dim), + {"verbose": cache.verbose, **RGR_BACKEND}, + ) + + # 2. o_gemv_ffn: 3-launch (matvec_2tile_add + matvec_swiglu_rms + + # matvec_2tile_add). Post-attention residual is routed + # through a row-0 subview of arg6 (the packed RMSNorm + # input buffer); see o_gemv_ffn_multi.py for the ABI. + from multi_launch_builder.o_gemv_ffn_multi import build_o_gemv_ffn_module + + cache.compile_and_cache( + "o_gemv_ffn", + build_o_gemv_ffn_module(emb_dim, hidden_dim), + {"verbose": cache.verbose, **OGF_BACKEND}, + ) # 3. LM Head GEMV multi-launch: 8-partition GEMV in one ELF from multi_launch_builder.lm_head_gemv_multi import ( @@ -145,6 +185,7 @@ def run_decode_block( v_cache_layer, current_pos, rope_lut_bf16, + quant: str = "bf16", ): """Run one transformer block for a single decode token. @@ -188,14 +229,21 @@ def _run(name, backend, *inputs, static_indices=None, **kwargs): # --- Call 1: rms_gemv_rope (6 launches, 13 args) --- # RMSNorm + Q/K/V GEMV + RoPE Q + RoPE K + # Int4 path uses the same arg slots; only the weight types change + # (slots 3/5/7: packed-int4 uint8 BO instead of bf16 matrix). x_in = x_bf16.flatten().astype(bfloat16) w_norm = layer_weights.attn_norm.reshape(emb_dim).astype(bfloat16) normed_buf = np.zeros(emb_dim, dtype=bfloat16) - wq = layer_weights._wq_t + if quant == "awq": + wq = layer_weights._wq_packed + wk = layer_weights._wk_packed + wv = layer_weights._wv_packed + else: + wq = layer_weights._wq_t + wk = layer_weights._wk_t + wv = layer_weights._wv_t q_buf = np.zeros(emb_dim, dtype=bfloat16) - wk = layer_weights._wk_t k_buf = np.zeros(kv_dim, dtype=bfloat16) - wv = layer_weights._wv_t v_buf = np.zeros(kv_dim, dtype=bfloat16) # RoPE LUT for current position @@ -205,17 +253,19 @@ def _run(name, backend, *inputs, static_indices=None, **kwargs): q_roped_buf = np.zeros(emb_dim, dtype=bfloat16) k_roped_buf = np.zeros(kv_dim, dtype=bfloat16) + rgr_name = "rms_qkv_int4_rope" if quant == "awq" else "rms_gemv_rope" + rgr_backend = RGR_INT4_BACKEND if quant == "awq" else RGR_BACKEND results = _run( - "rms_gemv_rope", - RGR_BACKEND, + rgr_name, + rgr_backend, x_in, # arg0 w_norm, # arg1 normed_buf, # arg2 (intermediate) - wq, # arg3 (static) + wq, # arg3 (static, packed-i8 in int4 mode) q_buf, # arg4 (intermediate) - wk, # arg5 (static) + wk, # arg5 (static, packed-i8 in int4 mode) k_buf, # arg6 (intermediate) - wv, # arg7 (static) + wv, # arg7 (static, packed-i8 in int4 mode) v_buf, # arg8 (intermediate/output) lut_q, # arg9 lut_k, # arg10 @@ -249,34 +299,42 @@ def _run(name, backend, *inputs, static_indices=None, **kwargs): # stage 1 in-kernel, row 1 = ffn_norm_w pre-loaded by host). # arg7 = interleaved w_gateup [2*hidden_dim, emb_dim]. arg2/4/5/8/9/10/13 # are dead ABI placeholders; pass small zero buffers. - wo = layer_weights._wo_t + # Int4 path: slots 0/7/12 hold packed-i8 BOs; bf16 dead-slot ABI unchanged. x_residual = x_bf16.flatten().astype(bfloat16) swiglu_buf = np.zeros(hidden_dim, dtype=bfloat16) - w_down = layer_weights._wdown_t output_buf = np.zeros(emb_dim, dtype=bfloat16) + if quant == "awq": + wo = layer_weights._wo_packed + w_gateup = layer_weights._wgateup_packed + w_down = layer_weights._wdown_packed + else: + wo = layer_weights._wo_t + w_gateup = layer_weights._wgateup_t + w_down = layer_weights._wdown_t arg6 = layer_weights._packed_rms_buf # [2, emb_dim] - arg7 = layer_weights._wgateup_t # [2*hidden, emb_dim] z_emb = np.zeros(emb_dim, dtype=bfloat16) z_hidden = np.zeros(hidden_dim, dtype=bfloat16) z_hidden_emb = np.zeros((hidden_dim, emb_dim), dtype=bfloat16) + ogf_name = "o_gemv_ffn_int4" if quant == "awq" else "o_gemv_ffn" + ogf_backend = OGF_INT4_BACKEND if quant == "awq" else OGF_BACKEND results = _run( - "o_gemv_ffn", - OGF_BACKEND, - wo, # arg0 wo (static) + ogf_name, + ogf_backend, + wo, # arg0 wo (static, packed-i8 in int4) attn_out, # arg1 attn_out (input) z_emb, # arg2 (dead) x_residual, # arg3 x_residual (input) z_emb, # arg4 (dead — was res1 bus) z_emb, # arg5 (dead — ffn_norm_w now in arg6[1]) arg6, # arg6 packed RMS input (static) - arg7, # arg7 w_gateup (static) + w_gateup, # arg7 w_gateup (static, packed-i8 in int4) z_hidden, # arg8 (dead) z_hidden_emb, # arg9 (dead — wup folded into arg7) z_hidden, # arg10 (dead) swiglu_buf, # arg11 swiglu (intermediate) - w_down, # arg12 wdown (static) + w_down, # arg12 wdown (static, packed-i8 in int4) z_emb, # arg13 (dead) output_buf, # arg14 output (output) output_indices=[14], diff --git a/programming_examples/llama32_1b/llama32_1b_inference.py b/programming_examples/llama32_1b/llama32_1b_inference.py index 20aff3e51..9576dab61 100644 --- a/programming_examples/llama32_1b/llama32_1b_inference.py +++ b/programming_examples/llama32_1b/llama32_1b_inference.py @@ -37,6 +37,7 @@ from llama32_1b_weights import ( LlamaConfig, load_weights, + load_weights_awq, synthetic_weights, generate_rope_lut, ) @@ -46,6 +47,8 @@ LM_GEMV_BACKEND, RGR_BACKEND, OGF_BACKEND, + RGR_INT4_BACKEND, + OGF_INT4_BACKEND, ) from llama32_1b_prefill import ( compile_all_kernels, @@ -110,10 +113,11 @@ class Session: seq_len: int # padded prompt length (today: 2048) weights: Any # LlamaWeights, mutated by prepare_runtime() tokenizer: Any # transformers AutoTokenizer - prefill_cache: Any # KernelCache + prefill_cache: Any # KernelCache (None in awq mode — CPU prefill placeholder) decode_cache: Any # KernelCache rope_lut_bf16: np.ndarray # (max_seq, head_dim) bfloat16 model_variant: str # "base" | "instruct" + quant: str = "bf16" # "bf16" or "awq" # Decode LM Head constants @@ -133,6 +137,7 @@ def prepare_runtime( config, seq_len, rope_lut_bf16, + quant: str = "bf16", ): """One-time runtime initialization. Called before any timed inference. @@ -166,11 +171,13 @@ def prepare_runtime( kv_dim = n_kv_heads * head_dim # 1. Compile external C++ kernels from source - compile_all_external_kernels(head_dim=head_dim) + compile_all_external_kernels(head_dim=head_dim, quant=quant) # 2. Pre-transpose all decode GEMV weights # GEMV kernel expects A[M,K] but HuggingFace stores (out_features, in_features) - if not hasattr(weights, "_decode_weights_transposed"): + # In awq mode the decode ELFs consume packed BOs (already on LayerWeights + # as _wq_packed etc.) so the bf16 transpose isn't needed. + if quant == "bf16" and not hasattr(weights, "_decode_weights_transposed"): print(" Pre-transposing weights for GEMV...") for lw in weights.layers: lw._wq_t = np.ascontiguousarray( @@ -200,13 +207,17 @@ def prepare_runtime( for i, lw in enumerate(weights.layers): lw._layer_idx = i - # 4. Pre-load prefill weights into per-layer BOs - preload_prefill_weights(weights, config, prefill_cache, seq_len, rope_lut_bf16) + # 4. Pre-load prefill weights into per-layer BOs. + # AWQ path skips this entirely — prefill runs on CPU as a placeholder + # until the int4 prefill ELFs land. Saves several seconds of compile + # time and ~110 MB of bf16 prefill weights on the device. + if quant == "bf16": + preload_prefill_weights(weights, config, prefill_cache, seq_len, rope_lut_bf16) # 5. Pre-load decode weights into per-layer BOs # (lm_head_gemv 8-partition weights here are also reused by prefill's # last-token projection — refactored from full-seq GEMM for ~150 ms savings) - _preload_decode_weights(decode_cache, weights, config) + _preload_decode_weights(decode_cache, weights, config, quant=quant) # Note: NPU warmup pass not needed here — the NPU prefill keeps # the NPU active. Only needed in llama32_1b_decode.py where CPU prefill @@ -216,12 +227,15 @@ def prepare_runtime( print(f" Runtime prepared in {t_prep:.1f}s") -def _preload_decode_weights(decode_cache, weights, config): +def _preload_decode_weights(decode_cache, weights, config, quant: str = "bf16"): """Pre-load all decode transformer block weights into per-layer BOs. Mirrors the preloading pattern from llama32_1b_decode.py: writes all weight data once before timing starts. During inference, static_input_indices skips weight re-writes. + + In quant="awq" mode, uses the int4 decode ELFs (rms_qkv_int4_rope + + o_gemv_ffn_int4) and packed-BO weights from LayerWeights._wq_packed etc. """ if hasattr(weights, "_decode_weights_preloaded_to_bos"): return @@ -242,18 +256,33 @@ def _preload_decode_weights(decode_cache, weights, config): for layer_idx in range(config.n_layers): lw = weights.layers[layer_idx] - # rms_gemv_rope: allocate + write weights + # rms_gemv_rope / rms_qkv_int4_rope: allocate + write weights. + # ABI is identical (13 args, same arg slots, same output_indices and + # static set) — only slots 3/5/7's *type* changes (bf16 [in,out] vs + # packed-i8 [total_tiles, tile_bytes]). + if quant == "awq": + rgr_name = "rms_qkv_int4_rope" + rgr_backend = RGR_INT4_BACKEND + wq_static, wk_static, wv_static = ( + lw._wq_packed, + lw._wk_packed, + lw._wv_packed, + ) + else: + rgr_name = "rms_gemv_rope" + rgr_backend = RGR_BACKEND + wq_static, wk_static, wv_static = lw._wq_t, lw._wk_t, lw._wv_t decode_cache.load_and_run( - "rms_gemv_rope", - RGR_BACKEND, + rgr_name, + rgr_backend, np.zeros(emb_dim, dtype=bfloat16), # x_in lw.attn_norm.reshape(emb_dim).astype(bfloat16), # norm_w np.zeros(emb_dim, dtype=bfloat16), # normed - lw._wq_t, # wq + wq_static, # wq (bf16 [in,out] or packed-i8) np.zeros(emb_dim, dtype=bfloat16), # q - lw._wk_t, # wk + wk_static, # wk np.zeros(kv_dim, dtype=bfloat16), # k - lw._wv_t, # wv + wv_static, # wv np.zeros(kv_dim, dtype=bfloat16), # v rope_lut_q_dummy, # lut_q rope_lut_k_dummy, # lut_k @@ -262,23 +291,24 @@ def _preload_decode_weights(decode_cache, weights, config): output_indices=[8, 11, 12], static_input_indices={1, 3, 5, 7}, intermediate_indices={2, 4, 6, 8, 11, 12}, - bo_key=f"rms_gemv_rope_L{layer_idx}", + bo_key=f"{rgr_name}_L{layer_idx}", ) # o_gemv_ffn (3-stage): build the interleaved w_gateup [2*hidden, emb] - # and the packed [2, emb] RMSNorm-input buffer (row 1 = ffn_norm_w, - # row 0 left zero for stage 1 to overwrite per token). Stashed on - # LayerWeights for reuse across all decode tokens. Frees the original - # _wgate_t/_wup_t once the interleaved copy is in place — they're - # otherwise unused after this preload (~1 GB host RAM saved). - wgate = lw._wgate_t - wup = lw._wup_t - wgateup = np.empty((2 * hidden_dim, emb_dim), dtype=bfloat16) - wgateup[0::2] = wgate - wgateup[1::2] = wup - lw._wgateup_t = wgateup - del lw._wgate_t - del lw._wup_t + # (bf16 only — int4 path already produced it in load_weights_awq) and + # the packed [2, emb] RMSNorm-input buffer (row 1 = ffn_norm_w, row 0 + # left zero for stage 1 to overwrite per token). Stashed on LayerWeights + # for reuse across all decode tokens. + if quant == "bf16": + wgate = lw._wgate_t + wup = lw._wup_t + wgateup = np.empty((2 * hidden_dim, emb_dim), dtype=bfloat16) + wgateup[0::2] = wgate + wgateup[1::2] = wup + lw._wgateup_t = wgateup + # ~1 GB host RAM saved across the 16 layers. + del lw._wgate_t + del lw._wup_t packed = np.empty((2, emb_dim), dtype=bfloat16) packed[0] = 0.0 @@ -289,28 +319,40 @@ def _preload_decode_weights(decode_cache, weights, config): z_hidden = np.zeros(hidden_dim, dtype=bfloat16) z_hidden_emb = np.zeros((hidden_dim, emb_dim), dtype=bfloat16) + if quant == "awq": + ogf_name = "o_gemv_ffn_int4" + ogf_backend = OGF_INT4_BACKEND + wo_static, wgu_static, wd_static = ( + lw._wo_packed, + lw._wgateup_packed, + lw._wdown_packed, + ) + else: + ogf_name = "o_gemv_ffn" + ogf_backend = OGF_BACKEND + wo_static, wgu_static, wd_static = lw._wo_t, lw._wgateup_t, lw._wdown_t decode_cache.load_and_run( - "o_gemv_ffn", - OGF_BACKEND, - lw._wo_t, # arg0 wo (static) - z_emb, # arg1 attn_out - z_emb, # arg2 (dead) - z_emb, # arg3 x_residual - z_emb, # arg4 (dead) - z_emb, # arg5 (dead) + ogf_name, + ogf_backend, + wo_static, # arg0 wo (static) + z_emb, # arg1 attn_out + z_emb, # arg2 (dead) + z_emb, # arg3 x_residual + z_emb, # arg4 (dead) + z_emb, # arg5 (dead) lw._packed_rms_buf, # arg6 packed (static) - lw._wgateup_t, # arg7 w_gateup (static) - z_hidden, # arg8 (dead) - z_hidden_emb, # arg9 (dead) + wgu_static, # arg7 w_gateup (static) + z_hidden, # arg8 (dead) + z_hidden_emb, # arg9 (dead) z_hidden, # arg10 (dead) z_hidden, # arg11 swiglu - lw._wdown_t, # arg12 wdown (static) + wd_static, # arg12 wdown (static) z_emb, # arg13 (dead) z_emb, # arg14 output output_indices=[14], static_input_indices={0, 6, 7, 12}, intermediate_indices={2, 4, 5, 8, 9, 10, 11, 13, 14}, - bo_key=f"o_gemv_ffn_L{layer_idx}", + bo_key=f"{ogf_name}_L{layer_idx}", ) # LM Head GEMV weights (8 partitions) @@ -556,6 +598,7 @@ def generate( verify=False, cpu_attn=True, on_token=None, + quant: str = "bf16", ): """Run NPU prefill + NPU decode generation. @@ -575,21 +618,36 @@ def generate( print(f"LLAMA Inference: prompt_len={seq_len}, n_tokens={n_tokens}") print(f"{'='*60}\n") - # --- Phase 1: NPU Prefill --- - prefill_token, k_cache, v_cache, prompt_len = run_npu_prefill( - prompt_tokens, - weights, - config, - prefill_cache, - decode_cache, - rope_lut_bf16, - max_seq, - tokenizer=tokenizer, - cpu_attn=cpu_attn, - profile=profile, - verify=verify, - quiet=streaming, - ) + # --- Phase 1: Prefill --- + # bf16: NPU prefill ELFs. awq: CPU prefill placeholder (no int4 prefill + # ELFs yet — see cpu_prefill.run_cpu_prefill). + if quant == "awq": + from cpu_prefill import run_cpu_prefill + + prefill_token, k_cache, v_cache, prompt_len = run_cpu_prefill( + prompt_tokens, + weights, + config, + rope_lut_bf16, + max_seq, + tokenizer=tokenizer, + quiet=streaming, + ) + else: + prefill_token, k_cache, v_cache, prompt_len = run_npu_prefill( + prompt_tokens, + weights, + config, + prefill_cache, + decode_cache, + rope_lut_bf16, + max_seq, + tokenizer=tokenizer, + cpu_attn=cpu_attn, + profile=profile, + verify=verify, + quiet=streaming, + ) # --- Phase 2: NPU Decode --- generated_tokens = [prefill_token] # Token 0 = from prefill @@ -620,6 +678,7 @@ def generate( v_cache[layer_idx], current_pos, rope_lut_bf16, + quant=quant, ) # Final RMSNorm (CPU) @@ -697,26 +756,51 @@ def build_session(args) -> Session: config = LlamaConfig() seq_len = 2048 - prefill_cache = KernelCache("prefill_kernel_cache", verbose=args.verbose) - decode_cache = KernelCache("decode_kernel_cache", verbose=args.verbose) + quant = getattr(args, "quant", "bf16") + # bf16: NPU prefill + NPU decode. awq: CPU prefill placeholder + NPU int4 + # decode (no prefill cache needed; the int4 prefill ELFs land in a future + # PR — see cpu_prefill.py). + prefill_cache = ( + KernelCache("prefill_kernel_cache", verbose=args.verbose) + if quant == "bf16" + else None + ) + decode_cache_dir = ( + "decode_kernel_cache" if quant == "bf16" else "decode_kernel_cache_int4" + ) + decode_cache = KernelCache(decode_cache_dir, verbose=args.verbose) if not args.run_only: - print("Compiling prefill kernels...") - compile_all_kernels(prefill_cache, config, seq_len, cpu_attn=args.cpu_attn) + if quant == "bf16": + print("Compiling prefill kernels...") + compile_all_kernels(prefill_cache, config, seq_len, cpu_attn=args.cpu_attn) + else: + print("AWQ mode: prefill runs on CPU, skipping NPU prefill compile.") print("\nCompiling decode kernels...") - compile_decode_kernels(decode_cache, config) + compile_decode_kernels(decode_cache, config, quant=quant) if args.compile_only: sys.exit(0) if args.run_only: - prefill_cache.load_manifest() + if prefill_cache is not None: + prefill_cache.load_manifest() decode_cache.load_manifest() if args.synthetic_weights: print("\nUsing synthetic random weights (skipping HuggingFace download).") weights = synthetic_weights(config) tokenizer = _SyntheticTokenizer() + elif quant == "awq": + model_id = args.model_path or ( + "amd/Llama-3.2-1B-Instruct-awq-uint4-asym-g128-bf16-lmhead" + ) + print(f"\nLoading AWQ weights ({model_id})...") + weights = load_weights_awq(model_id, config=config) + + from transformers import AutoTokenizer + + tokenizer = AutoTokenizer.from_pretrained(model_id) else: model_id = ( "meta-llama/Llama-3.2-1B-Instruct" @@ -736,7 +820,13 @@ def build_session(args) -> Session: ).astype(bfloat16) prepare_runtime( - prefill_cache, decode_cache, weights, config, seq_len, rope_lut_bf16 + prefill_cache, + decode_cache, + weights, + config, + seq_len, + rope_lut_bf16, + quant=quant, ) return Session( @@ -748,6 +838,7 @@ def build_session(args) -> Session: decode_cache=decode_cache, rope_lut_bf16=rope_lut_bf16, model_variant=args.model, + quant=quant, ) @@ -776,7 +867,11 @@ def run_once( (generated_token_ids, prompt_len_actual).""" tokens = _tokenize_prompt(session, prompt_text) prompt_len_actual = len(tokens) - if len(tokens) < session.seq_len: + # bf16 NPU prefill is shape-static at session.seq_len, so the prompt must + # be EOS-padded out to that length. CPU prefill (awq mode) runs at the + # raw prompt length — skipping the pad collapses a 2048² CPU-side + # attention from minutes to ~2 s for short prompts. + if session.quant == "bf16" and len(tokens) < session.seq_len: tokens = tokens + [session.tokenizer.eos_token_id] * ( session.seq_len - len(tokens) ) @@ -794,6 +889,7 @@ def run_once( verify=verify, cpu_attn=cpu_attn, on_token=on_token, + quant=session.quant, ) return generated, prompt_len_actual @@ -923,6 +1019,21 @@ def _stream_cb(_token_id: int, delta: str) -> None: default="instruct", help="Model variant: instruct (Q&A, default) or base (completion)", ) + parser.add_argument( + "--quant", + type=str, + choices=["bf16", "awq"], + default="bf16", + help="Weight precision. 'awq': int4-AWQ NPU decode with CPU prefill " + "placeholder; 'bf16' (default): bf16 NPU prefill + decode.", + ) + parser.add_argument( + "--model-path", + type=str, + default=None, + help="AWQ-mode override for the HF model id / local dir. Default: " + "amd/Llama-3.2-1B-Instruct-awq-uint4-asym-g128-bf16-lmhead.", + ) parser.add_argument( "--interactive", action="store_true", @@ -938,6 +1049,11 @@ def _stream_cb(_token_id: int, delta: str) -> None: if args.synthetic_weights and args.interactive: parser.error("--synthetic-weights cannot be combined with --interactive") + if args.synthetic_weights and args.quant == "awq": + parser.error( + "--synthetic-weights only generates bf16 weights; rerun with " + "--quant=bf16 or drop --synthetic-weights for an AWQ checkpoint." + ) if args.interactive: if args.compile_only: diff --git a/programming_examples/llama32_1b/llama32_1b_weights.py b/programming_examples/llama32_1b/llama32_1b_weights.py index 0f156f8ef..3840560b9 100644 --- a/programming_examples/llama32_1b/llama32_1b_weights.py +++ b/programming_examples/llama32_1b/llama32_1b_weights.py @@ -20,6 +20,7 @@ """ import os +import sys import glob as glob_module from dataclasses import dataclass, field from typing import List, Optional @@ -124,6 +125,18 @@ class LlamaWeights: "mlp.down_proj.weight": ("w_down", True), } +# AutoAWQ stores each Linear as three tensors. Field name = the dataclass +# field that owns the bf16 dequant (used by the CPU prefill placeholder). +_HF_AWQ_LINEARS = { + "self_attn.q_proj": "wq", + "self_attn.k_proj": "wk", + "self_attn.v_proj": "wv", + "self_attn.o_proj": "wo", + "mlp.gate_proj": "w_gate", + "mlp.up_proj": "w_up", + "mlp.down_proj": "w_down", +} + # --------------------------------------------------------------------------- # Safetensors loading helpers @@ -327,6 +340,219 @@ def load_weights( ) +# --------------------------------------------------------------------------- +# HuggingFace AutoAWQ loader +# --------------------------------------------------------------------------- + +# Mapping LayerWeights field name -> per-layer packed-BO attribute name. +# Packed BOs are dynamically attached to each LayerWeights instance (same +# pattern inference.py uses for ._wq_t etc.), keyed by the dataclass field +# to avoid extending the dataclass schema. The decode-side runtime reads +# `layer._wq_packed`, `layer._wo_packed`, etc. The fused FFN ELF +# (o_gemv_ffn_int4) wants gate+up as ONE row-interleaved packed BO and +# is exposed separately as `_wgateup_packed`. +_AWQ_PACKED_ATTR = { + "wq": "_wq_packed", + "wk": "_wk_packed", + "wv": "_wv_packed", + "wo": "_wo_packed", + "w_down": "_wdown_packed", +} +# w_gate and w_up are NOT in _AWQ_PACKED_ATTR — they go through a row-level +# interleave first (gate[i] -> row 2i, up[i] -> row 2i+1) before packing, +# so the int4 FFN ELF can consume them in one BO. + + +def load_weights_awq( + model_name_or_path: str, + config: Optional[LlamaConfig] = None, + group_size: int = 128, + m_tile: int = 8, + k_chunk: int = 2048, + n_cores: int = 8, +) -> LlamaWeights: + """Load a HuggingFace AutoAWQ Llama checkpoint. + + For each Linear, stashes BOTH: + - the bf16 dequant on the existing LayerWeights field (wq/wk/.../w_down), + for the CPU prefill placeholder via reference.transformer_block; + - the per-tile packed uint8 BO on `__packed`, for the NPU int4 + decode ELFs (rms_qkv_int4_rope and o_gemv_ffn_int4). + + Args: + model_name_or_path: local dir or HF model id of an AutoAWQ checkpoint. + config: model hyperparameters (defaults to Llama-3.2-1B). + group_size: AWQ group size (typical 128). Must match the checkpoint. + m_tile, k_chunk, n_cores: GEMV packed-BO tiling parameters; defaults + match the int4 decode ELF builders. + + Returns: + LlamaWeights with packed-BO attributes attached. + """ + from safetensors import safe_open + from awq_repacker import dequant_to_bf16, repack_for_gemv, repack_hf_awq_linear + + sys.path.insert( + 0, + os.path.join( + os.path.dirname(os.path.abspath(__file__)), + "..", + "matrix_vector_multiplication", + "int4_awq", + ), + ) + from matvec_int4_packed import pack_inputs as _pack_inputs + + if config is None: + config = LlamaConfig() + + safetensor_files = _resolve_safetensor_files(model_name_or_path) + + key_to_file = {} + for filepath in safetensor_files: + with safe_open(filepath, framework="numpy") as f: + for key in f.keys(): + key_to_file[key] = filepath + + # --- Embedding table --- + embed_key = "model.embed_tokens.weight" + if embed_key not in key_to_file: + raise KeyError(f"Missing weight: {embed_key}") + with safe_open(key_to_file[embed_key], framework="numpy") as f: + embed_table = _load_tensor(f, embed_key, bfloat16) + + # --- Per-layer --- + layers: List[LayerWeights] = [] + for layer_idx in range(config.n_layers): + # Non-quantized: layernorms. + attn_norm_key = f"model.layers.{layer_idx}.input_layernorm.weight" + ffn_norm_key = f"model.layers.{layer_idx}.post_attention_layernorm.weight" + with safe_open(key_to_file[attn_norm_key], framework="numpy") as f: + attn_norm = _load_tensor(f, attn_norm_key, bfloat16) + with safe_open(key_to_file[ffn_norm_key], framework="numpy") as f: + ffn_norm = _load_tensor(f, ffn_norm_key, bfloat16) + + # Each quantized Linear: load qweight/qzeros/scales, then both + # (a) dequant to bf16 -> existing LayerWeights field, for CPU prefill, + # (b) repack-and-pack -> packed BO, for NPU int4 decode. + # gate/up are special: their NPU packed BO must be ONE interleaved + # matrix (gate row 0, up row 0, gate row 1, ...) so the int4 ELF2 + # (matvec_int4_swiglu_rms) can consume them in one input slot. + linear_bf16 = {} + linear_packed = {} + gate_quants = up_quants = None + for hf_prefix, field_name in _HF_AWQ_LINEARS.items(): + base = f"model.layers.{layer_idx}.{hf_prefix}" + qw_key, qz_key, s_key = ( + f"{base}.qweight", + f"{base}.qzeros", + f"{base}.scales", + ) + for k in (qw_key, qz_key, s_key): + if k not in key_to_file: + raise KeyError( + f"Missing AWQ tensor: {k} (is this an AutoAWQ checkpoint?)" + ) + with safe_open(key_to_file[qw_key], framework="numpy") as f: + qw = f.get_tensor(qw_key) + with safe_open(key_to_file[qz_key], framework="numpy") as f: + qz = f.get_tensor(qz_key) + with safe_open(key_to_file[s_key], framework="numpy") as f: + sc = f.get_tensor(s_key) + if qw.dtype != np.int32: + qw = qw.astype(np.int32) + if qz.dtype != np.int32: + qz = qz.astype(np.int32) + # (a) bf16 dequant for CPU prefill: shape [in, out] (matches + # transformer_block's wq[in, out] convention — no transpose). + linear_bf16[field_name] = dequant_to_bf16(qw, qz, sc, group_size) + # (b) packed BO for NPU decode; gate/up are deferred until both + # are loaded so we can interleave them at the nibble level. + if field_name == "w_gate": + gate_quants = repack_hf_awq_linear(qw, qz, sc, group_size) + elif field_name == "w_up": + up_quants = repack_hf_awq_linear(qw, qz, sc, group_size) + else: + linear_packed[field_name] = repack_for_gemv( + qw, + qz, + sc, + group_size, + M_TILE=m_tile, + K_CHUNK=k_chunk, + N_CORES=n_cores, + ) + + # Interleave gate/up at the (A_q, A_s, A_z) level: row 2i = gate[i], + # row 2i+1 = up[i]. Then pack into one BO sized [2*hidden, emb] for + # arg7 of o_gemv_ffn_int4. + if gate_quants is None or up_quants is None: + raise RuntimeError( + "Could not find both mlp.gate_proj and mlp.up_proj AWQ tensors" + ) + g_q, g_s, g_z = gate_quants + u_q, u_s, u_z = up_quants + h_out, k_half = g_q.shape # h_out = hidden_dim, k_half = K/2 + if u_q.shape != (h_out, k_half): + raise RuntimeError("gate_proj and up_proj have different shapes") + gu_q = np.empty((2 * h_out, k_half), dtype=np.uint8) + gu_q[0::2] = g_q + gu_q[1::2] = u_q + n_groups = g_s.shape[0] + gu_s = np.empty((n_groups, 2 * h_out), dtype=g_s.dtype) + gu_s[:, 0::2] = g_s + gu_s[:, 1::2] = u_s + gu_z = np.empty((n_groups, 2 * h_out), dtype=np.uint8) + gu_z[:, 0::2] = g_z + gu_z[:, 1::2] = u_z + M_gateup = 2 * h_out + K_full = k_half * 2 + gateup_packed = _pack_inputs( + gu_q, + gu_s, + gu_z, + M_gateup, + K_full, + group_size, + m_tile, + k_chunk, + n_cores, + M_gateup, + ) + + layer = LayerWeights( + attn_norm=attn_norm, + ffn_norm=ffn_norm, + **linear_bf16, + ) + for field_name, packed in linear_packed.items(): + setattr(layer, _AWQ_PACKED_ATTR[field_name], packed) + layer._wgateup_packed = gateup_packed + layers.append(layer) + + if (layer_idx + 1) % 4 == 0 or layer_idx == 0: + print(f" AWQ layer {layer_idx + 1}/{config.n_layers} loaded") + + # --- Final norm and lm_head (both bf16 in this checkpoint) --- + norm_key = "model.norm.weight" + with safe_open(key_to_file[norm_key], framework="numpy") as f: + final_norm = _load_tensor(f, norm_key, bfloat16) + lm_head_key = "lm_head.weight" + if lm_head_key in key_to_file: + with safe_open(key_to_file[lm_head_key], framework="numpy") as f: + lm_head = _load_tensor(f, lm_head_key, bfloat16) + else: + print(" Tied embeddings: reusing embed_table as lm_head.") + lm_head = embed_table + + return LlamaWeights( + embed_table=embed_table, + layers=layers, + final_norm=final_norm, + lm_head=lm_head, + ) + + # --------------------------------------------------------------------------- # Synthetic-weights builder (CI smoke / verify without HuggingFace download) # ---------------------------------------------------------------------------